MNIST 手写数字识别 卷积神经网络 Pytorch框架

谨此纪念刚入门的我在卷积神经网络上面的摸爬滚打

说明

下面代码是使用pytorch来实现的LeNet,可以正常运行测试,自己添加了一些注释,方便查看。

代码实现

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms # Device configuration
#这里是个python的三元表达式,如果cuda存在的话,divice='cuda:0',否者就是'cpu'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # Hyper parameters
num_epochs = 5 #全部训练集使用的次数
num_classes = 10 #全连接层输出的结果种类
batch_size = 100 #批处理的图片的个数
learning_rate = 0.001 #学习率,在梯度下降法里面的系数 # MNIST dataset
#下载训练数据集,位置放在本文件的父文件夹下的data文件夹里面,数据需要转换格式为Tensor
train_dataset = torchvision.datasets.FashionMNIST(root='../data/',
train=True,
transform=transforms.ToTensor(),
download=True)
#下载测试集,位置放在放在本文件的父文件夹下的data文件夹里面,数据需要转换为Tensor格式
test_dataset = torchvision.datasets.FashionMNIST(root='../data/',
train=False,
transform=transforms.ToTensor()) # Data loader
#这里的shuffle(bool, optional):在每个epoch开始的时候,对数据进行重新打乱,就是重新分组
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)
#对于测试集来说不需要进行从新分组(这里好像不是必须的,也可以试试每次测试分组,有意义吗?)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False) # Convolutional neural network (two convolutional layers)
#定义一个卷积类,这里需要继承nn.Module,它是专门为神经网络设计的模块化接口
class ConvNet(nn.Module):
def __init__(self, num_classes=10):
#调用父类的初始化函数
super(ConvNet, self).__init__()
#一个有序的容器,神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行,同时以神经网络模块为元素的有序字典也可以作为传入参数。
self.layer1 = nn.Sequential(
#二维卷积层,输入通道数1,输出通道数16(相当于有16个filter,也就是16个卷积核),卷积核大小为5*5,步长为1,零填充2圈
#经过计算,可以得到卷积输出的图像的大小和输入的图像大小是等大小的,但是深度不一样,为28*28*16(16为深度),因为这里的padding抵消了卷积的缩小
nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
#BatchNorm2d是卷积网络中防止梯度消失或爆炸的函数,参数是卷积的输出通道数
nn.BatchNorm2d(16),
#激活函数
nn.ReLU(),
#二维最大池化,核的大小为2,步长为2
#这样输出的图片大小就是14*14*16(16为深度)
nn.MaxPool2d(kernel_size=2, stride=2))
#两层的卷积网络,具体含义和上面相同
self.layer2 = nn.Sequential(
#这里大小也没有变化,输出依然和输出的大小相同,深度为32,所以图像为14*14*32
#但是这里的卷积核的数量是32,和输出通道数相同。
nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
nn.BatchNorm2d(32),
nn.ReLU(),
#下面经过池化后输出就会变成7*7*32
nn.MaxPool2d(kernel_size=2, stride=2))
#对输入数据做线性变换,第一个参数是每个输入样本的大小:7*7*32;第二个参数是输出样本的大小,这里是10,正好代表10个数,相当于类别
#第三个参数为bias(偏差),默认为True。如果为False,那么这层将不会学习偏置。
self.fc = nn.Linear(7*7*32, num_classes) #定义了每次执行的 计算步骤。 在所有的子类中都需要重写forward函数。
#
def forward(self, x):
out = self.layer1(x)
out = self.layer2(out)
out = out.reshape(out.size(0), -1)
out = self.fc(out)
return out
model = ConvNet(num_classes).to(device) # Loss and optimizer
#损失函数,
criterion = nn.CrossEntropyLoss()
#优化函数
#params (iterable)第一个参数:待优化参数的iterable或者是定义了参数组的dict
#lr (float, 可选) – 学习率(默认:1e-3)同样也称为学习率或步长因子,它控制了权重的更新比率(如 0.001)。
#较大的值(如 0.3)在学习率更新前会有更快的初始学习,而较小的值(如 1.0E-5)会令训练收敛到更好的性能。
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) # Train the model
#total_step是每一轮的测试次数,这里就是60000/100=600次
total_step = len(train_loader)
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
images = images.to(device)
labels = labels.to(device) # Forward pass
#model(images)等价于module.forward(images)
outputs = model(images)
#根据输出的结果和标签对比,计算loss
loss = criterion(outputs, labels) # Backward and optimize
#根据pytorch中的backward()函数的计算,当网络参量进行反馈时,梯度是被积累的而不是被替换掉;
#但是在每一个batch时毫无疑问并不需要将两个batch的梯度混合起来累积,因此这里就需要每个batch设置一遍zero_grad 了
#将梯度初始化为零
optimizer.zero_grad()
#这里是使用反向传播计算梯度值
loss.backward()
#在scheduler的step_size表示scheduler.step()每调用step_size次,对应的学习率就会按照策略调整一次
optimizer.step() if (i+1) % 100 == 0:
print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
.format(epoch+1, num_epochs, i+1, total_step, loss.item())) # Test the model
model.eval() # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item() print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total)) # Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

