pytroch nn.Module源码解析(1)
今天在写一个分类网络时,要使用nn.Sequential中的一个模块,因为nn.Sequential中模块都没有名字,我一时竟无从下笔。于是决定写这篇博客梳理pytorch的nn.Module类,看完这篇博客,你大概率可以学会:
- 提取nn.Sequential中任意一个模块
- 能初始化一个网络的所有权重,不管是随机初始化还是使用权重文件
- 对nn.Module类有个总体把握
1 __init__方法
我们先不看代码,自己小脑袋里想一想这个类应该有什么东西。既然这个类是和各种layer相关,里面一定存着各种layer,如卷积或者relu,肯定还会有layer对应的权重。没错,这些东西类里都有:
from ..backends.thnn import backend as thnn_backend
def __init__(self):
self._backend = thnn_backend
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._modules = OrderedDict()
self.training = True
其中_modules和_parameters就是存储这些的,那么剩下这些奇奇怪怪的属性是什么呢?第一个奇怪的东西就是这个_backend属性,他的值是thnn_backend,好像和torch的底层代码有关,有点麻烦啊!不用怕,这个属性我们只要大概了解它的意思就可以,对之后这个模块的使用没有任何影响。举个例子来帮助我们理解,想象说我们有maxpooling层,pytorch底层会用cudnn实现,也会用cunn实现,pytorch在前向传播时会自动选择一个速度最快的实现,这里thnn_backend就是指定我们前向传播时用thnn这种实现方式。更多过于backend的直觉理解可以参考这个thread。
第二个奇怪的东西就是_buffer属性,如果我说这个buffer存储的也是网络参数,你会感觉到更加迷惑吗?一个提示:batch norm的参数。(花一分钟想一想。)batch norm中除了有参数$ alpha $和$ beta $外,还有running_mean和running_var,这些在整个学习过程中也需要存储起来但是不需要学习的参数我们就把它们存储到_buffer中。
第三个奇怪的东西就是3个钩子,这三个钩子具体的实现我可能很后面才会讲,不过为了一些好奇宝宝,如果我说forward_hook的功能是在module完成前向传播时做一些事,你能推断出其他两个钩子的功能吗?(答案)
2 children方法和modules方法
children方法和modules方法的作用是很类似的,我们先看一下children方法的代码。
def named_children(self):
memo = set()
for name, module in self._modules.items():
if module is not None and module not in memo:
memo.add(module)
yield name, module
def children(self):
for name, module in self.named_children():
yield module
我们看完代码发现children方法的作用就是把_modules遍历一遍,我们来看一下具体例子(你也可以自己在命令行中把这两个命令输入进去,尽量不要复制粘贴):
>>>model=nn.Sequential(nn.Linear(3,1), \
nn.Sequential(nn.BatchNorm2d(1), \
nn.Linear(1,3)))
>>> for m in model.children():
... print(m)
Linear(in_features=3, out_features=1, bias=True)
Sequential(
(0): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True)
(1): Linear(in_features=1, out_features=3, bias=True)
)
如果我们还要把更里层的modules提取出来,我们就需要用到modules方法:
def named_modules(self, memo=None, prefix=''):
if memo is None:
memo = set()
if self not in memo:
memo.add(self) yield prefix, self for name, module in self._modules.items(): if module is None: continue submodule_prefix = prefix + ('.' if prefix else '') + name for m in module.named_modules(memo, submodule_prefix): yield mdef modules(self):
for name, module in self.named_modules():
yeild module
>>> for m in model.modules():
... print(m)
Sequential(
(0): Linear(in_features=3, out_features=1, bias=True)
(1): Sequential(
(0): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True)
(1): Linear(in_features=1, out_features=3, bias=True)
)
)
Linear(in_features=3, out_features=1, bias=True)
Sequential(
(0): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True)
(1): Linear(in_features=1, out_features=3, bias=True)
)
BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True)
Linear(in_features=1, out_features=3, bias=True)
了解了这两个方法你应该可以完成我们的第一个目标:提取nn.Sequential的任一个模块。在下一节中我们会完成我们的第二个目标,初始化权重。在结束之前给大家个小问题:有时候你只需要底层的module,而不需要module的子类,如nn.Sequential,那么怎么去除呢?
机器学习小贴士:支持向量机的意思就是我们最后选择的模型只与支持向量有关。
最后编辑于2018-10-1019:39:53 有什么错误请不吝赐教
pytroch nn.Module源码解析(1)的更多相关文章
- Cognitive Graph for Multi-Hop Reading Comprehension at Scale(ACL2019) 阅读笔记与源码解析
论文地址为:Cognitive Graph for Multi-Hop Reading Comprehension at Scale github地址:CogQA 背景 假设你手边有一个维基百科的搜索 ...
- [源码解析] 深度学习流水线并行Gpipe(1)---流水线基本实现
[源码解析] 深度学习流水线并行Gpipe(1)---流水线基本实现 目录 [源码解析] 深度学习流水线并行Gpipe(1)---流水线基本实现 0x00 摘要 0x01 概述 1.1 什么是GPip ...
- [源码解析] 深度学习流水线并行GPipe (2) ----- 梯度累积
[源码解析] 深度学习流水线并行GPipe (2) ----- 梯度累积 目录 [源码解析] 深度学习流水线并行GPipe (2) ----- 梯度累积 0x00 摘要 0x01 概述 1.1 前文回 ...
- [源码解析] 深度学习流水线并行之PopeDream(1)--- Profile阶段
[源码解析] 深度学习流水线并行之PopeDream(1)--- Profile阶段 目录 [源码解析] 深度学习流水线并行之PopeDream(1)--- Profile阶段 0x00 摘要 0x0 ...
- [源码解析] 深度学习流水线并行 PipeDream(3)--- 转换模型
[源码解析] 深度学习流水线并行 PipeDream(3)--- 转换模型 目录 [源码解析] 深度学习流水线并行 PipeDream(3)--- 转换模型 0x00 摘要 0x01 前言 1.1 改 ...
- [源码解析] 深度学习流水线并行 PipeDream(4)--- 运行时引擎
[源码解析] 深度学习流水线并行 PipeDream(4)--- 运行时引擎 目录 [源码解析] 深度学习流水线并行 PipeDream(4)--- 运行时引擎 0x00 摘要 0x01 前言 1.1 ...
- [源码解析] 深度学习分布式训练框架 horovod (21) --- 之如何恢复训练
[源码解析] 深度学习分布式训练框架 horovod (21) --- 之如何恢复训练 目录 [源码解析] 深度学习分布式训练框架 horovod (21) --- 之如何恢复训练 0x00 摘要 0 ...
- [源码解析] PyTorch 流水线并行实现 (1)--基础知识
[源码解析] PyTorch 流水线并行实现 (1)--基础知识 目录 [源码解析] PyTorch 流水线并行实现 (1)--基础知识 0x00 摘要 0x01 历史 1.1 GPipe 1.2 t ...
- [源码解析] PyTorch 流水线并行实现 (2)--如何划分模型
[源码解析] PyTorch 流水线并行实现 (2)--如何划分模型 目录 [源码解析] PyTorch 流水线并行实现 (2)--如何划分模型 0x00 摘要 0x01 问题 0x01 自动平衡 1 ...
随机推荐
- apply,all,bind的区别
这三个都是用来改变this指向的 call() 和apply()的第一个参数相同,就是指定的对象.这个对象就是该函数的执行上下文.call()和apply()的区别就在于,两者接收的参数不一样.cal ...
- vue父子组件生命周期执行顺序
之前写了vue的生命周期,本以为明白了vue实例在创建到显示在页面上以及销毁等一系列过程,以及各个生命周期的特点.然而今天被问到父子组件生命周期执行顺序的时候一头雾水,根本不知道怎么回事.然后写了一段 ...
- Groovy里面闭包中变量符号的查找与变量定义的限制
class a { def v1 = "v1 in a" static def v2 = "v2 in a" def v4 = "v4 in a&qu ...
- 本地Git与Github建立关联
准备 本地与Github建立连接,需要用到SSH公钥.一般安装完Git,会在用户目录中生成一个 .ssh的文件夹 如果没有此文件夹,可以通过命令创建 $ ssh-keygen -t rsa -C &q ...
- 20175315 《Java程序设计》第6周学习总结
20175215 <Java程序设计>第6周学习总结 教材学习内容总结 第七章主要讲的是内部类,匿名类,异常类等等. 内部类:Java支持在一个类中定义另一个类,称作内部类,包含内部类的类 ...
- tftp--实现服务器与客户端的下载与上传【转】
转自:https://blog.csdn.net/xiaopangzi313/article/details/9122975 版权声明:本文为博主原创文章,未经博主允许不得转载. https://bl ...
- IDEA中执行MAVEN命令打jar包
SpringBoot Jar包打包 1.工程POM配置packaging为jar. <packaging>jar</packaging> 2.增加MAVEN运行配置 添加MAV ...
- shell编程练习-打印九九乘法表(附:awk编程)
小练习,仅供参考 shell编写 #!/bin/bash for i in {1..9}do for j in {1..9} do if [ $j -le $i ] ;then echo -ne &q ...
- neo4j-cypher
cypher查询务必在需要查询的节点上加上标签,否则数据量一大查询就会非常慢(在查询时必须设置实体标签,否则不走索引),另外Neo4j索引做好了查询的优化基本上就完成了80%.需要注意index是建立 ...
- android app 的插件化、组件化、模块化开发-2
Android 插件化 ——指将一个程序划分为不同的部分,比如一般 App的皮肤样式就可以看成一个插件 Android 组件化 ——这个概念实际跟上面相差不那么明显,组件和插件较大的区别就是:组件是指 ...