在上一篇博客CNN核心概念理解中,我们以LeNet为例介绍了CNN的重要概念。在这篇博客中,我们将利用著名深度学习框架PyTorch实现LeNet5,并且利用它实现手写体字母的识别。训练数据采用经典的MNIST数据集。本文主要分为两个部分,一是如何使用PyTorch实现LeNet模型,二是实现数据准备、定义网络、定义损失函数、训练、测试等完整流程。

一、LeNet模型定义

LeNet是识别手写字母的经典网络,虽然年代久远,但从学习的角度仍不失为一个优秀的范例。要实现这个网络,首先来看看这个网络的结构:

这是一个简单的前向传播的网络,它接受32x32图片作为输入,经过卷积、池化和全连接层的计算,最终给出输出结果。实现的过程并不复杂:

 from torch import nn
from torch.nn import functional as F class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
# 1 input image channel, 6 output channels, 5x5 square convolution
self.conv1 = nn.Conv2d(1, 6, 5, padding=2)
self.conv2 = nn.Conv2d(6, 16, 5)
# an affine operation: y = Wx + b
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10) def forward(self, x):
x = F.relu(self.conv1(x))
# Max pooling over a (2, 2) window
x = F.max_pool2d(x, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2) x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x

我们继承了nn.Module模块,在__init__中完成了卷积层和全连接层的初始化。值得注意的是由于池化层没有参数,因此并没有一起初始化。初始化参数包括输入个数、输出个数,卷积层的参数还有卷积核大小。除此之外在第一个卷积层C1中还定义了padding,这是因为数据集中图片是28x28的,padding=2表明输入的时候在图片四周各填充2个像素的空白,将输入变成了32x32。

在forward中我们实现了前向传播。这里我们根据定义对输入依次进行卷积、激活、池化等操作,最后返回计算结果。在全连接层之前,有一个对数据的展开操作,我们使用Tensor的view函数实现,这个函数可以将Tensor转变成任意合法的形状。我们只定义了forward函数,而没有定义backword函数,这是因为PyTorch的自动微分功能自动帮我们完成了反向传播的定义。

LeNet模型这样就定义完成了。但是需要注意的是,这个网络和最初LeCun论文中的实现略有不同:

  • 原始论文中C3与S2并不是全连接而是部分连接,这样能减少部分计算量。而现代CNN模型中,比如AlexNet,ResNet等,都采取全连接的方式了。我们的实现在这里做了一些简化。
  • 原文中使用双曲正切作为激活函数,而我们使用了收敛速度更快的ReLu函数。
  • 按照原文描述,网络最后一层为高斯连接层。而我们为了简单起见还是用了全连接层。

LeNet其实是一个比较“古老”的模型了,我们不必追求完美的复现,理解其中关键的概念即可。

二、准备数据

为PyTorch准备数据非常方便。对于一些经典数据集,PyTorch已经将它们封装好了,我们可以直接拿来用。当然MNIST数据集也在此列,但是我们仍然定义了自己的数据集,因为这种方法可以处理更通用的情况。为了定义自己的数据集,首先要继承torch.utils.data.database类,然后实现至少__getitem__和__len__两个方法。

 import gzip, struct
import numpy as np
import torch.utils.data as data class MnistDataset(data.Dataset):
def __init__(self, path, train=True):
self.path = path
if train:
X, y = self._read('train-images-idx3-ubyte.gz',
'train-labels-idx1-ubyte.gz')
else:
X, y = self._read('t10k-images-idx3-ubyte.gz',
't10k-labels-idx1-ubyte.gz') self.images = torch.from_numpy(X.reshape(-1, 1, 28, 28)).float()
self.labels = torch.from_numpy(y.astype(int)) def __getitem__(self, index):
return self.images[index], self.labels[index] def __len__(self):
return len(self.images) def _read(self, image, label):
with gzip.open(self.path + image, 'rb') as fimg:
magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
X = np.frombuffer(fimg.read(), dtype=np.uint8).reshape(-1, rows, cols)
with gzip.open(self.path + label) as flbl:
magic, num = struct.unpack(">II", flbl.read(8))
y = np.frombuffer(flbl.read(), dtype=np.int8)
return X, y

由于官网上提供的MNIST数据集是gzip压缩格式,因此我们在读取的时候首先要解压,然后转成numpy形式,最后转成Tensor保存起来。之后在__getitem__中返回相应的数据和类别就可以了,__len__函数直接返回数据集的大小。由于MNIST数据集有训练和测试两部分,因此需要分类处理。

三、使用数据训练网络

我们首先用DataLoader类加载数据集,DataLoader负责将数据转化成适当的形式放入模型训练。使用DataLoader可以方便地控制微批次大小、线程数等参数。

 train_dataset = MnistDataset('./data/')
train_loader = data.DataLoader(train_dataset, shuffle=True, batch_size=256,
num_workers=4)

这时候可以测试数据有没有成功加载进来,如图所示。

下一步定义评价函数和优化器,这一步很重要,但不是本文重点。直接给出代码:

 criterion = nn.CrossEntropyLoss(reduction='sum')
optimizer = optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.99))

最后的给出训练过程的简化版。这个两层循环就是实际的训练过程,外层循环控制遍历数据集的次数,内层循环控制每一次参数更新。

 for epoch in range(5):
for (inputs, label) in train_loader:
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
output = net(inputs)
loss = criterion(output, label)
loss.backward()
optimizer.step()

三、模型评估

模型经过训练之后,将测试集输入放入模型,将输出和标签比对可以计算出模型的准确率等信息,进而对模型不断优化。此外如果想要了解模型到底学到了什么东西,还可以将中间层结果输出。如图所示:

