去年一直立的flag,拖了许久。。终于下定决心要克服困难,真正去学习Pytorch的代码了。先从C部分看起,理解核心代码后再看上层

C-泛型

Pytorch使用了宏来实现c的泛型API,举个栗子

1
#define CONCAT(A, B, C) A ## B ## C

例如这个宏可以用来产生Double_Matrix_mul这个变量名

1
Double_Matrix CONCAT(Double_, Matrix_, mul)(Double_Matrix *A, Double_Matrix *B);

C的宏定义在出现###时会使用字符串替换,而不是展开,可以利用这个特性来产生函数名。同时需要一个中间宏来先展开宏名

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
#define CONCAT_2_EXPAND(A, B) A ## B         // NumVector
#define CONCAT_2(A, B) CONCAT_2_EXPAND(A, B)     //== #define CONCAT_2(Num,Vector) NumVector 
//#define CONCAT_2(A, B) A ## B  之所以不这样写,是因为这样写无法展开A,B,直接进行的字符串替换
#define CONCAT_3_EXPAND(A, B, C) A ## B ## C     
#define CONCAT_3(A, B, C) CONCAT_3_EXPAND(A, B, C)

#define Vector_(NAME) CONCAT_3(Num, Vector_, NAME)
#define Vector CONCAT_2(Num, Vector)

#define num float
#define Num Float
struct Vector
{
 num *data;
 int n;
};
void Vector_(add)(Vector *C, Vector *A, Vector *B) {
//codes
}

TH API

张量这个数学对象被TH分解为THTensorTHStorageTHTensor提供一种查看THStorage的方法,THStorage负责管理张量的存储方式。所有的THTensor类型最终都会替换成at::TensorImpl,所有的THStorage类型也都会替换成at::StorageImpl

THTensor

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
//Aten/src/TH/THTensor.h
/* fill and zero*/
#include <TH/generic/THTensorFill.h>
#include <TH/THGenerateAllTypes.h>

#include <TH/generic/THTensorFill.h>
#include <TH/THGenerateHalfType.h>

#include <TH/generic/THTensorFill.h>
#include <TH/THGenerateBoolType.h>

刚开始看这个的时候一脸懵逼,为啥要反复include同一个文件。。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
#ifndef TH_GENERIC_FILE
#error "You must define TH_GENERIC_FILE before including THGenerateBoolType.h"
#endif
#define scalar_t bool
#define TH_REAL_IS_BOOL
#line 1 TH_GENERIC_FILE     //
#include TH_GENERIC_FILE
#undef scalar_t
#undef TH_REAL_IS_BOOL
#ifndef THGenerateManyTypes
#undef TH_GENERIC_FILE
#endif

THGenerateBoolType.h举例子,泛型的函数与宏都写在THTensorFill.h里面,include这个.h后,TH_GENERIC_FILE得到了定义,再#include这个文件的时候,就会把这个文件里的泛型函数与宏进行展开特化,从而得到bool相关的特化函数,非常巧妙。通过这种方式,使得THTensor.h中包含了所有的c API函数定义

THStorage

Aten/src/TH/generic/THStorage.h

接口实现与THVector类似,

Util

Intrusive_ptr(c10/util/intrusive_ptr)

侵入式指针,用于解决shared_ptr无法适应的场景,区别是对象自己管理引用计数。性能好于shared_ptr 知乎链接