『PyTorch』第九弹_前馈网络简化写法
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下
在前面的例子中,基本上都是将每一层的输出直接作为下一层的输入,这种网络称为前馈传播网络(feedforward neural network)。对于此类网络如果每次都写复杂的forward函数会有些麻烦,在此就有两种简化方式,ModuleList和Sequential。其中Sequential是一个特殊的module,它包含几个子Module,前向传播时会将输入一层接一层的传递下去。ModuleList也是一个特殊的module,可以包含几个子module,可以像用list一样使用它,但不能直接把输入传给ModuleList。下面举例说明。
一、nn.Sequential()对象
nn.Sequential()对象是类似keras的前馈模型的对象,可以为之添加层实现前馈神经网络。
1、模型建立方式
第一种写法:
nn.Sequential()对象.add_module(层名,层class的实例)
net1 = nn.Sequential()
net1.add_module('conv', nn.Conv2d(3, 3, 3))
net1.add_module('batchnorm', nn.BatchNorm2d(3))
net1.add_module('activation_layer', nn.ReLU())
第二种写法:
nn.Sequential(*多个层class的实例)
net2 = nn.Sequential(
nn.Conv2d(3, 3, 3),
nn.BatchNorm2d(3),
nn.ReLU()
)
第三种写法:
nn.Sequential(OrderedDict([*多个(层名,层class的实例)]))
from collections import OrderedDict
net3= nn.Sequential(OrderedDict([
('conv', nn.Conv2d(3, 3, 3)),
('batchnorm', nn.BatchNorm2d(3)),
('activation_layer', nn.ReLU())
]))
2、检查以及调用模型
查看模型
print对象即可
print('net1:', net1)
print('net2:', net2)
print('net3:', net3)
net1: Sequential(
(conv): Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1))
(batchnorm): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True)
(activation_layer): ReLU()
)
net2: Sequential(
(0): Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1))
(1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True)
(2): ReLU()
)
net3: Sequential(
(conv): Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1))
(batchnorm): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True)
(activation_layer): ReLU()
)
提取子Module对象
# 可根据名字或序号取出子module
net1.conv, net2[0], net3.conv
(Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1)),
Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1)),
Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1)))
调用模型
可以直接网络对象(输入数据),也可以使用上面的Module子对象分别传入(input)。
input = V(t.rand(1, 3, 4, 4))
output = net1(input)
output = net2(input)
output = net3(input)
output = net3.activation_layer(net1.batchnorm(net1.conv(input)))
二、nn.ModuleList()对象
ModuleList是Module的子类,当在Module中使用它的时候,就能自动识别为子module。
建立以及使用方法如下,
modellist = nn.ModuleList([nn.Linear(3,4), nn.ReLU(), nn.Linear(4,2)])
input = V(t.randn(1, 3))
for model in modellist:
input = model(input)
# 下面会报错,因为modellist没有实现forward方法
# output = modelist(input)
和普通list不一样,它和torch的其他机制结合紧密,继承了nn.Module的网络模型class可以使用nn.ModuleList并识别其中的parameters,当然这只是个list,不会自动实现forward方法,
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.list = [nn.Linear(3, 4), nn.ReLU()]
self.module_list = nn.ModuleList([nn.Conv2d(3, 3, 3), nn.ReLU()])
def forward(self):
pass
model = MyModule()
print(model)
MyModule(
(module_list): ModuleList(
(0): Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1))
(1): ReLU()
)
)
for name, param in model.named_parameters():
print(name, param.size())
('module_list.0.weight', torch.Size([3, 3, 3, 3]))
('module_list.0.bias', torch.Size([3]))
可见,list中的子module并不能被主module所识别,而ModuleList中的子module能够被主module所识别。这意味着如果用list保存子module,将无法调整其参数,因其未加入到主module的参数中。
除ModuleList之外还有ParameterList,其是一个可以包含多个parameter的类list对象。在实际应用中,使用方式与ModuleList类似。如果在构造函数__init__中用到list、tuple、dict等对象时,一定要思考是否应该用ModuleList或ParameterList代替。
『PyTorch』第九弹_前馈网络简化写法的更多相关文章
- 『MXNet』第九弹_分类器以及迁移学习DEMO
解压文件命令: with zipfile.ZipFile('../data/kaggle_cifar10/' + fin, 'r') as zin: zin.extractall('../data/k ...
- 『TensorFlow』第九弹_图像预处理_不爱红妆爱武装
部分代码单独测试: 这里实践了图像大小调整的代码,值得注意的是格式问题: 输入输出图像时一定要使用uint8编码, 但是数据处理过程中TF会自动把编码方式调整为float32,所以输入时没问题,输出时 ...
- 『PyTorch』第二弹_张量
参考:http://www.jianshu.com/p/5ae644748f21# 几个数学概念: 标量(Scalar)是只有大小,没有方向的量,如1,2,3等 向量(Vector)是有大小和方向的量 ...
- 『PyTorch』第一弹_静动态图构建if逻辑对比
对比TensorFlow和Pytorch的动静态图构建上的差异 静态图框架设计好了不能够修改,且定义静态图时需要使用新的特殊语法,这也意味着图设定时无法使用if.while.for-loop等结构,而 ...
- 『PyTorch』第二弹重置_Tensor对象
『PyTorch』第二弹_张量 Tensor基础操作 简单的初始化 import torch as t Tensor基础操作 # 构建张量空间,不初始化 x = t.Tensor(5,3) x -2. ...
- 『MXNet』第一弹_基础架构及API
MXNet是基础,Gluon是封装,两者犹如TensorFlow和Keras,不过得益于动态图机制,两者交互比TensorFlow和Keras要方便得多,其基础操作和pytorch极为相似,但是方便不 ...
- 『TensorFlow』第二弹_线性拟合&神经网络拟合_恰是故人归
Step1: 目标: 使用线性模拟器模拟指定的直线:y = 0.1*x + 0.3 代码: import tensorflow as tf import numpy as np import matp ...
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...
- 『PyTorch』第十二弹_nn.Module和nn.functional
大部分nn中的层class都有nn.function对应,其区别是: nn.Module实现的layer是由class Layer(nn.Module)定义的特殊类,会自动提取可学习参数nn.Para ...
随机推荐
- linux常用命令:crontab 命令
前一天学习了 at 命令是针对仅运行一次的任务,循环运行的例行性计划任务,linux系统则是由 cron (crond) 这个系统服务来控制的.Linux 系统上面原本就有非常多的计划性工作,因此这个 ...
- Instagram 在 PyCon 2017 的演讲摘要
Instagram 在 PyCon 2017 的演讲摘要 PyCon 简介 PyCon 是全世界最大的以 Python 编程语言 为主题的技术大会.大会由 Python 社区组织,每年举办一次.在大会 ...
- 44 道 JavaScript 难题
JavaScript Puzzlers原文 1. ["1", "2", "3"].map(parseInt) 答案:[1, NaN, N ...
- Js基础知识4-函数的三种创建、四种调用(及关于new function()的解释)
在js中,函数本身属于对象的一种,因此可以定义.赋值,作为对象的属性或者成为其他函数的参数.函数名只是函数这个对象类的引用. 函数定义 // 函数的三种创建方法(定义方式) function one( ...
- 原 用Tomcat服务器配置https双向认证过程实战
什么是https? 百度百科足够解释它:http://baike.baidu.com/view/14121.htm 工具:keytool (Windows下路径:%JAVA_HOME%/bin/key ...
- Contiki源码+原理+功能+编程+移植+驱动+网络(转)
源:Contiki源码+原理+功能+编程+移植+驱动+网络 请链接:http://www.rimelink.com/nd.jsp? id=31&_np=105_315 假设您对于用Contik ...
- SQL学习之SQL注入学习总结
所谓SQL注入,就是通过把SQL命令插入到Web表单提交或输入域名或页面请求的查询字符串,最终达到欺骗服务器执行恶意的SQL命令. 测试数据库 我们本文就以如下数据库作为测试数据库,完成我们的注入分析 ...
- Kafka学习之(五)搭建kafka集群之Zookeeper集群搭建
Zookeeper是一种在分布式系统中被广泛用来作为:分布式状态管理.分布式协调管理.分布式配置管理.和分布式锁服务的集群.kafka增加和减少服务器都会在Zookeeper节点上触发相应的事件kaf ...
- python监控端口脚本[jkport1.0.py]
此脚本根据端口判断进程是否存活, 如果有指定的端口就证明进程是没问题的, 如果检测不到端口就是说业务进程已经挂掉了, 此时自动重启程序, 不多说下面请看脚本 创建脚本 我这里模拟的是nginx, 监控 ...
- ArrayList初始化的4种方法
In the last post we discussed about class ArrayList in Javaand it’s important methods. Here we are s ...