Pytorch是热门的深度学习框架之一,通过经典的MNIST 数据集进行快速的pytorch入门。

导入库

from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Compose, Normalize
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F
import torch.nn as nn
import os
import numpy as np

准备数据集

path = './data'

# 使用Compose 将tensor化和正则化操作打包
transform_fn = Compose([
ToTensor(),
Normalize(mean=(0.1307,), std=(0.3081,))
])
mnist_dataset = MNIST(root=path, train=True, transform=transform_fn)
data_loader = torch.utils.data.DataLoader(dataset=mnist_dataset, batch_size=2, shuffle=True)
# 1. 构建函数,数据集预处理
BATCH_SIZE = 128
TEST_BATCH_SIZE = 1000
def get_dataloader(train=True, batch_size=BATCH_SIZE):
'''
train=True, 获取训练集
train=False 获取测试集
'''
transform_fn = Compose([
ToTensor(),
Normalize(mean=(0.1307,), std=(0.3081,))
])
dataset = MNIST(root='./data', train=train, transform=transform_fn)
data_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)
return data_loader

构建模型


class MnistModel(nn.Module):
def __init__(self):
super().__init__() # 继承父类
self.fc1 = nn.Linear(1*28*28, 28) # 添加全连接层
self.fc2 = nn.Linear(28, 10) def forward(self, input):
x = input.view(-1, 1*28*28)
x = self.fc1(x)
x = F.relu(x)
out = self.fc2(x)
return F.log_softmax(out, dim=-1) # log_softmax 与 nll_loss合用,计算交叉熵

模型训练

mnist_model = MnistModel()
optimizer = torch.optim.Adam(params=mnist_model.parameters(), lr=0.001) # 如果有模型则加载
if os.path.exists('./model'):
mnist_model.load_state_dict(torch.load('model/mnist_model.pkl'))
optimizer.load_state_dict(torch.load('model/optimizer.pkl'))
def train(epoch):
data_loader = get_dataloader() for index, (data, target) in enumerate(data_loader):
optimizer.zero_grad() # 梯度先清零
output = mnist_model(data)
loss = F.nll_loss(output, target)
loss.backward() # 误差反向传播计算
optimizer.step() # 更新梯度 if index % 100 == 0:
# 保存训练模型
torch.save(mnist_model.state_dict(), 'model/mnist_model.pkl')
torch.save(optimizer.state_dict(), 'model/optimizer.pkl')
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, index * len(data), len(data_loader.dataset),
100. * index / len(data_loader), loss.item()))
for i in range(epoch=5):
train(i)
Train Epoch: 0 [0/60000 (0%)]	Loss: 0.023078
Train Epoch: 0 [12800/60000 (21%)] Loss: 0.019347
Train Epoch: 0 [25600/60000 (43%)] Loss: 0.105870
Train Epoch: 0 [38400/60000 (64%)] Loss: 0.050866
Train Epoch: 0 [51200/60000 (85%)] Loss: 0.097995
Train Epoch: 1 [0/60000 (0%)] Loss: 0.108337
Train Epoch: 1 [12800/60000 (21%)] Loss: 0.071196
Train Epoch: 1 [25600/60000 (43%)] Loss: 0.022856
Train Epoch: 1 [38400/60000 (64%)] Loss: 0.028392
Train Epoch: 1 [51200/60000 (85%)] Loss: 0.070508
Train Epoch: 2 [0/60000 (0%)] Loss: 0.037416
Train Epoch: 2 [12800/60000 (21%)] Loss: 0.075977
Train Epoch: 2 [25600/60000 (43%)] Loss: 0.024356
Train Epoch: 2 [38400/60000 (64%)] Loss: 0.042203
Train Epoch: 2 [51200/60000 (85%)] Loss: 0.020883
Train Epoch: 3 [0/60000 (0%)] Loss: 0.023487
Train Epoch: 3 [12800/60000 (21%)] Loss: 0.024403
Train Epoch: 3 [25600/60000 (43%)] Loss: 0.073619
Train Epoch: 3 [38400/60000 (64%)] Loss: 0.074042
Train Epoch: 3 [51200/60000 (85%)] Loss: 0.036283
Train Epoch: 4 [0/60000 (0%)] Loss: 0.021305
Train Epoch: 4 [12800/60000 (21%)] Loss: 0.062750
Train Epoch: 4 [25600/60000 (43%)] Loss: 0.016911
Train Epoch: 4 [38400/60000 (64%)] Loss: 0.039599
Train Epoch: 4 [51200/60000 (85%)] Loss: 0.026689

模型测试

