用PyTorch完成手写数字识别

 import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets batch_size = 128
learning_rate = 0.01
num_epoch = 10 # 实例化MNIST数据集对象
train_data = datasets.MNIST('./dataset', train=True, transform=transforms.ToTensor(), download=True)
test_data = datasets.MNIST('./dataset', train=False, transform=transforms.ToTensor(), download=True) # train_loader:以batch_size大小的样本组为单位的可迭代对象
train_loader = DataLoader(train_data, batch_size, shuffle=True)
test_loader = DataLoader(test_data) class CNN(nn.Module):
def __init__(self, in_dim, out_dim):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(in_dim, 6, 3, stride=1, padding=1)
self.batch_norm1 = nn.BatchNorm2d(6)
self.relu = nn.ReLU(True)
self.conv2 = nn.Conv2d(6, 16, 5, stride=1, padding=0)
self.pool = nn.MaxPool2d(2, 2)
self.batch_norm2 = nn.BatchNorm2d(16) self.fc1 = nn.Linear(400, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, out_dim) def forward(self, x):
x = self.batch_norm1(self.conv1(x))
x = F.relu(x)
x = self.pool(x)
x = self.batch_norm2(self.conv2(x))
x = self.relu(x)
x = self.pool(x)
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x def print_model_name(self):
print("Model Name: CNN") class Cnn(nn.Module):
def __init__(self, in_dim, n_class):
super(Cnn, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_dim, 6, 3, stride=1, padding=1),
nn.ReLU(True),
nn.MaxPool2d(2, 2),
nn.Conv2d(6, 16, 5, stride=1, padding=0),
nn.ReLU(True),
nn.MaxPool2d(2, 2)) self.fc = nn.Sequential(
nn.Linear(400, 120), nn.Linear(120, 84), nn.Linear(84, n_class)) def forward(self, x):
# print(x.size()) torch.Size([1024, 1, 28, 28])
out = self.conv(x)
out = out.view(out.size(0), -1)
# print(out.size()) = torch.Size([1024, 400])
out = self.fc(out)
# print(out.size()) torch.Size([1024, 10])
return out def print_model_name(self):
print("Model Name: Cnn") isGPU = torch.cuda.is_available()
print(isGPU)
model = CNN(1, 10)
if isGPU:
model = model.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
for epoch in range(num_epoch):
running_acc = 0.0
running_loss = 0.0
for i, data in enumerate(train_loader, 1): # train_loader:以batch_size大小的样本组为单位的可迭代对象
img, label = data
img = Variable(img)
label = Variable(label)
if isGPU:
img = img.cuda()
label = label.cuda()
# forward
out = model(img)
loss = criterion(out, label)
# print(label)
# backward
optimizer.zero_grad()
loss.backward()
optimizer.step() _, pred = torch.max(out, dim=1) # 按维度dim 返回最大值
running_loss += loss.item()*label.size(0)
current_num = (pred == label).sum() # variable
acc = (pred == label).float().mean() # variable
running_acc += current_num.item() if i % 100 == 0:
print("epoch: {}/{}, loss: {:.6f}, running_acc: {:.6f}"
.format(epoch+1, num_epoch, loss.item(), acc.item()))
print("epoch: {}, loss: {:.6f}, accuracy: {:.6f}".format(epoch+1, running_loss, running_acc/len(train_data))) model.eval()
current_num = 0
for i , data in enumerate(test_loader, 1):
img, label = data
if isGPU:
img = img.cuda()
label = label.cuda()
with torch.no_grad():
img = Variable(img)
label = Variable(label)
out = model(img)
_, pred = torch.max(out, 1)
current_num += (pred == label).sum().item() print("Test result: accuracy: {:.6f}".format(float(current_num/len(test_data)))) torch.save(model.state_dict(), './cnn.pth') # 保存模型

