Pytorch源码阅读-ch1-tensor&optim
文章目录
开个新坑, pytorch源码阅读. 从python代码开始读起, 炼丹师第一步~
torch/
1.tensor.py
继承自torch._C._TensorBase
, 包括各种操作,TODO:随后看cpp代码
__abs__, __iter__
之类的内建方法requires_grad
属性是否需要求导backward(self, gradient=None, retain_graph=None, create_graph=False)
retain_graph表示是否在backward之后free内存register_hook(self, hook) 每次
gradients
被计算的时候,这个hook
都被调用。返回的handle提供remove hook的能力1 2 3 4 5 6
v = Variable(torch.Tensor([0, 0, 0]), requires_grad=True) h = v.register_hook(lambda grad: grad * 2) # double the gradient v.backward(torch.Tensor([1, 1, 1])) #先计算原始梯度,再进hook,获得一个新梯度。 print(v.grad.data) #output [2,2,2] h.remove() # removes the hook, 返回的句柄
2.random.py //TODO default_generator
3.serialization.py 模型的load, store等方法
torch/optim
一系列优化方法的集合, 基类是optimizer.py, 其余op都是继承这个类, 基础上实现op.step(), 初始化默认参数由__init__
提供. 包括SGD, Adam, RMSProp等, 以SGD为例:
|
|
内部方法
state_dict() & load_state_dict()
更新state, param两个成员, 提供serialize的方法. 理解是可以训练到某个过程中进行op参数的存储, 下次可以继续, 避免训练失败重新训练
add_param_group()
transfer learning中将freeze固定层的参数加入训练时, 可以用该方法.
lr_scheduler
用来进行lr的调整, 动态decay
|
|
==tips==:
- id(k)获取object的单一标识,作为dict的key.
- isinstance(obj, class or tuple) 判断obj是否是class的实例
文章作者 Sun.StriKE
上次更新 2019-03-27 (417d97f)