这部分代码没有给出,完整代码可以到Github页面查看。

PyTorch实战:经典模型LeNet5实现手写体识别的更多相关文章

  1. 【Keras篇】---利用keras改写VGG16经典模型在手写数字识别体中的应用

    一.前述 VGG16是由16层神经网络构成的经典模型,包括多层卷积,多层全连接层,一般我们改写的时候卷积层基本不动,全连接层从后面几层依次向前改写,因为先改参数较小的. 二.具体 1.因为本文中代码需 ...

  2. 经典网络LeNet5看卷积神经网络各层的维度变化

    本文介绍以下几个CNN经典模型:Lenet(1986年).Alexnet(2012年).GoogleNet(2014年).VGG(2014年).Deep Residual Learning(2015年 ...

  3. 入门项目数字手写体识别:使用Keras完成CNN模型搭建(重要)

    摘要: 本文是通过Keras实现深度学习入门项目——数字手写体识别,整个流程介绍比较详细,适合初学者上手实践. 对于图像分类任务而言,卷积神经网络(CNN)是目前最优的网络结构,没有之一.在面部识别. ...

  4. 大话CNN经典模型:LeNet

        近几年来,卷积神经网络(Convolutional Neural Networks,简称CNN)在图像识别中取得了非常成功的应用,成为深度学习的一大亮点.CNN发展至今,已经有很多变种,其中有 ...

  5. PyTorch 实战:计算 Wasserstein 距离

    PyTorch 实战:计算 Wasserstein 距离 2019-09-23 18:42:56 This blog is copied from: https://mp.weixin.qq.com/ ...

  6. libsvm Minist Hog 手写体识别

    统计手写数字集的HOG特征 转载请注明出处,楼燚(yì)航的blog,http://www.cnblogs.com/louyihang-loves-baiyan/ 这篇文章是模式识别的小作业,利用sv ...

  7. Scala 深入浅出实战经典 第66讲:Scala并发编程实战初体验

    王家林亲授<DT大数据梦工厂>大数据实战视频 Scala 深入浅出实战经典(1-87讲)完整视频.PPT.代码下载:百度云盘:http://pan.baidu.com/s/1c0noOt6 ...

  8. Scala 深入浅出实战经典 第57讲:Scala中Dependency Injection实战详解

    王家林亲授<DT大数据梦工厂>大数据实战视频 Scala 深入浅出实战经典(1-87讲)完整视频.PPT.代码下载:百度云盘:http://pan.baidu.com/s/1c0noOt6 ...

  9. R︱Softmax Regression建模 (MNIST 手写体识别和文档多分类应用)

    本文转载自经管之家论坛, R语言中的Softmax Regression建模 (MNIST 手写体识别和文档多分类应用) R中的softmaxreg包,发自2016-09-09,链接:https:// ...

随机推荐

  1. Nginx配置详解 http://www.cnblogs.com/knowledgesea/p/5175711.html

    Nginx配置详解 序言 Nginx是lgor Sysoev为俄罗斯访问量第二的rambler.ru站点设计开发的.从2004年发布至今,凭借开源的力量,已经接近成熟与完善. Nginx功能丰富,可作 ...

  2. css限制文字显示字数长度,超出部分自动用省略号显示,防止溢出到第二行

    为了保证页面的整洁美观,在很多的时候,我们常需要隐藏超出长度的文字.这在列表条目,题目,名称等地方常用到. 效果如下: 未限制显示长度,如果超出了会溢出到第二行里.严重影响用户体验和显示效果. 我们在 ...

  3. StringGrid换行功能

    关闭stringgrid的defaultdrawing功能 StringGrid1.Cells[cCol,cRow] := '测试1'+#13#10+'测试2'; procedure TForm1.S ...

  4. Linux下四种安装软件方式

    1.yum源安装 可以解决依赖关系,但不确定安装的位置 2.rpm 基础安装 要自己解决依赖问题 rpm -ivh 安装 rpm -uvh 更新 rpm -e --nodeps 卸载    取消依赖 ...

  5. 吴裕雄--天生自然JAVA面向对象高级编程学习笔记:Object类

    class Demo{ // 定义Demo类,实际上就是继承了Object类 }; public class ObjectDemo01{ public static void main(String ...

  6. gitlab clone或者pull 仓库

    今天在学git操作,想从gitlab上面clone下来并操作一下,但是一直出现 没有权限的错误,一直搞不定 后来才知道,需要ssh密钥才可以 ssh-keygen -t rsa -C "ex ...

  7. java多线程(待完善)

    1.小型系统 // 线程完成的任务(Runnable对象)和线程对象(Thread)之间紧密相连 class A implements Runnable{ public void run(){ // ...

  8. LR_问题_平均响应时间解释,summary与analysis不一致----Summary Report中的时间说明

    Summary是按整个场景的时间来做平均的,最大最小值,也是从整个场景中取出来的. (1)       平均响应时间:事物全部响应时间做平均计算 (2)       90%响应时间:将事物全部响应时间 ...

  9. HiBench成长笔记——(9) 分析源码monitor.py

    monitor.py 是主监控程序,将监控数据写入日志,并统计监控数据生成HTML统计展示页面: #!/usr/bin/env python2 # Licensed to the Apache Sof ...

  10. Unable to execute dex:Multuple dex files define 解决方法

    困扰我两天的问题终于解决了,在网上查的方法无非有三种 一. Eclipse->Project->去掉Build Automatically->Clear ->Build Pro ...