都说MNIST相当于机器学习界的Hello World。最近加入实验室,导师给我们安排了一个任务,但是我才刚刚入门呐!!没办法,只能从最基本的学起。

  Pytorch是一套开源的深度学习张量库。或者我倾向于把它当成一个独立的深度学习框架。为了写这么一个"Hello World"。查阅了不少资料,也踩了不少坑。不过同时也学习了不少东西,下面我把我的代码记录下来,希望能够从中受益更多,同时帮助其他对Pytorch感兴趣的人。代码的注释中有不对的地方欢迎批评指正。

  代码进行了注释,应该很方便阅读。 dependences: numpy torch torchvision python3 使用pip安装即可。

 # encoding: utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F #加载nn中的功能函数
import torch.optim as optim #加载优化器有关包
import torch.utils.data as Data
from torchvision import datasets,transforms #加载计算机视觉有关包
from torch.autograd import Variable BATCH_SIZE = 64 #加载torchvision包内内置的MNIST数据集 这里涉及到transform:将图片转化成torchtensor
train_dataset = datasets.MNIST(root='~/data/',train=True,transform=transforms.ToTensor(),download=True)
test_dataset = datasets.MNIST(root='~/data/',train=False,transform=transforms.ToTensor()) #加载小批次数据,即将MNIST数据集中的data分成每组batch_size的小块,shuffle指定是否随机读取
train_loader = Data.DataLoader(dataset=train_dataset,batch_size=BATCH_SIZE,shuffle=True)
test_loader = Data.DataLoader(dataset=test_dataset,batch_size=BATCH_SIZE,shuffle=False) #定义网络模型亦即Net 这里定义一个简单的全连接层784->10
class Model(nn.Module):
def __init__(self):
super(Model,self).__init__()
self.linear1 = nn.Linear(784,10) def forward(self,X):
return F.relu(self.linear1(X)) model = Model() #实例化全连接层
loss = nn.CrossEntropyLoss() #损失函数选择,交叉熵函数
optimizer = optim.SGD(model.parameters(),lr = 0.1)
num_epochs = 5 #以下四个列表是为了可视化(暂未实现)
losses = []
acces = []
eval_losses = []
eval_acces = [] for echo in range(num_epochs):
train_loss = 0 #定义训练损失
train_acc = 0 #定义训练准确度
model.train() #将网络转化为训练模式
for i,(X,label) in enumerate(train_loader): #使用枚举函数遍历train_loader
X = X.view(-1,784) #X:[64,1,28,28] -> [64,784]将X向量展平
X = Variable(X) #包装tensor用于自动求梯度
label = Variable(label)
out = model(X) #正向传播
lossvalue = loss(out,label) #求损失值
optimizer.zero_grad() #优化器梯度归零
lossvalue.backward() #反向转播,刷新梯度值
optimizer.step() #优化器运行一步,注意optimizer搜集的是model的参数 #计算损失
train_loss += float(lossvalue)
#计算精确度
_,pred = out.max(1)
num_correct = (pred == label).sum()
acc = int(num_correct) / X.shape[0]
train_acc += acc losses.append(train_loss / len(train_loader))
acces.append(train_acc / len(train_loader))
print("echo:"+' ' +str(echo))
print("lose:" + ' ' + str(train_loss / len(train_loader)))
print("accuracy:" + ' '+str(train_acc / len(train_loader)))
eval_loss = 0
eval_acc = 0
model.eval() #模型转化为评估模式
for X,label in test_loader:
X = X.view(-1,784)
X = Variable(X)
label = Variable(label)
testout = model(X)
testloss = loss(testout,label)
eval_loss += float(testloss) _,pred = testout.max(1)
num_correct = (pred == label).sum()
acc = int(num_correct) / X.shape[0]
eval_acc += acc eval_losses.append(eval_loss / len(test_loader))
eval_acces.append(eval_acc / len(test_loader))
print("testlose: " + str(eval_loss/len(test_loader)))
print("testaccuracy:" + str(eval_acc/len(test_loader)) + '\n')

运行后的结果如下:

  我们在上面的代码中,将图片对应的Pytorchtensor展平,并通过一个全连接层,仅仅是这样就达到了90%以上的准确率。如果使用卷积层,正确率有望达到更高。

  代码并不完备,还可以增加visualize和predict功能,等我学到更多知识后,有待后续添加。  

