nn.Module基类的构造函数:

def __init__(self):
self._parameters = OrderedDict()
self._modules = OrderedDict()
self._buffers = OrderedDict()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self.training = True

其中每个属性的解释如下:

  • _parameters:字典,保存用户直接设置的parameter,self.param1 = nn.Parameter(t.randn(3, 3))会被检测到,在字典中加入一个key为'param',value为对应parameter的item。而self.submodule = nn.Linear(3, 4)中的parameter则不会存于此。
  • _modules:子module,通过self.submodel = nn.Linear(3, 4)指定的子module会保存于此。
  • _buffers:缓存。如batchnorm使用momentum机制,每次前向传播需用到上一次前向传播的结果。
  • _backward_hooks_forward_hooks:钩子技术,用来提取中间变量,类似variable的hook。
  • training:BatchNorm与Dropout层在训练阶段和测试阶段中采取的策略不同,通过判断training值来决定前向传播策略。

上述几个属性中,_parameters_modules_buffers这三个字典中的键值,都可以通过self.key方式获得,效果等价于self._parameters['key'].

定义一个Module,这个Module即包含自己的Parameters有包含子Module及其Parameters,

import torch as t
from torch import nn
from torch.autograd import Variable as V class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
# 等价与self.register_parameter('param1' ,nn.Parameter(t.randn(3, 3)))
self.param1 = nn.Parameter(t.rand(3, 3))
self.submodel1 = nn.Linear(3, 4)
def forward(self, input):
x = self.param1.mm(input)
x = self.submodel11(x)
return x
net = Net()

一、_modules

# 打印网络对象的话会输出子module结构
print(net)

Net(
(submodel1): Linear(in_features=3, out_features=4)
)

# ._modules输出的也是子module结构,不过数据结构和上面的有所不同
print(net.submodel1)
print(net._modules) # 字典子类

Linear(in_features=3, out_features=4)
OrderedDict([('submodel1', Linear(in_features=3, out_features=4))])

for name, submodel in net.named_modules():
    print(name, submodel)

 Net(
(submodel1): Linear(in_features=3, out_features=4)
)
submodel1 Linear(in_features=3, out_features=4)

print(list(net.named_modules())) # named_modules其实是包含了本层的module集合

[('', Net(
(submodel1): Linear(in_features=3, out_features=4)
)), ('submodel1', Linear(in_features=3, out_features=4))]

二、_parameters

# ._parameters存储的也是这个结构
print(net.param1)
print(net._parameters) # 字典子类,仅仅包含直接定义的nn.Parameters参数

Parameter containing:
0.6135 0.8082 0.4519
0.9052 0.5929 0.2810
0.6825 0.4437 0.3874
[torch.FloatTensor of size 3x3] OrderedDict([('param1', Parameter containing:
0.6135 0.8082 0.4519
0.9052 0.5929 0.2810
0.6825 0.4437 0.3874
[torch.FloatTensor of size 3x3]
)])

for name, param in net.named_parameters():
    print(name, param.size())

param1 torch.Size([3, 3])
submodel1.weight torch.Size([4, 3])
submodel1.bias torch.Size([4])

三、_buffers

bn = nn.BatchNorm1d(2)
input = V(t.rand(3, 2), requires_grad=True)
output = bn(input)
bn._buffers
OrderedDict([('running_mean',
1.00000e-02 *
9.1559
1.9914
[torch.FloatTensor of size 2]), ('running_var',
0.9003
0.9019
[torch.FloatTensor of size 2])])

四、training

input = V(t.arange(0, 12).view(3, 4))
model = nn.Dropout()
# 在训练阶段,会有一半左右的数被随机置为0
model(input)
Variable containing:
0 2 4 0
8 10 0 0
0 18 0 22
[torch.FloatTensor of size 3x4]
model.training  = False
# 在测试阶段,dropout什么都不做
model(input)
Variable containing:
0 1 2 3
4 5 6 7
8 9 10 11
[torch.FloatTensor of size 3x4]

Module.train()、Module.eval() 方法和 Module.training属性的关系

print(net.training, net.submodel1.training)
net.train() # 将本层及子层的training设定为True
net.eval() # 将本层及子层的training设定为False
net.training = True # 注意,对module的设置仅仅影响本层,子module不受影响
net.training, net.submodel1.training
True True
(True, False)

