用PyTorch完成手写数字识别

 import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets batch_size = 128
learning_rate = 0.01
num_epoch = 10 # 实例化MNIST数据集对象
train_data = datasets.MNIST('./dataset', train=True, transform=transforms.ToTensor(), download=True)
test_data = datasets.MNIST('./dataset', train=False, transform=transforms.ToTensor(), download=True) # train_loader:以batch_size大小的样本组为单位的可迭代对象
train_loader = DataLoader(train_data, batch_size, shuffle=True)
test_loader = DataLoader(test_data) class CNN(nn.Module):
def __init__(self, in_dim, out_dim):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(in_dim, 6, 3, stride=1, padding=1)
self.batch_norm1 = nn.BatchNorm2d(6)
self.relu = nn.ReLU(True)
self.conv2 = nn.Conv2d(6, 16, 5, stride=1, padding=0)
self.pool = nn.MaxPool2d(2, 2)
self.batch_norm2 = nn.BatchNorm2d(16) self.fc1 = nn.Linear(400, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, out_dim) def forward(self, x):
x = self.batch_norm1(self.conv1(x))
x = F.relu(x)
x = self.pool(x)
x = self.batch_norm2(self.conv2(x))
x = self.relu(x)
x = self.pool(x)
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x def print_model_name(self):
print("Model Name: CNN") class Cnn(nn.Module):
def __init__(self, in_dim, n_class):
super(Cnn, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_dim, 6, 3, stride=1, padding=1),
nn.ReLU(True),
nn.MaxPool2d(2, 2),
nn.Conv2d(6, 16, 5, stride=1, padding=0),
nn.ReLU(True),
nn.MaxPool2d(2, 2)) self.fc = nn.Sequential(
nn.Linear(400, 120), nn.Linear(120, 84), nn.Linear(84, n_class)) def forward(self, x):
# print(x.size()) torch.Size([1024, 1, 28, 28])
out = self.conv(x)
out = out.view(out.size(0), -1)
# print(out.size()) = torch.Size([1024, 400])
out = self.fc(out)
# print(out.size()) torch.Size([1024, 10])
return out def print_model_name(self):
print("Model Name: Cnn") isGPU = torch.cuda.is_available()
print(isGPU)
model = CNN(1, 10)
if isGPU:
model = model.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
for epoch in range(num_epoch):
running_acc = 0.0
running_loss = 0.0
for i, data in enumerate(train_loader, 1): # train_loader:以batch_size大小的样本组为单位的可迭代对象
img, label = data
img = Variable(img)
label = Variable(label)
if isGPU:
img = img.cuda()
label = label.cuda()
# forward
out = model(img)
loss = criterion(out, label)
# print(label)
# backward
optimizer.zero_grad()
loss.backward()
optimizer.step() _, pred = torch.max(out, dim=1) # 按维度dim 返回最大值
running_loss += loss.item()*label.size(0)
current_num = (pred == label).sum() # variable
acc = (pred == label).float().mean() # variable
running_acc += current_num.item() if i % 100 == 0:
print("epoch: {}/{}, loss: {:.6f}, running_acc: {:.6f}"
.format(epoch+1, num_epoch, loss.item(), acc.item()))
print("epoch: {}, loss: {:.6f}, accuracy: {:.6f}".format(epoch+1, running_loss, running_acc/len(train_data))) model.eval()
current_num = 0
for i , data in enumerate(test_loader, 1):
img, label = data
if isGPU:
img = img.cuda()
label = label.cuda()
with torch.no_grad():
img = Variable(img)
label = Variable(label)
out = model(img)
_, pred = torch.max(out, 1)
current_num += (pred == label).sum().item() print("Test result: accuracy: {:.6f}".format(float(current_num/len(test_data)))) torch.save(model.state_dict(), './cnn.pth') # 保存模型

