Module 是 pytorch 提供的一个基类,每次我们要 搭建 自己的神经网络的时候都要继承这个类,继承这个类会使得我们 搭建网络的过程变得异常简单。

本文主要关注 Module 类的内部是怎么样的。

初始化方法中做了什么
def __init__(self):
self._backend = thnn_backend
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._modules = OrderedDict()
self.training = True
1
2
3
4
5
6
7
8
9
这是 Module 的初始化方法:

self._parameters 用来存放注册的 Parameter 对象
self._buffers 用来存放注册的 Buffer 对象。(pytorch 中 buffer 的概念就是 不需要反向传导更新的值)
self._modules 用来保存注册的 Module 对象。
self.training 标志位,用来表示是不是在 training 状态下
...hooks 用来保存 注册的 hook
__setattr__ 与 __getattr__
__setattr__ 每次给属性赋值的时候,都会调用这个方法。

__setattr__ 的代码比较多,我们一点一点看。

remove_from :工具函数, 用来从 self.__dict__, self._buffers, self._modules 中删除对象。
第一种情况: value 的类型是 Paramter

从 三大 字典中将 同名的 对象删掉
然后,注册 paramter
第二种情况: value不是 Parameter对象, name在 self._parameter 中

self._parameters[name] = None
已经考虑了 value 是 Parameter对象,剩下的就是考虑 value 为 buffer或 Module 了

第三种情况:value不是 Parameter对象, value 为 Module 对象

从三大字典里面移除同名 对象
然后直接向 self._modules 字典里添加 value
第四种情况:value不是Parameter对象, value不为 Module对象, 但是 name 在 self._modules 里

self._modules[name]=None
第五种情况:value不是Parameter对象, value不为 Module对象, name 存在 self._buffers 里

self._buffers[name]=None
最后一种情况: 就是 普通的属性了。

def __setattr__(self, name, value):
def remove_from(*dicts):
for d in dicts:
if name in d:
del d[name]

params = self.__dict__.get('_parameters')

if isinstance(value, Parameter):
if params is None:
raise AttributeError(
"cannot assign parameters before Module.__init__() call")
remove_from(self.__dict__, self._buffers, self._modules)
self.register_parameter(name, value)
elif params is not None and name in params:
if value is not None:
raise TypeError("cannot assign '{}' as parameter '{}' "
"(torch.nn.Parameter or None expected)"
.format(torch.typename(value), name))
self.register_parameter(name, value)
else:
modules = self.__dict__.get('_modules')
if isinstance(value, Module):
if modules is None:
raise AttributeError(
"cannot assign module before Module.__init__() call")
remove_from(self.__dict__, self._parameters, self._buffers)
modules[name] = value
elif modules is not None and name in modules:
if value is not None:
raise TypeError("cannot assign '{}' as child module '{}' "
"(torch.nn.Module or None expected)"
.format(torch.typename(value), name))
modules[name] = value
else:
buffers = self.__dict__.get('_buffers')
if buffers is not None and name in buffers:
if value is not None and not torch.is_tensor(value):
raise TypeError("cannot assign '{}' as buffer '{}' "
"(torch.Tensor or None expected)"
.format(torch.typename(value), name))
buffers[name] = value
else:
object.__setattr__(self, name, value)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
__getattr__ : 当获取 self.__dict__ 中没有的键所对应的值的时候,就会调用这个方法

因为 parameter, module, buffer 的键值对存在与 self._parameters, self._modules, self.buffer 中,所以,当想获取这些 值时, 就会调用这个方法。

def __getattr__(self, name):
if '_parameters' in self.__dict__:
_parameters = self.__dict__['_parameters']
if name in _parameters:
return _parameters[name]
if '_buffers' in self.__dict__:
_buffers = self.__dict__['_buffers']
if name in _buffers:
return _buffers[name]
if '_modules' in self.__dict__:
modules = self.__dict__['_modules']
if name in modules:
return modules[name]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, name))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
register_parameter
向模型中注册 Parameter

def register_parameter(self, name, param):
"""Adds a parameter to the module.

The parameter can be accessed as an attribute using given name.
"""
if '_parameters' not in self.__dict__:
raise AttributeError(
"cannot assign parameter before Module.__init__() call")
if param is None:
self._parameters[name] = None
elif not isinstance(param, Parameter):
raise TypeError("cannot assign '{}' object to parameter '{}' "
"(torch.nn.Parameter or None required)"
.format(torch.typename(param), name))
elif param.grad_fn:
raise ValueError(
"Cannot assign non-leaf Variable to parameter '{0}'. Model "
"parameters must be created explicitly. To express '{0}' "
"as a function of another variable, compute the value in "
"the forward() method.".format(name))
else:
self._parameters[name] = param
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
Module.training 标志 如何影响 前向过程
从nn.Dropout 来看 Module.training

class Dropout(Module):
def __init__(self, p=0.5, inplace=False):
super(Dropout, self).__init__()
if p < 0 or p > 1:
raise ValueError("dropout probability has to be between 0 and 1, "
"but got {}".format(p))
self.p = p
self.inplace = inplace

def forward(self, input):
return F.dropout(input, self.p, self.training, self.inplace)
1
2
3
4
5
6
7
8
9
10
11
可以看出,在forward 过程中,直接获取,父类的training的值。

我们 通常通过 module.train() 和 module.eval() 来切换模型的 训练测试阶段。

def train(self, mode=True):
"""Sets the module in training mode.
This has any effect only on modules such as Dropout or BatchNorm.
"""
self.training = mode