PytorchMNIST(使用Pytorch进行MNIST字符集识别任务)的更多相关文章

  1. R︱Softmax Regression建模 (MNIST 手写体识别和文档多分类应用)

    本文转载自经管之家论坛, R语言中的Softmax Regression建模 (MNIST 手写体识别和文档多分类应用) R中的softmaxreg包,发自2016-09-09,链接:https:// ...

  2. 一个简单的TensorFlow可视化MNIST数据集识别程序

    下面是TensorFlow可视化MNIST数据集识别程序,可视化内容是,TensorFlow计算图,表(loss, 直方图, 标准差(stddev)) # -*- coding: utf-8 -*- ...

  3. 深度学习-mnist手写体识别

    mnist手写体识别 Mnist数据集可以从官网下载,网址: http://yann.lecun.com/exdb/mnist/ 下载下来的数据集被分成两部分:55000行的训练数据集(mnist.t ...

  4. Pytorch实现MNIST手写数字识别

    Pytorch是热门的深度学习框架之一,通过经典的MNIST 数据集进行快速的pytorch入门. 导入库 from torchvision.datasets import MNIST from to ...

  5. Pytorch CNN网络MNIST数字识别 [超详细记录] 学习笔记(三)

    目录 1. 准备数据集 1.1 MNIST数据集获取: 1.2 程序部分 2. 设计网络结构 2.1 网络设计 2.2 程序部分 3. 迭代训练 4. 测试集预测部分 5. 全部代码 1. 准备数据集 ...

  6. 基于PyTorch实现MNIST手写字识别

    本篇不涉及模型原理,只是分享下代码.想要了解模型原理的可以去看网上很多大牛的博客. 目前代码实现了CNN和LSTM两个网络,整个代码分为四部分: Config:项目中涉及的参数: CNN:卷积神经网络 ...

  7. 全网最详细最好懂 PyTorch CNN案例分析 识别手写数字

    先来看一下这是什么任务.就是给你手写数组的图片,然后识别这是什么数字: dataset 首先先来看PyTorch的dataset类: 我已经在从零学习pytorch 第2课 Dataset类讲解了什么 ...

  8. MNIST数字识别问题

    摘自<Tensorflow:实战Google深度学习框架> import tensorflow as tf from tensorflow.examples.tutorials.mnist ...

  9. CNN算法解决MNIST数据集识别问题

    网络实现程序如下 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # 用于设置将记 ...

随机推荐

  1. CSS3新子代选择器

    :nth-child(n) 选择器匹配属于其父元素的第 N 个子元素,不论元素的类型,除了<h>标签. n 可以是数字.关键词或公式 例子一 <!DOCTYPE html> & ...

  2. 001_C语言中运算符的优先级

    总的来说就是: 1. 最高:单目运算符(() > * 解引用,&取地址,-取相反数,++等自增(或减)运算,!取反运算...); 2. 次之:双目运算符(算数运算符 > 移位运算符 ...

  3. Kubernetes学习笔记(六):使用ConfigMap和Secret配置应用程序

    概述 本文的核心是:如何处理应用程序的数据配置. 配置应用程序可以使用以下几种途径: 向容器传递命令行参数 为每个容器配置环境变量 通过特殊的卷将配置文件挂载到容器中 向容器传递命令行参数 在Kube ...

  4. Cypress系列(2)- Cypress 框架的详细介绍

    如果想从头学起Cypress,可以看下面的系列文章哦 https://www.cnblogs.com/poloyy/category/1768839.html Cypress 简介 基于 JavaSc ...

  5. 使用 IdentityService4 集成小程序登录一种尝试

    1 场景介绍 主要业务是通过 App 承载,在 App 中可以注册和登录,为了更好的发展业务引入了微信小程序,于是如何让这两个入口的用户互通便成了需要解决的问题. 看了一下其它 App 大致地思路是两 ...

  6. 上古神器vim系列之移动三板斧

    [导读] 前文总结了vim如何进入,如何保存退出,如何进入编辑模式.本文来总结一些稍微进阶的内容,在normal模式下如何高效的浏览代码. 模式回顾 在normal模式下主要用于浏览代码,那么有哪些方 ...

  7. [Chrome插件开发]001.入门

    Chrome插件开发入门 Chrome扩展文件 Browser Actions(扩展图标) Page Actions(地址栏图标) popup弹出窗口 Background Pages后台页面 实战讲 ...

  8. MySQL不香吗,清华架构师告诉你为什么还要有noSQL?

    强烈推荐观看: 阿里P8架构师谈(数据库系列):NoSQL使用场景和选型比较,以及与SQL的区别_哔哩哔哩 (゜-゜)つロ 干杯~-bilibili​www.bilibili.com noSQL的大概 ...

  9. Chisel3 - bind - Wire, Reg, MemPort

    https://mp.weixin.qq.com/s/AxYlRtAXjd55eoGX5l1W-A   模块(Module)从输入端口(input ports)接收输入,经过内部实现的转换逻辑,从输出 ...

  10. InnoSetup汉化版打包工具下载-附带脚本模板

    InnoSetup汉化版打包工具下载地址: https://www.90pan.com/b1907264 脚本模板 ; 脚本用 Inno Setup 脚本向导 生成.; 查阅文档获取创建 INNO S ...