去年一直立的flag,拖了许久。。终于下定决心要克服困难,真正去学习Pytorch的代码了。先从C部分看起,理解核心代码后再看上层
C-泛型
Pytorch使用了宏来实现c的泛型API,举个栗子
1
|
#define CONCAT(A, B, C) A ## B ## C
|
例如这个宏可以用来产生Double_Matrix_mul
这个变量名
1
|
Double_Matrix CONCAT(Double_, Matrix_, mul)(Double_Matrix *A, Double_Matrix *B);
|
C的宏定义在出现#
和##
时会使用字符串替换,而不是展开,可以利用这个特性来产生函数名。同时需要一个中间宏来先展开宏名
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
|
#define CONCAT_2_EXPAND(A, B) A ## B // NumVector
#define CONCAT_2(A, B) CONCAT_2_EXPAND(A, B) //== #define CONCAT_2(Num,Vector) NumVector
//#define CONCAT_2(A, B) A ## B 之所以不这样写,是因为这样写无法展开A,B,直接进行的字符串替换
#define CONCAT_3_EXPAND(A, B, C) A ## B ## C
#define CONCAT_3(A, B, C) CONCAT_3_EXPAND(A, B, C)
#define Vector_(NAME) CONCAT_3(Num, Vector_, NAME)
#define Vector CONCAT_2(Num, Vector)
#define num float
#define Num Float
struct Vector
{
num *data;
int n;
};
void Vector_(add)(Vector *C, Vector *A, Vector *B) {
//codes
}
|
TH API
张量这个数学对象被TH分解为THTensor和THStorage,THTensor提供一种查看THStorage的方法,THStorage负责管理张量的存储方式。所有的THTensor类型最终都会替换成at::TensorImpl
,所有的THStorage类型也都会替换成at::StorageImpl
THTensor
1
2
3
4
5
6
7
8
9
10
|
//Aten/src/TH/THTensor.h
/* fill and zero*/
#include <TH/generic/THTensorFill.h>
#include <TH/THGenerateAllTypes.h>
#include <TH/generic/THTensorFill.h>
#include <TH/THGenerateHalfType.h>
#include <TH/generic/THTensorFill.h>
#include <TH/THGenerateBoolType.h>
|
刚开始看这个的时候一脸懵逼,为啥要反复include同一个文件。。
1
2
3
4
5
6
7
8
9
10
11
12
|
#ifndef TH_GENERIC_FILE
#error "You must define TH_GENERIC_FILE before including THGenerateBoolType.h"
#endif
#define scalar_t bool
#define TH_REAL_IS_BOOL
#line 1 TH_GENERIC_FILE //
#include TH_GENERIC_FILE
#undef scalar_t
#undef TH_REAL_IS_BOOL
#ifndef THGenerateManyTypes
#undef TH_GENERIC_FILE
#endif
|
以THGenerateBoolType.h
举例子,泛型的函数与宏都写在THTensorFill.h里面,include这个.h后,TH_GENERIC_FILE
得到了定义,再#include这个文件的时候,就会把这个文件里的泛型函数与宏进行展开特化,从而得到bool相关的特化函数,非常巧妙。通过这种方式,使得THTensor.h中包含了所有的c API函数定义
THStorage
Aten/src/TH/generic/THStorage.h
接口实现与THVector类似,
Util
Intrusive_ptr(c10/util/intrusive_ptr)
侵入式指针,用于解决shared_ptr无法适应的场景,区别是对象自己管理引用计数。性能好于shared_ptr 知乎链接