本篇不涉及模型原理,只是分享下代码。想要了解模型原理的可以去看网上很多大牛的博客。

目前代码实现了CNN和LSTM两个网络,整个代码分为四部分:

  • Config:项目中涉及的参数;

  • CNN:卷积神经网络结构;

  • LSTM:长短期记忆网络结构;

  • TrainProcess

    模型训练及评估,参数model控制训练何种模型(CNN or LSTM)。

完整代码

Talk is cheap, show me the code.

# -*- coding: utf-8 -*-

# @author: Awesome_Tang
# @date: 2019-04-05
# @version: python3.7 import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from datetime import datetime class Config:
batch_size = 64
epoch = 10
alpha = 1e-3 print_per_step = 100 # 控制输出 class CNN(nn.Module): def __init__(self):
super(CNN, self).__init__()
"""
Conv2d参数:
第一位:input channels 输入通道数
第二位:output channels 输出通道数
第三位:kernel size 卷积核尺寸
第四位:stride 步长,默认为1
第五位:padding size 默认为0,不补
"""
self.conv1 = nn.Sequential(
nn.Conv2d(1, 32, 3, 1, 2),
nn.ReLU(),
nn.MaxPool2d(2, 2)
) self.conv2 = nn.Sequential(
nn.Conv2d(32, 64, 5),
nn.ReLU(),
nn.MaxPool2d(2, 2)
) self.fc1 = nn.Sequential(
nn.Linear(64 * 5 * 5, 128),
nn.BatchNorm1d(128),
nn.ReLU()
) self.fc2 = nn.Sequential(
nn.Linear(128, 64),
nn.BatchNorm1d(64), # 加快收敛速度的方法(注:批标准化一般放在全连接层后面,激活函数层的前面)
nn.ReLU()
) self.fc3 = nn.Linear(64, 10) def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size()[0], -1)
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
return x class LSTM(nn.Module):
def __init__(self):
super(LSTM, self).__init__() self.lstm = nn.LSTM(
input_size=28,
hidden_size=64,
num_layers=1,
batch_first=True,
) self.output = nn.Linear(64, 10) def forward(self, x):
r_out, (_, _) = self.lstm(x, None) out = self.output(r_out[:, -1, :])
return out class TrainProcess: def __init__(self, model="CNN"):
self.train, self.test = self.load_data()
self.model = model
if self.model == "CNN":
self.net = CNN()
elif self.model == "LSTM":
self.net = LSTM()
else:
raise ValueError('"CNN" or "LSTM" is expected, but received "%s".' % model)
self.criterion = nn.CrossEntropyLoss() # 定义损失函数
self.optimizer = optim.Adam(self.net.parameters(), lr=Config.alpha) @staticmethod
def load_data():
print("Loading Data......")
"""加载MNIST数据集,本地数据不存在会自动下载"""
train_data = datasets.MNIST(root='./data/',
train=True,
transform=transforms.ToTensor(),
download=True) test_data = datasets.MNIST(root='./data/',
train=False,
transform=transforms.ToTensor()) # 返回一个数据迭代器
# shuffle:是否打乱顺序
train_loader = torch.utils.data.DataLoader(dataset=train_data,
batch_size=Config.batch_size,
shuffle=True) test_loader = torch.utils.data.DataLoader(dataset=test_data,
batch_size=Config.batch_size,
shuffle=False)
return train_loader, test_loader def train_step(self):
steps = 0
start_time = datetime.now() print("Training & Evaluating based on '%s'......" % self.model)
for epoch in range(Config.epoch):
print("Epoch {:3}.".format(epoch + 1)) for data, label in self.train:
data, label = Variable(data.cpu()), Variable(label.cpu())
# LSTM输入为3维,CNN输入为4维
if self.model == "LSTM":
data = data.view(-1, 28, 28)
self.optimizer.zero_grad() # 将梯度归零
outputs = self.net(data) # 将数据传入网络进行前向运算
loss = self.criterion(outputs, label) # 得到损失函数
loss.backward() # 反向传播
self.optimizer.step() # 通过梯度做一步参数更新 # 每100次打印一次结果
if steps % Config.print_per_step == 0:
_, predicted = torch.max(outputs, 1)
correct = int(sum(predicted == label)) # 计算预测正确个数
accuracy = correct / Config.batch_size # 计算准确率
end_time = datetime.now()
time_diff = (end_time - start_time).seconds
time_usage = '{:3}m{:3}s'.format(int(time_diff / 60), time_diff % 60)
msg = "Step {:5}, Loss:{:6.2f}, Accuracy:{:8.2%}, Time usage:{:9}."
print(msg.format(steps, loss, accuracy, time_usage)) steps += 1 test_loss = 0.
test_correct = 0
for data, label in self.test:
data, label = Variable(data.cpu()), Variable(label.cpu())
if self.model == "LSTM":
data = data.view(-1, 28, 28)
outputs = self.net(data)
loss = self.criterion(outputs, label)
test_loss += loss * Config.batch_size
_, predicted = torch.max(outputs, 1)
correct = int(sum(predicted == label))
test_correct += correct accuracy = test_correct / len(self.test.dataset)
loss = test_loss / len(self.test.dataset)
print("Test Loss: {:5.2f}, Accuracy: {:6.2%}".format(loss, accuracy)) end_time = datetime.now()
time_diff = (end_time - start_time).seconds
print("Time Usage: {:5.2f} mins.".format(time_diff / 60.)) if __name__ == "__main__":
p = TrainProcess(model='CNN')
p.train_step()

