【小白学PyTorch】4 构建模型三要素与权重初始化
文章目录:
1 模型三要素
三要素其实很简单
- 必须要继承nn.Module这个类,要让PyTorch知道这个类是一个Module
- 在__init__(self)中设置好需要的组件,比如conv,pooling,Linear,BatchNorm等等
- 最后在forward(self,x)中用定义好的组件进行组装,就像搭积木,把网络结构搭建出来,这样一个模型就定义好了
我们来看一个例子:
先看__init__(self)函数
def __init__(self):
super(Net,self).__init__()
self.conv1 = nn.Conv2d(3,6,5)
self.pool1 = nn.MaxPool2d(2,2)
self.conv2 = nn.Conv2d(6,16,5)
self.pool2 = nn.MaxPool2d(2,2)
self.fc1 = nn.Linear(16*5*5,120)
self.fc2 = nn.Linear(120,84)
self.fc3 = nn.Linear(84,10)
第一行是初始化,往后定义了一系列组件。nn.Conv2d就是一般图片处理的卷积模块,然后池化层,全连接层等等。
定义完这些定义forward函数
def forward(self,x):
x = self.pool1(F.relu(self.conv1(x)))
x = self.pool2(F.relu(self.conv2(x)))
x = x.view(-1,16*5*5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
x为模型的输入,第一行表示x经过conv1,然后经过激活函数relu,然后经过pool1操作
第三行表示对x进行reshape,为后面的全连接层做准备
至此,对一个模型的定义完毕,如何使用呢?
例如:
net = Net()
outputs = net(inputs)
其实net(inputs),就是类似于使用了net.forward(inputs)这个函数。
2 参数初始化
简单地说就是设定什么层用什么初始方法,初始化的方法会在torch.nn.init中
话不多说,看一个案例:
# 定义权值初始化
def initialize_weights(self):
for m in self.modules():
if isinstance(m,nn.Conv2d):
torch.nn.init.xavier_normal_(m.weight.data)
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m,nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m,nn.Linear):
torch.nn.init.normal_(m.weight.data,0,0.01)
# m.weight.data.normal_(0,0.01)
m.bias.data.zero_()
这段代码的基本流程就是,先从self.modules()中遍历每一层,然后判断更曾属于什么类型,是否是Conv2d,是否是BatchNorm2d,是否是Linear的,然后根据不同类型的层,设定不同的权值初始化方法,例如Xavier,kaiming,normal_等等。kaiming也是MSRA初始化,是何恺明大佬在微软亚洲研究院的时候,因此得名。
上面代码中用到了self.modules(),这个是什么东西呢?
# self.modules的源码
def modules(self):
for name,module in self.named_modules():
yield module
功能就是:能依次返回模型中的各层,yield是让一个函数可以像迭代器一样可以用for循环不断从里面遍历(可能说的不太明确)。
3 完整运行代码
我们用下面的例子来更深入的理解self.modules(),同时也把上面的内容都串起来(下面的代码块可以运行):
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool1 = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.pool2 = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool1(F.relu(self.conv1(x)))
x = self.pool2(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
def initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
torch.nn.init.xavier_normal_(m.weight.data)
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
torch.nn.init.normal_(m.weight.data, 0, 0.01)
# m.weight.data.normal_(0,0.01)
m.bias.data.zero_()
net = Net()
net.initialize_weights()
print(net.modules())
for m in net.modules():
print(m)
运行结果:
# 这个是print(net.modules())的输出
<generator object Module.modules at 0x0000023BDCA23258>
# 这个是第一次从net.modules()取出来的东西,是整个网络的结构
Net(
(conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
(pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(fc1): Linear(in_features=400, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)
# 从net.modules()第二次开始取得东西就是每一层了
Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
Linear(in_features=400, out_features=120, bias=True)
Linear(in_features=120, out_features=84, bias=True)
Linear(in_features=84, out_features=10, bias=True)
其中呢,并不是每一层都有偏执bias的,有的卷积层可以设置成不要bias的,所以对于卷积网络参数的初始化,需要判断一下是否有bias,(不过我好像记得bias默认初始化为0?不确定,有知道的朋友可以交流)
torch.nn.init.xavier_normal(m.weight.data)
if m.bias is not None:
m.bias.data.zero_()
上面代码表示用xavier_normal方法对该层的weight初始化,并判断是否存在偏执bias,若存在,将bias初始化为0。
4 尺寸计算与参数计算
我们把上面的主函数部分改成:
net = Net()
net.initialize_weights()
layers = {}
for m in net.modules():
if isinstance(m,nn.Conv2d):
print(m)
break
这里的输出m就是:
Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
这个卷积层,就是我们设置的第一个卷积层,含义就是:输入3通道,输出6通道,卷积核\(5\times 5\),步长1,padding=0.
【问题1:输入特征图和输出特征图的尺寸计算】
之前的文章也讲过这个了,
\(output = \frac{input+2\times padding -kernel}{stride}+1\)
用代码来验证一下这个公式:
net = Net()
net.initialize_weights()
input = torch.ones((16,3,10,10))
output = net.conv1(input)
print(input.shape)
print(output.shape)
初始结果:
torch.Size([16, 3, 10, 10])
torch.Size([16, 6, 6, 6])
第一个维度上batch,第二个是通道channel,第三个和第四个是图片(特征图)的尺寸。
\(\frac{10+2\times 0-5}{1}+1=6\) 算出来的结果没毛病。
【问题2:这个卷积层中有多少的参数?】
输入通道是3通道的,输出是6通道的,卷积核是\(5\times 5\)的,所以理解为6个\(3\times 5\times 5\)的卷积核,所以不考虑bias的话,参数量是\(3\times 5\times 5\times 6=450\),考虑bais的话,就每一个卷积核再增加一个偏置值。(这是一个一般人会忽略的知识点欸)
下面用代码来验证:
net = Net()
net.initialize_weights()
for m in net.modules():
if isinstance(m,nn.Conv2d):
print(m)
print(m.weight.shape)
print(m.bias.shape)
break
输出结果是:
Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
torch.Size([6, 3, 5, 5])
torch.Size([6])
都和预料中一样。
【小白学PyTorch】4 构建模型三要素与权重初始化的更多相关文章
- 【小白学PyTorch】6 模型的构建访问遍历存储(附代码)
文章转载自微信公众号:机器学习炼丹术.欢迎大家关注,这是我的学习分享公众号,100+原创干货. 文章目录: 目录 1 模型构建函数 1.1 add_module 1.2 ModuleList 1.3 ...
- 【小白学PyTorch】20 TF2的eager模式与求导
[新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...
- 【小白学PyTorch】18 TF2构建自定义模型
[机器学习炼丹术]的炼丹总群已经快满了,要加入的快联系炼丹兄WX:cyx645016617 参考目录: 目录 1 创建自定义网络层 2 创建一个完整的CNN 2.1 keras.Model vs ke ...
- 【小白学PyTorch】15 TF2实现一个简单的服装分类任务
[新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...
- 【小白学PyTorch】5 torchvision预训练模型与数据集全览
文章来自:微信公众号[机器学习炼丹术].一个ai专业研究生的个人学习分享公众号 文章目录: 目录 torchvision 1 torchvision.datssets 2 torchvision.mo ...
- 【小白学PyTorch】8 实战之MNIST小试牛刀
文章来自微信公众号[机器学习炼丹术].有什么问题都可以咨询作者WX:cyx645016617.想交个朋友占一个好友位也是可以的~好友位快满了不过. 参考目录: 目录 1 探索性数据分析 1.1 数据集 ...
- 【小白学PyTorch】17 TFrec文件的创建与读取
[新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...
- 小白学phoneGap《构建跨平台APP:phoneGap移动应用实战》连载三(通过实例来体验生命周期)
4.1.2 通过实例来亲身体验Activity的生命周期 上一小节介绍了Activity生命周期中的各个过程,本小节将以一个简单的实例来使读者亲身体验到Activity生命周期中的各个事件. 在Ec ...
- 【小白学PyTorch】19 TF2模型的存储与载入
[新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...
随机推荐
- electron开发 - 打印流程(仅支持6.0.0版本以上)
Electron打印 标签打印 标签打印一般有两种方式: 驱动打印,与普通打印机一样通过驱动方式打印. 通过指令打印,不同厂家的的打印机指令集不一样,可查看厂家提供的手册. electron 打印方式 ...
- Android RecyclerView的补充。
明天写吧.. 今天写,然后再写今天的内容,虽然结课了,我们还是得学习,所以如果我学习了一些知识,不出意外每天会持续更新的. RecyclerView其实是可以完全代替ListView的存在, 但是为啥 ...
- 为什么overflow:hidden能达到清除浮动的目的?
1. 什么是浮动 <精通CSS>(第3版)关于浮动的描述: 浮动盒子可以向左或向右移动,直到其外边沿接触包含块的外边沿,或接触另一个浮动盒子的外边沿. 浮动盒子也会脱离常规文档流,因此常规 ...
- ZooKeeper Watcher 机制
前言 在 ZooKeeper 中,客户端可以向服务端注册一个监听器,监听某个节点或者其子节点列表,当监听对象发生变化时,服务端就会向指定的客户端发送通知,这是 ZooKeeper 中的 Watcher ...
- 博客主题推荐——复杂&简单
首先感谢原作者cjunn提供的主题autm,以下配置都基于此主题设定.很多小伙伴喜欢现在的样式,分享如下.只需简单几步即可. 如果你想使用本博客主题样式,并希望能得到远程推送更新,只需查看 快速部署. ...
- 盘点 35 个 Apache 顶级项目,我拜服了…
Apache 软件基金会 Apache 软件基金会,全称:Apache Software Foundation,简称:ASF,成立于 1999 年 7 月,是目前世界上最大的最受欢迎的开源软件基金会, ...
- LeetCode 392. Is Subsequence 详解
题目详情 给定字符串 s 和 t ,判断 s 是否为 t 的子序列. 你可以认为 s 和 t 中仅包含英文小写字母.字符串 t 可能会很长(长度 ~= 500,000),而 s 是个短字符串(长度 & ...
- 后端排序时去掉element表格排序的null状态
经常会遇到远程排序,需要去掉null状态的排序,当设置sortable='custom'时,设置sort-orders为['ascending', 'descending']是不生效的.然后查到了一种 ...
- 攻防世界-web(进阶)-NaNNaNNaNNaN-Batman
用winhex打开,发现是一个javascript代码.将文件重命名为html文件,用浏览器打开. 打开是一个输入框,输入任何东西都梅反应,尝试弹框输入也无果,继续查看代码. 查看代码,可以看到最开始 ...
- 关于java中jdk的环境变量配置
关于java中jdk的环境变量配置 烦死人,在网上找了很长时间.最终找到了一个方法!现在将其总结帮助后来人. 方法/步骤 1 下载好jdk,并按照提示一步步安装,最后记下jdk所在的安装位置,这里 ...