『PyTorch』第十二弹_nn.Module和nn.functional
大部分nn中的层class都有nn.function对应,其区别是:
- nn.Module实现的layer是由class Layer(nn.Module)定义的特殊类,会自动提取可学习参数nn.Parameter
- nn.functional中的函数更像是纯函数,由def function(input)定义。
由于两者性能差异不大,所以具体使用取决于个人喜好。对于激活函数和池化层,由于没有可学习参数,一般使用nn.functional完成,其他的有学习参数的部分则使用类。但是Droupout由于在训练和测试时操作不同,所以建议使用nn.Module实现,它能够通过model.eval加以区分。
一、nn.functional函数基本使用
import torch as t
import torch.nn as nn
from torch.autograd import Variable as V input_ = V(t.randn(2, 3))
model = nn.Linear(3, 4)
output1 = model(input_)
output2 = nn.functional.linear(input_, model.weight, model.bias)
print(output1 == output2) b1 = nn.functional.relu(input_)
b2 = nn.ReLU()(input_)
print(b1 == b2)
二、搭配使用nn.Module和nn.functional
并不是什么难事,之前有接触过,nn.functional不需要放入__init__进行构造,所以不具有可学习参数的部分可以使用nn.functional进行代替。
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下
# Author : Hellcat
# Time : 2018/2/11 import torch as t
import torch.nn as nn
import torch.nn.functional as F class LeNet(nn.Module):
def __init__(self):
super(LeNet,self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6,16,5)
self.fc1 = nn.Linear(16*5*5,120)
self.fc2 = nn.Linear(120,84)
self.fc3 = nn.Linear(84,10) def forward(self,x):
x = F.max_pool2d(F.relu(self.conv1(x)),(2,2))
x = F.max_pool2d(F.relu(self.conv2(x)),2)
x = x.view(x.size()[0], -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
三、nn.functional函数构造nn.Module类
两者主要的区别就是对于可学习参数nn.Parameter的识别能力,所以构造时添加了识别能力即可。
class Linear(nn.Module):
def __init__(self, in_features, out_features):
# nn.Module.__init__(self)
super(Linear, self).__init__()
self.w = nn.Parameter(t.randn(out_features, in_features)) # nn.Parameter是特殊Variable
self.b = nn.Parameter(t.randn(out_features)) def forward(self, x):
# wx+b
return F.linear(x, self.w, self.b) layer = Linear(4, 3)
input = V(t.randn(2, 4))
output = layer(input)
print(output)
Variable containing:
1.7498 -0.8839 0.5314
-2.4863 -0.6442 1.1036
[torch.FloatTensor of size 2x3]
『PyTorch』第十二弹_nn.Module和nn.functional的更多相关文章
- 『PyTorch』第十六弹_hook技术
由于pytorch会自动舍弃图计算的中间结果,所以想要获取这些数值就需要使用钩子函数. 钩子函数包括Variable的钩子和nn.Module钩子,用法相似. 一.register_hook impo ...
- 『PyTorch』第十五弹_torch.nn.Module的属性设置&查询
一.背景知识 python中两个属相相关方法 result = obj.name 会调用builtin函数getattr(obj,'name')查找对应属性,如果没有name属性则调用obj.__ge ...
- 『PyTorch』第十四弹_torch.nn.Module类属性
nn.Module基类的构造函数: def __init__(self): self._parameters = OrderedDict() self._modules = OrderedDict() ...
- 『MXNet』第十二弹_再谈新建计算节点
上一节我们已经谈到了计算节点,但是即使是官方文档介绍里面相关内容也过于简略,我们使用Faster-RCNN代码中的新建节点为例,重新介绍一下新建节点的调用栈. 1.调用新建节点 参数分为三部分,op_ ...
- 『PyTorch』第十弹_循环神经网络
RNN基础: 『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练 TensorFlow RNN: 『TensotFlow』基础RNN网络分类问题 『TensotFlow』基础R ...
- 『PyTorch』第九弹_前馈网络简化写法
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下 在前面的例子中,基本上都是将每一层的输出直接作为下一层的 ...
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...
- 『PyTorch』第三弹重置_Variable对象
『PyTorch』第三弹_自动求导 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Varibale包含三个属性: data ...
- 『PyTorch』第二弹重置_Tensor对象
『PyTorch』第二弹_张量 Tensor基础操作 简单的初始化 import torch as t Tensor基础操作 # 构建张量空间,不初始化 x = t.Tensor(5,3) x -2. ...
随机推荐
- 获取Linux时间函数
Linux下clock_gettime函数详解 要包含这头文件<time.h> 且在编译链接时需加上 -lrt ;因为在librt中实现了clock_gettime函数. --- stru ...
- [分享] 采用opencv_cascadetrain进行训练的步骤及注意事项 [复制链接]
http://f.dataguru.cn/thread-725364-1-1.html 很有用的一个帖子 转自:http://blog.csdn.net/xidianzhimeng/article/d ...
- (二)github的价值意义篇
为什么需要社会化编程? 如果您是程序员面试官,两者之间你会选择哪一位呢? 能查看以前所写代码的程序员 or 无法查看的程序员 精通最新软件的程序员 or 不精通的程序员 对语言或软件差异带来的不同文化 ...
- inotify工具安装配置
一.安装 1) 从内核和目录里面查看是否支持inotify [root@nfs01 ~]# uname -r 2.6.32-573.el6.x86_64 [root@nfs01 ~]# ls -l ...
- 20165211 2017-2018-2 《Java程序设计》第5周学习总结
20165211 2017-2018-2 <Java程序设计>第5周学习总结 教材学习内容总结 本周,我学习了书本上第五.六两章的内容,以下是我整理的主要知识. 第五章 内部类与异常类 内 ...
- np.tile 和np.newaxis
output array([[ 0.24747071, -0.43886742], [-0.03916734, -0.70580089], [ 0.00462337, -0.5143158 ...
- Java集合总结(List、Map、Set)
集合的引入 当我们有种需求,需要存储多个元素的结构时,我们前面讲过数组,数组可以存储.但是数组也有它的弊端,使用的时候,必须先定义好长度,也就是数组的长度是固定,不能根据我们的需求自动变长或者变短. ...
- HTML语法分析
什么是HTML htyper text markup language 即超文本标记语言HTML是一个网页的主体部分,也是一个网页的基础.因为一个网页可以没有样式,可以没有交互,但是必须要有网页需要呈 ...
- linux下去掉pdf的密码(前提:知道密码)
一.背景 Linux jello 4.16.3 SMP Thu Apr 19 07:32:02 UTC 2018 x86_64 x86_64 x86_64 GNU/Linux 二.去掉密码 2.1 先 ...
- 括号序和dfs序
记得清北讲过括号序和dfs序,忘记了 dfs序 dfs序就是dfs的顺序,这个好记 就是在dfs遍历树的时候,将每个结点开始时记录一次,结束时记录一次 而且一个子树可以表示为连续的一段, 只有子树操作 ...