def test():
loss_list = []
acc_list = [] test_loader = get_dataloader(train=False, batch_size = TEST_BATCH_SIZE)
mnist_model.eval() # 设为评估模式 for index, (data, target) in enumerate(test_loader):
with torch.no_grad():
out = mnist_model(data)
loss = F.nll_loss(out, target)
loss_list.append(loss) pred = out.data.max(1)[1]
acc = pred.eq(target).float().mean() # eq()函数用于将两个tensor中的元素对比,返回布尔值
acc_list.append(acc) print('平均准确率, 平均损失', np.mean(acc_list), np.mean(loss_list))
test()
平均准确率, 平均损失 0.9662777 0.12309619

Pytorch实现MNIST手写数字识别的更多相关文章

  1. Pytorch入门——手把手教你MNIST手写数字识别

    MNIST手写数字识别教程 要开始带组内的小朋友了,特意出一个Pytorch教程来指导一下 [!] 这里是实战教程,默认读者已经学会了部分深度学习原理,若有不懂的地方可以先停下来查查资料 目录 MNI ...

  2. mnist手写数字识别——深度学习入门项目(tensorflow+keras+Sequential模型)

    前言 今天记录一下深度学习的另外一个入门项目——<mnist数据集手写数字识别>,这是一个入门必备的学习案例,主要使用了tensorflow下的keras网络结构的Sequential模型 ...

  3. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  4. 深度学习之 mnist 手写数字识别

    深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...

  5. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  6. 第三节,CNN案例-mnist手写数字识别

    卷积:神经网络不再是对每个像素做处理,而是对一小块区域的处理,这种做法加强了图像信息的连续性,使得神经网络看到的是一个图像,而非一个点,同时也加深了神经网络对图像的理解,卷积神经网络有一个批量过滤器, ...

  7. mnist 手写数字识别

    mnist 手写数字识别三大步骤 1.定义分类模型2.训练模型3.评价模型 import tensorflow as tfimport input_datamnist = input_data.rea ...

  8. 持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型

    持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tens ...

  9. 用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别

    用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别 http://phunter.farbox.com/post/mxnet-tutorial1 用MXnet实战深度学 ...

随机推荐

  1. 洛谷 P5176 公约数 题解

    原题链接 我天哪 大大的庆祝一下: 数论黑题 \(T1\) 达成! 激动地不行 记住套路:乱推 \(\gcd\),欧拉筛模板,然后乱换元,乱换式子,完了整除分块,欧拉筛和前缀和就解决了! \[\sum ...

  2. WeChat 搭建过程

    [被动回复消息] 1.创建项目(基于MyEclipse + Tomcat 7 编写):wechat 2.导入jar包(用于解析xml):dom4j-1.6.1.jar,xstream-1.3.jar ...

  3. OpenCV-Python ORB(面向快速和旋转的BRIEF) | 四十三

    目标 在本章中,我们将了解ORB的基础知识 理论 作为OpenCV的狂热者,关于ORB的最重要的事情是它来自" OpenCV Labs".该算法由Ethan Rublee,Vinc ...

  4. 医学图像 | 使用深度学习实现乳腺癌分类(附python演练)

    乳腺癌是全球第二常见的女性癌症.2012年,它占所有新癌症病例的12%,占所有女性癌症病例的25%. 当乳腺细胞生长失控时,乳腺癌就开始了.这些细胞通常形成一个肿瘤,通常可以在x光片上直接看到或感觉到 ...

  5. K8S 资源收集和展示 top & DashBoard-UI

    一.前言 在近期的 K8S 开发调试的过程中,总会想知道 Node 或者 Pod 的更多信息.但 $ kubectl top node $ kubectl top pod 中的 top 操作符,需要 ...

  6. A 大地魂力

    时间限制 : - MS   空间限制 : - KB  评测说明 : 1s,256m 问题描述 奶牛贝西认为,要改变世界,就必须吸收大地的力量,贝西把大地的力量称为魂力.要吸取大地的魂力就需要在地上开出 ...

  7. hacknos-player靶机渗透

    靶机下载地址https://www.vulnhub.com/entry/hacknos-player,459/ 网络配置 该靶机可能会存在无法自动分配IP的情况,所以无法扫描到的情况下需要手动配置获取 ...

  8. 使用 python 进行身份证号校验

    使用 python 代码进行身份证号校验 先说,还有很多可以优化的地方. 1.比如加入15位身份证号的校验,嗯哼,15位的好像没有校验,那就只能提取个出生年月日啥的了. 2.比如判断加入地址数据库,增 ...

  9. Android 启动一个Activity的几种方式

    启动一个Activity的几种方式在Android中我们可以通过下面两种方式来启动一个新的Activity,注意这里是怎么启动,而非启动模式!分为显示启动和隐式启动! 1.显式启动,通过包名来启动,写 ...

  10. MTK Android MCC(移动国家码)和 MNC(移动网络码)

    国际移动用户识别码(IMSI) international mobile subscriber identity 国际上为唯一识别一个移动用户所分配的号码. 从技术上讲,IMSI可以彻底解决国际漫游问 ...