nn.moduleList 和Sequential由来、用法和实例 —— 写网络模型
对于cnn前馈神经网络如果前馈一次写一个forward函数会有些麻烦,在此就有两种简化方式,ModuleList和Sequential。其中Sequential是一个特殊的module,它包含几个子Module,前向传播时会将输入一层接一层的传递下去。ModuleList也是一个特殊的module,可以包含几个子module,可以像用list一样使用它,但不能直接把输入传给ModuleList。下面举例说明。
目录
一、nn.Sequential()对象
1、模型建立方式
第一种写法:
第二种写法:
第三种写法:
2、检查以及调用模型
查看模型
根据名字或序号提取子Module对象
调用模型
二、nn.ModuleList()对象
为什么有他?
什么时候用?
和list的区别?
1. extend和append方法
2. 建立以及使用方法
3. yolo v3构建网络
一、nn.Sequential()对象
建立nn.Sequential()对象,必须小心确保一个块的输出大小与下一个块的输入大小匹配。基本上,它的行为就像一个nn.Module。
1、模型建立方式
第一种写法:
nn.Sequential()对象.add_module(层名,层class的实例)
1
2
3
4
net1 = nn.Sequential()
net1.add_module('conv', nn.Conv2d(3, 3, 3))
net1.add_module('batchnorm', nn.BatchNorm2d(3))
net1.add_module('activation_layer', nn.ReLU())
第二种写法:
nn.Sequential(*多个层class的实例)
1
2
3
4
5
net2 = nn.Sequential(
nn.Conv2d(3, 3, 3),
nn.BatchNorm2d(3),
nn.ReLU()
)
第三种写法:
nn.Sequential(OrderedDict([*多个(层名,层class的实例)]))
1
2
3
4
5
6
from collections import OrderedDict
net3= nn.Sequential(OrderedDict([
('conv', nn.Conv2d(3, 3, 3)),
('batchnorm', nn.BatchNorm2d(3)),
('activation_layer', nn.ReLU())
]))
2、检查以及调用模型
查看模型
print对象即可
1
2
3
print('net1:', net1)
print('net2:', net2)
print('net3:', net3)
net1: Sequential(
(conv): Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1))
(batchnorm): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True)
(activation_layer): ReLU()
)
net2: Sequential(
(0): Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1))
(1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True)
(2): ReLU()
)
net3: Sequential(
(conv): Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1))
(batchnorm): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True)
(activation_layer): ReLU()
)
根据名字或序号提取子Module对象
1
2
# 可根据名字或序号取出子module
net1.conv, net2[0], net3.conv
(Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1)),
Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1)),
Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1)))
调用模型
可以直接网络对象(输入数据),也可以使用上面的Module子对象分别传入(input)。
1
2
3
4
5
input = V(t.rand(1, 3, 4, 4))
output = net1(input)
output = net2(input)
output = net3(input)
output = net3.activation_layer(net1.batchnorm(net1.conv(input)))
二、nn.ModuleList()对象
为什么有他?
写一个module然后就写foreword函数很麻烦,所以就有了这两个。它被设计用来存储任意数量的nn. module。
什么时候用?
如果在构造函数__init__中用到list、tuple、dict等对象时,一定要思考是否应该用ModuleList或ParameterList代替。
如果你想设计一个神经网络的层数作为输入传递。
和list的区别?
ModuleList是Module的子类,当在Module中使用它的时候,就能自动识别为子module。
当添加 nn.ModuleList 作为 nn.Module 对象的一个成员时(即当我们添加模块到我们的网络时),所有 nn.ModuleList 内部的 nn.Module 的 parameter 也被添加作为 我们的网络的 parameter。
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
def forward(self, x):
# ModuleList can act as an iterable, or be indexed using ints
for i, l in enumerate(self.linears):
x = self.linears[i // 2](x) + l(x)
return x
1. extend和append方法
nn.moduleList定义对象后,有extend和append方法,用法和python中一样,extend是添加另一个modulelist append是添加另一个module
class LinearNet(nn.Module):
def __init__(self, input_size, num_layers, layers_size, output_size):
super(LinearNet, self).__init__()
self.linears = nn.ModuleList([nn.Linear(input_size, layers_size)])
self.linears.extend([nn.Linear(layers_size, layers_size) for i in range(1, self.num_layers-1)])
self.linears.append(nn.Linear(layers_size, output_size)
2. 建立以及使用方法
建立以及使用方法如下,
1
2
3
4
5
6
modellist = nn.ModuleList([nn.Linear(3,4), nn.ReLU(), nn.Linear(4,2)])
input = V(t.randn(1, 3))
for model in modellist:
input = model(input)
# 下面会报错,因为modellist没有实现forward方法
# output = modelist(input)
和普通list不一样,它和torch的其他机制结合紧密,继承了nn.Module的网络模型class可以使用nn.ModuleList并识别其中的parameters,当然这只是个list,不会自动实现forward方法,
1
2
3
4
5
6
7
8
9
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.list = [nn.Linear(3, 4), nn.ReLU()]
self.module_list = nn.ModuleList([nn.Conv2d(3, 3, 3), nn.ReLU()])
def forward(self):
pass
model = MyModule()
print(model)
MyModule(
(module_list): ModuleList(
(0): Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1))
(1): ReLU()
)
)
1
2
for name, param in model.named_parameters():
print(name, param.size())
('module_list.0.weight', torch.Size([3, 3, 3, 3]))
('module_list.0.bias', torch.Size([3]))
可见,普通list中的子module并不能被主module所识别,而ModuleList中的子module能够被主module所识别。这意味着如果用list保存子module,将无法调整其参数,因其未加入到主module的参数中。
除ModuleList之外还有ParameterList,其是一个可以包含多个parameter的类list对象。在实际应用中,使用方式与ModuleList类似。
3. yolo v3构建网络
首先module_list = nn.ModuleList()
然后
for index, x in enumerate(blocks[1:]):#根据不同的block 遍历module
module = nn.Sequential()
然后根据cfg读进来的数据,
module.add_module("batch_norm_{0}".format(index), bn)
module.add_module("conv_{0}".format(index), conv)
等等
module_list.append(module)
---------------------
作者:Snoopy_Dream
来源:CSDN
原文:https://blog.csdn.net/e01528/article/details/84397174
版权声明:本文为博主原创文章,转载请附上博文链接!

