深度学习(一)之MNIST数据集分类
任务目标
对MNIST手写数字数据集进行训练和评估,最终使得模型能够在测试集上达到\(98\%\)的正确率。(最终本文达到了\(99.36\%\))
使用的库的版本:
- python:3.8.12
- pytorch:1.5.1
代码地址GitHub:https://github.com/xiaohuiduan/deeplearning-study/tree/main/手写数字识别
数据集介绍
MNIST数字数据集来自MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges。
在torchvision中自带了关于MNIST的数据集。如果直接使用自带的数据集,能方便不少。关于具体使用,可参考:PyTorch初探MNIST数据集 - 知乎 (zhihu.com)
在Lecun的提供的MNIST数据集,有如下4个文件(images文件和labels文件):

training set包含了60000张手写数字图片,test set包含了10000张图片。在images文件和labels文件中,数据是使用二进制进行保存的。
图像文件的二进制储存格式如下(参考python处理MNIST数据集 - 简书 (jianshu.com)):
第1-4个byte(字节,1byte=8bit),即前32bit存的是文件的magic number,对应的十进制大小是2051;
第5-8个byte存的是number of images,即图像数量60000;
第9-12个byte存的是每张图片行数/高度,即28;
第13-16个byte存的是每张图片的列数/宽度,即28。
从第17个byte开始,每个byte存储一张图片中的一个像素点的值。
标签文件的二进制储存格式如下(参考python处理MNIST数据集 - 简书 (jianshu.com)):
第1-4个byte存的是文件的magic number,对应的十进制大小是2049;
第5-8个byte存的是number of items,即label数量60000;
从第9个byte开始,每个byte存一个图片的label信息,即数字0-9中的一个。
二进制文件的Python处理代码:
import numpy as np
def read_image(file_path):
"""读取MNIST图片
Args:
file_path (str): 图片文件位置
Returns:
list: 图片列表
"""
with open(file_path,'rb') as f:
file = f.read()
img_num = int.from_bytes(file[4:8],byteorder='big') #图片数量
img_h = int.from_bytes(file[8:12],byteorder='big') #图片h
img_w = int.from_bytes(file[12:16],byteorder='big') #图片w
img_data = []
file = file[16:]
data_len = img_h*img_w
for i in range(img_num):
data = [item/255 for item in file[i*data_len:(i+1)*data_len]]
img_data.append(np.array(data).reshape(img_h,img_w))
return img_data
def read_label(file_path):
with open(file_path,'rb') as f:
file = f.read()
label_num = int.from_bytes(file[4:8],byteorder='big') #label的数量
file = file[8:]
label_data = []
for i in range(label_num):
label_data.append(file[i])
return label_data
train_img = read_image("mnist/train/train-images.idx3-ubyte")
train_label = read_label("mnist/train/train-labels.idx1-ubyte")
# test_img = read_image("mnist/test/t10k-images.idx3-ubyte")
# test_label = read_label("mnist/test/t10k-labels.idx1-ubyte")
数据集部分数据如下所示:

数据集划分
在深度学习中,需要将trainset划分成训练集,验证集。最终使用测试集去验证模型的结果。
训练集:用来训练模型参数。
验证集:验证模型的状况和收敛情况。
测试集:验证模型结果。
形象上来说训练集就像是学生的课本,学生 根据课本里的内容来掌握知识,验证集就像是作业,通过作业可以知道 不同学生学习情况、进步的速度快慢,而最终的测试集就像是考试,考的题是平常都没有见过,考察学生举一反三的能力。
来源:训练集(train)验证集(validation)测试集(test)与交叉验证法 - 知乎 (zhihu.com)
因此,需要将上文中的train_img,train_label进行划分,划分为训练集和验证集。这里使用sklearn中的train_test_split进行划分,训练集和测试集的比例为\(8:2\)。
from sklearn.model_selection import train_test_split
train_img,valid_img,train_label,valid_label = train_test_split(train_img,train_label,test_size=0.2,shuffle=True)
网络结构
根据网络的权重,Netron生成的网络结构图如下,图中详细的介绍了每一层的结构参数。

网络结构的简洁图如下所示,网络一共由3层卷积层(每层卷积分别由Conv2d,BatchNorm2d,MaxPool2d和Dropout构成)和2个全连接层构成。

Pytorch代码如下:
class MyNet(nn.Module):
def __init__(self):
super(MyNet,self).__init__()
self.conv_1 = nn.Sequential(
nn.Conv2d(1,32,kernel_size=3,padding=1),
nn.ReLU(),
nn.BatchNorm2d(32),
nn.MaxPool2d(2,2),
nn.Dropout(0.25)
)
self.conv_2 = nn.Sequential(
nn.Conv2d(32,64,kernel_size=3,padding=1),
nn.ReLU(),
nn.BatchNorm2d(64),
nn.MaxPool2d(2,2),
nn.Dropout(0.25),
)
self.conv_3 = nn.Sequential(
nn.Conv2d(64,128,kernel_size=3),
nn.ReLU(),
nn.BatchNorm2d(128),
nn.MaxPool2d(2,2),
nn.Dropout(0.25),
)
self.fc = nn.Sequential(
nn.Linear(512,128),
nn.Linear(128,10)
)
def forward(self,x): #x (3,28,28)
x = self.conv_1(x) #x (32,14,14)
x = self.conv_2(x) #x (64,7,7)
x = self.conv_3(x) #x (128,4,4)
x = x.view(x.size(0),-1)
x = self.fc(x)
return F.log_softmax(x,dim=1)
myNet = MyNet().to(device)
训练集以及验证集结果
大概经过300个epoch训练,验证集便能够达到\(99.9\%\)以上的正确率。