『PyTorch』第十四弹_torch.nn.Module类属性的更多相关文章

  1. 『PyTorch』第十五弹_torch.nn.Module的属性设置&查询

    一.背景知识 python中两个属相相关方法 result = obj.name 会调用builtin函数getattr(obj,'name')查找对应属性,如果没有name属性则调用obj.__ge ...

  2. 『PyTorch』第十二弹_nn.Module和nn.functional

    大部分nn中的层class都有nn.function对应,其区别是: nn.Module实现的layer是由class Layer(nn.Module)定义的特殊类,会自动提取可学习参数nn.Para ...

  3. 『PyTorch』第十六弹_hook技术

    由于pytorch会自动舍弃图计算的中间结果,所以想要获取这些数值就需要使用钩子函数. 钩子函数包括Variable的钩子和nn.Module钩子,用法相似. 一.register_hook impo ...

  4. 『PyTorch』第十弹_循环神经网络

    RNN基础: 『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练 TensorFlow RNN: 『TensotFlow』基础RNN网络分类问题 『TensotFlow』基础R ...

  5. 『MXNet』第十二弹_再谈新建计算节点

    上一节我们已经谈到了计算节点,但是即使是官方文档介绍里面相关内容也过于简略,我们使用Faster-RCNN代码中的新建节点为例,重新介绍一下新建节点的调用栈. 1.调用新建节点 参数分为三部分,op_ ...

  6. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...

  7. 『PyTorch』第九弹_前馈网络简化写法

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下 在前面的例子中,基本上都是将每一层的输出直接作为下一层的 ...

  8. 『PyTorch』第三弹重置_Variable对象

    『PyTorch』第三弹_自动求导 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Varibale包含三个属性: data ...

  9. 『PyTorch』第二弹重置_Tensor对象

    『PyTorch』第二弹_张量 Tensor基础操作 简单的初始化 import torch as t Tensor基础操作 # 构建张量空间,不初始化 x = t.Tensor(5,3) x -2. ...

随机推荐

  1. 查看firefox浏览器 驱动geckodriver.exe文件的版本号的方法,以及下载链接

    1-进入到geckodriver.exe文件的目录: 2-在路径栏下输入cmd: 3-命令行界面下输入:geckodriver.exe -h 可以看到文件的帮助信息,其中第一行就列出了版本号 为0.1 ...

  2. python 文件操作 练习:取得文件的最后存取时间

    #coding=utf-8 import osimport time file_atime=int(os.path.getatime('d:\\a.txt'))print "file_ati ...

  3. 使用Spring实现读写分离( MySQL实现主从复制)(转)

    本文转自:http://blog.csdn.net/jack85986370/article/details/51559232 1.  背景 我们一般应用对数据库而言都是“读多写少”,也就说对数据库读 ...

  4. Linux基础命令---which

    which 在环境变量PATH中搜索某个命令,返回命令的执行文件或者脚本位置,默认只显示第一个结果.这需要一个或多个参数.对于它的每个参数,它会打印出当在shell提示符下输入该参数时将执行的可执行文 ...

  5. python列表list

    1.通过中括号[ ]括起来,用逗号分隔每个元素,元素可以是数字.字符串.布尔值.列表.元组.字典.集合 2.列表有序(体现在每次打印结果都一样),因此可通过下标索引的方式取元素,下标从0开始,li[m ...

  6. Python3 Selenium定位不到元素常见原因及解决办法

    Python3 Selenium定位不到元素常见原因及解决办法 一.问题描述 在做web应用的自动化测试时,定位元素是必不可少的,这个过程经常会碰到定位不到元素的情况: 报错信息: no such e ...

  7. Java 第二次实验20145104 张家明

    实验二 Java面向对象程序设计 实验内容 初步掌握单元测试和TDD 理解并掌握面向对象三要素:封装.继承.多态 初步掌握UML建模 熟悉S.O.L.I.D原则 了解设计模式 实验步骤 (一)单元测试 ...

  8. 20145335郝昊《网络攻防》Exp 4 利用nmap扫描

    20145335郝昊<网络攻防>Exp 4 利用nmap扫描 实验原理 使用msf辅助模块,nmap来扫描发现局域网中的主机ip 实验步骤 首先使用命令创建一个msf所需的数据库 serv ...

  9. C语言结构体,点运算和箭头运算

    C语言有一种数据类型叫结构体,其定义格式为: struct 结构体名 { 结构体成员变量定义; }; 如: struct student { char name[20]; int age ; doub ...

  10. Python3基础 str find+index 是否存在指定字符串,有则返回第一个索引值

             Python : 3.7.0          OS : Ubuntu 18.04.1 LTS         IDE : PyCharm 2018.2.4       Conda ...