[pytorch笔记] torch.nn vs torch.nn.functional; model.eval() vs torch.no_grad(); nn.Sequential() vs nn.moduleList
1. torch.nn与torch.nn.functional之间的区别和联系
https://blog.csdn.net/GZHermit/article/details/78730856
nn
和nn.functional
之间的差别如下,我们以conv2d的定义为例
torch.nn.Conv2d
import torch.nn.functional as F
class Conv2d(_ConvNd): def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
dilation = _pair(dilation)
super(Conv2d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _pair(0), groups, bias) def forward(self, input):
return F.conv2d(input, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
torch.nn.functional.conv2d
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1,
groups=1): if input is not None and input.dim() != 4:
raise ValueError("Expected 4D tensor as input, got {}D tensor instead.".format(input.dim())) f = _ConvNd(_pair(stride), _pair(padding), _pair(dilation), False,
_pair(0), groups, torch.backends.cudnn.benchmark,
torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
return f(input, weight, bias)
区别:
1. nn.Conv2d是一个类;F.conv2d是一个函数
联系:
nn.Conv2d的forword()函数是用F.conv2d()实现的,两者功能并无区别。
(在Module
类里的__call__
实现了forward()
函数的调用,所以当实例化nn.Conv2d
类时,forward()
函数也被执行了,详细可阅读torch源码)
为什么要有这样的两种实现方式同时存在呢?
原因其实在于,为了兼顾灵活性和便利性。
在建图过程中,往往有两种层,一种如全连接层,卷积层等,当中有 Variable, 另一种如 Pooling层,ReLU层,当中没有 Variable.
如果所有的层都用 nn.functional 来定义,那么所有的Variable, 如 weights, bias 等,都需要用户手动定义,非常不便;
如果所有的层都用 nn 来定义,那么即便是简单的计算都需要建类来做,而这些可以用更为简单的函数来代替。
综上,在定义网络的时候,如果层内有 Variable, 那么用 nn 定义, 反之,则用 nn.functional定义。
2. ‘model.eval()’ vs ‘with torch.no_grad()’
https://discuss.pytorch.org/t/model-eval-vs-with-torch-no-grad/19615
1. model.eval() will notify all your layers that you are in eval mode, that way, batchnorm or dropout layers will work in eval model instead of training mode.
model.eval()会告知模型中的所有layers, 目前处在eval模式,batchnorm和dropout层等都会在eval模式中工作。
2. torch.no_grad() impacts the autograd engine and deactivate it. It will reduce memory usage and speed up computations but you won’t be able to backprop (which you don’t want in an eval script).
torch.no_grad() 会影响 autograd 引擎,并关闭它。这样会降低内存的使用并且加速计算。但是将不可以使用backprop.
3. nn.Sequential() vs nn.moduleList
https://blog.csdn.net/e01528/article/details/84397174
对于cnn前馈神经网络如果前馈一次写一个forward函数会有些麻烦,在此就有两种简化方式,ModuleList和Sequential。
其中Sequential是一个特殊的module,它包含几个子Module,前向传播时会将输入一层接一层的传递下去。
ModuleList也是一个特殊的module,可以包含几个子module,可以像用list一样使用它,但不能直接把输入传给ModuleList。
3.1 nn.Sequential()
1. 模型的建立方式:
import torch
import torch.nn as nn
from torch.autograd import Variable ''' nn.Sequential
''' net1 = nn.Sequential()
net1.add_module('conv', nn.Conv2d(3, 3, 3))
# net1.add_module('conv2', nn.Conv2d(3, 3, 2))
net1.add_module('batchnorm', nn.BatchNorm2d(3))
net1.add_module('activation_layer', nn.ReLU()) print("net1:")
print(net1) # net1:
# Sequential(
# (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
# (batchnorm): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (activation_layer): ReLU()
# ) net2 = nn.Sequential(
nn.Conv2d(3, 3, 3),
nn.BatchNorm2d(3),
nn.ReLU()
) print("net2:")
print(net2) # net2:
# Sequential(
# (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
# (1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (2): ReLU()
# ) from collections import OrderedDict
net3 = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(3, 3, 3)),
('batchnorm', nn.BatchNorm2d(3)),
('activation_layer', nn.ReLU())
])) print("net3:")
print(net3) # net3:
# Sequential(
# (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
# (batchnorm): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (activation_layer): ReLU()
# )
2. 获取子Module对象
# get the sub module by the name or index
print("Get the sub module by the name or index:")
print(net1.conv)
print(net2[0])
print(net3.conv) # Get the sub module by the name or index:
# Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
# Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
# Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
3. 调用模型
# use the model
input = Variable(torch.rand(1, 3, 4, 4))
output1 = net1(input)
output2 = net2(input)
output3 = net3(input)
output4 = net3.activation_layer(net1.batchnorm(net1.conv(input)))
print("output1:", output1)
print("output2:", output2)
print("output3:", output3)
print("output4:", output4)
# output1: tensor([[[[0.0000, 0.1066],
# [0.0075, 0.1379]], # [[0.0558, 0.9517],
# [0.0000, 0.0000]], # [[0.5355, 0.0000],
# [0.4478, 0.0000]]]], grad_fn=<ThresholdBackward0>)
# output2: tensor([[[[0.4227, 0.3509],
# [0.0868, 0.0000]], # [[0.0000, 0.0034],
# [0.0038, 0.0000]], # [[0.0000, 0.0000],
# [0.4002, 0.1882]]]], grad_fn=<ThresholdBackward0>)
# output3: tensor([[[[0.0000, 0.0000],
# [0.4779, 0.0000]], # [[0.0000, 1.5064],
# [0.0000, 0.1515]], # [[0.7417, 0.0000],
# [0.3366, 0.0000]]]], grad_fn=<ThresholdBackward0>)
# output4: tensor([[[[0.0000, 0.1066],
# [0.0075, 0.1379]], # [[0.0558, 0.9517],
# [0.0000, 0.0000]], # [[0.5355, 0.0000],
# [0.4478, 0.0000]]]], grad_fn=<ThresholdBackward0>)
3.2 nn.moduleList
它被设计用来存储任意数量的nn. module。
如果在构造函数__init__中用到list、tuple、dict等对象时,一定要思考是否应该用ModuleList或ParameterList代替。
1. 可以采用迭代或下标索引方式获取Module
# 1. support index and enumerate
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)]) def forward(self, x):
for i, l in enumerate(self.linears):
x = self.linears[i // 2](x) + l(x)
return x
2. extend 和 append方法
nn.moduleList定义对象后,有extend和append方法,用法和python中一样。
extend是添加另一个modulelist ;
append是添加另一个module。
# 2. extend a modulelist; attend a module
class LinearNet(nn.Module):
"""docstring for LinearNet"""
def __init__(self, input_size, num_layers, layers_size, output_size):
super(LinearNet, self).__init__()
self.linears = nn.ModuleList([nn.Linear(input_size, layers_size)])
self.linears.extend([nn.Linear(layers_size, layers_size) for i in range(1, num_layers - 1)])
self.linears.append(nn.Linear(layers_size, output_size)) model1 = LinearNet(5, 3, 4, 2)
print("---model LinearNet---")
print(model1)
print() # ---model LinearNet---
# LinearNet(
# (linears): ModuleList(
# (0): Linear(in_features=5, out_features=4, bias=True)
# (1): Linear(in_features=4, out_features=4, bias=True)
# (2): Linear(in_features=4, out_features=2, bias=True)
# )
# )
3. 建立以及使用方法
# 3. create and use -- not implement the forward
modellist = nn.ModuleList([nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2)])
input = Variable(torch.randn(1, 3))
for model in modellist:
input = model(input) # output = modellist(input) --> wrong 因为modellist没有实现forward方法
4. ModuleList与list的区别
普通list中的子module并不能被主module所识别,而ModuleList中的子module能够被主module所识别。这意味着如果用list保存子module,将无法调整其参数,因其未加入到主module的参数中。
除ModuleList之外还有ParameterList,其是一个可以包含多个parameter的类list对象。在实际应用中,使用方式与ModuleList类似。
class MyModule_list(nn.Module):
"""docstring for MyModule_list"""
def __init__(self):
super(MyModule_list, self).__init__()
self.list = [nn.Linear(3, 4), nn.ReLU()]
self.module_list = nn.ModuleList([nn.Conv2d(3, 3, 3), nn.ReLU()]) def forward(self):
pass
model = MyModule_list()
print(model) # MyModule_list(
# (module_list): ModuleList(
# (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
# (1): ReLU()
# )
# )
# 只有ModuleList的信息,并没有list的信息 for name, param in model.named_parameters():
print(name, param.size()) # module_list.0.weight torch.Size([3, 3, 3, 3])
# module_list.0.bias torch.Size([3])
# 只有ModuleList的信息,并没有list的信息
[pytorch笔记] torch.nn vs torch.nn.functional; model.eval() vs torch.no_grad(); nn.Sequential() vs nn.moduleList的更多相关文章
- PyTorch中,关于model.eval()和torch.no_grad()
一直对于model.eval()和torch.no_grad()有些疑惑 之前看博客说,只用torch.no_grad()即可 但是今天查资料,发现不是这样,而是两者都用,因为两者有着不同的作用 引用 ...
- 2、pytorch——Linear模型(最基础版,理解框架,背诵记忆)(调用nn.Modules模块)
#define y = X @ w import torch from torch import nn #第一模块,数据初始化 n = 100 X = torch.rand(n,2) true_w = ...
- [pytorch笔记] 调整网络学习率
1. 为网络的不同部分指定不同的学习率 class LeNet(t.nn.Module): def __init__(self): super(LeNet, self).__init__() self ...
- [Pytorch] pytorch笔记 <三>
pytorch笔记 optimizer.zero_grad() 将梯度变为0,用于每个batch最开始,因为梯度在不同batch之间不是累加的,所以必须在每个batch开始的时候初始化累计梯度,重置为 ...
- [Pytorch] pytorch笔记 <一>
pytorch笔记 - torchvision.utils.make_grid torchvision.utils.make_grid torchvision.utils.make_grid(tens ...
- `TypeError: torch.mm received an invalid combination of arguments - got (torch.FloatTensor, Variable),
`TypeError: torch.mm received an invalid combination of arguments - got (torch.FloatTensor, Variable ...
- Pytorch笔记 (3) 科学计算1
一.张量 标量 可以看作是 零维张量 向量 可以看作是 一维张量 矩阵 可以看作是 二维张量 继续扩展数据的维度,可以得到更高维度的张量 ————> 张量又称 多维数组 给定一个张量数据 ...
- [Pytorch] pytorch笔记 <二>
pytorch笔记2 用到的关于plt的总结 plt.scatter scatter(x, y, s=None, c=None, marker=None, cmap=None, norm=None, ...
- Pytorch笔记 (3) 科学计算2
一.组织张量的元素 (1)重排张量元素 本节介绍在不改变 张量元素个数 和 各元素的值的情况下改变张量的大小 torch.Tensor类的成员方法 reshape() 参数是多个int类型的值. 如果 ...
随机推荐
- 使用antd List组件实现轮播图
import { List, Avatar, Carousel } from 'antd'; import { connect } from 'dva'; import './lamp.less' c ...
- Max History CodeForces - 938E (组合计数)
You are given an array a of length n. We define fa the following way: Initially fa = 0, M = 1; for e ...
- Spring经典高频面试题,原来是长这个样子
Spring经典高频面试题,原来是长这个样子 2019年08月23日 15:01:32 博文视点 阅读数 719 版权声明:本文为博主原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文 ...
- MySQL性能优化(二):优化数据库的设计
原文:MySQL性能优化(二):优化数据库的设计 版权声明:本文为博主原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.n ...
- GridView中点击某行的任意位置就选中该行
GridView中点击某行的任意位置就选中该行 -- :: 分类: 第一步:添加选择列 点击GridView右边小尖头,双击CommandField,选中"选择",添加,将起设置为不可见: 第二步:处 ...
- 这十个MySQL经典错误
今天就给大家列举 MySQL 数据库中,最经典的十大错误案例,并附有处理问题的解决思路和方法,希望能给刚入行,或数据库爱好者一些帮助,今后再遇到任何报错,我们都可以很淡定地去处理.学习任何一门技术的同 ...
- Java高并发程序设计学习笔记(二):多线程基础
转自:https://blog.csdn.net/dataiyangu/article/details/86226835# 什么是线程?线程的基本操作线程的基本操作新建线程调用run的一种方式调用ru ...
- Tensorflow源码编译常见问题点总结
Tensorflow源码编译分两种:一种是本地源码编译,另一种是针对ARM平台的源码编译. 接下来分别介绍: 一.本地编译 本地编译时,使用的编译工具是本地GCC. 一般会碰到以下问题: 第1个:ex ...
- PowerDesigner连接 MySQL 生成 ER图
powerdesigner 16.5 http://www.pcsoft.com.cn/soft/27495.html jdk 1.8 32位 https://mirrors.huaweicloud. ...
- Jmeter多接口测试之参数传递
接口测试包含单接口测试和多接口测试,通过组合多个接口实现一组功能的验证称为多接口测试,单接口重在单个接口多种请求组合的响应断言,多接口重在组合不同接口,实现流程的串联和验证.多接口测试涉及到接口之间参 ...