深度学习之 mnist 手写数字识别
深度学习之 mnist 手写数字识别
开始学习深度学习,先来一个手写数字的程序
import numpy as np
import os
import codecs
import torch
from PIL import Image
lr = 0.01
momentum = 0.5
epochs = 10
def get_int(b):
return int(codecs.encode(b, 'hex'), 16)
def read_label_file(path):
with open(path, 'rb') as f:
data = f.read()
assert get_int(data[:4]) == 2049
length = get_int(data[4:8])
parsed = np.frombuffer(data, dtype=np.uint8, offset=8)
return torch.from_numpy(parsed).view(length).long()
def read_image_file(path):
with open(path, 'rb') as f:
data = f.read()
assert get_int(data[:4]) == 2051
length = get_int(data[4:8])
num_rows = get_int(data[8:12])
num_cols = get_int(data[12:16])
images = []
parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
return torch.from_numpy(parsed).view(length, num_rows, num_cols)
def loadmnist(path, kind='train'):
labels_path = os.path.join(path, 'mnist' ,'%s-labels.idx1-ubyte' % kind)
images_path = os.path.join(path,'mnist' ,'%s-images.idx3-ubyte' % kind)
labels = read_label_file(labels_path)
images = read_image_file(images_path)
return images, labels
import torch.utils.data as data
import torchvision.transforms as transforms
class Loader(data.Dataset):
def __init__(self, root, label, transforms):
self.imgs = []
imgs,labels = loadmnist(root, label)
self.imgs = imgs
self.labels = labels
self.transforms = transforms
def __getitem__(self, index):
img, label = self.imgs[index],self.labels[index]
img = Image.fromarray(img.numpy(), mode='L')
if self.transforms:
img = self.transforms(img)
return img, label
def __len__(self):
return len(self.imgs)
def getTrainDataset():
return Loader('d:\\work\\yoho\\dl\\dl-study\\chapter0\\', 'train', transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]))
def getTestDataset():
return Loader('d:\\work\\yoho\\dl\\dl-study\\chapter0\\', 't10k', transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]))
import torch as t
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(1, 10, kernel_size=5),
nn.MaxPool2d(2),
nn.ReLU(inplace=True),
nn.Conv2d(10, 20, kernel_size=5),
nn.Dropout2d(),
nn.MaxPool2d(2),
nn.ReLU(inplace=True),
)
self.classifier = nn.Sequential(
nn.Linear(320, 50),
nn.ReLU(inplace=True),
nn.Dropout(),
nn.Linear(50, 10),
nn.LogSoftmax(dim=1)
)
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
net = Net()
import torch.optim as optim
from torch.nn.modules import loss
optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
criterion = loss.CrossEntropyLoss()
train_dataset = getTrainDataset()
test_dataset = getTestDataset()
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4, shuffle=False)
from torch.autograd import Variable as V
def train(epoch):
for i, (inputs, labels) in enumerate(train_loader):
inputs_var, labels_var = V(inputs), V(labels)
outputs = net(inputs_var)
losses = criterion(outputs, labels_var)
optimizer.zero_grad()
losses.backward()
optimizer.step()
def test(epoch):
for i, (inputs, labels) in enumerate(test_loader):
inputs_var = V(inputs)
outputs = net(inputs_var)
_, pred = outputs.data.topk(5, 1, True, True)
batch_size = labels.size(0)
pred = pred.t()
corrent = pred.eq(labels.view(1, -1).expand_as(pred))
res = []
for k in (1,5):
correct_k = corrent[:k].view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
print('{} {} top1 {} top5 {}'.format(epoch, i ,res[0][0], res[1][0]))
def main():
for epoch in range(0, epochs):
train(epoch)
test(epoch)
main()
学习之后的,正确率很高,这种问题对于深度学习已经解决了。
深度学习之 mnist 手写数字识别的更多相关文章
- 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)
上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...
- 用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别
用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别 http://phunter.farbox.com/post/mxnet-tutorial1 用MXnet实战深度学 ...
- mnist手写数字识别——深度学习入门项目(tensorflow+keras+Sequential模型)
前言 今天记录一下深度学习的另外一个入门项目——<mnist数据集手写数字识别>,这是一个入门必备的学习案例,主要使用了tensorflow下的keras网络结构的Sequential模型 ...
- Android+TensorFlow+CNN+MNIST 手写数字识别实现
Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...
- 持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型
持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tens ...
- Tensorflow之MNIST手写数字识别:分类问题(1)
一.MNIST数据集读取 one hot 独热编码独热编码是一种稀疏向量,其中:一个向量设为1,其他元素均设为0.独热编码常用于表示拥有有限个可能值的字符串或标识符优点: 1.将离散特征的取值扩展 ...
- Tensorflow实现MNIST手写数字识别
之前我们讲了神经网络的起源.单层神经网络.多层神经网络的搭建过程.搭建时要注意到的具体问题.以及解决这些问题的具体方法.本文将通过一个经典的案例:MNIST手写数字识别,以代码的形式来为大家梳理一遍神 ...
- Pytorch入门——手把手教你MNIST手写数字识别
MNIST手写数字识别教程 要开始带组内的小朋友了,特意出一个Pytorch教程来指导一下 [!] 这里是实战教程,默认读者已经学会了部分深度学习原理,若有不懂的地方可以先停下来查查资料 目录 MNI ...
- 基于tensorflow的MNIST手写数字识别(二)--入门篇
http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...
随机推荐
- webpack深入场景——开发环境和生产环境配置
以前自己写一小项目时,webpack的配置基本就是一套配置,没有考虑生产环境和开发环境的区分,最近在做一个复杂的商城项目接触到了webpack的高级配置,经过两天的研究,写出了一份目前来说比叫满意的配 ...
- java——对象学习笔记
1.面向对象(OOP)的三大特性 对象的行为(behavior):可以对对象施加哪些操作,或者可以对对象施加哪些方法. 对象的状态(state):当施加那些方法后,对象如何响应. 对象标识(ident ...
- Vuejs实例-使用vue-cli创建项目
1,首先从官方网站下载安装Node.js,建议使用6.x版本,同时也会一并安装npm工具,npm>3.10以上. 2,npm安装很慢(国外服务器),所以一般推荐使用npm淘宝镜像cnpm,先安装 ...
- Activiti就是这么简单
Activiti介绍 什么是Activiti? Activiti5是由Alfresco软件在2010年5月17日发布的业务流程管理(BPM)框架,它是覆盖了业务流程管理.工作流.服务协作等领域的一个开 ...
- pdf文件中截取eps图片并压缩
最近遇到了一个问题,需要从pdf裁剪出其中部分的矢量图格式的图片,并保存为eps格式,方便使用. 最简单的方法就是先用acrobat pro将pdf进行页面抽取,并裁剪,剩下所需要的图片部分,然后另存 ...
- Kon-boot v2.5介绍与使用方法总结(支持win10)
Kon-boot这个工具相信大家都不陌生,这是一款专门针对Windows.Linux.MAC登陆密码破解工具,他能绕过系统所设有的登陆密码,让你的登陆畅通无阻.KON-Boot的原理是在于处理BIOS ...
- cmd 命令大全
1.windows 系统定时关机 定时关机:shutdown -s -t 300 at 18:30 shutdown -s 取消定时:shutdown -a 注意:300为秒数,在windows co ...
- phpStorm安装方法
1)下载 http://big2.h5gamen.com/soft/jetbrainscrack-2.6.2.zip 放到phpstorm安装目录下的lib文件夹 如放到f盘 F:\PhpStorm ...
- node命令curl
一.打开另一个命令行窗口,运行下面的命令. curl -X POST --data "name=Jack" 127.0.0.1:3000 上面代码使用 POST 方法向服务器发送一 ...
- 兼容的Ajax
/** * 创建XMLHttpRequest对象 * @param _method 请求方式: post||get * @param _url 远程服务器地址 * @param _async 是否异步 ...