任务目标

对MNIST手写数字数据集进行训练和评估,最终使得模型能够在测试集上达到\(98\%\)的正确率。(最终本文达到了\(99.36\%\))

使用的库的版本:

  1. python:3.8.12
  2. pytorch:1.5.1

代码地址GitHub:https://github.com/xiaohuiduan/deeplearning-study/tree/main/手写数字识别

数据集介绍

MNIST数字数据集来自MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges

在torchvision中自带了关于MNIST的数据集。如果直接使用自带的数据集,能方便不少。关于具体使用,可参考:PyTorch初探MNIST数据集 - 知乎 (zhihu.com)

在Lecun的提供的MNIST数据集,有如下4个文件(images文件和labels文件):

training set包含了60000张手写数字图片,test set包含了10000张图片。在images文件和labels文件中,数据是使用二进制进行保存的。

图像文件的二进制储存格式如下(参考python处理MNIST数据集 - 简书 (jianshu.com)):

  • 第1-4个byte(字节,1byte=8bit),即前32bit存的是文件的magic number,对应的十进制大小是2051;

  • 第5-8个byte存的是number of images,即图像数量60000;

  • 第9-12个byte存的是每张图片行数/高度,即28;

  • 第13-16个byte存的是每张图片的列数/宽度,即28。

  • 从第17个byte开始,每个byte存储一张图片中的一个像素点的值。

标签文件的二进制储存格式如下(参考python处理MNIST数据集 - 简书 (jianshu.com)):

  • 第1-4个byte存的是文件的magic number,对应的十进制大小是2049;

  • 第5-8个byte存的是number of items,即label数量60000;

  • 从第9个byte开始,每个byte存一个图片的label信息,即数字0-9中的一个。

二进制文件的Python处理代码:

import numpy as np
def read_image(file_path):
"""读取MNIST图片 Args:
file_path (str): 图片文件位置 Returns:
list: 图片列表
"""
with open(file_path,'rb') as f:
file = f.read()
img_num = int.from_bytes(file[4:8],byteorder='big') #图片数量
img_h = int.from_bytes(file[8:12],byteorder='big') #图片h
img_w = int.from_bytes(file[12:16],byteorder='big') #图片w
img_data = []
file = file[16:]
data_len = img_h*img_w for i in range(img_num):
data = [item/255 for item in file[i*data_len:(i+1)*data_len]]
img_data.append(np.array(data).reshape(img_h,img_w)) return img_data def read_label(file_path):
with open(file_path,'rb') as f:
file = f.read()
label_num = int.from_bytes(file[4:8],byteorder='big') #label的数量
file = file[8:]
label_data = []
for i in range(label_num):
label_data.append(file[i])
return label_data train_img = read_image("mnist/train/train-images.idx3-ubyte")
train_label = read_label("mnist/train/train-labels.idx1-ubyte") # test_img = read_image("mnist/test/t10k-images.idx3-ubyte")
# test_label = read_label("mnist/test/t10k-labels.idx1-ubyte")

数据集部分数据如下所示:

数据集划分

在深度学习中,需要将trainset划分成训练集验证集。最终使用测试集去验证模型的结果。

训练集:用来训练模型参数。

验证集:验证模型的状况和收敛情况。

测试集:验证模型结果。

形象上来说训练集就像是学生的课本,学生 根据课本里的内容来掌握知识,验证集就像是作业,通过作业可以知道 不同学生学习情况、进步的速度快慢,而最终的测试集就像是考试,考的题是平常都没有见过,考察学生举一反三的能力。

来源:训练集(train)验证集(validation)测试集(test)与交叉验证法 - 知乎 (zhihu.com)

因此,需要将上文中的train_img,train_label进行划分,划分为训练集验证集。这里使用sklearn中的train_test_split进行划分,训练集和测试集的比例为\(8:2\)。

from sklearn.model_selection import train_test_split
train_img,valid_img,train_label,valid_label = train_test_split(train_img,train_label,test_size=0.2,shuffle=True)

网络结构

根据网络的权重,Netron生成的网络结构图如下,图中详细的介绍了每一层的结构参数。

网络结构的简洁图如下所示,网络一共由3层卷积层(每层卷积分别由Conv2d,BatchNorm2d,MaxPool2d和Dropout构成)和2个全连接层构成。

Pytorch代码如下:

class MyNet(nn.Module):
def __init__(self):
super(MyNet,self).__init__()
self.conv_1 = nn.Sequential(
nn.Conv2d(1,32,kernel_size=3,padding=1),
nn.ReLU(),
nn.BatchNorm2d(32),
nn.MaxPool2d(2,2),
nn.Dropout(0.25)
)
self.conv_2 = nn.Sequential(
nn.Conv2d(32,64,kernel_size=3,padding=1),
nn.ReLU(),
nn.BatchNorm2d(64),
nn.MaxPool2d(2,2),
nn.Dropout(0.25),
) self.conv_3 = nn.Sequential(
nn.Conv2d(64,128,kernel_size=3),
nn.ReLU(),
nn.BatchNorm2d(128),
nn.MaxPool2d(2,2),
nn.Dropout(0.25),
) self.fc = nn.Sequential(
nn.Linear(512,128),
nn.Linear(128,10)
) def forward(self,x): #x (3,28,28)
x = self.conv_1(x) #x (32,14,14)
x = self.conv_2(x) #x (64,7,7)
x = self.conv_3(x) #x (128,4,4)
x = x.view(x.size(0),-1) x = self.fc(x)
return F.log_softmax(x,dim=1)
myNet = MyNet().to(device)

训练集以及验证集结果