训练集的Loss曲线:

测试集结果
测试集使用训练400个epoch之后的模型进行预测。其最终预测的正确率为:\(99.36 \%\)。实际上,大概300个epoch就能够在测试集达到\(99\%\)以上的正确率。
参考
- MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges
- MNIST — Torchvision 0.12 documentation (pytorch.org)
- python处理MNIST数据集 - 简书 (jianshu.com)
- 训练集(train)验证集(validation)测试集(test)与交叉验证法 - 知乎 (zhihu.com)
- sklearn.model_selection.train_test_split — scikit-learn 1.0.2 documentation
- Netron
深度学习(一)之MNIST数据集分类的更多相关文章
- 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化
一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...
- 6.keras-基于CNN网络的Mnist数据集分类
keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...
- 深度学习之 cnn 进行 CIFAR10 分类
深度学习之 cnn 进行 CIFAR10 分类 import torchvision as tv import torchvision.transforms as transforms from to ...
- 3.keras-简单实现Mnist数据集分类
keras-简单实现Mnist数据集分类 1.载入数据以及预处理 import numpy as np from keras.datasets import mnist from keras.util ...
- keras框架下的深度学习(二)二分类和多分类问题
本文第一部分是对数据处理中one-hot编码的讲解,第二部分是对二分类模型的代码讲解,其模型的建立以及训练过程与上篇文章一样:在最后我们将训练好的模型保存下来,再用自己的数据放入保存下来的模型中进行分 ...
- 深度学习笔记(一):logistic分类【转】
本文转载自:https://blog.csdn.net/u014595019/article/details/52554582 这个系列主要记录我在学习各个深度学习算法时候的笔记,因为之前已经学过大概 ...
- Tensorflow学习教程------普通神经网络对mnist数据集分类
首先是不含隐层的神经网络, 输入层是784个神经元 输出层是10个神经元 代码如下 #coding:utf-8 import tensorflow as tf from tensorflow.exam ...
- 自己动手实现深度学习框架-8 RNN文本分类和文本生成模型
代码仓库: https://github.com/brandonlyg/cute-dl 目标 上阶段cute-dl已经可以构建基础的RNN模型.但对文本相模型的支持不够友好, 这个阶段 ...
- Python深度学习案例1--电影评论分类(二分类问题)
我觉得把课本上的案例先自己抄一遍,然后将书看一遍.最后再写一篇博客记录自己所学过程的感悟.虽然与课本有很多相似之处.但自己写一遍感悟会更深 电影评论分类(二分类问题) 本节使用的是IMDB数据集,使用 ...
随机推荐
- NSString 类介绍及用法
1.NSString常见方法 NSString是 Objective-C 中核心处理字符串的类之一 创建常量字符串,注意使用"@"符号. NSString *astring = @ ...
- k8s之Pod基础概念
1. 资源限制 Pod是kubernetes中最小的资源管理组件,Pod也是最小化运行容器化应用的资源对象.一个Pod代表着集群中运行的一个进程.kubernetes中其他大多数组件都是围绕着Pod来 ...
- LVS-DR群集
LVS-DR群集 目录 LVS-DR群集 一.LVS-DR的工作原理 1. LVS-DR数据包流向分析 2. IP包头及数据帧头信息的变化 3. DR模式的特点 4.LVS-DR中的ARP问题 (1) ...
- 一个好用的多方隐私求交算法库JasonCeng/MultipartyPSI-Pro
Github链接传送:JasonCeng/MultipartyPSI-Pro 大家好,我是阿创,这是我的第29篇原创文章. 今天是一篇纯技术性文章,希望对工程狮们有所帮助. 向大家推荐一个我最近改造的 ...
- 关于sys.path.append()
当我们导入一个模块时:import xxx,默认情况下python解析器会搜索当前目录.已安装的内置模块和第三方模块,搜索路径存放在sys模块的path中: >>> import ...
- 如何从0到1设计一个类Dubbo的RPC框架
之前分享了如何从0到1设计一个MQ消息队列,今天谈谈"如何从0到1设计一个Dubbo的RPC框架",重点考验: 你对RPC框架的底层原理掌握程度. 以及考验你的整体RPC框架系统设 ...
- MySQL windows下cmd安装操作
sh1.下载安装包,解压到指定目录 网址:https://dev.mysql.com/downloads/mysql/ 2.添加环境变量 右键点击计算机-属性-高级系统设置-环境变量: 将mysql ...
- Java 使用jcifs读写共享文件夹报错jcifs.smb.SmbException: Failed to connect: 0.0.0.0<00>/10.1.*.*
Q:使用jcifs读写Windows 10 共享文件夹中的文件报jcifs.smb.SmbException: Failed to connect: 0.0.0.0<00>/10.1.*. ...
- [文档]运维故障报告template
RCA的基本概念 根本原因分析技术(root cause analysis,RCA). IOWA州立大学质量管理学院认为,很多公司在设备发生故障后,都能够很快修复, 但难以发现故障的根本原因,所以此故 ...
- 年底获奖人太多?奖状可以用Smartbi电子表格这么做!
又到一年年终时,你的年终奖到手了吗?奖金没领到,发个奖状压压惊 今天给大家分享年终奖相关的年终奖状的批量套打功能,保证你的奖状及时到手! 示例说明 现有多个人员的奖励需要通知,需要生成可翻页的奖状.并 ...