1. 准备数据集

1.1 MNIST数据集获取:

  • torchvision.datasets接口直接下载,该接口可以直接构建数据集,推荐

  • 其他途径下载后,编写程序进行读取,然后由Datasets构建自己的数据集

​ ​ 本文使用第一种方法获取数据集,并使用Dataloader进行按批装载。如果使用程序下载失败,请将其他途径下载的MNIST数据集 [文件][解压文件] 放置在 <data/MNIST/raw/> 位置下,本文的程序及文件结构图如下:

​ ​ 其中,model文件夹用来存储每个epoch训练的模型参数,根文件夹下包含model.py用于训练模型,test.py为测试集测试,show.py为展示部分

1.2 程序部分

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time # 1. 准备数据集
## 1.1 使用torchvision自动下载MNIST数据集
train_data = datasets.MNIST(root='data\\',
train=True,
transform=transforms.ToTensor(),
download=True) ## 1.2 构建数据集装载器
train_loader = DataLoader(dataset=train_data,
batch_size=100,
shuffle=True,
drop_last=False,
num_workers=4) if __name__ == "__main__":
print("===============数据统计===============")
print("训练集样本:",train_data.__len__(), train_data.data.shape)

​ ​ 【代码解析】

  • root为存放MNIST的路径,trian=True代表下载的为训练集和训练集标签,False则代表测试集和标签

  • transforms.ToTensor()表示将shape为(H, W, C)的 numpy 数组或 img 转为shape为(C, H, W)的tensor,并将数值归一化为[0,1]

  • download为True则代表自动下载,若该文件夹下已经下载,则直接跳过下载步骤

  • shuffle=True,表示对分好的batch进行洗牌操作,drop_last=True表示对最后不足batch大小的剩余样本舍去,False表示保留

  • num_works表示每次读取的进程数,和核心数有关

​ ​ Dataset和Dataloader详细说明,请移步:[Pytorch Dataset和Dataloader 学习笔记(二)]

2. 设计网络结构

2.1 网络设计

​ ​ 网络结构如上图所示,输入图像—>卷积1—>池化1—>卷积2—>池化2—>全连接1—>全连接2—>softmax,每次卷积通道数都增加一倍,最后送入全连接层实现分类

2.2 程序部分

# 2. Design model using class
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv_layer1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
self.max_pooling1 = nn.MaxPool2d(2)
self.conv_layer2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
self.max_pooling2 = nn.MaxPool2d(2)
self.fc1 = nn.Linear(1568, 256)
self.fc2 = nn.Linear(256, 10) def forward(self, x):
x = self.max_pooling1(F.relu(self.conv_layer1(x)))
x = self.max_pooling2(F.relu(self.conv_layer2(x)))
x = x.view(-1, 32*7*7)
x = F.relu(self.fc1(x))
y_hat = self.fc2(x) # CrossEntropyLoss会自动激活最后一层的输出以及softmax处理
return y_hat net = Net() # 3. Construct loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.5)

​ ​ 【代码解析】

  • fc1的1568维度是因为最后一次池化后的shape为32*7*7=1568

  • 在最后一层,并没有进行relu激活以及接入softmax,是因为,在CrossEntropyLoss中会自动激活最后一层的输出以及softmax处理

​ ​ CrossEntropyLoss图参考:《PyTorch深度学习实践》完结合集

​ ​ 详细网络结构搭建说明,请移步:Pytorch线性规划模型 学习笔记(一)

3. 迭代训练

# 3. Construct loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.5) # 4. Training
if __name__ == "__main__":
print("Training...")
for epoch in range(20):
strat = time.time()
total_correct = 0
for x, y in train_loader:
y_hat = net(x)
y_pre = torch.argmax(y_hat, dim=1)
total_correct += sum(torch.eq(y_pre, y)) # 统计当前epoch下的正确个数 loss = criterion(y_hat, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
acc = (float(total_correct) / train_data.__len__())*100
save_path = "model/net" + str(epoch+1) + ".pth"
torch.save(obj=net.state_dict(), f=save_path)
print("epoch:", str(epoch + 1) + "/20",
" \n time:", "%.1f" % (time.time() - strat) + "s"
" train_loss:", loss.item(),
" acc:%.3f%%" % acc,) print("we are done!")

​ ​ 【代码解析】

  • total_correct变量用于统计每个epoch下正确预测值的个数,每进行epoch进行一次清零
  • torch.argmax(y_hat, dim=1)用于选取y_hat下每一行的最大值(每个样本的最高得分),并返回与y相同维度的tensor
  • torch.eq(y_pre, y)用于比较两个矩阵元素是否相同,相同则返回True,不同则返回False,用于判断预测值与真实值是否相同
  • torch.save保存了每个epoch的网络权重参数

4. 测试集预测部分

# 测试模型,测试集为test_data

import torch
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from model import Net test_data = datasets.MNIST(root='data\\',
train=False,
transform=transforms.ToTensor(),
download=True)
test_loader = DataLoader(dataset=test_data,
batch_size=100,
shuffle=True,
drop_last=False,
num_workers=4) if __name__ == "__main__":
print("---------------预测分析---------------")
print("测试集样本:", test_data.__len__(), test_data.data.shape)
model = Net()
model.load_state_dict(torch.load("model/net20.pth"))
model.eval() total_correct = 0
for x, y in test_loader:
y_hat = model(x)
y_pre = torch.argmax(y_hat, dim=1)
total_correct += sum(torch.eq(y_pre, y)) acc = (float(total_correct) / test_data.__len__())*100
print("total_test_samples:", test_data.__len__(),
" test_acc:", "%.3f%%" % acc)

