Task7.手写数字识别
用PyTorch完成手写数字识别
import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets batch_size = 128
learning_rate = 0.01
num_epoch = 10 # 实例化MNIST数据集对象
train_data = datasets.MNIST('./dataset', train=True, transform=transforms.ToTensor(), download=True)
test_data = datasets.MNIST('./dataset', train=False, transform=transforms.ToTensor(), download=True) # train_loader:以batch_size大小的样本组为单位的可迭代对象
train_loader = DataLoader(train_data, batch_size, shuffle=True)
test_loader = DataLoader(test_data) class CNN(nn.Module):
def __init__(self, in_dim, out_dim):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(in_dim, 6, 3, stride=1, padding=1)
self.batch_norm1 = nn.BatchNorm2d(6)
self.relu = nn.ReLU(True)
self.conv2 = nn.Conv2d(6, 16, 5, stride=1, padding=0)
self.pool = nn.MaxPool2d(2, 2)
self.batch_norm2 = nn.BatchNorm2d(16) self.fc1 = nn.Linear(400, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, out_dim) def forward(self, x):
x = self.batch_norm1(self.conv1(x))
x = F.relu(x)
x = self.pool(x)
x = self.batch_norm2(self.conv2(x))
x = self.relu(x)
x = self.pool(x)
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x def print_model_name(self):
print("Model Name: CNN") class Cnn(nn.Module):
def __init__(self, in_dim, n_class):
super(Cnn, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_dim, 6, 3, stride=1, padding=1),
nn.ReLU(True),
nn.MaxPool2d(2, 2),
nn.Conv2d(6, 16, 5, stride=1, padding=0),
nn.ReLU(True),
nn.MaxPool2d(2, 2)) self.fc = nn.Sequential(
nn.Linear(400, 120), nn.Linear(120, 84), nn.Linear(84, n_class)) def forward(self, x):
# print(x.size()) torch.Size([1024, 1, 28, 28])
out = self.conv(x)
out = out.view(out.size(0), -1)
# print(out.size()) = torch.Size([1024, 400])
out = self.fc(out)
# print(out.size()) torch.Size([1024, 10])
return out def print_model_name(self):
print("Model Name: Cnn") isGPU = torch.cuda.is_available()
print(isGPU)
model = CNN(1, 10)
if isGPU:
model = model.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
for epoch in range(num_epoch):
running_acc = 0.0
running_loss = 0.0
for i, data in enumerate(train_loader, 1): # train_loader:以batch_size大小的样本组为单位的可迭代对象
img, label = data
img = Variable(img)
label = Variable(label)
if isGPU:
img = img.cuda()
label = label.cuda()
# forward
out = model(img)
loss = criterion(out, label)
# print(label)
# backward
optimizer.zero_grad()
loss.backward()
optimizer.step() _, pred = torch.max(out, dim=1) # 按维度dim 返回最大值
running_loss += loss.item()*label.size(0)
current_num = (pred == label).sum() # variable
acc = (pred == label).float().mean() # variable
running_acc += current_num.item() if i % 100 == 0:
print("epoch: {}/{}, loss: {:.6f}, running_acc: {:.6f}"
.format(epoch+1, num_epoch, loss.item(), acc.item()))
print("epoch: {}, loss: {:.6f}, accuracy: {:.6f}".format(epoch+1, running_loss, running_acc/len(train_data))) model.eval()
current_num = 0
for i , data in enumerate(test_loader, 1):
img, label = data
if isGPU:
img = img.cuda()
label = label.cuda()
with torch.no_grad():
img = Variable(img)
label = Variable(label)
out = model(img)
_, pred = torch.max(out, 1)
current_num += (pred == label).sum().item() print("Test result: accuracy: {:.6f}".format(float(current_num/len(test_data)))) torch.save(model.state_dict(), './cnn.pth') # 保存模型
Task7.手写数字识别的更多相关文章
- C#中调用Matlab人工神经网络算法实现手写数字识别
手写数字识别实现 设计技术参数:通过由数字构成的图像,自动实现几个不同数字的识别,设计识别方法,有较高的识别率 关键字:二值化 投影 矩阵 目标定位 Matlab 手写数字图像识别简介: 手写 ...
- CNN 手写数字识别
1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...
- 【深度学习系列】PaddlePaddle之手写数字识别
上周在搜索关于深度学习分布式运行方式的资料时,无意间搜到了paddlepaddle,发现这个框架的分布式训练方案做的还挺不错的,想跟大家分享一下.不过呢,这块内容太复杂了,所以就简单的介绍一下padd ...
- 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)
上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...
- 机器学习(二)-kNN手写数字识别
一.kNN算法是机器学习的入门算法,其中不涉及训练,主要思想是计算待测点和参照点的距离,选取距离较近的参照点的类别作为待测点的的类别. 1,距离可以是欧式距离,夹角余弦距离等等. 2,k值不能选择太大 ...
- 利用神经网络算法的C#手写数字识别
欢迎大家前往云+社区,获取更多腾讯海量技术实践干货哦~ 下载Demo - 2.77 MB (原始地址):handwritten_character_recognition.zip 下载源码 - 70. ...
- Android+TensorFlow+CNN+MNIST 手写数字识别实现
Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...
- 深度学习之 mnist 手写数字识别
深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...
- 手写数字识别 ----在已经训练好的数据上根据28*28的图片获取识别概率(基于Tensorflow,Python)
通过: 手写数字识别 ----卷积神经网络模型官方案例详解(基于Tensorflow,Python) 手写数字识别 ----Softmax回归模型官方案例详解(基于Tensorflow,Pytho ...
随机推荐
- 全局namespace与模块内的namespace
declare global{ declare namespace xxx } 相当于 在一个js文件的顶级部分 declare namespace xxx 声明的都是全局的namespace, 如果 ...
- 网易云课堂_C++程序设计入门(下)_第8单元:年年岁岁花相似– 运算符重载_第8单元 - 作业2:OJ编程 - 重载数组下标运算符
第8单元 - 作业2:OJ编程 - 重载数组下标运算符 查看帮助 返回 温馨提示: 1.本次作业属于Online Judge题目,提交后由系统即时判分. 2.学生可以在作业截止时间之前不限次数提 ...
- springmvc中获取request对象,加载biz(service)的方法
获取request对象: 首先配置web.xml文件--> <listener> <listener-class> org.springframework.web.con ...
- mysql驱动表与被驱动表及join优化
驱动表与被驱动表 先了解在join连接时哪个表是驱动表,哪个表是被驱动表:1.当使用left join时,左表是驱动表,右表是被驱动表2.当使用right join时,右表时驱动表,左表是驱动表3.当 ...
- (ROT-13解密)Flare-On4: Challenge1 login.html
说是FlareOn的逆向 倒不如说是crypto....... 题目不难 F12看源码: document.getElementById("prompt").onclick = f ...
- [Web 前端] 032 vue 初识
目录 0. 先下载 1. 先写个轮廓 2. 牛刀小试 2.1 例子 1 2.2 例子 2 3. 模板语法 上例子 4. 文本指令 上例子 5. 属性操作 上例子 6. 样式操作 上例子 类名的操作 s ...
- Maven 项目中的groupId和artifactId
maven进行项目管理,如果我们要将项目加入到maven到本地仓库中,则需要对项目进行唯一性标示,而groupId和artifactId就起到这样对作用. groupId为项目组织对唯一标识符,可以理 ...
- Linux 创建与删除(5)
相对于Windows下的右键新建文件与删除,我更喜爱Linux下的命令式创建与删除,真的方便.不过Windows下也可以借助工具来实现,比如git bash.cmder等等终端工具. 创建文件 新建文 ...
- 在搭建Maven项目时导入elasticsearch架包时遇到的问题
<!-- 使用elasticsearch 需要导入两个包,从网上复制的可能因为有特殊字符报 cvc-complex-type.2.3: Element 'dependency' cannot h ...
- ENGINE=InnoDB AUTO_INCREMENT=22 DEFAULT CHARSET=utf8;
参考来源:https://blog.csdn.net/yuxinha11/article/details/80090197 ENGINE=InnoDB不是默认就是这个引擎吗?——是的,如果不写也是ok ...