设计的CNN模型包括一个输入层,输入的是MNIST数据集中28*28*1的灰度图

两个卷积层,

第一层卷积层使用6个3*3的kernel进行filter,步长为1,填充1.这样得到的尺寸是(28+1*2-3)/1+1=28,即6个28*28的feature map

在后面进行池化,尺寸变为14*14

第二层卷积层使用16个5*5的kernel,步长为1,无填充,得到(14-5)/1+1=10,即16个10*10的feature map

池化后尺寸为5*5

后面加两层全连接层,第一层将16*5*5=400个神经元线性变换为120个,第二层将120个变为84个

最后的输出层将84个输出为10个种类

代码如下:

###MNIST数据集上卷积神经网络的简单实现###

# 配置库
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets # 配置参数
torch.manual_seed(1) # 设置随机数种子,确保结果可重复
batch_size = 128 # 批处理大小
learning_rate = 1e-2 # 学习率
num_epoches = 10 # 训练次数 # 加载MNIST数据
# 下载训练集MNIST手写数字训练集
train_dataset = datasets.MNIST(
root='./data', # 数据保持的位置
train=True, # 训练集
transform=transforms.ToTensor(), # 一个取值范围是【0,255】的PIL.Image
# 转化成取值范围是[0,1.0]的torch.FloatTensor
download=True
)
test_dataset = datasets.MNIST(
root='./data',
train=False, # 测试集
transform=transforms.ToTensor()
)
# 数据的批处理中,尺寸大小为batch_size
# 在训练集中,shuffle必须设置为True,表示次序是随机的
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) # 创建CNN模型
# 使用一个类来创建,这个模型包括1个输入层,2个卷积层,2个全连接层和1个输出层。
# 其中卷积层构成为卷积(conv2d)->激励函数(ReLU)->池化(MaxPooling)
# 全连接层由线性层(Linear)构成 # 定义卷积神经网络模型
class Cnn(nn.Module):
def __init__(self, in_dim, n_class): # 28*28*1
super(Cnn, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_dim, 6, 3, stride=1, padding=1), # 28*28
nn.ReLU(True),
nn.MaxPool2d(2, 2), # 14*14
nn.Conv2d(6, 16, 5, stride=1, padding=0), # 10*10*16
nn.ReLU(True),
nn.MaxPool2d(2, 2) # 5*5*16
)
self.fc = nn.Sequential(
nn.Linear(400, 120),
nn.Linear(120, 84),
nn.Linear(84, n_class)
) def forward(self, x):
out = self.conv(x)
out = out.view(out.size(0), 400) # 400=5*5*16
out = self.fc(out)
return out # 图片大小是28*28,10是数据的种类
model = Cnn(1, 10)
# 打印模型,呈现网络结构
print(model) # 模型训练,将img\label都用Variable包装起来,放入model中计算out,最后计算loss和正确率 # 定义loss和optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate) # 开始训练
for epoch in range(num_epoches):
running_loss = 0.0
running_acc = 0.0
for i, data in enumerate(train_loader, 1): # 批处理
img, label = data
img = Variable(img)
label = Variable(label)
# 前向传播
out = model(img)
loss = criterion(out, label) # loss
running_loss += loss.item() * label.size(0)
# total loss,由于loss是batch取均值的,需要把batch_size乘进去
_, pred = torch.max(out, 1) # 预测结果
num_correct = (pred == label).sum() # 正确结果的数量
#accuracy = (pred == label).float().mean() # 正确率
running_acc += num_correct.item() # 正确结果的总数
# 后向传播
optimizer.zero_grad() # 梯度清零,以免影响其他batch
loss.backward() # 后向传播,计算梯度
optimizer.step() # 利用梯度更新W,b参数 # 打印一个循环后,训练集合上的loss和正确率
print('Train {} epoch, Loss:{:.6f},Acc:{:.6f}'.format(epoch + 1, running_loss / (len(train_dataset)),
running_acc / (len(train_dataset)))) # 在测试集上测试识别率
# 模型测试
model.eval()
# 由于训练和测试BatchNorm,Dropout配置不同,需要说明是否模型测试
eval_loss = 0
eval_acc = 0
for data in test_loader: # test set批处理
img, label = data
with torch.no_grad():
img = Variable(img)
# volatile确定你是否调用.backward(),
# 测试中不需要label=Variable(label,volatile=True)
#不需要梯度更新改为with torch.no_grad()
out = model(img)
loss = criterion(out, label) # 计算loss
eval_loss += loss.item() * label.size(0) # total loss
_, pred = torch.max(out, 1) # 预测结果
num_correct = (pred == label).sum() # 正确结果
eval_acc += num_correct.item() # 正确结果总数
print('Test loss:{:.6f},Acc:{:.6f}'.format(eval_loss / (len(test_dataset)), eval_acc * 1.0 / (len(test_dataset))))

  

