1. 准备数据集

1.1 MNIST数据集获取:

  • torchvision.datasets接口直接下载,该接口可以直接构建数据集,推荐

  • 其他途径下载后,编写程序进行读取,然后由Datasets构建自己的数据集

​ ​ 本文使用第一种方法获取数据集,并使用Dataloader进行按批装载。如果使用程序下载失败,请将其他途径下载的MNIST数据集 [文件][解压文件] 放置在 <data/MNIST/raw/> 位置下,本文的程序及文件结构图如下:

​ ​ 其中,model文件夹用来存储每个epoch训练的模型参数,根文件夹下包含model.py用于训练模型,test.py为测试集测试,show.py为展示部分

1.2 程序部分

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time # 1. 准备数据集
## 1.1 使用torchvision自动下载MNIST数据集
train_data = datasets.MNIST(root='data\\',
train=True,
transform=transforms.ToTensor(),
download=True) ## 1.2 构建数据集装载器
train_loader = DataLoader(dataset=train_data,
batch_size=100,
shuffle=True,
drop_last=False,
num_workers=4) if __name__ == "__main__":
print("===============数据统计===============")
print("训练集样本:",train_data.__len__(), train_data.data.shape)

​ ​ 【代码解析】

  • root为存放MNIST的路径,trian=True代表下载的为训练集和训练集标签,False则代表测试集和标签

  • transforms.ToTensor()表示将shape为(H, W, C)的 numpy 数组或 img 转为shape为(C, H, W)的tensor,并将数值归一化为[0,1]

  • download为True则代表自动下载,若该文件夹下已经下载,则直接跳过下载步骤

  • shuffle=True,表示对分好的batch进行洗牌操作,drop_last=True表示对最后不足batch大小的剩余样本舍去,False表示保留

  • num_works表示每次读取的进程数,和核心数有关

​ ​ Dataset和Dataloader详细说明,请移步:[Pytorch Dataset和Dataloader 学习笔记(二)]

2. 设计网络结构

2.1 网络设计

​ ​ 网络结构如上图所示,输入图像—>卷积1—>池化1—>卷积2—>池化2—>全连接1—>全连接2—>softmax,每次卷积通道数都增加一倍,最后送入全连接层实现分类

2.2 程序部分

# 2. Design model using class
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv_layer1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
self.max_pooling1 = nn.MaxPool2d(2)
self.conv_layer2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
self.max_pooling2 = nn.MaxPool2d(2)
self.fc1 = nn.Linear(1568, 256)
self.fc2 = nn.Linear(256, 10) def forward(self, x):
x = self.max_pooling1(F.relu(self.conv_layer1(x)))
x = self.max_pooling2(F.relu(self.conv_layer2(x)))
x = x.view(-1, 32*7*7)
x = F.relu(self.fc1(x))
y_hat = self.fc2(x) # CrossEntropyLoss会自动激活最后一层的输出以及softmax处理
return y_hat net = Net() # 3. Construct loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.5)

​ ​ 【代码解析】

  • fc1的1568维度是因为最后一次池化后的shape为32*7*7=1568

  • 在最后一层,并没有进行relu激活以及接入softmax,是因为,在CrossEntropyLoss中会自动激活最后一层的输出以及softmax处理

​ ​ CrossEntropyLoss图参考:《PyTorch深度学习实践》完结合集

​ ​ 详细网络结构搭建说明,请移步:Pytorch线性规划模型 学习笔记(一)

3. 迭代训练

# 3. Construct loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.5) # 4. Training
if __name__ == "__main__":
print("Training...")
for epoch in range(20):
strat = time.time()
total_correct = 0
for x, y in train_loader:
y_hat = net(x)
y_pre = torch.argmax(y_hat, dim=1)
total_correct += sum(torch.eq(y_pre, y)) # 统计当前epoch下的正确个数 loss = criterion(y_hat, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
acc = (float(total_correct) / train_data.__len__())*100
save_path = "model/net" + str(epoch+1) + ".pth"
torch.save(obj=net.state_dict(), f=save_path)
print("epoch:", str(epoch + 1) + "/20",
" \n time:", "%.1f" % (time.time() - strat) + "s"
" train_loss:", loss.item(),
" acc:%.3f%%" % acc,) print("we are done!")

​ ​ 【代码解析】

  • total_correct变量用于统计每个epoch下正确预测值的个数,每进行epoch进行一次清零
  • torch.argmax(y_hat, dim=1)用于选取y_hat下每一行的最大值(每个样本的最高得分),并返回与y相同维度的tensor
  • torch.eq(y_pre, y)用于比较两个矩阵元素是否相同,相同则返回True,不同则返回False,用于判断预测值与真实值是否相同
  • torch.save保存了每个epoch的网络权重参数

4. 测试集预测部分

# 测试模型,测试集为test_data

import torch
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from model import Net test_data = datasets.MNIST(root='data\\',
train=False,
transform=transforms.ToTensor(),
download=True)
test_loader = DataLoader(dataset=test_data,
batch_size=100,
shuffle=True,
drop_last=False,
num_workers=4) if __name__ == "__main__":
print("---------------预测分析---------------")
print("测试集样本:", test_data.__len__(), test_data.data.shape)
model = Net()
model.load_state_dict(torch.load("model/net20.pth"))
model.eval() total_correct = 0
for x, y in test_loader:
y_hat = model(x)
y_pre = torch.argmax(y_hat, dim=1)
total_correct += sum(torch.eq(y_pre, y)) acc = (float(total_correct) / test_data.__len__())*100
print("total_test_samples:", test_data.__len__(),
" test_acc:", "%.3f%%" % acc)