Peace~~

基于PyTorch实现MNIST手写字识别的更多相关文章

  1. 基于tensorflow的MNIST手写识别

    这个例子,是学习tensorflow的人员通常会用到的,也是基本的学习曲线中的一环.我也是! 这个例子很简单,这里,就是简单的说下,不同的tensorflow版本,相关的接口函数,可能会有不一样哟.在 ...

  2. 基于tensorflow实现mnist手写识别 (多层神经网络)

    标题党其实也不多,一个输入层,三个隐藏层,一个输出层 老样子先上代码 导入mnist的路径很长,现在还记不住 import tensorflow as tf import tensorflow.exa ...

  3. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  4. 基于TensorFlow的MNIST手写数字识别-初级

    一:MNIST数据集    下载地址 MNIST是一个包含很多手写数字图片的数据集,一共4个二进制压缩文件 分别是test set images,test set labels,training se ...

  5. Tensorflow之基于MNIST手写识别的入门介绍

    Tensorflow是当下AI热潮下,最为受欢迎的开源框架.无论是从Github上的fork数量还是star数量,还是从支持的语音,开发资料,社区活跃度等多方面,他当之为superstar. 在前面介 ...

  6. 用TensorFlow教你手写字识别

    博主原文链接:用TensorFlow教你做手写字识别(准确率94.09%) 如需转载,请备注出处及链接,谢谢. 2012 年,Alex Krizhevsky, Geoff Hinton, and Il ...

  7. Tensorflow编程基础之Mnist手写识别实验+关于cross_entropy的理解

    好久没有静下心来写点东西了,最近好像又回到了高中时候的状态,休息不好,无法全心学习,恶性循环,现在终于调整的好一点了,听着纯音乐突然非常伤感,那些曾经快乐的大学时光啊,突然又慢慢的一下子出现在了眼前, ...

  8. tensorflow笔记(四)之MNIST手写识别系列一

    tensorflow笔记(四)之MNIST手写识别系列一 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7436310.html ...

  9. tensorflow笔记(五)之MNIST手写识别系列二

    tensorflow笔记(五)之MNIST手写识别系列二 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7455233.html ...

随机推荐

  1. 在线热备份数据库之innobackupex 增量备份InnoDB

    在线热备份数据库之innobackupex 增量备份InnoDB 什么是增量备份?其原理是什么? 增量备份是基于上一次备份后对新增加的内容进行备份,优点相较于完整备份而言备份内容少时间短,能够节省磁盘 ...

  2. python-->二进制的用法

    1.10进制转换为其他进制 方法一:函数 十进制转二进制:bin(10) --> '0b1010' tpye:是字符串类型 0b:表示2进制 十进制转八进制:oct(10) --> '0o ...

  3. EffectiveJava-4

    一.基本类型优先于装箱基本类型 Java的基本数据类型也叫做内置类型是java语言本身提供的数据类型,是引用其他类型的基础.Java的基本数据类型分为:整数类型.浮点类型.字符类型.布尔类型这四个类型 ...

  4. Unix/Linux 从哪儿来?那些改变世界的人们...

    昨天看文章时发现自己对 linux 操作系统不够了解,还记得 17 年时听过老师的一些课,对 linux 的历史有一点了解,不过当时并没有记录笔记,现在已经忘的差不多了. 这次从网上找资料,又重新看了 ...

  5. 基于cookie的用户登录状态管理

    cookie是什么 先来花5分钟看完这篇文章:https://developer.mozilla.org/zh-CN/docs/Web/HTTP/Cookies 看完上文,相信大家对cookie已经有 ...

  6. VM小技巧——虚拟机解决vm窗口太小的办法

    ——" 慢下来总结才能增大效率" 很多人在装虚拟机的时候,遇到了窗口过小不能自适应的问题.我也是查了好多资料,都说安装Vmware Tools即可解决,还有说修改分辨率也可以.两种 ...

  7. Android9.0 SystemUI 网络信号栏定制修改

    前情提要 Android 8.1平台SystemUI 导航栏加载流程解析 9.0 改动点简要说明 1.新增 StatusBarMobileView 替代 SignalClusterView,用以控制信 ...

  8. egret Tiledmap编写障碍物的思路

    egret Tiledmap编写障碍物的思路 获取控制对象下一刻移动的坐标,将其转换成瓦片坐标,如果getTileGIDAt(根据瓦片坐标获取瓦片id)的值不为0,说明对象将要移动的位置有障碍物,不做 ...

  9. hdu 2846 Repository (字典树)

    RepositoryTime Limit: 2000/1000 MS (Java/Others)    Memory Limit: 65536/65536 K (Java/Others)Total S ...

  10. MySQL数据库root账户密码忘记两种处理方法(保有效)

    方法1: 1.停止MySQL服务 # kill `cat /var/run/mysqld/mysqld.pid` 或者 # pkill mysqld 2.创建一个密码赋值语句的文本文件 # vi my ...