#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'denny'
__time__ = '2017-9-9 9:03' import torch
import torchvision
from torch.autograd import Variable
import torch.utils.data.dataloader as Data train_data = torchvision.datasets.MNIST(
'./mnist', train=True, transform=torchvision.transforms.ToTensor(), download=True
)
test_data = torchvision.datasets.MNIST(
'./mnist', train=False, transform=torchvision.transforms.ToTensor()
)
print("train_data:", train_data.train_data.size())
print("train_labels:", train_data.train_labels.size())
print("test_data:", test_data.test_data.size()) train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = Data.DataLoader(dataset=test_data, batch_size=64) class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = torch.nn.Sequential(
torch.nn.Conv2d(1, 32, 3, 1, 1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2))
self.conv2 = torch.nn.Sequential(
torch.nn.Conv2d(32, 64, 3, 1, 1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2)
)
self.conv3 = torch.nn.Sequential(
torch.nn.Conv2d(64, 64, 3, 1, 1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2)
)
self.dense = torch.nn.Sequential(
torch.nn.Linear(64 * 3 * 3, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 10)
) def forward(self, x):
conv1_out = self.conv1(x)
conv2_out = self.conv2(conv1_out)
conv3_out = self.conv3(conv2_out)
res = conv3_out.view(conv3_out.size(0), -1)
out = self.dense(res)
return out model = Net()
print(model) optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss() for epoch in range(10):
print('epoch {}'.format(epoch + 1))
# training-----------------------------
train_loss = 0.
train_acc = 0.
for batch_x, batch_y in train_loader:
batch_x, batch_y = Variable(batch_x), Variable(batch_y)
out = model(batch_x)
loss = loss_func(out, batch_y)
train_loss += loss.data[0]
pred = torch.max(out, 1)[1]
train_correct = (pred == batch_y).sum()
train_acc += train_correct.data[0]
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
train_data)), train_acc / (len(train_data)))) # evaluation--------------------------------
model.eval()
eval_loss = 0.
eval_acc = 0.
for batch_x, batch_y in test_loader:
batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
out = model(batch_x)
loss = loss_func(out, batch_y)
eval_loss += loss.data[0]
pred = torch.max(out, 1)[1]
num_correct = (pred == batch_y).sum()
eval_acc += num_correct.data[0]
print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
test_data)), eval_acc / (len(test_data))))

Pytorch入门实例:mnist分类训练的更多相关文章

  1. Pytorch入门——手把手教你MNIST手写数字识别

    MNIST手写数字识别教程 要开始带组内的小朋友了,特意出一个Pytorch教程来指导一下 [!] 这里是实战教程,默认读者已经学会了部分深度学习原理,若有不懂的地方可以先停下来查查资料 目录 MNI ...

  2. Pytorch学习记录-torchtext和Pytorch的实例( 使用神经网络训练Seq2Seq代码)

    Pytorch学习记录-torchtext和Pytorch的实例1 0. PyTorch Seq2Seq项目介绍 1. 使用神经网络训练Seq2Seq 1.1 简介,对论文中公式的解读 1.2 数据预 ...

  3. 超简单!pytorch入门教程(五):训练和测试CNN

    我们按照超简单!pytorch入门教程(四):准备图片数据集准备好了图片数据以后,就来训练一下识别这10类图片的cnn神经网络吧. 按照超简单!pytorch入门教程(三):构造一个小型CNN构建好一 ...

  4. pytorch入门2.2构建回归模型初体验(开始训练)

    pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...

  5. Pytorch入门中 —— 搭建网络模型

    本节内容参照小土堆的pytorch入门视频教程,主要通过查询文档的方式讲解如何搭建卷积神经网络.学习时要学会查询文档,这样会比直接搜索良莠不齐的博客更快.更可靠.讲解的内容主要是pytorch核心包中 ...

  6. Pytorch入门之VAE

    关于自编码器的原理见另一篇博客 : 编码器AE & VAE 这里谈谈对于变分自编码器(Variational auto-encoder)即VAE的实现. 1. 稀疏编码 首先介绍一下“稀疏编码 ...

  7. pytorch 入门指南

    两类深度学习框架的优缺点 动态图(PyTorch) 计算图的进行与代码的运行时同时进行的. 静态图(Tensorflow <2.0) 自建命名体系 自建时序控制 难以介入 使用深度学习框架的优点 ...

  8. Pytorch入门上 —— Dataset、Tensorboard、Transforms、Dataloader

    本节内容参照小土堆的pytorch入门视频教程.学习时建议多读源码,通过源码中的注释可以快速弄清楚类或函数的作用以及输入输出类型. Dataset 借用Dataset可以快速访问深度学习需要的数据,例 ...

  9. Pytorch入门下 —— 其他

    本节内容参照小土堆的pytorch入门视频教程. 现有模型使用和修改 pytorch框架提供了很多现有模型,其中torchvision.models包中有很多关于视觉(图像)领域的模型,如下图: 下面 ...

随机推荐

  1. 异步简析之BlockingCollection实现生产消费模式

    目前市面上有诸多的产品实现队列功能,比如Redis.MemCache等... 其实c#中也有一个基础的集合类专门用来实现生产/消费模式 (生产模式还是建议使用Redis等产品) 下面是官方的一些资料和 ...

  2. Linux命令学习总结之rmdir命令的相关资料可以参考下

    这篇文章主要介绍了Linux命令学习总结之rmdir命令的相关资料,需要的朋友可以参考下(http://www.nanke0834.com) 命令简介: rmdir命令用用来删除空目录,如果目录非空, ...

  3. SpringBoot-异常问题总结

    一:创建的SpringBoot项目之后测试访问接口报错: Whitelabel Error Page This application has no explicit mapping for /err ...

  4. vue笔记-条件渲染

    条件渲染 1:指令v-if单独使用和结合v-else //单独使用 <h1 v-if="ok">Yes</h1> //组合使用 <h1 v-if=&q ...

  5. 基于Cmake+QT+VS的C++项目构建开发编译简明教程

    目录 一.工具下载与安装 1.     Qt 2.     Visual Studio 2015 3.     Cmake 二.C++及Qt项目构建 1.     基于VS构建Qt项目 2.     ...

  6. JAVA基础复习与总结<四> 抽象类与接口

    抽象类(Abstract Class) 是一种模版模式.抽象类为所有子类提供了一个通用模版,子类可以在这个模版基础上进行扩展.通过抽象类,可以避免子类设计的随意性.通过抽象类,我们就可以做到严格限制子 ...

  7. python-MYSQL(包括ORM)交互

    1.首先,我们必须得连上我们的MYSQL数据库.个人遇到连不上MYSQL数据的问题主要有:数据库的权限问题.数据库表权限的问题 同时获取数据库中的数据等. //==================== ...

  8. git 提交代码到库

    今天用git commit -m “注释”提交的时候,注释写错了,于是各种查资料开始了和git bash vim的纠缠...(网上的资料我真是没操作成功,不过最后还是摸索出来了) 首先 使用 git ...

  9. 图解Raft之日志复制

    日志复制可以说是Raft集群的核心之一,保证了Raft数据的一致性,下面通过几张图片介绍Raft集群中日志复制的逻辑与流程: 在一个Raft集群中只有Leader节点能够接受客户端的请求,由Leade ...

  10. SSM 框架搭建

    SSM框架搭建(Spring.SpringMVC.Mybatis) 一:基本概念 Spring :      Spring是一个开源框架,Spring是于2003 年兴起的一个轻量级的Java 开发框 ...