for module in self.children():
# 递归调用子模块 train 函数, 来设定所有 module 的 training 值。
module.train(mode)
return self
1
2
3
4
5
6
7
8
9
10
需要注意的是:module.eval() 仅仅设置 module 的 training 属性,如果我们想获得最快的推断速度, 还需要 设置 输入 Variable的volatile 属性为 True。

参考资料
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py
---------------------
作者:ke1th
来源:CSDN
原文:https://blog.csdn.net/u012436149/article/details/78281553
版权声明:本文为博主原创文章,转载请附上博文链接!

pytorch学习笔记(十二):详解 Module 类的更多相关文章

  1. expect学习笔记及实例详解【转】

    1. expect是基于tcl演变而来的,所以很多语法和tcl类似,基本的语法如下所示:1.1 首行加上/usr/bin/expect1.2 spawn: 后面加上需要执行的shell命令,比如说sp ...

  2. python3.4学习笔记(十二) python正则表达式的使用,使用pyspider匹配输出带.html结尾的URL

    python3.4学习笔记(十二) python正则表达式的使用,使用pyspider匹配输出带.html结尾的URL实战例子:使用pyspider匹配输出带.html结尾的URL:@config(a ...

  3. Go语言学习笔记十二: 范围(Range)

    Go语言学习笔记十二: 范围(Range) rang这个关键字主要用来遍历数组,切片,通道或Map.在数组和切片中返回索引值,在Map中返回key. 这个特别像python的方式.不过写法上比较怪异使 ...

  4. Pytorch学习笔记(二)---- 神经网络搭建

    记录如何用Pytorch搭建LeNet-5,大体步骤包括:网络的搭建->前向传播->定义Loss和Optimizer->训练 # -*- coding: utf-8 -*- # Al ...

  5. java jvm学习笔记十二(访问控制器的栈校验机制)

    欢迎装载请说明出处:http://blog.csdn.net/yfqnihao 本节源码:http://download.csdn.net/detail/yfqnihao/4863854 这一节,我们 ...

  6. 『PyTorch』第十二弹_nn.Module和nn.functional

    大部分nn中的层class都有nn.function对应,其区别是: nn.Module实现的layer是由class Layer(nn.Module)定义的特殊类,会自动提取可学习参数nn.Para ...

  7. Docker技术入门与实战 第二版-学习笔记-3-Dockerfile 指令详解

    前面已经讲解了FROM.RUN指令,还提及了COPY.ADD,接下来学习其他的指令 5.Dockerfile 指令详解 1> COPY 复制文件 格式: COPY  <源路径> .. ...

  8. Redis学习笔记4-Redis配置详解

    在Redis中直接启动redis-server服务时, 采用的是默认的配置文件.采用redis-server   xxx.conf 这样的方式可以按照指定的配置文件来运行Redis服务.按照本Redi ...

  9. (C/C++学习笔记) 十二. 指针

    十二. 指针 ● 基本概念 位系统下为4字节(8位十六进制数),在64位系统下为8字节(16位十六进制数) 进制表示的, 内存地址不占用内存空间 指针本身是一种数据类型, 它可以指向int, char ...

随机推荐

  1. jnhs-springmvc 请求组合不正确,比如请求路径出现两次

    初学springmvc 向后台传参数,结果发现,一直404 404后没有路径,说明 没有进入controller 仔细一看,请求路径不对,重复出现 页面是这样写的 页面是这样写的 原因是,请求链接和当 ...

  2. winform应用程序异常处理

    对于winform应用程序补抓异常信息,我们经常用到得try catch. 如果代码中在某个地方执行异常,但是没有加try catch,这个时候就需要做一些全局异常捕捉. 怎么做到全局异常捕捉.win ...

  3. RSA 2019安全大会:企业资产管理成行业新风向标,云上安全占优势

    美国时间3月4-8日,国际知名信息安全峰会RSA Conference在美国旧金山开幕,云安全及云可以为企业提供更可靠的资产管理方式成为大会热点. 此次峰会共吸引全球700多家机构参展,其中近42%为 ...

  4. bzoj 3209 花神的数论题——二进制下的数位dp

    题目:https://www.lydsy.com/JudgeOnline/problem.php?id=3209 可以枚举 “1的个数是...的数有多少个” ,然后就是用组合数算在多少位里选几个1. ...

  5. 火狐下button标签子元素无法点击

    button下元素点击事件:在chrome和safari下每个a标签可以点击,在火狐下a标签无法点击. <button> <a href="javascript:;&quo ...

  6. hdu 1711Number Sequence (KMP入门,子串第一次出现的位置)

    Number Sequence Time Limit: 10000/5000 MS (Java/Others)    Memory Limit: 32768/32768 K (Java/Others) ...

  7. Spring Boot:Boot2.0版本整合Neo4j

    前面介绍了Boot 1.5版本集成Neo4j,Boot 2.0以上版本Neo4j变化较大. 场景还是电影人员关系 Boot 2.0主要变化 GraphRepository在Boot2.0下不支持了,调 ...

  8. vsync信号产生与分发

    以下分析基于android 4.4代码 vsync信号的产生.分发涉及到以下几个类,先主要了解下他们各自的功能: HWComposer:产生hardware vsync,post fb VSyncTh ...

  9. LintCode_111 爬楼梯

    题目 假设你正在爬楼梯,需要n步你才能到达顶部.但每次你只能爬一步或者两步,你能有多少种不同的方法爬到楼顶部? 比如n=3,中不同的方法 返回 3 1 2 3 5 8 13... step[2] = ...

  10. vue 保存数组和对象, 避免双向绑定影响

    很多时候需要保存数据然后复用该数据,因vue的双向绑定总是不能保存原始数据 随笔记录解决方式 1. 不要把变量放置在data中 2. 保存至新的变量 object :   let obj= Objec ...