MNIST数据集上卷积神经网络的简单实现(使用PyTorch)的更多相关文章

  1. 基于MNIST数据的卷积神经网络CNN

    基于tensorflow使用CNN识别MNIST 参数数量:第一个卷积层5x5x1x32=800个参数,第二个卷积层5x5x32x64=51200个参数,第三个全连接层7x7x64x1024=3211 ...

  2. pytorch实现MLP并在MNIST数据集上验证

    写在前面 由于MLP的实现框架已经非常完善,网上搜到的代码大都大同小异,而且MLP的实现是deeplearning学习过程中较为基础的一个实验.因此完全可以找一份源码以参考,重点在于照着源码手敲一遍, ...

  3. TensorFlow+实战Google深度学习框架学习笔记(12)------Mnist识别和卷积神经网络LeNet

    一.卷积神经网络的简述 卷积神经网络将一个图像变窄变长.原本[长和宽较大,高较小]变成[长和宽较小,高增加] 卷积过程需要用到卷积核[二维的滑动窗口][过滤器],每个卷积核由n*m(长*宽)个小格组成 ...

  4. 嵌入式设备上卷积神经网络推理时memory的优化

    以前的神经网络几乎都是部署在云端(服务器上),设备端采集到数据通过网络发送给服务器做inference(推理),结果再通过网络返回给设备端.如今越来越多的神经网络部署在嵌入式设备端上,即inferen ...

  5. 利用mnist数据集进行深度神经网络

    初始神经网络 这里要解决的问题是,将手写数字的灰度图像(28 像素 x28 像素)划分到 10 个类别中(0~9).我们将使用 MINST 数据集,它是机器学习领域的一个经典数据集,其历史几乎和这个领 ...

  6. TensorFlow技术解析与实战学习笔记(13)------Mnist识别和卷积神经网络AlexNet

    一.AlexNet:共8层:5个卷积层(卷积+池化).3个全连接层,输出到softmax层,产生分类. 论文中lrn层推荐的参数:depth_radius = 4,bias = 1.0 , alpha ...

  7. 使用CIFAR-10样本数据集测试卷积神经网络(ConvolutionalNeuralNetwork,CNN)

    第一次将例程跑起来了,有些兴趣. 参考的是如下URL: http://www.yidianzixun.com/article/0KNz7OX1 本来是比较Keras和Tensorflow的,我现在的水 ...

  8. 3层-CNN卷积神经网络预测MNIST数字

    3层-CNN卷积神经网络预测MNIST数字 本文创建一个简单的三层卷积网络来预测 MNIST 数字.这个深层网络由两个带有 ReLU 和 maxpool 的卷积层以及两个全连接层组成. MNIST 由 ...

  9. TersorflowTutorial_MNIST数据集上简单CNN实现

    MNIST数据集上简单CNN实现 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 Tensorflow机器学习实战指南 源代码请点击下方链接欢迎加星 Tesorflow实现基于MNI ...

随机推荐

  1. 初学JavaScript正则表达式(九)

    分组:可以用 ( ) 来进行分组 一.Byron重复三次             Byron{3} --------- Byronnn 只是将紧挨着量词的字符重复            (Byron) ...

  2. Windows开机自动登录账户

    如何在Windows设了账户密码的情况下开机自动登录账户,有以下两种方法. 通过Windows设置自动登录 按“Win+R”组合键打开“运行”框内输入“netplwiz”. 打开以下窗口,将“要使用本 ...

  3. Angular命令和基础操作

    本文档假设你已经熟悉了 HTML,CSS,JavaScript和来自最新标准的一些知识,比如类和模块. 一.Angular命令 命令语法: 大多数命令以及少量选项,会有别名.别名会显示在每个命令的语法 ...

  4. ​LeetCode 26:删除排序数组中的重复项 Remove Duplicates from Sorted Array

    给定一个排序数组,你需要在原地删除重复出现的元素,使得每个元素只出现一次,返回移除后数组的新长度. 不要使用额外的数组空间,你必须在原地修改输入数组并在使用 O(1) 额外空间的条件下完成. Give ...

  5. 【前端知识体系-JS相关】深入理解MVVM和VUE

    1. v-bind和v-model的区别? v-bind用来绑定数据和属性以及表达式,缩写为':' v-model使用在表单中,实现双向数据绑定的,在表单元素外使用不起作用 2. Vue 中三要素的是 ...

  6. Springboot概述

    目录 什么是springboot Springboot的优点 SpringBoot的缺点 一:什么是springboot Springboot是Spring开源组织下的子项目,是Spring组件一站式 ...

  7. jvm的组成入门

    JVM的组成分为整体组成部分和运行时数据区组成部分. JVM的整体组成 JVM的整体组成可以分为4个部分:类加载器(Classloader).运行时数据区(Runtime Data Area).执行引 ...

  8. 初探云原生应用管理(二): 为什么你必须尽快转向 Helm v3

    系列介绍:这个系列是介绍如何用云原生技术来构建.测试.部署.和管理应用的内容专辑.做这个系列的初衷是为了推广云原生应用管理的最佳实践,以及传播开源标准和知识.在这个系列文章的开篇初探云原生应用管理(一 ...

  9. Window权限维持(五):屏幕保护程序

    屏幕保护是Windows功能的一部分,使用户可以在一段时间不活动后放置屏幕消息或图形动画.众所周知,Windows的此功能被威胁参与者滥用为持久性方法.这是因为屏幕保护程序是具有.scr文件扩展名的可 ...

  10. Elasticsearch 7.x从入门到精通

    Elasticsearch是一个分布式.可扩展.近实时的搜索与数据分析引擎,它能从项目一开始就赋予你的数据以搜索.分析和探索的能力. 通过本专栏的学习,你可以了解到,Elasticsearch在互联网 ...