一、继承nn.Module类并自定义层

我们要利用pytorch提供的很多便利的方法,则需要将很多自定义操作封装成nn.Module类。

首先,简单实现一个Mylinear类:

from torch import nn

# Mylinear继承Module
class Mylinear(nn.Module):
# 传入输入维度和输出维度
def __init__(self,in_d,out_d):
# 调用父类构造函数
super(Mylinear,self).__init__()
# 使用Parameter类将w和b封装,这样可以通过nn.Module直接管理,并提供给优化器优化
self.w = nn.Parameter(torch.randn(out_d,in_d))
self.b = nn.Parameter(torch.randn(out_d)) # 实现forward函数,该函数为默认执行的函数,即计算过程,并将输出返回
def forward(self, x):
x = x@self.w.t() + self.b
return x

这样就可以将我们自定义的Mylinear加入整个网络:

# 网络结构
class MLP(nn.Module):
def __init__(self):
super(MLP, self).__init__() self.model = nn.Sequential(
#nn.Linear(784, 200),
Mylinear(784,200),
nn.BatchNorm1d(200, eps=1e-8),
nn.LeakyReLU(inplace=True),
#nn.Linear(200, 200),
Mylinear(200, 200),
nn.BatchNorm1d(200, eps=1e-8),
nn.LeakyReLU(inplace=True),
#nn.Linear(200, 10),
Mylinear(200,10),
nn.LeakyReLU(inplace=True)
)

我们可以看出,MLP网络实际上也是继承自Module,这就说明了,nn.Module实际上可以实现一个嵌套的结构,我们的整个网络就是由一个嵌套的树形结构组成的。例如:

# Mylinear继承Module
class Mylinear(nn.Module):
# 传入输入维度和输出维度
def __init__(self, in_d, out_d):
# 调用父类构造函数
super(Mylinear, self).__init__()
# 使用Parameter类将w和b封装,这样可以通过nn.Module直接管理,并提供给优化器优化
self.w = nn.Parameter(torch.randn(out_d, in_d))
self.b = nn.Parameter(torch.randn(out_d)) # 实现forward函数,该函数为默认执行的函数,即计算过程,并将输出返回
def forward(self, x):
x = x @ self.w.t() + self.b
return x # 将几个nn.Module组件综合成一个
class Mylayer(nn.Module):
def __init__(self, in_d, out_d):
super(Mylayer, self).__init__()
# 包含一个全连接层,一个BN层,一个Leaky Relu层
self.lin = Mylinear(in_d, out_d)
self.bn = nn.BatchNorm1d(out_d, eps=1e-8)
self.lrelu = nn.LeakyReLU(inplace=True) # 按顺序跑一遍3种网络,返回最终结果
def forward(self, x):
x = self.lin(x)
x = self.bn(x)
x = self.lrelu(x)
return x # 网络结构
class MLP(nn.Module):
def __init__(self):
super(MLP, self).__init__() self.model = nn.Sequential(
Mylayer(784, 200),
Mylayer(200, 200),
# nn.Linear(200, 10),
Mylinear(200, 10),
nn.LeakyReLU(inplace=True)
)

上述代表表示的结构如下图所示:

其中所有的类都继承自nn.Module,从前往后是嵌套的关系。在上述代码中,真正做计算的是橙色部分1-8,而其他的都只是作为封装。其中nn.Sequential、nn.BatchNorm1d、nn.LeakyReLU是pytorch提供的类,Mylinear和Mylayer是我们自己封装的类。

二、实现一个常用类Flatten类

Flatten就是将2D的特征图压扁为1D的特征向量,用于全连接层的输入。

# Flatten继承Module
class Flatten(nn.Module):
# 构造函数,没有什么要做的
def __init__(self):
# 调用父类构造函数
super(Flatten, self).__init__() # 实现forward函数
def forward(self, input):
# 保存batch维度,后面的维度全部压平,例如输入是28*28的特征图,压平后为784的向量
return input.view(input.size(0), -1)

