nn.Module 函数详解

nn.Module是所有网络模型结构的基类,无论是pytorch自带的模型,还是要自定义模型,都需要继承这个类。这个模块包含了很多子模块,如下所示,_parameters存放的是模型的参数,_buffers也存放的是模型的参数,但是是那些不需要更新的参数。带hook的都是钩子函数,详见钩子函数部分。

self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._non_persistent_buffers_set = set()
self._backward_hooks = OrderedDict()
self._is_full_backward_hook = None
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
self._modules = OrderedDict()

此外,每一个模块还内置了一些常用的方法来帮助访问和操作网络。

load_state_dict() #加载模型权重参数 

parameters() #读取所有参数

named_parameters() #读取参数名称和参数

buffers() #读取self.named_buffers中的参数

named_buffers() #读取self.named_buffers中的参数名称和参数

children() #读取模型中,所有的子模型

named_children() #读取子模型名称和子模型

requires_grad_() #设置模型是否开启梯度反向传播

Parameter类

Parameter是Tensor子类,所以继承了Tensor类的属性。例如data和grad属性,可以根据data来访问参数数值,用grad来访问参数梯度。

weight_0 = nn.Parameters(torch.randn(10,10))

print(weight_0.data)
print(weight_0.grad)

定义变量的时候,nn.Parameter会被自动加入到参数列表中去

class MyModel(nn.Module):
def __init__(self):
super(MyModel,self).__init__()
self.weight1 = nn.Parameter(torch.randn(10,10))
self.weight2 = torch.randn(10,10)
def forward(self,x):
pass model = MyModel()
for name,param in model.named_parameters():
print(name) output: weight1

ParameterList

接定义成Parameter类外,还可以使用ParameterList和ParameterDict分别定义参数的列表和字典。ParameterList接收一个Parameter实例的列表作为输入然后得到一个参数列表,使用的时候可以用索引来访问某个参数,另外也可以使用append和extend在列表后面新增参数。

params = nn.ParameterList(
[nn.Parameter(torch.randn(10,10)) for i in range(5)]
) params.append(nn.Parameter(torch.randn(3,3)))

ParameterDict

可以像添加字典数据那样添加参数

params = nn.ParameterDict({
'linear1':nn.Parameter(torch.randn(10,5)),
'linear2':nn.Parameter(torch.randn(5,2))
})

模型构建

使用Sequential构建模型

# 写法一
net = nn.Sequential(
nn.Linear(num_inputs, 1)
# 此处还可以传入其他层
) # 写法二
net = nn.Sequential()
net.add_module('linear', nn.Linear(num_inputs, 1))
# net.add_module ...... # 写法三
from collections import OrderedDict
net = nn.Sequential(OrderedDict([
('linear', nn.Linear(num_inputs, 1))
# ......
])) print(net)

自定义模型

  1. 无参数模型

下面是一个展开操作,比如将2维图像展开成一维

class Flatten(nn.Module):
def __init__(self):
super(Flatten,self).__init__() def forward(self,input):
return input.view(input.size(0),-1)
  1. 有参数模型

自定义一个Linear层

class MLinear(nn.Module):
def __init__(self,input,output):
super(MyLinear,self).__init__() self.w = nn.Parameter(torch.randn(input,output))
self.b = nn.Parameter(torch.randn(output)) def foward(self,x):
x = self.w @ x + self.b
return x
  1. 组合模型
class Model(nn.Module):
def __init__(self):
super(Model,self).__init__()
self.l1 = nn.Linear(10,20)
self.l2 = nn.Linear(20,5) def forward(self,x):
x = self.l1(x)
x = self.l2(x) return x

ModuleList & ModuleDict

ModuleList 和 ModuleDict都是继承与nn.Module, 与Seuqential不同的是,ModuleList 和 ModuleDict没有自带forward方法,所以只能作为一个模块和其他自定义方法进行组合。下面是使用示例:

class MyModuleList(nn.Module):
def __init__(self):
super(MyModuleList, self).__init__()
self.linears = nn.ModuleList(
[nn.Linear(10, 10) for i in range(3)]
)
def forward(self, x):
for linear in self.linears:
x = linear(x)
return x class MyModuleDict(nn.Module):
def __init__(self):
super(MyModuleDict, self).__init__()
self.linears = nn.ModuleDict({
"linear1":nn.Linear(10,10),
"linear2":nn.Linear(10,10)
})
def forward(self, x):
x = self.linears["linear1"](x)
x = self.linears["linear2"](x)
return x

