任务目标

对MNIST手写数字数据集进行训练和评估,最终使得模型能够在测试集上达到\(98\%\)的正确率。(最终本文达到了\(99.36\%\))

使用的库的版本:

  1. python:3.8.12
  2. pytorch:1.5.1

代码地址GitHub:https://github.com/xiaohuiduan/deeplearning-study/tree/main/手写数字识别

数据集介绍

MNIST数字数据集来自MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges

在torchvision中自带了关于MNIST的数据集。如果直接使用自带的数据集,能方便不少。关于具体使用,可参考:PyTorch初探MNIST数据集 - 知乎 (zhihu.com)

在Lecun的提供的MNIST数据集,有如下4个文件(images文件和labels文件):

training set包含了60000张手写数字图片,test set包含了10000张图片。在images文件和labels文件中,数据是使用二进制进行保存的。

图像文件的二进制储存格式如下(参考python处理MNIST数据集 - 简书 (jianshu.com)):

  • 第1-4个byte(字节,1byte=8bit),即前32bit存的是文件的magic number,对应的十进制大小是2051;

  • 第5-8个byte存的是number of images,即图像数量60000;

  • 第9-12个byte存的是每张图片行数/高度,即28;

  • 第13-16个byte存的是每张图片的列数/宽度,即28。

  • 从第17个byte开始,每个byte存储一张图片中的一个像素点的值。

标签文件的二进制储存格式如下(参考python处理MNIST数据集 - 简书 (jianshu.com)):

  • 第1-4个byte存的是文件的magic number,对应的十进制大小是2049;

  • 第5-8个byte存的是number of items,即label数量60000;

  • 从第9个byte开始,每个byte存一个图片的label信息,即数字0-9中的一个。

二进制文件的Python处理代码:

import numpy as np
def read_image(file_path):
"""读取MNIST图片 Args:
file_path (str): 图片文件位置 Returns:
list: 图片列表
"""
with open(file_path,'rb') as f:
file = f.read()
img_num = int.from_bytes(file[4:8],byteorder='big') #图片数量
img_h = int.from_bytes(file[8:12],byteorder='big') #图片h
img_w = int.from_bytes(file[12:16],byteorder='big') #图片w
img_data = []
file = file[16:]
data_len = img_h*img_w for i in range(img_num):
data = [item/255 for item in file[i*data_len:(i+1)*data_len]]
img_data.append(np.array(data).reshape(img_h,img_w)) return img_data def read_label(file_path):
with open(file_path,'rb') as f:
file = f.read()
label_num = int.from_bytes(file[4:8],byteorder='big') #label的数量
file = file[8:]
label_data = []
for i in range(label_num):
label_data.append(file[i])
return label_data train_img = read_image("mnist/train/train-images.idx3-ubyte")
train_label = read_label("mnist/train/train-labels.idx1-ubyte") # test_img = read_image("mnist/test/t10k-images.idx3-ubyte")
# test_label = read_label("mnist/test/t10k-labels.idx1-ubyte")

数据集部分数据如下所示:

数据集划分

在深度学习中,需要将trainset划分成训练集验证集。最终使用测试集去验证模型的结果。

训练集:用来训练模型参数。

验证集:验证模型的状况和收敛情况。

测试集:验证模型结果。

形象上来说训练集就像是学生的课本,学生 根据课本里的内容来掌握知识,验证集就像是作业,通过作业可以知道 不同学生学习情况、进步的速度快慢,而最终的测试集就像是考试,考的题是平常都没有见过,考察学生举一反三的能力。

来源:训练集(train)验证集(validation)测试集(test)与交叉验证法 - 知乎 (zhihu.com)

因此,需要将上文中的train_img,train_label进行划分,划分为训练集验证集。这里使用sklearn中的train_test_split进行划分,训练集和测试集的比例为\(8:2\)。

from sklearn.model_selection import train_test_split
train_img,valid_img,train_label,valid_label = train_test_split(train_img,train_label,test_size=0.2,shuffle=True)

网络结构

根据网络的权重,Netron生成的网络结构图如下,图中详细的介绍了每一层的结构参数。

网络结构的简洁图如下所示,网络一共由3层卷积层(每层卷积分别由Conv2d,BatchNorm2d,MaxPool2d和Dropout构成)和2个全连接层构成。

Pytorch代码如下:

class MyNet(nn.Module):
def __init__(self):
super(MyNet,self).__init__()
self.conv_1 = nn.Sequential(
nn.Conv2d(1,32,kernel_size=3,padding=1),
nn.ReLU(),
nn.BatchNorm2d(32),
nn.MaxPool2d(2,2),
nn.Dropout(0.25)
)
self.conv_2 = nn.Sequential(
nn.Conv2d(32,64,kernel_size=3,padding=1),
nn.ReLU(),
nn.BatchNorm2d(64),
nn.MaxPool2d(2,2),
nn.Dropout(0.25),
) self.conv_3 = nn.Sequential(
nn.Conv2d(64,128,kernel_size=3),
nn.ReLU(),
nn.BatchNorm2d(128),
nn.MaxPool2d(2,2),
nn.Dropout(0.25),
) self.fc = nn.Sequential(
nn.Linear(512,128),
nn.Linear(128,10)
) def forward(self,x): #x (3,28,28)
x = self.conv_1(x) #x (32,14,14)
x = self.conv_2(x) #x (64,7,7)
x = self.conv_3(x) #x (128,4,4)
x = x.view(x.size(0),-1) x = self.fc(x)
return F.log_softmax(x,dim=1)
myNet = MyNet().to(device)