​ ​ 经过20个epoch的训练,在测试集上达到了98.590%的准确率,部分batch真实值与预测值展示如下:


5. 全部代码

链接:链接:https://pan.baidu.com/s/1GGhG1Slw2Tlsgl13yzHUIw

提取码:82l4

转载请说明出处

Pytorch CNN网络MNIST数字识别 [超详细记录] 学习笔记(三)的更多相关文章

  1. pytorch CNN 手写数字识别

    一个被放弃的入门级的例子终于被我实现了,虽然还不太完美,但还是想记录下 1.预处理 相比较从库里下载数据集(关键是经常失败,格式也看不懂),更喜欢直接拿图片,从网上找了半天,最后从CSDN上下载了一个 ...

  2. 用pytorch做手写数字识别,识别l率达97.8%

    pytorch做手写数字识别 效果如下: 工程目录如下 第一步  数据获取 下载MNIST库,这个库在网上,执行下面代码自动下载到当前data文件夹下 from torchvision.dataset ...

  3. muduo网络库学习笔记(三)TimerQueue定时器队列

    目录 muduo网络库学习笔记(三)TimerQueue定时器队列 Linux中的时间函数 timerfd简单使用介绍 timerfd示例 muduo中对timerfd的封装 TimerQueue的结 ...

  4. Python 工匠:使用数字与字符串的技巧学习笔记

    #Python 工匠:使用数字与字符串的技巧学习笔记#https://github.com/piglei/one-python-craftsman/blob/master/zh_CN/3-tips-o ...

  5. Keras cnn 手写数字识别示例

    #基于mnist数据集的手写数字识别 #构造了cnn网络拟合识别函数,前两层为卷积层,第三层为池化层,第四层为Flatten层,最后两层为全连接层 #基于Keras 2.1.1 Tensorflow ...

  6. CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  7. 卷积神经网络CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  8. TensorFlow学习笔记(三)MNIST数字识别问题

    一.MNSIT数据处理 MNSIT是一个非常有名的手写体数字识别数据集.包含60000张训练图片,10000张测试图片.每张图片是28X28的数字. TonserFlow提供了一个类来处理 MNSIT ...

  9. MNIST数字识别问题

    摘自<Tensorflow:实战Google深度学习框架> import tensorflow as tf from tensorflow.examples.tutorials.mnist ...

随机推荐

  1. Mysql连接查询示例语句

    SELECT *FROM ssm_emp; SELECT * FROM ssm_dept; #查询两表交集 SELECT * FROM ssm_emp e INNER JOIN ssm_dept d ...

  2. vscode 终端操作命令npm报错

    错误: 如果没有安装的node.js ,则需要安装. node.js官网下载地址: https://nodejs.org/zh-cn/ 安装node.js 后会看到C:\Users\XXX\AppDa ...

  3. 【原创】JVM如何运行Java程序的?

    [Deerhang] 我们知道Java程序的运行是依赖于JVM虚拟机的,JVM类语言经过编译生成class字节码文件,字节码又经JVM进一步的编译生成机器码,最终运行在硬件上.那么JVM存在的意义是什 ...

  4. Windows进程间通讯(IPC)----管道

    管道的分类 管道其实际就是一段共享内存,只不过Windows规定需要使用I/O的形式类访问这块共享内存,管道可以分为匿名管道和命名管道. 匿名管道就是没有名字的管道,其支持单向传输数据,如果需要双向传 ...

  5. linux下符号链接和硬链接的区别

    存在2众不同类型的链接,软链接和硬链接,修改其中一个,硬链接指向的是节点(inode),软链接指向的是路径(path) 软连接文件 软连接文件也叫符号连接,这个文件包含了另一个文件的路径名,类似于wi ...

  6. 关于有符号数和无符号数的转换 - C/C++

    转载自:http://www.94cto.com/index/Article/content/id/59973.html 1.引例: 今天在做了一道关于有符号数和无符号数相互转换及其左移/右移的问题, ...

  7. BUA软件工程个人博客作业

    写在前面 项目 内容 所属课程 2020春季计算机学院软件工程(罗杰 任健) (北航) 作业要求 个人博客作业 课程目标 培养软件开发能力 本作业对实现目标的具体作用 阅读教材,了解软件工程,并比较各 ...

  8. Markdown使用概述

    Markdown使用概述 序言 作为一名编程学习的爱好者和初学者,由于学习编程的过程中总是存在遗忘以及很难动手写起来的问题,所以在看了许多关于编程学习方法的文章之后,选择使用typora作为我的笔记工 ...

  9. ruby基础(二)

    ruby语法基础 1.方法 方法时对象定义的与该对象相关的操作.在Ruby中,对象的所有的操作都被封装成 方法. 语法糖:语法糖是一种为了照顾一般人的习惯而产生的特殊语法. ruby中一切数据都是对象 ...

  10. 桌面支持qt版本是多少检查

    桌面支持qt版本是多少 # rpm -qa |grep qt |grep 3 |sortqt3-3.3.8b-60.nd7.2.x86_64qt-4.8.6-13.nd7.3.x86_64qt5-qt ...