Task7.手写数字识别的更多相关文章

  1. C#中调用Matlab人工神经网络算法实现手写数字识别

    手写数字识别实现 设计技术参数:通过由数字构成的图像,自动实现几个不同数字的识别,设计识别方法,有较高的识别率 关键字:二值化  投影  矩阵  目标定位  Matlab 手写数字图像识别简介: 手写 ...

  2. CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  3. 【深度学习系列】PaddlePaddle之手写数字识别

    上周在搜索关于深度学习分布式运行方式的资料时,无意间搜到了paddlepaddle,发现这个框架的分布式训练方案做的还挺不错的,想跟大家分享一下.不过呢,这块内容太复杂了,所以就简单的介绍一下padd ...

  4. 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  5. 机器学习(二)-kNN手写数字识别

    一.kNN算法是机器学习的入门算法,其中不涉及训练,主要思想是计算待测点和参照点的距离,选取距离较近的参照点的类别作为待测点的的类别. 1,距离可以是欧式距离,夹角余弦距离等等. 2,k值不能选择太大 ...

  6. 利用神经网络算法的C#手写数字识别

    欢迎大家前往云+社区,获取更多腾讯海量技术实践干货哦~ 下载Demo - 2.77 MB (原始地址):handwritten_character_recognition.zip 下载源码 - 70. ...

  7. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  8. 深度学习之 mnist 手写数字识别

    深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...

  9. 手写数字识别 ----在已经训练好的数据上根据28*28的图片获取识别概率(基于Tensorflow,Python)

    通过: 手写数字识别  ----卷积神经网络模型官方案例详解(基于Tensorflow,Python) 手写数字识别  ----Softmax回归模型官方案例详解(基于Tensorflow,Pytho ...

随机推荐

  1. php7.2 下安装yaf扩展

    wget http://pecl.php.net/get/yaf-3.0.7.tgz  解压并进入目录: 1 tar -zxvf yaf-3.0.7* && cd yaf-3.0.7  ...

  2. 一文学会Go - 2 数据结构与算法实践篇

    练习:使用go语言实现冒泡排序和归并排序 冒泡排序是所有排序算法中最简单的,练习时先实现它: func bubbleSort(array []int) { n := len(array) ; j &l ...

  3. session 、cookie、token的区别(转)

    session  session的中文翻译是“会话”,当用户打开某个web应用时,便与web服务器产生一次session.服务器使用session把用户的信息临时保存在了服务器上,用户离开网站后ses ...

  4. 书籍:wpf学习书籍介绍

    WPF参考书推荐 下面先整理下,本人主要学习的WPF参考书: 1.WPF编程宝典(C#2010) 该书:(必读) 心得体会:读完该书后,你对WPF的基础和基本控件的使用,包括WPF的编程模型,相比Wi ...

  5. linux 正则表达式 元字符

    \b 单词边界 \bcool\b  只匹配cool字符串 [root@MongoDB ~]# cat test.txt i am mike1 i am mike i am mike12 匹配有mike ...

  6. Java第一周总结

    通过两周的Java学习最深刻的体会就是Java好像要比C要简单一些. 然后这两周我学习到了很多东西,李老师第一次上课就给我们介绍了Java的发展历程,同时还有Java工具jdk的发展历程. Java语 ...

  7. [Jupyter Notebook] 01 这么多快捷键,我可顶不住!先记个八成吧

    0. 一些说明 为了入门 Python3 安装了 Anaconda,它集成了 Jupyter Notebook 1. 调出快捷键表 打开 Jupyter Notebook,新建一个 Python3(我 ...

  8. android app开发中的常用组件

    1 Activity 1.1 Activity的启动 第一,android manifest中指定的主activity,点击app的icon启动进入. 第二,使用intent,在另外一个activit ...

  9. Javascript原型介绍

    原型及原型链 原型基础概念 function Person () { this.name = 'John'; } var person = new Person(); Person.prototype ...

  10. PHP数据结构基本概念

    原文:https://www.cnblogs.com/crystaltu/p/6408484.html 学习任何一种技术都应该先清楚它的基本概念,这是学习任何知识的起点!本文是讲述数据结构的基本概念, ...