Pytorch卷积神经网络识别手写数字集
卷积神经网络目前被广泛地用在图片识别上, 已经有层出不穷的应用, 如果你对卷积神经网络充满好奇心,这里为你带来pytorch实现cnn一些入门的教程代码
#首先导入包
import torch
from torch.autograd import Variable
import torch.nn as nn
import torchvision
import torch.utils.data as Data
#一、数据准备
#训练数据:用了torchvision.datasets.MNIST,root是文件路径,train为True(这是训练数据),transform是把图像数据转换为张量,download(如果本地已有该文件选择false,没有就选择true)
train_data = torchvision.datasets.MNIST(root='./mnist/',train=True,transform=torchvision.transforms.ToTensor(),download=False)
#训练数据:同上,train为False(这是测试数据)
test_data = torchvision.datasets.MNIST(root='./mnist/',train=False)
# "训练数据加载器":dataset为训练数据,shuflle为打乱数据的顺序,batch_size是让数据50个为一组
train_loader = Data.DataLoader(dataset=train_data,shuffle=True,batch_size=50)
test_data.test_data.size()
torch.Size([10000, 28, 28])
#测试数据 test_data下的test_data为测试数据,因为下面conv2d输入的为4维数据,所以此处用torch.unsqueeze升维
test_x = Variable(torch.unsqueeze(test_data.test_data,dim=1),volatile=True).type(torch.FloatTensor)
#测试数据目标值
test_y = test_data.test_labels
#二、实现模型
class CNN(nn.Module):
def __init__(self):
super(CNN,self).__init__()
#conv2d参数:输入1维,输出16维,5个卷积核(kernel),步长(stride)为1,padding是2(如果想要 con2d 出来的图片长宽没有变化, padding=(kernel_size-1)/2 当 stride=1)
self.conv1 = nn.Sequential(nn.Conv2d(1,16,5,1,2),nn.ReLU(),nn.MaxPool2d(2))
self.conv2 = nn.Sequential(nn.Conv2d(16,32,5,1,2),nn.ReLU(),nn.MaxPool2d(2))
#Linear参数:输入维数,输出分的种类数
self.out = nn.Linear(32*7*7,10)
def forward(self,x):
x1 = self.conv1(x)
x2 = self.conv2(x1)
#这里给x3降为2维可以让linear函数使用
x3 = x2.view(x2.size(0),-1)
out = self.out(x3)
return out
#自动调整参数,最优化模型
cnn = CNN()
optimizer = torch.optim.Adam(cnn.parameters(),lr = 0.02)
loss_func = nn.CrossEntropyLoss()
#三、训练模型
for step,(x,y) in enumerate(train_loader):
x = Variable(x)
y = Variable(y)
out = cnn(x)
loss = loss_func(out,y)
#以下为固定操作,为了训练每一条数据,不断调整参数
optimizer.zero_grad()
loss.backward()
optimizer.step()
#四、测试
predict = cnn(test_x[:10])
res = torch.max(predict,1)[1]
res #测试数据
tensor([7, 2, 1, 0, 4, 1, 4, 9, 9, 9])
test_y[:10] #真实数据
tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9])
#在这里我们发现前十个数据分类准确率达到90
Pytorch卷积神经网络识别手写数字集的更多相关文章
- 如何用卷积神经网络CNN识别手写数字集?
前几天用CNN识别手写数字集,后来看到kaggle上有一个比赛是识别手写数字集的,已经进行了一年多了,目前有1179个有效提交,最高的是100%,我做了一下,用keras做的,一开始用最简单的MLP, ...
- Tensorflow搭建卷积神经网络识别手写英语字母
更新记录: 2018年2月5日 初始文章版本 近几天需要进行英语手写体识别,查阅了很多资料,但是大多数资料都是针对MNIST数据集的,并且主要识别手写数字.为了满足实际的英文手写识别需求,需要从训练集 ...
- PyTorch基础——使用卷积神经网络识别手写数字
一.介绍 实验内容 内容包括用 PyTorch 来实现一个卷积神经网络,从而实现手写数字识别任务. 除此之外,还对卷积神经网络的卷积核.特征图等进行了分析,引出了过滤器的概念,并简单示了卷积神经网络的 ...
- Python实现神经网络算法识别手写数字集
最近忙里偷闲学习了一点机器学习的知识,看到神经网络算法时我和阿Kun便想到要将它用Python代码实现.我们用了两种不同的方法来编写它.这里只放出我的代码. MNIST数据集基于美国国家标准与技术研究 ...
- 使用TensorFlow的卷积神经网络识别手写数字(3)-识别篇
from PIL import Image import numpy as np import tensorflow as tf import time bShowAccuracy = True # ...
- 使用TensorFlow的卷积神经网络识别手写数字(2)-训练篇
import numpy as np import tensorflow as tf import matplotlib import matplotlib.pyplot as plt import ...
- 使用TensorFlow的卷积神经网络识别手写数字(1)-预处理篇
功能: 将文件夹下的20*20像素黑白图片,根据重心位置绘制到28*28图片上,然后保存.经过预处理的图片有利于数字的准确识别.参见MNIST对图片的要求. 此处可下载已处理好的图片: https:/ ...
- 李宏毅 Keras手写数字集识别(优化篇)
在之前的一章中我们讲到的keras手写数字集的识别中,所使用的loss function为‘mse’,即均方差.那我们如何才能知道所得出的结果是不是overfitting?我们通过运行结果中的trai ...
- TensorFlow卷积神经网络实现手写数字识别以及可视化
边学习边笔记 https://www.cnblogs.com/felixwang2/p/9190602.html # https://www.cnblogs.com/felixwang2/p/9190 ...
随机推荐
- MySQL里默认的几个库是干啥的?
本文涉及:MySQL安装后自带的4个数据库:information_schema. performance_schema.sys.mysql的作用及其中各个表所存储的数据含义 information_ ...
- JS this指向总结
使用 JavaScript 开发的时候,很多开发者多多少少会被 this 的指向搞蒙圈,但是实际上,关于 this 的指向,记住最核心的一句话:哪个对象调用函数,函数里面的this指向哪个对象. 下面 ...
- 逻辑回归 之 Logist 推导
Logist从概率角度认识 可以咱学校教材大二版的<> - 山大版, 来整一波, 为了简化推导形式呢, 这里就假设2个样本空间的形式来展开, 基于(条件概率) 全概率与贝叶斯 作为核心. ...
- ML-线性 SVM 推导
Max Margin svm 即Suport Vector Machine, 中文意为:支持向量机. 对于二分类问题, 在样本空间中(即便是多维向量, 在空间中可表示为一个点). svm的核心思想就是 ...
- Linux系统的时间比北京时间慢12个小时的处理方案(将EDT时区改为CST)
今天查看Linux操作系统的时间,发现比正常时间慢12个小时整,感觉很奇怪,后来使用ntp服务器校对时间发现也是不管用的,还是慢12个小时.之前遇到过是慢8个小时,但是我知道是因为使用的是UTC时间, ...
- Beta冲刺第6次
二.Scrum部分 1. 各成员情况 翟仕佶 学号:201731103226 今日进展 新增图像拼接合并功能 存在问题 无 明日安排 视情况而定 截图 曾中杰 学号:201731062517 今日进展 ...
- oracle 字符集安装错了,修改字符集 及创建用户 表空间 ,删除用户及所有的表
1.首先以sysdba的身份登录上去 conn /as sysdba 2.关闭数据库shutdown immediate; 3.以mount打来数据库,startup mount 4.设置sessio ...
- python的sort和sorted
sort 只适用于列表,返回列表类型. sorted 可适用于字典,元组和列表. 使用方法 sort的使用方法是list.sort(cmp=None, key=None, reverse=False) ...
- php单例型(singleton pattern)
搞定,吃饭 <?php /* The purpose of singleton pattern is to restrict instantiation of class to a single ...
- OAuth 第三方登录授权码(authorization code)方式的小例子
假如上面的网站A,可以通过GitHub账号登录: 下面以OAuth其中一种方式,授权码(authorization code)方式为例. 一.第三方登录的原理 所谓第三方登录,实质就是 OAuth 授 ...