Pytorch系列:(三)模型构建
nn.Module 函数详解
nn.Module是所有网络模型结构的基类,无论是pytorch自带的模型,还是要自定义模型,都需要继承这个类。这个模块包含了很多子模块,如下所示,_parameters存放的是模型的参数,_buffers也存放的是模型的参数,但是是那些不需要更新的参数。带hook的都是钩子函数,详见钩子函数部分。
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._non_persistent_buffers_set = set()
self._backward_hooks = OrderedDict()
self._is_full_backward_hook = None
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
self._modules = OrderedDict()
此外,每一个模块还内置了一些常用的方法来帮助访问和操作网络。
load_state_dict() #加载模型权重参数
parameters() #读取所有参数
named_parameters() #读取参数名称和参数
buffers() #读取self.named_buffers中的参数
named_buffers() #读取self.named_buffers中的参数名称和参数
children() #读取模型中,所有的子模型
named_children() #读取子模型名称和子模型
requires_grad_() #设置模型是否开启梯度反向传播
Parameter类
Parameter是Tensor子类,所以继承了Tensor类的属性。例如data和grad属性,可以根据data来访问参数数值,用grad来访问参数梯度。
weight_0 = nn.Parameters(torch.randn(10,10))
print(weight_0.data)
print(weight_0.grad)
定义变量的时候,nn.Parameter会被自动加入到参数列表中去
class MyModel(nn.Module):
def __init__(self):
super(MyModel,self).__init__()
self.weight1 = nn.Parameter(torch.randn(10,10))
self.weight2 = torch.randn(10,10)
def forward(self,x):
pass
model = MyModel()
for name,param in model.named_parameters():
print(name)
output: weight1
ParameterList
接定义成Parameter类外,还可以使用ParameterList和ParameterDict分别定义参数的列表和字典。ParameterList接收一个Parameter实例的列表作为输入然后得到一个参数列表,使用的时候可以用索引来访问某个参数,另外也可以使用append和extend在列表后面新增参数。
params = nn.ParameterList(
[nn.Parameter(torch.randn(10,10)) for i in range(5)]
)
params.append(nn.Parameter(torch.randn(3,3)))
ParameterDict
可以像添加字典数据那样添加参数
params = nn.ParameterDict({
'linear1':nn.Parameter(torch.randn(10,5)),
'linear2':nn.Parameter(torch.randn(5,2))
})
模型构建
使用Sequential构建模型
# 写法一
net = nn.Sequential(
nn.Linear(num_inputs, 1)
# 此处还可以传入其他层
)
# 写法二
net = nn.Sequential()
net.add_module('linear', nn.Linear(num_inputs, 1))
# net.add_module ......
# 写法三
from collections import OrderedDict
net = nn.Sequential(OrderedDict([
('linear', nn.Linear(num_inputs, 1))
# ......
]))
print(net)
自定义模型
- 无参数模型
下面是一个展开操作,比如将2维图像展开成一维
class Flatten(nn.Module):
def __init__(self):
super(Flatten,self).__init__()
def forward(self,input):
return input.view(input.size(0),-1)
- 有参数模型
自定义一个Linear层
class MLinear(nn.Module):
def __init__(self,input,output):
super(MyLinear,self).__init__()
self.w = nn.Parameter(torch.randn(input,output))
self.b = nn.Parameter(torch.randn(output))
def foward(self,x):
x = self.w @ x + self.b
return x
- 组合模型
class Model(nn.Module):
def __init__(self):
super(Model,self).__init__()
self.l1 = nn.Linear(10,20)
self.l2 = nn.Linear(20,5)
def forward(self,x):
x = self.l1(x)
x = self.l2(x)
return x
ModuleList & ModuleDict
ModuleList 和 ModuleDict都是继承与nn.Module, 与Seuqential不同的是,ModuleList 和 ModuleDict没有自带forward方法,所以只能作为一个模块和其他自定义方法进行组合。下面是使用示例:
class MyModuleList(nn.Module):
def __init__(self):
super(MyModuleList, self).__init__()
self.linears = nn.ModuleList(
[nn.Linear(10, 10) for i in range(3)]
)
def forward(self, x):
for linear in self.linears:
x = linear(x)
return x
class MyModuleDict(nn.Module):
def __init__(self):
super(MyModuleDict, self).__init__()
self.linears = nn.ModuleDict({
"linear1":nn.Linear(10,10),
"linear2":nn.Linear(10,10)
})
def forward(self, x):
x = self.linears["linear1"](x)
x = self.linears["linear2"](x)
return x
Pytorch系列:(三)模型构建的更多相关文章
- 【转载】PyTorch系列 (二):pytorch数据读取
原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...
- pytorch入门2.1构建回归模型初体验(模型构建)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
- pytorch入门2.2构建回归模型初体验(开始训练)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
- pytorch入门2.0构建回归模型初体验(数据生成)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
- 前端构建大法 Gulp 系列 (三):gulp的4个API 让你成为gulp专家
系列目录 前端构建大法 Gulp 系列 (一):为什么需要前端构建 前端构建大法 Gulp 系列 (二):为什么选择gulp 前端构建大法 Gulp 系列 (三):gulp的4个API 让你成为gul ...
- [深度学习] Pytorch(三)—— 多/单GPU、CPU,训练保存、加载模型参数问题
[深度学习] Pytorch(三)-- 多/单GPU.CPU,训练保存.加载预测模型问题 上一篇实践学习中,遇到了在多/单个GPU.GPU与CPU的不同环境下训练保存.加载使用使用模型的问题,如果保存 ...
- 【小白学PyTorch】6 模型的构建访问遍历存储(附代码)
文章转载自微信公众号:机器学习炼丹术.欢迎大家关注,这是我的学习分享公众号,100+原创干货. 文章目录: 目录 1 模型构建函数 1.1 add_module 1.2 ModuleList 1.3 ...
- Web 开发人员和设计师必读文章推荐【系列三十】
<Web 前端开发精华文章推荐>2014年第9期(总第30期)和大家见面了.梦想天空博客关注 前端开发 技术,分享各类能够提升网站用户体验的优秀 jQuery 插件,展示前沿的 HTML5 ...
- CSS3之简易的3D模型构建[原创开源]
CSS3之简易的3D模型构建[开源分享] 先上一张图(成果图):这个是使用 3D建模空间[源码之一] 制作出来的模型之一 当然这是一部分模型特写, 之前还制作过枪的3D模型等等. 感兴趣的朋友可以自己 ...
随机推荐
- 不同浏览器CSS样式不兼容问题
一句话,我想的太复杂了.向朋友请教才了解到,其实只要加个判断即可,首先获取到浏览器的基本信息,像什么版本啊,名称啊.默认语言啊等等,然后根据不同浏览器默认加载不同CSS样式即可,获取浏览器版本的连接如 ...
- 从Java的堆栈到Equals和==的比较
以下为链接 https://www.2cto.com/kf/201503/383832.html 栈与堆都是Java用来在Ram中存放数据的地方.与C++不同,Java自动管理栈和堆,程序员不能直接地 ...
- Ping 的工作原理你懂了,那 ICMP 你懂不懂?
计算机网络我也连载了很多篇了,大家可以在我的公众号「程序员cxuan」 或者我的 github 系统学习. 计算机网络第一篇,聊一聊网络基础 :计算机网络基础知识总结 计算机网络第二篇,聊一聊 TCP ...
- 必知必会之 Java
必知必会之 Java 目录 不定期更新中-- 基础知识 数据计量单位 面向对象三大特性 基础数据类型 注释格式 访问修饰符 运算符 算数运算符 关系运算符 位运算符 逻辑运算符 赋值运算符 三目表达式 ...
- 第31天学习打卡(File类。字符流读写文件)
File类 概念 文件,文件夹,一个file对象代表磁盘上的某个文件或者文件夹 构造方法 File(String pathname) File(String parent,String child) ...
- std和stl的关系
[前言]在写程序时,虽然一直这么用,有点疑惑为甚么引入了头文件.h还要在加上using namespace std?例如: 1 #include<iostream> 2 using nam ...
- 面试被吊打系列 - Redis原理
小张兴冲冲去面试,结果被面试官吊打! 小张: 面试官,你好.我是来参加面试的. 面试官: 你好,小张.我看了你的简历,熟练掌握Redis,那么我就随便问你几个Redis相关的问题吧.首先我的问题是,R ...
- 分布式session实现方式
一.背景 在搭建完集群环境后,不得不考虑的一个问题就是用户访问产生的session如何处理. 如果不做任何处理的话,用户将出现频繁登录的现象,比如集群中存在A.B两台服务器,用户在第一次访问网站时,N ...
- WinFrom与百度地图完美交互
如何在WinFrom上显示百度地图,也许这个很多人会说这个很容易,但是如何实现WinFrom与百度地图的完美交互,如果在Winfrom的控制输入经纬度,准确定位出在百度地图上,并且还能显示出具体的地址 ...
- 记录Java注解在JavaWeb中的一个应用实例
概述 在学习注解的时候,学了个懵懵懂懂.学了JavaWeb之后,在做Demo项目的过程中,借助注解和反射实现了对页面按钮的权限控制,对于注解才算咂摸出了点味儿来. 需求 以"角色列表&quo ...