1. 准备数据集

1.1 MNIST数据集获取:

  • torchvision.datasets接口直接下载,该接口可以直接构建数据集,推荐

  • 其他途径下载后,编写程序进行读取,然后由Datasets构建自己的数据集

​ ​ 本文使用第一种方法获取数据集,并使用Dataloader进行按批装载。如果使用程序下载失败,请将其他途径下载的MNIST数据集 [文件][解压文件] 放置在 <data/MNIST/raw/> 位置下,本文的程序及文件结构图如下:

​ ​ 其中,model文件夹用来存储每个epoch训练的模型参数,根文件夹下包含model.py用于训练模型,test.py为测试集测试,show.py为展示部分

1.2 程序部分

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time # 1. 准备数据集
## 1.1 使用torchvision自动下载MNIST数据集
train_data = datasets.MNIST(root='data\\',
train=True,
transform=transforms.ToTensor(),
download=True) ## 1.2 构建数据集装载器
train_loader = DataLoader(dataset=train_data,
batch_size=100,
shuffle=True,
drop_last=False,
num_workers=4) if __name__ == "__main__":
print("===============数据统计===============")
print("训练集样本:",train_data.__len__(), train_data.data.shape)

​ ​ 【代码解析】

  • root为存放MNIST的路径,trian=True代表下载的为训练集和训练集标签,False则代表测试集和标签

  • transforms.ToTensor()表示将shape为(H, W, C)的 numpy 数组或 img 转为shape为(C, H, W)的tensor,并将数值归一化为[0,1]

  • download为True则代表自动下载,若该文件夹下已经下载,则直接跳过下载步骤

  • shuffle=True,表示对分好的batch进行洗牌操作,drop_last=True表示对最后不足batch大小的剩余样本舍去,False表示保留

  • num_works表示每次读取的进程数,和核心数有关

​ ​ Dataset和Dataloader详细说明,请移步:[Pytorch Dataset和Dataloader 学习笔记(二)]

2. 设计网络结构

2.1 网络设计

​ ​ 网络结构如上图所示,输入图像—>卷积1—>池化1—>卷积2—>池化2—>全连接1—>全连接2—>softmax,每次卷积通道数都增加一倍,最后送入全连接层实现分类

2.2 程序部分

# 2. Design model using class
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv_layer1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
self.max_pooling1 = nn.MaxPool2d(2)
self.conv_layer2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
self.max_pooling2 = nn.MaxPool2d(2)
self.fc1 = nn.Linear(1568, 256)
self.fc2 = nn.Linear(256, 10) def forward(self, x):
x = self.max_pooling1(F.relu(self.conv_layer1(x)))
x = self.max_pooling2(F.relu(self.conv_layer2(x)))
x = x.view(-1, 32*7*7)
x = F.relu(self.fc1(x))
y_hat = self.fc2(x) # CrossEntropyLoss会自动激活最后一层的输出以及softmax处理
return y_hat net = Net() # 3. Construct loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.5)

​ ​ 【代码解析】

  • fc1的1568维度是因为最后一次池化后的shape为32*7*7=1568

  • 在最后一层,并没有进行relu激活以及接入softmax,是因为,在CrossEntropyLoss中会自动激活最后一层的输出以及softmax处理

​ ​ CrossEntropyLoss图参考:《PyTorch深度学习实践》完结合集

​ ​ 详细网络结构搭建说明,请移步:Pytorch线性规划模型 学习笔记(一)

3. 迭代训练

# 3. Construct loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.5) # 4. Training
if __name__ == "__main__":
print("Training...")
for epoch in range(20):
strat = time.time()
total_correct = 0
for x, y in train_loader:
y_hat = net(x)
y_pre = torch.argmax(y_hat, dim=1)
total_correct += sum(torch.eq(y_pre, y)) # 统计当前epoch下的正确个数 loss = criterion(y_hat, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
acc = (float(total_correct) / train_data.__len__())*100
save_path = "model/net" + str(epoch+1) + ".pth"
torch.save(obj=net.state_dict(), f=save_path)
print("epoch:", str(epoch + 1) + "/20",
" \n time:", "%.1f" % (time.time() - strat) + "s"
" train_loss:", loss.item(),
" acc:%.3f%%" % acc,) print("we are done!")

​ ​ 【代码解析】

  • total_correct变量用于统计每个epoch下正确预测值的个数,每进行epoch进行一次清零
  • torch.argmax(y_hat, dim=1)用于选取y_hat下每一行的最大值(每个样本的最高得分),并返回与y相同维度的tensor
  • torch.eq(y_pre, y)用于比较两个矩阵元素是否相同,相同则返回True,不同则返回False,用于判断预测值与真实值是否相同
  • torch.save保存了每个epoch的网络权重参数

4. 测试集预测部分

# 测试模型,测试集为test_data

import torch
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from model import Net test_data = datasets.MNIST(root='data\\',
train=False,
transform=transforms.ToTensor(),
download=True)
test_loader = DataLoader(dataset=test_data,
batch_size=100,
shuffle=True,
drop_last=False,
num_workers=4) if __name__ == "__main__":
print("---------------预测分析---------------")
print("测试集样本:", test_data.__len__(), test_data.data.shape)
model = Net()
model.load_state_dict(torch.load("model/net20.pth"))
model.eval() total_correct = 0
for x, y in test_loader:
y_hat = model(x)
y_pre = torch.argmax(y_hat, dim=1)
total_correct += sum(torch.eq(y_pre, y)) acc = (float(total_correct) / test_data.__len__())*100
print("total_test_samples:", test_data.__len__(),
" test_acc:", "%.3f%%" % acc)

