nn.Module 函数详解

nn.Module是所有网络模型结构的基类,无论是pytorch自带的模型,还是要自定义模型,都需要继承这个类。这个模块包含了很多子模块,如下所示,_parameters存放的是模型的参数,_buffers也存放的是模型的参数,但是是那些不需要更新的参数。带hook的都是钩子函数,详见钩子函数部分。

self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._non_persistent_buffers_set = set()
self._backward_hooks = OrderedDict()
self._is_full_backward_hook = None
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
self._modules = OrderedDict()

此外,每一个模块还内置了一些常用的方法来帮助访问和操作网络。

load_state_dict() #加载模型权重参数 

parameters() #读取所有参数

named_parameters() #读取参数名称和参数

buffers() #读取self.named_buffers中的参数

named_buffers() #读取self.named_buffers中的参数名称和参数

children() #读取模型中,所有的子模型

named_children() #读取子模型名称和子模型

requires_grad_() #设置模型是否开启梯度反向传播

Parameter类

Parameter是Tensor子类,所以继承了Tensor类的属性。例如data和grad属性,可以根据data来访问参数数值,用grad来访问参数梯度。

weight_0 = nn.Parameters(torch.randn(10,10))

print(weight_0.data)
print(weight_0.grad)

定义变量的时候,nn.Parameter会被自动加入到参数列表中去

class MyModel(nn.Module):
def __init__(self):
super(MyModel,self).__init__()
self.weight1 = nn.Parameter(torch.randn(10,10))
self.weight2 = torch.randn(10,10)
def forward(self,x):
pass model = MyModel()
for name,param in model.named_parameters():
print(name) output: weight1

ParameterList

接定义成Parameter类外,还可以使用ParameterList和ParameterDict分别定义参数的列表和字典。ParameterList接收一个Parameter实例的列表作为输入然后得到一个参数列表,使用的时候可以用索引来访问某个参数,另外也可以使用append和extend在列表后面新增参数。

params = nn.ParameterList(
[nn.Parameter(torch.randn(10,10)) for i in range(5)]
) params.append(nn.Parameter(torch.randn(3,3)))

ParameterDict

可以像添加字典数据那样添加参数

params = nn.ParameterDict({
'linear1':nn.Parameter(torch.randn(10,5)),
'linear2':nn.Parameter(torch.randn(5,2))
})

模型构建

使用Sequential构建模型

# 写法一
net = nn.Sequential(
nn.Linear(num_inputs, 1)
# 此处还可以传入其他层
) # 写法二
net = nn.Sequential()
net.add_module('linear', nn.Linear(num_inputs, 1))
# net.add_module ...... # 写法三
from collections import OrderedDict
net = nn.Sequential(OrderedDict([
('linear', nn.Linear(num_inputs, 1))
# ......
])) print(net)

自定义模型

  1. 无参数模型

下面是一个展开操作,比如将2维图像展开成一维

class Flatten(nn.Module):
def __init__(self):
super(Flatten,self).__init__() def forward(self,input):
return input.view(input.size(0),-1)
  1. 有参数模型

自定义一个Linear层

class MLinear(nn.Module):
def __init__(self,input,output):
super(MyLinear,self).__init__() self.w = nn.Parameter(torch.randn(input,output))
self.b = nn.Parameter(torch.randn(output)) def foward(self,x):
x = self.w @ x + self.b
return x
  1. 组合模型
class Model(nn.Module):
def __init__(self):
super(Model,self).__init__()
self.l1 = nn.Linear(10,20)
self.l2 = nn.Linear(20,5) def forward(self,x):
x = self.l1(x)
x = self.l2(x) return x

ModuleList & ModuleDict

ModuleList 和 ModuleDict都是继承与nn.Module, 与Seuqential不同的是,ModuleList 和 ModuleDict没有自带forward方法,所以只能作为一个模块和其他自定义方法进行组合。下面是使用示例:

class MyModuleList(nn.Module):
def __init__(self):
super(MyModuleList, self).__init__()
self.linears = nn.ModuleList(
[nn.Linear(10, 10) for i in range(3)]
)
def forward(self, x):
for linear in self.linears:
x = linear(x)
return x class MyModuleDict(nn.Module):
def __init__(self):
super(MyModuleDict, self).__init__()
self.linears = nn.ModuleDict({
"linear1":nn.Linear(10,10),
"linear2":nn.Linear(10,10)
})
def forward(self, x):
x = self.linears["linear1"](x)
x = self.linears["linear2"](x)
return x