手写数字识别 卷积神经网络 Pytorch框架实现的更多相关文章

  1. 手写数字识别 ----卷积神经网络模型官方案例注释(基于Tensorflow,Python)

    # 手写数字识别 ----卷积神经网络模型 import os import tensorflow as tf #部分注释来源于 # http://www.cnblogs.com/rgvb178/p/ ...

  2. 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  3. 第三节,TensorFlow 使用CNN实现手写数字识别(卷积函数tf.nn.convd介绍)

    上一节,我们已经讲解了使用全连接网络实现手写数字识别,其正确率大概能达到98%,这一节我们使用卷积神经网络来实现手写数字识别, 其准确率可以超过99%,程序主要包括以下几块内容 [1]: 导入数据,即 ...

  4. 识别手写数字增强版100% - pytorch从入门到入道(一)

    手写数字识别,神经网络领域的“hello world”例子,通过pytorch一步步构建,通过训练与调整,达到“100%”准确率 1.快速开始 1.1 定义神经网络类,继承torch.nn.Modul ...

  5. 手写数字识别 ----在已经训练好的数据上根据28*28的图片获取识别概率(基于Tensorflow,Python)

    通过: 手写数字识别  ----卷积神经网络模型官方案例详解(基于Tensorflow,Python) 手写数字识别  ----Softmax回归模型官方案例详解(基于Tensorflow,Pytho ...

  6. 卷积神经网络CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  7. TensorFlow卷积神经网络实现手写数字识别以及可视化

    边学习边笔记 https://www.cnblogs.com/felixwang2/p/9190602.html # https://www.cnblogs.com/felixwang2/p/9190 ...

  8. TensorFlow 卷积神经网络手写数字识别数据集介绍

    欢迎大家关注我们的网站和系列教程:http://www.tensorflownews.com/,学习更多的机器学习.深度学习的知识! 手写数字识别 接下来将会以 MNIST 数据集为例,使用卷积层和池 ...

  9. 【TensorFlow-windows】(四) CNN(卷积神经网络)进行手写数字识别(mnist)

    主要内容: 1.基于CNN的mnist手写数字识别(详细代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3-4.2.0-Windows-x86_64. ...

随机推荐

  1. BZOJ 3732: Network Kruskal 重构树

    模板题,练练手~ Code: #include <cstdio> #include <algorithm> #define N 80000 #define setIO(s) f ...

  2. XFF和referer

    XFF构造来源IP Refer构造来源浏览器

  3. 误删系统服务Task Schedule的恢复方法

    cmd命令 sc query Schedule查询该服务是否存在 sc delete Schedule删除服务 sc create Schedule binpath= "C:\Windows ...

  4. LocalDate/LocalDateTime与String的互相转换示例(附DateTimeFormatter详解)

    摘自:https://www.jianshu.com/p/b7e72e585a37 LocalDate/LocalDateTime与String的互相转换示例(附DateTimeFormatter详解 ...

  5. C++入门经典-例6.20-修改string字符串的单个字符

    1:使用+可以将两个string 字符串连接起来.同时,string还支持标准输入输出函数.代码如下: // 6.20.cpp : 定义控制台应用程序的入口点. // #include "s ...

  6. 对于Java培训出身的同学,接下来该怎么学习技术?

    首先恭喜从培训班出来找到工作的同学,确实挺不容易的,4个月的培训,每天从早上9点到晚上9点,也是996,主要的活动地方就是宿舍和教室, 让我现在也去培训,我估计还熬不下来. 尤其是对于从小白开始的同学 ...

  7. Python学习笔记:数据的处理

    上次的学习中有个split函数,照着head first Python上敲一遍代码: >>> with open('james.txt') as jaf: data=jaf.read ...

  8. IDEA里面maven菜单解读

  9. latexdiff中的大坑:字符编码问题

    最近用latex写文章,要用到修订模式,于是采用latexdiff命令生成修订版pdf.这原本是一个非常简单方便的方法,却隐藏着字符编码的问题,初次用可能会遇到意想不到的问题,让人很烦,比如,生成出来 ...

  10. ddms 和 traceview 的区别?

    ddms 原意是:davik debug monitor service.简单的说 ddms 是一个程序执行查看器,在里面可以看见线程和堆栈等信息,traceView 是程序性能分析器.tracevi ...