Task7.手写数字识别的更多相关文章

  1. C#中调用Matlab人工神经网络算法实现手写数字识别

    手写数字识别实现 设计技术参数:通过由数字构成的图像,自动实现几个不同数字的识别,设计识别方法,有较高的识别率 关键字:二值化  投影  矩阵  目标定位  Matlab 手写数字图像识别简介: 手写 ...

  2. CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  3. 【深度学习系列】PaddlePaddle之手写数字识别

    上周在搜索关于深度学习分布式运行方式的资料时,无意间搜到了paddlepaddle,发现这个框架的分布式训练方案做的还挺不错的,想跟大家分享一下.不过呢,这块内容太复杂了,所以就简单的介绍一下padd ...

  4. 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  5. 机器学习(二)-kNN手写数字识别

    一.kNN算法是机器学习的入门算法,其中不涉及训练,主要思想是计算待测点和参照点的距离,选取距离较近的参照点的类别作为待测点的的类别. 1,距离可以是欧式距离,夹角余弦距离等等. 2,k值不能选择太大 ...

  6. 利用神经网络算法的C#手写数字识别

    欢迎大家前往云+社区,获取更多腾讯海量技术实践干货哦~ 下载Demo - 2.77 MB (原始地址):handwritten_character_recognition.zip 下载源码 - 70. ...

  7. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  8. 深度学习之 mnist 手写数字识别

    深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...

  9. 手写数字识别 ----在已经训练好的数据上根据28*28的图片获取识别概率(基于Tensorflow,Python)

    通过: 手写数字识别  ----卷积神经网络模型官方案例详解(基于Tensorflow,Python) 手写数字识别  ----Softmax回归模型官方案例详解(基于Tensorflow,Pytho ...

随机推荐

  1. 阶段3 1.Mybatis_10.JNDI扩展知识_2 补充-JNDI搭建maven的war工程

    使用骨架 src下创建test目录 再新建java和resources两个Directory test下面创建java java的目录,让他作为源码的跟目录 test下的java文件夹 选择 完成之后 ...

  2. 中国MOOC_零基础学Java语言_第7周 函数_2完数

    2 完数(5分) 题目内容: 一个正整数的因子是所有可以整除它的正整数.而一个数如果恰好等于除它本身外的因子之和,这个数就称为完数.例如6=1+2+3(6的因子是1,2,3). 现在,你要写一个程序, ...

  3. wpf 收集的不错的datagrid样式

    <ResourceDictionary xmlns="http://schemas.microsoft.com/winfx/2006/xaml/presentation" x ...

  4. 6.824 Lab 3: Fault-tolerant Key/Value Service 3B

    Part B: Key/value service with log compaction Do a git pull to get the latest lab software. As thing ...

  5. 远程桌面 使用 本地输入法(虚拟化 终端 远程接入 RemoteApp)

    远程桌面连接组件是微软从Windows 2000 Server开始提供的,该组件一经推出便受到了很多用户的拥护和使用.   在WINDOWS XP和WINDOWS SERVER 2003中微软公司将该 ...

  6. robot framework python3环境下学习笔记(1)——安装robot framework

    安装环境:win10 64位,python3.6 1,安装robot framework pip install robotframework 2,安装wxPython pip install wxP ...

  7. tensorflow学习之tf.truncated_normal和tf.random_noraml的区别

    tf版本1.13.1,CPU 最近在tf里新学了一个函数,一查发现和tf.random_normal差不多,于是记录一下.. 1.首先是tf.truncated_normal函数 tf.truncat ...

  8. python+selenium控制浏览器窗口(刷新、前进、后退、退出浏览器)

    调用说明: driver.属性值 变量说明: 1.driver.current_url:用于获得当前页面的URL 2.driver.title:用于获取当前页面的标题 3.driver.page_so ...

  9. Java实验报告(一)

    1.水仙花数 public class test1{ public static void main(String[] args){ for(int num=100;num<1000;num++ ...

  10. 高效编程之 多线程Event

    Event 简介 Event 事件 是线程间通信的最简单方法之一,主要用于线程同步. 处理机制 定义一个全局内置标志Flag,如果Flag为False,执行到 event.wait 时程序就会阻塞,如 ...