任务目标

对MNIST手写数字数据集进行训练和评估,最终使得模型能够在测试集上达到\(98\%\)的正确率。(最终本文达到了\(99.36\%\))

使用的库的版本:

  1. python:3.8.12
  2. pytorch:1.5.1

代码地址GitHub:https://github.com/xiaohuiduan/deeplearning-study/tree/main/手写数字识别

数据集介绍

MNIST数字数据集来自MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges

在torchvision中自带了关于MNIST的数据集。如果直接使用自带的数据集,能方便不少。关于具体使用,可参考:PyTorch初探MNIST数据集 - 知乎 (zhihu.com)

在Lecun的提供的MNIST数据集,有如下4个文件(images文件和labels文件):

training set包含了60000张手写数字图片,test set包含了10000张图片。在images文件和labels文件中,数据是使用二进制进行保存的。

图像文件的二进制储存格式如下(参考python处理MNIST数据集 - 简书 (jianshu.com)):

  • 第1-4个byte(字节,1byte=8bit),即前32bit存的是文件的magic number,对应的十进制大小是2051;

  • 第5-8个byte存的是number of images,即图像数量60000;

  • 第9-12个byte存的是每张图片行数/高度,即28;

  • 第13-16个byte存的是每张图片的列数/宽度,即28。

  • 从第17个byte开始,每个byte存储一张图片中的一个像素点的值。

标签文件的二进制储存格式如下(参考python处理MNIST数据集 - 简书 (jianshu.com)):

  • 第1-4个byte存的是文件的magic number,对应的十进制大小是2049;

  • 第5-8个byte存的是number of items,即label数量60000;

  • 从第9个byte开始,每个byte存一个图片的label信息,即数字0-9中的一个。

二进制文件的Python处理代码:

import numpy as np
def read_image(file_path):
"""读取MNIST图片 Args:
file_path (str): 图片文件位置 Returns:
list: 图片列表
"""
with open(file_path,'rb') as f:
file = f.read()
img_num = int.from_bytes(file[4:8],byteorder='big') #图片数量
img_h = int.from_bytes(file[8:12],byteorder='big') #图片h
img_w = int.from_bytes(file[12:16],byteorder='big') #图片w
img_data = []
file = file[16:]
data_len = img_h*img_w for i in range(img_num):
data = [item/255 for item in file[i*data_len:(i+1)*data_len]]
img_data.append(np.array(data).reshape(img_h,img_w)) return img_data def read_label(file_path):
with open(file_path,'rb') as f:
file = f.read()
label_num = int.from_bytes(file[4:8],byteorder='big') #label的数量
file = file[8:]
label_data = []
for i in range(label_num):
label_data.append(file[i])
return label_data train_img = read_image("mnist/train/train-images.idx3-ubyte")
train_label = read_label("mnist/train/train-labels.idx1-ubyte") # test_img = read_image("mnist/test/t10k-images.idx3-ubyte")
# test_label = read_label("mnist/test/t10k-labels.idx1-ubyte")

数据集部分数据如下所示:

数据集划分

在深度学习中,需要将trainset划分成训练集验证集。最终使用测试集去验证模型的结果。

训练集:用来训练模型参数。

验证集:验证模型的状况和收敛情况。

测试集:验证模型结果。

形象上来说训练集就像是学生的课本,学生 根据课本里的内容来掌握知识,验证集就像是作业,通过作业可以知道 不同学生学习情况、进步的速度快慢,而最终的测试集就像是考试,考的题是平常都没有见过,考察学生举一反三的能力。

来源:训练集(train)验证集(validation)测试集(test)与交叉验证法 - 知乎 (zhihu.com)

因此,需要将上文中的train_img,train_label进行划分,划分为训练集验证集。这里使用sklearn中的train_test_split进行划分,训练集和测试集的比例为\(8:2\)。

from sklearn.model_selection import train_test_split
train_img,valid_img,train_label,valid_label = train_test_split(train_img,train_label,test_size=0.2,shuffle=True)

网络结构

根据网络的权重,Netron生成的网络结构图如下,图中详细的介绍了每一层的结构参数。

网络结构的简洁图如下所示,网络一共由3层卷积层(每层卷积分别由Conv2d,BatchNorm2d,MaxPool2d和Dropout构成)和2个全连接层构成。

Pytorch代码如下:

class MyNet(nn.Module):
def __init__(self):
super(MyNet,self).__init__()
self.conv_1 = nn.Sequential(
nn.Conv2d(1,32,kernel_size=3,padding=1),
nn.ReLU(),
nn.BatchNorm2d(32),
nn.MaxPool2d(2,2),
nn.Dropout(0.25)
)
self.conv_2 = nn.Sequential(
nn.Conv2d(32,64,kernel_size=3,padding=1),
nn.ReLU(),
nn.BatchNorm2d(64),
nn.MaxPool2d(2,2),
nn.Dropout(0.25),
) self.conv_3 = nn.Sequential(
nn.Conv2d(64,128,kernel_size=3),
nn.ReLU(),
nn.BatchNorm2d(128),
nn.MaxPool2d(2,2),
nn.Dropout(0.25),
) self.fc = nn.Sequential(
nn.Linear(512,128),
nn.Linear(128,10)
) def forward(self,x): #x (3,28,28)
x = self.conv_1(x) #x (32,14,14)
x = self.conv_2(x) #x (64,7,7)
x = self.conv_3(x) #x (128,4,4)
x = x.view(x.size(0),-1) x = self.fc(x)
return F.log_softmax(x,dim=1)
myNet = MyNet().to(device)