nn.moduleList 和Sequential由来、用法和实例 —— 写网络模型的更多相关文章
- [pytorch笔记] torch.nn vs torch.nn.functional; model.eval() vs torch.no_grad(); nn.Sequential() vs nn.moduleList
1. torch.nn与torch.nn.functional之间的区别和联系 https://blog.csdn.net/GZHermit/article/details/78730856 nn和n ...
- tf.nn.embedding_lookup TensorFlow embedding_lookup 函数最简单实例
tf.nn.embedding_lookup TensorFlow embedding_lookup 函数最简单实例 #!/usr/bin/env python # -*- coding: utf-8 ...
- java中List的用法和实例详解
java中List的用法和实例详解 List的用法List包括List接口以及List接口的所有实现类.因为List接口实现了Collection接口,所以List接口拥有Collection接口提供 ...
- Matlab 之meshgrid, interp, griddata 用法和实例
http://blog.sina.com.cn/s/blog_67f37e760101bu4e.html 实例结果http://wenku.baidu.com/link?url=SiGsFZIxuS1 ...
- Matlab 之meshgrid, interp, griddata 用法和实例(转)
http://blog.sina.com.cn/s/blog_67f37e760101bu4e.html 实例结果http://wenku.baidu.com/link?url=SiGsFZIxuS1 ...
- [转] php foreach用法和实例
PHP 4 引入了 foreach 结构,和 Perl 以及其他语言很像.这只是一种遍历数组简便方法.foreach 仅能用于数组,当试图将其用于其它数据类型或者一个未初始化的变量时会产生错误.有两种 ...
- echo命令的简单用法和实例
在CentOS 6.8版本下,通过实例的形式,展现选项和参数的灵活运用,可以简明的了解echo的用法. 一.语法:echo [SHORT-OPTION]… [STRING]… :echo [选项]…[ ...
- Vue.directive()的用法和实例
官网实例: https://cn.vuejs.org/v2/api/#Vue-directive https://cn.vuejs.org/v2/guide/custom-directive.html ...
- Liunx(centos8)下的yum的基本用法和实例
yum 命令 Yum(全称为 Yellow dog Updater, Modified)是一个在Fedora和RedHat以及CentOS中的Shell前端软件包管理器.基于RPM包管理,能够从指定的 ...
随机推荐
- img标签下多余空白BUG解决方法
在进行页面的DIV CSS排版时,遇到IE6(当然有时Firefox下也会偶遇)浏览器中的图片元素img下出现多余空白的问题绝对是常见的 对于该问题的解决方法也是“见机行事”. 1.将图片转换为块级对 ...
- 006-使用python编写一个猜数字的程序
题目:随机生成一个数字,共有三次机会对该数字进行猜测. #功能点# 1.猜错的时候给出提示,告诉用户输入的值是大了还是小了# 2.最多提供三次机会# 3.随机生成需要猜的数字答案 编写思路: 1.刚开 ...
- ifconfig配置IP地址和子网掩码
ifconfig eth0 192.168.2.10 ifconfig eth0 192.168.2.10 netmask 255.255.255.0
- day37 08-Hibernate的反向工程
反向工程:先创建表,创建好表之后,就是持久化类和映射文件可以不用你写,而且你的DAO它也可以帮你生成.但是它生成的DAO可能会多很多的方法.你可以不用那么多方法,但是它里面提供了这种的.用hibern ...
- tcpdump命令介绍
命令格式为:tcpdump [-nn] [-i 接口] [-w 储存档名] [-c 次数] [-Ae] [-qX] [-r 文件] [所欲捕获的数据内容] 参数: -nn,直接以 IP 及 Port ...
- 从0开始学习 GitHub 系列之「01.初识 GitHub
转载:http://blog.csdn.net/googdev/article/details/52787516 1. 写在前面 我一直认为 GitHub 是程序员必备技能,程序员应该没有不知道 Gi ...
- Android平台本地(离线)打包指南 - Android Studio
预备环境 AndroidStudio开发环境,要求安装Android4.0或以上(API 14)SDK. 下载HBuilder离线打包Android版SDK(5+ SDK下载). 离线打包SDK目录说 ...
- [新手必看] 17个常见的Python运行时错误
对于刚入门的Pythoner在学习过程中运行代码是或多或少会遇到一些错误,刚开始可能看起来比较费劲.随着代码量的积累,熟能生巧当遇到一些运行时错误时能够很快的定位问题原题.下面整理了常见的17个错误, ...
- C++ 之手写memcpy
#include<iostream>#include<cstdio>using namespace std; void* mymemcpy(void* dst, const v ...
- 2019.10.22 用TCP实现服务端并发接收
client import socket client = socket.socket() client.connect( ('127.0.0.1',8888) ) while 1: msg = in ...