三、nn.Module类的作用

1.便于保存模型:

# 每隔N epoch保存一次模型
torch.save(net.state_dict(),'ckpt_n_epoch.mdl')
# 下次训练时可以直接导入接着训练
net.load_state_dict(torch.load('ckpt_n_epoch.mdl'))

2.方便切换train和val模式

### 不同模式对于某些层的操作时不同的,例如BN,dropout层等
# 切换到train模式
net.train()
# 切换到validation模式
net.eval()

3.方便将网络转移到GPU上

# 定义GPU设备
device = torch.device('cuda')
# 将网络转移到GPU,注意to函数返回的是net的引用(引用是不变的)
# 不同的是net中的参数都转移到GPU上去了
net.to(device) # 不同于参数直接转移,转移后的w2(在GPU上)和转移前的w(在CPU上)两者完全是不一样的
# 我们要使之在GPU上运行,则必须使用w2
#w2 = w.to(device)

4.方便查看各层参数

# 获取由每一层参数组成的列表
para_list = list(net.parameters())
# 获取一个(name,每层参数)的tuple组成的列表
para_named_list = list(net.named_parameters())
# 获取一个{'model.0.weight': 参数,'model.0.bias': 参数, 'model.1.weight': 参数}
para_named_dict = dict(net.named_parameters())

四、数据增强

torchvision提供了很方便的数据预处理工具,数据增强可以一次性搞定。

from torchvision import datasets, transforms

train_data_trans = datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([
# 水平翻转,50%执行
transforms.RandomHorizontalFlip(),
# 垂直翻转,50%执行
transforms.RandomVerticalFlip(),
# 随机旋转范围在正负15°之间,也可以写(-15,15)
transforms.RandomRotation(15),
# 旋转范围在90-270之间
#transforms.RandomRotation([90,270]),
# 将图片方缩放到指定大小
transforms.Resize([32,32]),
# 随机剪裁图片到指定大小
transforms.RandomCrop([28,28]), transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))

如果pytorch没有提供需要的预处理类,我们可以参照源码仿造写一个自定义处理的类来进行处理。例如对图片添加白噪声,按通道变换颜色等等。

