用PyTorch完成手写数字识别

 import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets batch_size = 128
learning_rate = 0.01
num_epoch = 10 # 实例化MNIST数据集对象
train_data = datasets.MNIST('./dataset', train=True, transform=transforms.ToTensor(), download=True)
test_data = datasets.MNIST('./dataset', train=False, transform=transforms.ToTensor(), download=True) # train_loader:以batch_size大小的样本组为单位的可迭代对象
train_loader = DataLoader(train_data, batch_size, shuffle=True)
test_loader = DataLoader(test_data) class CNN(nn.Module):
def __init__(self, in_dim, out_dim):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(in_dim, 6, 3, stride=1, padding=1)
self.batch_norm1 = nn.BatchNorm2d(6)
self.relu = nn.ReLU(True)
self.conv2 = nn.Conv2d(6, 16, 5, stride=1, padding=0)
self.pool = nn.MaxPool2d(2, 2)
self.batch_norm2 = nn.BatchNorm2d(16) self.fc1 = nn.Linear(400, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, out_dim) def forward(self, x):
x = self.batch_norm1(self.conv1(x))
x = F.relu(x)
x = self.pool(x)
x = self.batch_norm2(self.conv2(x))
x = self.relu(x)
x = self.pool(x)
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x def print_model_name(self):
print("Model Name: CNN") class Cnn(nn.Module):
def __init__(self, in_dim, n_class):
super(Cnn, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_dim, 6, 3, stride=1, padding=1),
nn.ReLU(True),
nn.MaxPool2d(2, 2),
nn.Conv2d(6, 16, 5, stride=1, padding=0),
nn.ReLU(True),
nn.MaxPool2d(2, 2)) self.fc = nn.Sequential(
nn.Linear(400, 120), nn.Linear(120, 84), nn.Linear(84, n_class)) def forward(self, x):
# print(x.size()) torch.Size([1024, 1, 28, 28])
out = self.conv(x)
out = out.view(out.size(0), -1)
# print(out.size()) = torch.Size([1024, 400])
out = self.fc(out)
# print(out.size()) torch.Size([1024, 10])
return out def print_model_name(self):
print("Model Name: Cnn") isGPU = torch.cuda.is_available()
print(isGPU)
model = CNN(1, 10)
if isGPU:
model = model.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
for epoch in range(num_epoch):
running_acc = 0.0
running_loss = 0.0
for i, data in enumerate(train_loader, 1): # train_loader:以batch_size大小的样本组为单位的可迭代对象
img, label = data
img = Variable(img)
label = Variable(label)
if isGPU:
img = img.cuda()
label = label.cuda()
# forward
out = model(img)
loss = criterion(out, label)
# print(label)
# backward
optimizer.zero_grad()
loss.backward()
optimizer.step() _, pred = torch.max(out, dim=1) # 按维度dim 返回最大值
running_loss += loss.item()*label.size(0)
current_num = (pred == label).sum() # variable
acc = (pred == label).float().mean() # variable
running_acc += current_num.item() if i % 100 == 0:
print("epoch: {}/{}, loss: {:.6f}, running_acc: {:.6f}"
.format(epoch+1, num_epoch, loss.item(), acc.item()))
print("epoch: {}, loss: {:.6f}, accuracy: {:.6f}".format(epoch+1, running_loss, running_acc/len(train_data))) model.eval()
current_num = 0
for i , data in enumerate(test_loader, 1):
img, label = data
if isGPU:
img = img.cuda()
label = label.cuda()
with torch.no_grad():
img = Variable(img)
label = Variable(label)
out = model(img)
_, pred = torch.max(out, 1)
current_num += (pred == label).sum().item() print("Test result: accuracy: {:.6f}".format(float(current_num/len(test_data)))) torch.save(model.state_dict(), './cnn.pth') # 保存模型