​ ​ 经过20个epoch的训练,在测试集上达到了98.590%的准确率,部分batch真实值与预测值展示如下:


5. 全部代码

链接:链接:https://pan.baidu.com/s/1GGhG1Slw2Tlsgl13yzHUIw

提取码:82l4

转载请说明出处

Pytorch CNN网络MNIST数字识别 [超详细记录] 学习笔记(三)的更多相关文章

  1. pytorch CNN 手写数字识别

    一个被放弃的入门级的例子终于被我实现了,虽然还不太完美,但还是想记录下 1.预处理 相比较从库里下载数据集(关键是经常失败,格式也看不懂),更喜欢直接拿图片,从网上找了半天,最后从CSDN上下载了一个 ...

  2. 用pytorch做手写数字识别,识别l率达97.8%

    pytorch做手写数字识别 效果如下: 工程目录如下 第一步  数据获取 下载MNIST库,这个库在网上,执行下面代码自动下载到当前data文件夹下 from torchvision.dataset ...

  3. muduo网络库学习笔记(三)TimerQueue定时器队列

    目录 muduo网络库学习笔记(三)TimerQueue定时器队列 Linux中的时间函数 timerfd简单使用介绍 timerfd示例 muduo中对timerfd的封装 TimerQueue的结 ...

  4. Python 工匠:使用数字与字符串的技巧学习笔记

    #Python 工匠:使用数字与字符串的技巧学习笔记#https://github.com/piglei/one-python-craftsman/blob/master/zh_CN/3-tips-o ...

  5. Keras cnn 手写数字识别示例

    #基于mnist数据集的手写数字识别 #构造了cnn网络拟合识别函数,前两层为卷积层,第三层为池化层,第四层为Flatten层,最后两层为全连接层 #基于Keras 2.1.1 Tensorflow ...

  6. CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  7. 卷积神经网络CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  8. TensorFlow学习笔记(三)MNIST数字识别问题

    一.MNSIT数据处理 MNSIT是一个非常有名的手写体数字识别数据集.包含60000张训练图片,10000张测试图片.每张图片是28X28的数字. TonserFlow提供了一个类来处理 MNSIT ...

  9. MNIST数字识别问题

    摘自<Tensorflow:实战Google深度学习框架> import tensorflow as tf from tensorflow.examples.tutorials.mnist ...

随机推荐

  1. 【报错】No converter found for return value of type: class java.util.HashMap

    ssm开发碰到的错误 @ResponseBody的作用是是将java对象转为json格式的数据 @ResponseBody注解标识该方法的返回值直接写回到HTTP响应体中去(而不会被被放置到Model ...

  2. Linux系统运行netstat命令时的过三关斩一将

    1.简介 这篇文章主要是记录在日常工作中遇到的一些问题,将其都总结整合到一起,方便查看,提高工作效率.小伙伴们看到标题可能觉得奇怪,不是过五关斩六将而是过三关斩一将.慢慢地往后看发现其中奥秘. 2.过 ...

  3. opencv——图像遍历以及像素操作

    摘要 我们在图像处理时经常会用到遍历图像像素点的方式,在OpenCV中一般有四种图像遍历的方式,在这里我们通过像素变换的点操作来实现对图像亮度和对比度的调整. 补充: 图像变换可以看成 像素变换--点 ...

  4. NtQuerySystemInformation获取进程/线程状态

    __kernel_entry NTSTATUS NtQuerySystemInformation( SYSTEM_INFORMATION_CLASS SystemInformationClass, P ...

  5. HEVC学习(一) —— HM的使用

    http://blog.csdn.net/hevc_cjl/article/details/8169182 首先自然是先把这个测试模型下载下来,链接地址如下:https://hevc.hhi.frau ...

  6. Beta_测试说明

    Beta阶段测试说明 测试发现的BUG Beta阶段测试BUG: 测试发现的BUG都放在BUG FIX里面 GitHUB issue BUG FIX 后端:实体识别结果重复. 解决:把处理结果的id和 ...

  7. 有哪些适合中小企业使用的PaaS平台?

    对于中小企业来说,在业务上同样需要工作流.应用平台来进行支持,但是,面对诸如ERP等动辄好几十万的费用来说,完全是在增加运营成本.如何解决中小企业对于业务应用.工作流管理的需求问题呢?使用PaaS低代 ...

  8. [bug] Maven每次都自动下载jar包非常慢

    解决 方法一:将maven改为离线模式,自己下载jar包复制到仓库中 eclipse中Window>preferences>maven>勾选Offline 方法二:将maven镜像改 ...

  9. Ansible_管理机密

    一.Ansible Vault 1.什么是Ansible Vault Ansible提供的Ansible Vault可以加密和解密任何由Ansible使用的结构化数据文件 若要使用Ansible Va ...

  10. 046.Python协程

    协程 1 生成器 初始化生成器函数 返回生成器对象,简称生成器 def gen(): for i in range(10): #yield 返回便能够保留状态 yield i mygen = gen( ...