[深度学习] pytorch学习笔记(4)(Module类、实现Flatten类、Module类作用、数据增强)的更多相关文章

  1. [深度学习] Pytorch学习(一)—— torch tensor

    [深度学习] Pytorch学习(一)-- torch tensor 学习笔记 . 记录 分享 . 学习的代码环境:python3.6 torch1.3 vscode+jupyter扩展 #%% im ...

  2. [深度学习] pytorch学习笔记(2)(梯度、梯度下降、凸函数、鞍点、激活函数、Loss函数、交叉熵、Mnist分类实现、GPU)

    一.梯度 导数是对某个自变量求导,得到一个标量. 偏微分是在多元函数中对某一个自变量求偏导(将其他自变量看成常数). 梯度指对所有自变量分别求偏导,然后组合成一个向量,所以梯度是向量,有方向和大小. ...

  3. [深度学习] pytorch学习笔记(3)(visdom可视化、正则化、动量、学习率衰减、BN)

    一.visdom可视化工具 安装:pip install visdom 启动:命令行直接运行visdom 打开WEB:在浏览器使用http://localhost:8097打开visdom界面 二.使 ...

  4. [深度学习] pytorch学习笔记(1)(数据类型、基础使用、自动求导、矩阵操作、维度变换、广播、拼接拆分、基本运算、范数、argmax、矩阵比较、where、gather)

    一.Pytorch安装 安装cuda和cudnn,例如cuda10,cudnn7.5 官网下载torch:https://pytorch.org/ 选择下载相应版本的torch 和torchvisio ...

  5. [深度学习] Pytorch学习(二)—— torch.nn 实践:训练分类器(含多GPU训练CPU加载预测的使用方法)

    Learn From: Pytroch 官方Tutorials Pytorch 官方文档 环境:python3.6 CUDA10 pytorch1.3 vscode+jupyter扩展 #%% #%% ...

  6. pytorch学习笔记(十二):详解 Module 类

    Module 是 pytorch 提供的一个基类,每次我们要 搭建 自己的神经网络的时候都要继承这个类,继承这个类会使得我们 搭建网络的过程变得异常简单. 本文主要关注 Module 类的内部是怎么样 ...

  7. [PyTorch 学习笔记] 3.1 模型创建步骤与 nn.Module

    本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson3/module_containers.py 这篇文章来看下 ...

  8. 【PyTorch深度学习】学习笔记之PyTorch与深度学习

    第1章 PyTorch与深度学习 深度学习的应用 接近人类水平的图像分类 接近人类水平的语音识别 机器翻译 自动驾驶汽车 Siri.Google语音和Alexa在最近几年更加准确 日本农民的黄瓜智能分 ...

  9. 深度学习Keras框架笔记之AutoEncoder类

    深度学习Keras框架笔记之AutoEncoder类使用笔记 keras.layers.core.AutoEncoder(encoder, decoder,output_reconstruction= ...

随机推荐

  1. C语言 俄罗斯方块的实现1 全局变量

    目录 全局变量 程序的模块化之MVC 关于俄罗斯方块的代码实现要点 使用数组表示背景和方块 方块表示及其初始化 要让游戏动起来 方块自动下落 全局变量 简而言之,定义在函数外的变量,就是全局变量. 所 ...

  2. 【图论好题】ABC #142 Task F Pure

    题目大意 给定一个 $N$ 个点 $M$ 条边的有向图 $G$,无重边.自环.找出图 $G$ 的一个导出子图(induced subgraph) $G'$,且 $G'$ 中的每个点的入度和出度都是 1 ...

  3. 双01字典树最小XOR(three arrays)--2019 Multi-University Training Contest 5(hdu杭电多校第5场)

    题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=6625 题意: 给你两串数 a串,b串,让你一一配对XOR使得新的 C 串字典序最小. 思路: 首先这边 ...

  4. 自定义字段的设计与实现(Java实用版)

    前言 自定义字段又叫做"开放模型",用户可以根据自已的需求,添加需要的字段,实现个性化定制. 使用自定义字段的目的,使用自定义字段解决哪些问题 如现有一套CRM系统,客户模块中客户 ...

  5. Linux之常用脚本

    1) #检查php Money 队列脚本是否启动 php_count=`ps -ef | grep Money | grep -v "grep" | wc -l` ];then e ...

  6. N1考试必备词汇

    相次ぐ あいつぐ 淡い あわい 合間 あいま 渋い しぶい 相俟つ あいまつ 慌てよう あわてよう 明るい あかるい 安易過ぎる 明らか あきらか 用心 ようじん 悪事 あくじ 案の定 あんのじょう ...

  7. LeetCode——回文链表

    题目 给定一个链表的头节点head,请判断该链表是否为回 文结构. 例如: 1->2->1,返回true. 1->2->2->1,返回true. 15->6-> ...

  8. ASP.NET Core WebAPI帮助页--Swagger简单使用1.0

    1.什么是Swagger? Swagger是一个规范且完整的框架,提供描述.生产.消费和可视化RESTful API,它是为了解决Web API生成有用文档和帮助页的问题.   2.为啥选用swagg ...

  9. MP4 ISO基础媒体文件格式术语

    术语.定义和缩略术语 box 由唯一类型标识符和长度定义的面向对象的构造块(注:在一些标准称为“atom") chunk(块) 一个track连续采样集合 container box 唯一目 ...

  10. golang 方法

    方法: 在函数声明时,在其名字之前放上一个变量,即是一个方法.这个附加的参数会将该函数附 加到这种类型上,即相当于为这种类型定义了一个独占的方法. package main import " ...