大概经过300个epoch训练,验证集便能够达到\(99.9\%\)以上的正确率。

训练集的Loss曲线:

测试集结果

测试集使用训练400个epoch之后的模型进行预测。其最终预测的正确率为:\(99.36 \%\)。实际上,大概300个epoch就能够在测试集达到\(99\%\)以上的正确率。

参考

  1. MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges
  2. MNIST — Torchvision 0.12 documentation (pytorch.org)
  3. python处理MNIST数据集 - 简书 (jianshu.com)
  4. 训练集(train)验证集(validation)测试集(test)与交叉验证法 - 知乎 (zhihu.com)
  5. sklearn.model_selection.train_test_split — scikit-learn 1.0.2 documentation
  6. Netron

深度学习(一)之MNIST数据集分类的更多相关文章

  1. 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化

    一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...

  2. 6.keras-基于CNN网络的Mnist数据集分类

    keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...

  3. 深度学习之 cnn 进行 CIFAR10 分类

    深度学习之 cnn 进行 CIFAR10 分类 import torchvision as tv import torchvision.transforms as transforms from to ...

  4. 3.keras-简单实现Mnist数据集分类

    keras-简单实现Mnist数据集分类 1.载入数据以及预处理 import numpy as np from keras.datasets import mnist from keras.util ...

  5. keras框架下的深度学习(二)二分类和多分类问题

    本文第一部分是对数据处理中one-hot编码的讲解,第二部分是对二分类模型的代码讲解,其模型的建立以及训练过程与上篇文章一样:在最后我们将训练好的模型保存下来,再用自己的数据放入保存下来的模型中进行分 ...

  6. 深度学习笔记(一):logistic分类【转】

    本文转载自:https://blog.csdn.net/u014595019/article/details/52554582 这个系列主要记录我在学习各个深度学习算法时候的笔记,因为之前已经学过大概 ...

  7. Tensorflow学习教程------普通神经网络对mnist数据集分类

    首先是不含隐层的神经网络, 输入层是784个神经元 输出层是10个神经元 代码如下 #coding:utf-8 import tensorflow as tf from tensorflow.exam ...

  8. 自己动手实现深度学习框架-8 RNN文本分类和文本生成模型

    代码仓库: https://github.com/brandonlyg/cute-dl 目标         上阶段cute-dl已经可以构建基础的RNN模型.但对文本相模型的支持不够友好, 这个阶段 ...

  9. Python深度学习案例1--电影评论分类(二分类问题)

    我觉得把课本上的案例先自己抄一遍,然后将书看一遍.最后再写一篇博客记录自己所学过程的感悟.虽然与课本有很多相似之处.但自己写一遍感悟会更深 电影评论分类(二分类问题) 本节使用的是IMDB数据集,使用 ...

随机推荐

  1. NSString 类介绍及用法

    1.NSString常见方法 NSString是 Objective-C 中核心处理字符串的类之一 创建常量字符串,注意使用"@"符号. NSString *astring = @ ...

  2. k8s之Pod基础概念

    1. 资源限制 Pod是kubernetes中最小的资源管理组件,Pod也是最小化运行容器化应用的资源对象.一个Pod代表着集群中运行的一个进程.kubernetes中其他大多数组件都是围绕着Pod来 ...

  3. LVS-DR群集

    LVS-DR群集 目录 LVS-DR群集 一.LVS-DR的工作原理 1. LVS-DR数据包流向分析 2. IP包头及数据帧头信息的变化 3. DR模式的特点 4.LVS-DR中的ARP问题 (1) ...

  4. 一个好用的多方隐私求交算法库JasonCeng/MultipartyPSI-Pro

    Github链接传送:JasonCeng/MultipartyPSI-Pro 大家好,我是阿创,这是我的第29篇原创文章. 今天是一篇纯技术性文章,希望对工程狮们有所帮助. 向大家推荐一个我最近改造的 ...

  5. 关于sys.path.append()

    当我们导入一个模块时:import  xxx,默认情况下python解析器会搜索当前目录.已安装的内置模块和第三方模块,搜索路径存放在sys模块的path中: >>> import  ...

  6. 如何从0到1设计一个类Dubbo的RPC框架

    之前分享了如何从0到1设计一个MQ消息队列,今天谈谈"如何从0到1设计一个Dubbo的RPC框架",重点考验: 你对RPC框架的底层原理掌握程度. 以及考验你的整体RPC框架系统设 ...

  7. MySQL windows下cmd安装操作

    sh1.下载安装包,解压到指定目录  网址:https://dev.mysql.com/downloads/mysql/ 2.添加环境变量 右键点击计算机-属性-高级系统设置-环境变量: 将mysql ...

  8. Java 使用jcifs读写共享文件夹报错jcifs.smb.SmbException: Failed to connect: 0.0.0.0<00>/10.1.*.*

    Q:使用jcifs读写Windows 10 共享文件夹中的文件报jcifs.smb.SmbException: Failed to connect: 0.0.0.0<00>/10.1.*. ...

  9. [文档]运维故障报告template

    RCA的基本概念 根本原因分析技术(root cause analysis,RCA). IOWA州立大学质量管理学院认为,很多公司在设备发生故障后,都能够很快修复, 但难以发现故障的根本原因,所以此故 ...

  10. 年底获奖人太多?奖状可以用Smartbi电子表格这么做!

    又到一年年终时,你的年终奖到手了吗?奖金没领到,发个奖状压压惊 今天给大家分享年终奖相关的年终奖状的批量套打功能,保证你的奖状及时到手! 示例说明 现有多个人员的奖励需要通知,需要生成可翻页的奖状.并 ...