训练集以及验证集结果

大概经过300个epoch训练,验证集便能够达到\(99.9\%\)以上的正确率。

训练集的Loss曲线:

测试集结果

测试集使用训练400个epoch之后的模型进行预测。其最终预测的正确率为:\(99.36 \%\)。实际上,大概300个epoch就能够在测试集达到\(99\%\)以上的正确率。

参考

  1. MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges
  2. MNIST — Torchvision 0.12 documentation (pytorch.org)
  3. python处理MNIST数据集 - 简书 (jianshu.com)
  4. 训练集(train)验证集(validation)测试集(test)与交叉验证法 - 知乎 (zhihu.com)
  5. sklearn.model_selection.train_test_split — scikit-learn 1.0.2 documentation
  6. Netron

深度学习(一)之MNIST数据集分类的更多相关文章

  1. 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化

    一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...

  2. 6.keras-基于CNN网络的Mnist数据集分类

    keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...

  3. 深度学习之 cnn 进行 CIFAR10 分类

    深度学习之 cnn 进行 CIFAR10 分类 import torchvision as tv import torchvision.transforms as transforms from to ...

  4. 3.keras-简单实现Mnist数据集分类

    keras-简单实现Mnist数据集分类 1.载入数据以及预处理 import numpy as np from keras.datasets import mnist from keras.util ...

  5. keras框架下的深度学习(二)二分类和多分类问题

    本文第一部分是对数据处理中one-hot编码的讲解,第二部分是对二分类模型的代码讲解,其模型的建立以及训练过程与上篇文章一样:在最后我们将训练好的模型保存下来,再用自己的数据放入保存下来的模型中进行分 ...

  6. 深度学习笔记(一):logistic分类【转】

    本文转载自:https://blog.csdn.net/u014595019/article/details/52554582 这个系列主要记录我在学习各个深度学习算法时候的笔记,因为之前已经学过大概 ...

  7. Tensorflow学习教程------普通神经网络对mnist数据集分类

    首先是不含隐层的神经网络, 输入层是784个神经元 输出层是10个神经元 代码如下 #coding:utf-8 import tensorflow as tf from tensorflow.exam ...

  8. 自己动手实现深度学习框架-8 RNN文本分类和文本生成模型

    代码仓库: https://github.com/brandonlyg/cute-dl 目标         上阶段cute-dl已经可以构建基础的RNN模型.但对文本相模型的支持不够友好, 这个阶段 ...

  9. Python深度学习案例1--电影评论分类(二分类问题)

    我觉得把课本上的案例先自己抄一遍,然后将书看一遍.最后再写一篇博客记录自己所学过程的感悟.虽然与课本有很多相似之处.但自己写一遍感悟会更深 电影评论分类(二分类问题) 本节使用的是IMDB数据集,使用 ...

随机推荐

  1. laravel中closure和curry 科里化函数式编程

    推荐值得的一看博客文档:谢谢作者  : https://my.oschina.net/zhmsong 函数式编程curry的概念: 只传递给函数一部分参数来调用函数,然后返回一个函数去处理剩下的参数. ...

  2. Javascript疑问【长期更新】

    1.插入 Javacript 的正确位置是? 答:<body> 部分和 <head> 部分均可. 2.外部脚本必须包含 <script> 标签吗? 答:外部脚本不能 ...

  3. Java基础复习(四)

    1.Integer与int的区别 int是java提供的8种原始数据类型之一.Java为每个原始类型提供了封装类,Integer是java为int提供的封装类.int的默认值为0,而Integer的默 ...

  4. Docker Explore the application

    https://docs.docker.com/docker-for-mac/#explore-the-application   Open a command-line terminal and t ...

  5. 基于UDP传输协议局域网文件接收软件设计 Java版

    网路传输主要的两大协议为TCP/IP协议和UDP协议,本文主要介绍基于UDP传输的一个小软件分享,针对于Java网络初学者是一个很好的练笔,大家可以参考进行相关的联系,但愿能够帮助到大家. 话不多说, ...

  6. redis(一)-----初识redis

    Redis是一种基于键值对(key-value)的NoSQL数据库 因为Redis会将所有数据都存放在内存 中,所以它的读写性能非常惊人.不仅如此,Redis还可以将内存的数据利 用快照和日志的形式保 ...

  7. matlab文件处理

    1.读取文件(按行读取) fid = open('file_name');while(~feof(fid)) line = fgetl(fid); % 读取一行数据 endfid.close(); 2 ...

  8. Aluminum: An Asynchronous, GPU-Aware Communication Library Optimized for Large-Scale Training of Deep Neural Networks on HPC Systems

    本文发表在MLHPC 2018上,主要介绍了一个名为Aluminum通信库,这个库针对Allreduce做了一些关于计算通信重叠以及针对延迟的优化,以加速分布式深度学习训练过程. 分布式训练的通信需求 ...

  9. vue前端渲染和thymeleaf模板渲染冲突问题

    vue前端渲染和thymeleaf模板渲染冲突问题 话不多说直接上现象: 解决办法: 在此做个记录吧,说不定以后会碰到 <<QIUQIU&LL>>

  10. 【性能测试实战:jmeter+k8s+微服务+skywalking+efk】系列之:性能测试场景设计

    说明: 本文是基于虚拟机环境配置设计的 性能测试需求 总tps≥100 每个业务的rt<500ms 持续稳定跑50万业务量 单场景 目的:找到单场景的性能问题,为容量场景提供参考,如果低于容量场 ...