从头学pytorch(十四):lenet
卷积神经网络
在之前的文章里,对28 X 28的图像,我们是通过把它展开为长度为784的一维向量,然后送进全连接层,训练出一个分类模型.这样做主要有两个问题
- 图像在同一列邻近的像素在这个向量中可能相距较远。它们构成的模式可能难以被模型识别。
- 对于大尺寸的输入图像,使用全连接层容易造成模型过大。假设输入是高和宽均为1000像素的彩色照片(含3个通道)。即使全连接层输出个数仍是256,该层权重参数的形状是\(3,000,000\times 256\),按照参数为float,占用4字节计算,它占用了大约3000000 X 256 X4bytes=3000000kb=3000M=3G的内存或显存。
很显然,通过使用卷积操作可以有效的改善这两个问题.关于卷积操作,池化操作等,参见置顶文章https://www.cnblogs.com/sdu20112013/p/10149529.html.
LENET
lenet是比较早期提出来的一个神经网络,其结构如下图所示.
LeNet的结构比较简单,就是2次重复的卷积激活池化后面接三个全连接层.卷积层的卷积核用的5 X 5,池化用的窗口大小为2 X 2,步幅为2.
对我们的输入(28 x 28)来说,卷积层得到的输出shape为[batch,16,4,4],在送入全连接层前,要reshape成[batch,16x4x4].可以理解为通过卷积,对没一个样本,我们
都提取出来了16x4x4=256个特征.这些特征用来识别图像里的空间模式,比如线条和物体局部.
全连接层块含3个全连接层。它们的输出个数分别是120、84和10,其中10为输出的类别个数。
net0 = nn.Sequential(
nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_size
nn.Sigmoid(),
nn.MaxPool2d(2, 2), # kernel_size, stride
nn.Conv2d(6, 16, 5),
nn.Sigmoid(),
nn.MaxPool2d(2, 2)
)
batch_size=64
X = torch.randn((batch_size,1,28,28))
out=net0(X)
print(out.shape)
输出
torch.Size([64, 16, 4, 4])
这就是上面我们说的"对我们的输入(28 x 28)来说,卷积层得到的输出shape为[batch,16,4,4]"的由来.
模型定义
至此,我们可以给出LeNet的定义:
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_size
nn.Sigmoid(),
nn.MaxPool2d(2, 2), # kernel_size, stride
nn.Conv2d(6, 16, 5),
nn.Sigmoid(),
nn.MaxPool2d(2, 2)
)
self.fc = nn.Sequential(
nn.Linear(16*4*4, 120),
nn.Sigmoid(),
nn.Linear(120, 84),
nn.Sigmoid(),
nn.Linear(84, 10)
)
def forward(self, img):
feature = self.conv(img)
output = self.fc(feature.view(img.shape[0], -1))
return output
在forward()中,在输入全连接层之前,要先feature.view(img.shape[0], -1)做一次reshape.
我们用gpu来做训练,所以要把net的参数都存储在显存上:
net = LeNet().cuda()
数据加载
import torch
from torch import nn
import sys
sys.path.append("..")
import learntorch_utils
batch_size,num_workers=64,4
train_iter,test_iter = learntorch_utils.load_data(batch_size,num_workers)
load_data定义于learntorch_utils.py,如下:
def load_data(batch_size,num_workers):
mnist_train = torchvision.datasets.FashionMNIST(root='/home/sc/disk/keepgoing/learn_pytorch/Datasets/FashionMNIST',
train=True, download=True,
transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='/home/sc/disk/keepgoing/learn_pytorch/Datasets/FashionMNIST',
train=False, download=True,
transform=transforms.ToTensor())
train_iter = torch.utils.data.DataLoader(
mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(
mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
return train_iter,test_iter
定义损失函数
l = nn.CrossEntropyLoss()
定义优化器
opt = torch.optim.Adam(net.parameters(),lr=0.01)
定义评估函数
def test():
acc_sum = 0
batch = 0
for X,y in test_iter:
X,y = X.cuda(),y.cuda()
y_hat = net(X)
acc_sum += (y_hat.argmax(dim=1) == y).float().sum().item()
batch += 1
print('acc:%f' % (acc_sum/(batch*batch_size)))
训练
- 前向传播
- 计算loss
- 梯度清空,反向传播
- 更新参数
num_epochs=5
def train():
for epoch in range(num_epochs):
train_l_sum,batch=0,0
for X,y in train_iter:
X,y = X.cuda(),y.cuda() #把tensor放到显存
y_hat = net(X) #前向传播
loss = l(y_hat,y) #计算loss,nn.CrossEntropyLoss中会有softmax的操作
opt.zero_grad()#梯度清空
loss.backward()#反向传播,求出梯度
opt.step()#根据梯度,更新参数
train_l_sum += loss.item()
batch += 1
print('epoch %d,train_loss %f' % (epoch + 1,train_l_sum/(batch*batch_size)))
test()
输出如下:
epoch 1,train_loss 0.011750
acc:0.799064
epoch 2,train_loss 0.006442
acc:0.855195
epoch 3,train_loss 0.005401
acc:0.857584
epoch 4,train_loss 0.004946
acc:0.874602
epoch 5,train_loss 0.004631
acc:0.874403
从头学pytorch(十四):lenet的更多相关文章
- 从头学pytorch(十五):AlexNet
AlexNet AlexNet是2012年提出的一个模型,并且赢得了ImageNet图像识别挑战赛的冠军.首次证明了由计算机自动学习到的特征可以超越手工设计的特征,对计算机视觉的研究有着极其重要的意义 ...
- 从头学pytorch(十九):批量归一化batch normalization
批量归一化 论文地址:https://arxiv.org/abs/1502.03167 批量归一化基本上是现在模型的标配了. 说实在的,到今天我也没搞明白batch normalize能够使得模型训练 ...
- 从头学pytorch(十二):模型保存和加载
模型读取和存储 总结下来,就是几个函数 torch.load()/torch.save() 通过python的pickle完成序列化与反序列化.完成内存<-->磁盘转换. Module.s ...
- 从头学pytorch(十六):VGG NET
VGG AlexNet在Lenet的基础上增加了几个卷积层,改变了卷积核大小,每一层输出通道数目等,并且取得了很好的效果.但是并没有提出一个简单有效的思路. VGG做到了这一点,提出了可以通过重复使⽤ ...
- 从头学pytorch(十八):GoogLeNet
GoogLeNet GoogLeNet和vgg分别是2014的ImageNet挑战赛的冠亚军.GoogLeNet则做了更加大胆的网络结构尝试,虽然深度只有22层,但大小却比AlexNet和VGG小很多 ...
- HDU 6467 简单数学题 【递推公式 && O(1)优化乘法】(广东工业大学第十四届程序设计竞赛)
传送门:http://acm.hdu.edu.cn/showproblem.php?pid=6467 简单数学题 Time Limit: 4000/2000 MS (Java/Others) M ...
- HDU 6464 免费送气球 【权值线段树】(广东工业大学第十四届程序设计竞赛)
传送门:http://acm.hdu.edu.cn/showproblem.php?pid=6464 免费送气球 Time Limit: 2000/1000 MS (Java/Others) M ...
- HDU 6470 Count 【矩阵快速幂】(广东工业大学第十四届程序设计竞赛 )
题目传送门:http://acm.hdu.edu.cn/showproblem.php?pid=6470 Count Time Limit: 6000/3000 MS (Java/Others) ...
- HDU 6467.简单数学题-数学题 (“字节跳动-文远知行杯”广东工业大学第十四届程序设计竞赛)
简单数学题 Time Limit: 4000/2000 MS (Java/Others) Memory Limit: 65536/65536 K (Java/Others)Total Submi ...
随机推荐
- There is no getter for property named 'XXX' in 'class java.lang.String'
实验环境:spring boot+mybitis 由于采用的不带映射xml文件的模式,因此 方法1: 把#{xxx}修改为 #{_parameter} 即可 select count(*) from ...
- Python--day71--Cookie和Session
一.Cookie Cookie图示: 二.Session 引用:http://www.cnblogs.com/liwenzhou/p/8343243.html cookie Cookie的由来 大家都 ...
- poj 2993
跟poj 2996反过来了,这里比较麻烦的就是处理白棋和黑棋各棋子对应的位置 还有在最后打印棋盘式|,:,.的时候会有点繁琐(- - ACMer新手 ): 直接看代码吧: #include<cs ...
- 第三章 通过java SDK 实现个性化智能合约的部署与测试
想了解相关区块链开发,技术提问,请加QQ群:538327407 前提 已经部署好底层,外网可以正常请求访问. 正常流程 1.基础合约处理 https://fisco-bcos-documentatio ...
- P1047 汉诺塔
题目描述 汉诺塔是根据一个印度传说形成的数学问题:有三根杆子A, B, C, A杆上有n个穿孔圆盘, 盘的尺寸由下到上依次变小. 要求按照下列规则将所有圆盘移至C杆: 每次只能移动一个圆盘 大盘不能叠 ...
- 关于axios的一些封装
关于Axios的封装 为何需要在封装 应用场景,项目中涉及100个AJAX请求,其中: 1.其中60个需要在请求头header设置token headers: {token: token}用于权限校验 ...
- Cannot destructure property `createHash` of 'undefined' or 'null'(next服务端渲染引入next-less错误).
next中引入@zeit/next-less因next版本过低(webpack4之前的版本)无法执行next-less内置的mini-css-extract-plugin mini-css-extra ...
- [USACO10OCT]Lake Counting(DFS)
很水的DFS. 为什么放上来主要是为了让自己的博客有一道DFS题解,,, #include<bits/stdc++.h> using namespace std; ][],ans,flag ...
- API自动化测试指南
我相信自动化技能已经成为高级测试工程师总体技能的标配.敏捷和持续测试破坏了传统的测试自动化实践,导致测试工程师重新考虑自动化的完成方式.当今的自动化工程师需要在GUI的下方深入到API级别完成软件质量 ...
- 闲着没事,做个chrome浏览器插件,适合初学者
时光偷走的,永远都是我们眼皮底下看不见的珍贵. 本插件功能:替换掉网页中的指定图片的src地址. 使用插件前: 使用插件后: 鲜花(闲话):这个网站的不加水印的图片连接被保存在,图片的data-ima ...