训练集以及验证集结果

大概经过300个epoch训练,验证集便能够达到\(99.9\%\)以上的正确率。

训练集的Loss曲线:

测试集结果

测试集使用训练400个epoch之后的模型进行预测。其最终预测的正确率为:\(99.36 \%\)。实际上,大概300个epoch就能够在测试集达到\(99\%\)以上的正确率。

参考

  1. MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges
  2. MNIST — Torchvision 0.12 documentation (pytorch.org)
  3. python处理MNIST数据集 - 简书 (jianshu.com)
  4. 训练集(train)验证集(validation)测试集(test)与交叉验证法 - 知乎 (zhihu.com)
  5. sklearn.model_selection.train_test_split — scikit-learn 1.0.2 documentation
  6. Netron

深度学习(一)之MNIST数据集分类的更多相关文章

  1. 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化

    一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...

  2. 6.keras-基于CNN网络的Mnist数据集分类

    keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...

  3. 深度学习之 cnn 进行 CIFAR10 分类

    深度学习之 cnn 进行 CIFAR10 分类 import torchvision as tv import torchvision.transforms as transforms from to ...

  4. 3.keras-简单实现Mnist数据集分类

    keras-简单实现Mnist数据集分类 1.载入数据以及预处理 import numpy as np from keras.datasets import mnist from keras.util ...

  5. keras框架下的深度学习(二)二分类和多分类问题

    本文第一部分是对数据处理中one-hot编码的讲解,第二部分是对二分类模型的代码讲解,其模型的建立以及训练过程与上篇文章一样:在最后我们将训练好的模型保存下来,再用自己的数据放入保存下来的模型中进行分 ...

  6. 深度学习笔记(一):logistic分类【转】

    本文转载自:https://blog.csdn.net/u014595019/article/details/52554582 这个系列主要记录我在学习各个深度学习算法时候的笔记,因为之前已经学过大概 ...

  7. Tensorflow学习教程------普通神经网络对mnist数据集分类

    首先是不含隐层的神经网络, 输入层是784个神经元 输出层是10个神经元 代码如下 #coding:utf-8 import tensorflow as tf from tensorflow.exam ...

  8. 自己动手实现深度学习框架-8 RNN文本分类和文本生成模型

    代码仓库: https://github.com/brandonlyg/cute-dl 目标         上阶段cute-dl已经可以构建基础的RNN模型.但对文本相模型的支持不够友好, 这个阶段 ...

  9. Python深度学习案例1--电影评论分类(二分类问题)

    我觉得把课本上的案例先自己抄一遍,然后将书看一遍.最后再写一篇博客记录自己所学过程的感悟.虽然与课本有很多相似之处.但自己写一遍感悟会更深 电影评论分类(二分类问题) 本节使用的是IMDB数据集,使用 ...

随机推荐

  1. Squid代理服务器应用

    Squid代理服务器应用 目录 Squid代理服务器应用 一.Squid的脚本概念 1. Squid的作用 2. Web代理的工作机制 3. 代理服务器的概念 4. 代理服务器的作用 5. 代理的基本 ...

  2. 搭建golang开发环境(1.14之后版本)

    Go语言1.14版本之后推荐使用go modules管理依赖,也不再需要把代码写在GOPATH目录下. 下载地址 Go官网下载地址:https://golang.org/dl/ Go官方镜像站(推荐) ...

  3. axios请求配置

    全局配置示例(在js文件配置): axios.defaults.baseURL = 'https://api.example.com'; axios.defaults.headers.common[' ...

  4. MATLAB基础学习(2)

    function result=mysum(a,b)%创建函数以及外部接口 s=0; for i=a:b s=s+i; end result=s; disp(s); end Matlab中ones() ...

  5. linux13

    ansible-playbook实现MySQL的二进制部署 Ansible playbook实现apache批量部署,并对不同主机提供以各自IP地址为内容的index.html http的报文结构和状 ...

  6. ELK-EFK-v7.12.0日志平台部署

    ELK和EFK是什么 ELK和EFK是四个开源产品的组合: Elasticsearch 一个基于Lucene搜索引擎的NoSQL数据库 Logstatsh 一个日志管道工具,接受数据输入,执行数据转换 ...

  7. 天啦,从Mongo到ClickHouse我到底经历了什么?

    前言: 在实现前端监控系统的最初,使用了 Mongo 作为日志数据存储库.文档型存储,在日志字段扩展和收缩上都能非常方便.天生的 JSON 格式和 NodeJs 配合也非常贴合.就这样度过了几个月的蜜 ...

  8. 云原生 PostgreSQL 集群 - PGO:来自 Crunchy Data 的 Postgres Operator

    使用 PGO 在 Kubernetes 上运行 Cloud Native PostgreSQL:来自 Crunchy Data 的 Postgres Operator! Cloud Native Po ...

  9. java中的运算符介绍

    运算符&和&&的区别&运算符有两种用法:(1)按位与:(2)逻辑与. &&运算符是短路与运算.逻辑与跟短路与的差别是非常巨大的,虽然二者都要求运算符左右 ...

  10. vue从后台拿数据渲染页面图片

    <div class="list-content"> <div v-for="goods in goodsList" class=" ...