Task7.手写数字识别的更多相关文章

  1. C#中调用Matlab人工神经网络算法实现手写数字识别

    手写数字识别实现 设计技术参数:通过由数字构成的图像,自动实现几个不同数字的识别,设计识别方法,有较高的识别率 关键字:二值化  投影  矩阵  目标定位  Matlab 手写数字图像识别简介: 手写 ...

  2. CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  3. 【深度学习系列】PaddlePaddle之手写数字识别

    上周在搜索关于深度学习分布式运行方式的资料时,无意间搜到了paddlepaddle,发现这个框架的分布式训练方案做的还挺不错的,想跟大家分享一下.不过呢,这块内容太复杂了,所以就简单的介绍一下padd ...

  4. 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  5. 机器学习(二)-kNN手写数字识别

    一.kNN算法是机器学习的入门算法,其中不涉及训练,主要思想是计算待测点和参照点的距离,选取距离较近的参照点的类别作为待测点的的类别. 1,距离可以是欧式距离,夹角余弦距离等等. 2,k值不能选择太大 ...

  6. 利用神经网络算法的C#手写数字识别

    欢迎大家前往云+社区,获取更多腾讯海量技术实践干货哦~ 下载Demo - 2.77 MB (原始地址):handwritten_character_recognition.zip 下载源码 - 70. ...

  7. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  8. 深度学习之 mnist 手写数字识别

    深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...

  9. 手写数字识别 ----在已经训练好的数据上根据28*28的图片获取识别概率(基于Tensorflow,Python)

    通过: 手写数字识别  ----卷积神经网络模型官方案例详解(基于Tensorflow,Python) 手写数字识别  ----Softmax回归模型官方案例详解(基于Tensorflow,Pytho ...

随机推荐

  1. 【Bean】标签常用属性

    [Bean]标签常用属性 Id 说明:起名称,id属性值名称任意,不能包含特殊符号,根据id得到配置对象. Class 说明:创建对象所在类的全路径. Name 说明:功能和id是一样的,id属性值不 ...

  2. strtoul()要优于atoi()函数---C语言

    strtoul():将字符串转为长整型整数 atoi():将字符串转为整型整数 在32位STM32中,int是32位的,如果字符串是“3123456789”,大于0x7fff fff,用atoi()函 ...

  3. Mybatis-学习笔记(4)1对1、1对多、多对多

    1.1对1 有2种方式对内嵌Bean设值: 1>关联查询就一条语句.使用association关键字,直接将嵌套对象的映射表的字段赋值内嵌对象. <association property ...

  4. 【SSL2325】最小转弯问题

    题面: \[\Large\text{最小转弯问题}\] \[Time~Limit:1000MS~~Memory~Limit:65536K\] Description 给出一张地图,这张地图被分为 n× ...

  5. [项目实战]训练retinanet(pytorch版)

    采用github上star比较高的一个开源实现https://github.com/yhenon/pytorch-retinanet 在anaconda中新建了一个环境,因为一开始并没有新建环境,在原 ...

  6. 使用form表单提交请求如何获取后台返回的数据?

    问题描述 一般的form表单提交是单向的:只能给服务器发送数据,但是无法获取服务器返回的数据,也就是无法读取HTTP应答包. 想要真正的半双工通讯一般需要使用Ajax, 但是Ajax对文件传输也很麻烦 ...

  7. 关于“如何只用2GB内存从20亿,40亿,80亿个整数中找到出现次数最多的数?”的一种思路

    小弟不才,只懂一些c#的皮毛,有一些想法, int32值范围大概在-20亿——20亿,按hashtable一个keyvalue占8B的设定来说,最大可以存储大约2.5亿个 数字-次数对. 那么,可以将 ...

  8. es5继承和es6类和继承

    es6新增关键字class,代表类,其实相当于代替了es5的构造函数 通过构造函数可以创建一个对象实例,那么通过class也可以创建一个对象实列 /* es5 创建一个person 构造函数 */ f ...

  9. HTML回顾之表格

    HTML表格 由什么组成? 表格由<table>标签来定义.每个表格有若干行(<tr>标签来定义),每行被分割成若干单元格(<td>标签来定义). td值表格数据, ...

  10. Nginx中配置https中引用http的问题

    Nginx中配置https中引用http的问题 遇到问题: 今天公司要在后台增加直播入口,使用腾讯云的实时音视频,要求是必须使用https,在配置完强制跳转https候,发现后台无法上传图片,在浏览器 ...