Pytorch系列:(三)模型构建的更多相关文章

  1. 【转载】PyTorch系列 (二):pytorch数据读取

    原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...

  2. pytorch入门2.1构建回归模型初体验(模型构建)

    pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...

  3. pytorch入门2.2构建回归模型初体验(开始训练)

    pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...

  4. pytorch入门2.0构建回归模型初体验(数据生成)

    pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...

  5. 前端构建大法 Gulp 系列 (三):gulp的4个API 让你成为gulp专家

    系列目录 前端构建大法 Gulp 系列 (一):为什么需要前端构建 前端构建大法 Gulp 系列 (二):为什么选择gulp 前端构建大法 Gulp 系列 (三):gulp的4个API 让你成为gul ...

  6. [深度学习] Pytorch(三)—— 多/单GPU、CPU,训练保存、加载模型参数问题

    [深度学习] Pytorch(三)-- 多/单GPU.CPU,训练保存.加载预测模型问题 上一篇实践学习中,遇到了在多/单个GPU.GPU与CPU的不同环境下训练保存.加载使用使用模型的问题,如果保存 ...

  7. 【小白学PyTorch】6 模型的构建访问遍历存储(附代码)

    文章转载自微信公众号:机器学习炼丹术.欢迎大家关注,这是我的学习分享公众号,100+原创干货. 文章目录: 目录 1 模型构建函数 1.1 add_module 1.2 ModuleList 1.3 ...

  8. Web 开发人员和设计师必读文章推荐【系列三十】

    <Web 前端开发精华文章推荐>2014年第9期(总第30期)和大家见面了.梦想天空博客关注 前端开发 技术,分享各类能够提升网站用户体验的优秀 jQuery 插件,展示前沿的 HTML5 ...

  9. CSS3之简易的3D模型构建[原创开源]

    CSS3之简易的3D模型构建[开源分享] 先上一张图(成果图):这个是使用 3D建模空间[源码之一] 制作出来的模型之一 当然这是一部分模型特写, 之前还制作过枪的3D模型等等. 感兴趣的朋友可以自己 ...

随机推荐

  1. 元类、orm

    目录 一.内置函数exec 二.元类 1. 什么是元类 2. 元类的作用 3. 创建类的两种方法 4. 怎么自定义创建元类 三.ORM 1. ORM中可能会遇到的问题 2. ORM中元类需要解决的问题 ...

  2. SSAS表格模型

    Analysis Services 是在决策支持和业务分析中使用的分析数据引擎 (Vertipaq) . 它为商业智能提供企业级语义数据模型功能 (BI) .数据分析和报告应用程序,如 Power B ...

  3. ElasticSearch 集群安全

    公号:码农充电站pro 主页:https://codeshellme.github.io 在安装完 ES 后,ES 默认是没有任何安全防护的. ES 的安全管理主要包括以下内容: 身份认证:鉴定访问用 ...

  4. while、do...while和for循环

    一.循环 1.1 定义 当满足一定条件的时候,重复执行某一段代码的操作 while和for和do...while是java中的循环 二.while循环 2.1 定义 int i = 0: 初始化值 w ...

  5. Java开发不懂Docker,学尽Java也枉然,阿里P8架构师手把手带你玩转Docker实战

    转: Java开发不懂Docker,学尽Java也枉然,阿里P8架构师手把手带你玩转Docker实战 Docker简介 Docker 是一个开源的应用容器引擎,让开发者可以打包他们的应用以及依赖包到一 ...

  6. 话说 synchronized

    一.前言 ​ 说起java的锁呀,我们先想到的肯定是synchronized[ˈsɪŋ krə naɪ zd]了 ,这个单词很拗口,会读这个单词在以后的面试中很加分(我面试过一些人 不会读 ,他们说的 ...

  7. 剑指 Offer 59 - II. 队列的最大值--滑动窗口的建模+Deque的基本使用(常用方法)

    剑指 Offer 59 - II. 队列的最大值 题目链接 package com.walegarrett; /** * @Author WaleGarrett * @Date 2020/12/3 1 ...

  8. 比较String 字符串的字节大小

    package com.ittx.edi.erp;import java.io.File;import java.io.FileWriter;import java.io.IOException;pu ...

  9. AtCoder Beginner Contest 171-175 F

    171 F - Strivore 直接把初始字符当成隔板,统计的方案数会有重复 为了避免重复情况,规定隔板字母尽可能最后出现,即在隔板字母后面不能插入含隔板字母的字符串 所以在隔板字母后插入的字符只有 ...

  10. FreeBSD jail 折腾记(一)

    创建jail目录 mkdir -p /usr/jail/ 放入基本系统 方案一 make buildworld #编译基本系统 make installworld DESTDIR=/usr/jail/ ...