​ ​ 经过20个epoch的训练,在测试集上达到了98.590%的准确率,部分batch真实值与预测值展示如下:


5. 全部代码

链接:链接:https://pan.baidu.com/s/1GGhG1Slw2Tlsgl13yzHUIw

提取码:82l4

转载请说明出处

Pytorch CNN网络MNIST数字识别 [超详细记录] 学习笔记(三)的更多相关文章

  1. pytorch CNN 手写数字识别

    一个被放弃的入门级的例子终于被我实现了,虽然还不太完美,但还是想记录下 1.预处理 相比较从库里下载数据集(关键是经常失败,格式也看不懂),更喜欢直接拿图片,从网上找了半天,最后从CSDN上下载了一个 ...

  2. 用pytorch做手写数字识别,识别l率达97.8%

    pytorch做手写数字识别 效果如下: 工程目录如下 第一步  数据获取 下载MNIST库,这个库在网上,执行下面代码自动下载到当前data文件夹下 from torchvision.dataset ...

  3. muduo网络库学习笔记(三)TimerQueue定时器队列

    目录 muduo网络库学习笔记(三)TimerQueue定时器队列 Linux中的时间函数 timerfd简单使用介绍 timerfd示例 muduo中对timerfd的封装 TimerQueue的结 ...

  4. Python 工匠:使用数字与字符串的技巧学习笔记

    #Python 工匠:使用数字与字符串的技巧学习笔记#https://github.com/piglei/one-python-craftsman/blob/master/zh_CN/3-tips-o ...

  5. Keras cnn 手写数字识别示例

    #基于mnist数据集的手写数字识别 #构造了cnn网络拟合识别函数,前两层为卷积层,第三层为池化层,第四层为Flatten层,最后两层为全连接层 #基于Keras 2.1.1 Tensorflow ...

  6. CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  7. 卷积神经网络CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  8. TensorFlow学习笔记(三)MNIST数字识别问题

    一.MNSIT数据处理 MNSIT是一个非常有名的手写体数字识别数据集.包含60000张训练图片,10000张测试图片.每张图片是28X28的数字. TonserFlow提供了一个类来处理 MNSIT ...

  9. MNIST数字识别问题

    摘自<Tensorflow:实战Google深度学习框架> import tensorflow as tf from tensorflow.examples.tutorials.mnist ...

随机推荐

  1. 『政善治』Postman工具 — 10、Postman中对Cookie的操作

    目录 1.往常的Cookie处理方式 2.Postman中的Cookie管理机制 3.自定义Cookie管理内容 在接口测试中,某些接口的调用,需要带入已有Cookie,比如有些接口需要登陆后才能访问 ...

  2. C#中的partial关键字

    这节讲一下partial(局部的,部分的)关键字,初学者可能没有接触过这个关键字,但是只要你写过winform或者WPF应用程序的话,那你肯定被动用过这个关键字.首先介绍一下这个关键字的作用,它用作定 ...

  3. Web安全之PHP反序列化漏洞

    漏洞原理: 序列化可以将对象变成可以传输的字符串,方便数据保存传输,反序列化就是将字符串还原成对象.如果web应用没有对用户输入的反序列化字符串进行检测,导致反序列化过程可以被控制,就会造成代码执行, ...

  4. ruby基础(一)

    Ruby基础 1.对象.变量和常量 1.1 对象 在Ruby中表示数据的最基本单位是对象,任何数据都是对象,使用类来表示对象的种类. 一个某个类的对象称作对象的实例. 对象 类 eg 数值 Numer ...

  5. Linux_LVM管理

    一.Ivm的应用场景及其弊端 1.应用场景: 随着公司的发展,数据增长较快,最初规划的磁盘容量不够用了 2.弊端: 数据不是直接存放在硬盘上,而是在硬盘的.上面又虚拟出来--层逻辑卷存放数据,故而增加 ...

  6. 马哥Linux SysAdmin学习笔记(一)

    Linux入门 Linux系统管理: 磁盘管理,文件系统管理 RAID基础原理,LVM2 网络管理:TCP/IP协议,Linux网络属性配置 程序包管理:rpm,yum 进程管理:htop,glanc ...

  7. 1.1Ubuntu安装

    在虚拟机中安装 Ubuntu 步骤 安装前的准备和基本安装 设置语言环境 安装常用软件 1. 安装前的准备和基本安装 1.1 安装前的准备 访问 http://cn.ubuntu.com/downlo ...

  8. Python 简单的龟鱼游戏

    游戏编程:按一下要求定义一个乌龟类和鱼类并尝试编程 假设游戏场景为范围(x,y)为 0<=x<=10,0<=y<=10 游戏生成1只乌龟和10条鱼 他们的移动方向均随机 乌龟的 ...

  9. linux进阶之Tomcat服务篇

    一.Tomcat简介 Tomcat服务器是一个免费的开放源代码的Web应用服务器,属于轻量级应用服务器,在中小型系统和并发访问用户不是很多的场合下被普遍使用,是开发和调试JSP程序的首选. Tomca ...

  10. IIC通信时遇到问题的解决

    如果遇到问题,反复查不到 就DEBUG  下单点运行,执行每一个SCK 和SDA的拉高拉低 看看是否能正常的拉高拉低 先解决掉底层的GPIO的控制问题, 有的时候可能数据引脚为特殊功能引脚