Pytorch系列:(三)模型构建的更多相关文章

  1. 【转载】PyTorch系列 (二):pytorch数据读取

    原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...

  2. pytorch入门2.1构建回归模型初体验(模型构建)

    pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...

  3. pytorch入门2.2构建回归模型初体验(开始训练)

    pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...

  4. pytorch入门2.0构建回归模型初体验(数据生成)

    pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...

  5. 前端构建大法 Gulp 系列 (三):gulp的4个API 让你成为gulp专家

    系列目录 前端构建大法 Gulp 系列 (一):为什么需要前端构建 前端构建大法 Gulp 系列 (二):为什么选择gulp 前端构建大法 Gulp 系列 (三):gulp的4个API 让你成为gul ...

  6. [深度学习] Pytorch(三)—— 多/单GPU、CPU,训练保存、加载模型参数问题

    [深度学习] Pytorch(三)-- 多/单GPU.CPU,训练保存.加载预测模型问题 上一篇实践学习中,遇到了在多/单个GPU.GPU与CPU的不同环境下训练保存.加载使用使用模型的问题,如果保存 ...

  7. 【小白学PyTorch】6 模型的构建访问遍历存储(附代码)

    文章转载自微信公众号:机器学习炼丹术.欢迎大家关注,这是我的学习分享公众号,100+原创干货. 文章目录: 目录 1 模型构建函数 1.1 add_module 1.2 ModuleList 1.3 ...

  8. Web 开发人员和设计师必读文章推荐【系列三十】

    <Web 前端开发精华文章推荐>2014年第9期(总第30期)和大家见面了.梦想天空博客关注 前端开发 技术,分享各类能够提升网站用户体验的优秀 jQuery 插件,展示前沿的 HTML5 ...

  9. CSS3之简易的3D模型构建[原创开源]

    CSS3之简易的3D模型构建[开源分享] 先上一张图(成果图):这个是使用 3D建模空间[源码之一] 制作出来的模型之一 当然这是一部分模型特写, 之前还制作过枪的3D模型等等. 感兴趣的朋友可以自己 ...

随机推荐

  1. SpringBoot解决特殊符号 []报400问题

    当遇到特殊符号传递给后台时,如果不加处理,就会报400的错误,解决办法有两种. 1.前台解决 前台解决的方法就是把这些特殊符号转义,转义之后浏览器和后台都可以识别. //对特殊字符进行转义 encod ...

  2. 关于 HTTP 后端人员需要了解的 20+ 图片!

    前言 当您网上冲浪时,HTTP 协议无处不在.当您浏览网页.获取一张图片.一段视频时,HTTP 协议就正在发生. 本篇将尽可能用简短的例子和必要的说明来让您了解基础的 HTTP 知识. 目录: 什么是 ...

  3. go 语言 如何发送微信信息到自己手机

    使用  wxmgo 包可以把微信信息发送到自己的手机上.第一步: go get github.com/rehylas/wxmgo 第二步: import ( "fmt" wxm & ...

  4. Android Studio|IntelliJ IDEA Git使用小技巧

    一 分支管理 1. 新建分支 在master的基础上创建新分支dev 2. 推送分支 将新建的分支dev推送到远程 3. 切换分支 4. 合并分支 当我们在dev分支完成代码修改并测试通过后 需要将d ...

  5. Hi3559AV100 NNIE RFCN开发:V4L2->VDEC->VPSS->NNIE->VGS->VO系统整体动态调试实现

    下面随笔将给出Hi3559AV100 NNIE RFCN开发:V4L2->VDEC->VPSS->NNIE->VGS->VO系统整体动态调试实现,最终的效果是:USB摄像 ...

  6. 2020年12月-第02阶段-前端基础-CSS Day07

    CSS Day07 CSS高级技巧 理解 能说出元素显示隐藏最常见的写法 能说出精灵图产生的目的 能说出去除图片底侧空白缝隙的方法 应用 能写出最常见的鼠标样式 能使用精灵图技术 能用滑动门做导航栏案 ...

  7. swaks制作钓鱼邮件

      一.在网站:https://bccto.me/ 申请一个十分钟的邮箱 二.使用命令行,命令行解释如下: --from hacker@qq.com //发件人的邮箱 --ehlo qq.com // ...

  8. 死磕Spring之IoC篇 - @Bean 等注解的实现原理

    该系列文章是本人在学习 Spring 的过程中总结下来的,里面涉及到相关源码,可能对读者不太友好,请结合我的源码注释 Spring 源码分析 GitHub 地址 进行阅读 Spring 版本:5.1. ...

  9. Java并发编程之同步/并发集合

    同步集合 Java中同步集合如下: Vector:基于数组的线程安全集合,扩容默认增加1倍(ArrayList50%) Stack:继承于Vector,基于动态数组实现的一个线程安全的栈 Hashta ...

  10. 攻防世界 reverse tt3441810

    tt3441810 tinyctf-2014 附件给了一堆数据,将十六进制数据部分提取出来, flag应该隐藏在里面,(这算啥子re,) 保留可显示字符,然后去除填充字符(找规律 0.0) 处理脚本: ...