Pytorch入门实例:mnist分类训练
#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'denny'
__time__ = '2017-9-9 9:03' import torch
import torchvision
from torch.autograd import Variable
import torch.utils.data.dataloader as Data train_data = torchvision.datasets.MNIST(
'./mnist', train=True, transform=torchvision.transforms.ToTensor(), download=True
)
test_data = torchvision.datasets.MNIST(
'./mnist', train=False, transform=torchvision.transforms.ToTensor()
)
print("train_data:", train_data.train_data.size())
print("train_labels:", train_data.train_labels.size())
print("test_data:", test_data.test_data.size()) train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = Data.DataLoader(dataset=test_data, batch_size=64) class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = torch.nn.Sequential(
torch.nn.Conv2d(1, 32, 3, 1, 1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2))
self.conv2 = torch.nn.Sequential(
torch.nn.Conv2d(32, 64, 3, 1, 1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2)
)
self.conv3 = torch.nn.Sequential(
torch.nn.Conv2d(64, 64, 3, 1, 1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2)
)
self.dense = torch.nn.Sequential(
torch.nn.Linear(64 * 3 * 3, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 10)
) def forward(self, x):
conv1_out = self.conv1(x)
conv2_out = self.conv2(conv1_out)
conv3_out = self.conv3(conv2_out)
res = conv3_out.view(conv3_out.size(0), -1)
out = self.dense(res)
return out model = Net()
print(model) optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss() for epoch in range(10):
print('epoch {}'.format(epoch + 1))
# training-----------------------------
train_loss = 0.
train_acc = 0.
for batch_x, batch_y in train_loader:
batch_x, batch_y = Variable(batch_x), Variable(batch_y)
out = model(batch_x)
loss = loss_func(out, batch_y)
train_loss += loss.data[0]
pred = torch.max(out, 1)[1]
train_correct = (pred == batch_y).sum()
train_acc += train_correct.data[0]
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
train_data)), train_acc / (len(train_data)))) # evaluation--------------------------------
model.eval()
eval_loss = 0.
eval_acc = 0.
for batch_x, batch_y in test_loader:
batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
out = model(batch_x)
loss = loss_func(out, batch_y)
eval_loss += loss.data[0]
pred = torch.max(out, 1)[1]
num_correct = (pred == batch_y).sum()
eval_acc += num_correct.data[0]
print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
test_data)), eval_acc / (len(test_data))))
Pytorch入门实例:mnist分类训练的更多相关文章
- Pytorch入门——手把手教你MNIST手写数字识别
MNIST手写数字识别教程 要开始带组内的小朋友了,特意出一个Pytorch教程来指导一下 [!] 这里是实战教程,默认读者已经学会了部分深度学习原理,若有不懂的地方可以先停下来查查资料 目录 MNI ...
- Pytorch学习记录-torchtext和Pytorch的实例( 使用神经网络训练Seq2Seq代码)
Pytorch学习记录-torchtext和Pytorch的实例1 0. PyTorch Seq2Seq项目介绍 1. 使用神经网络训练Seq2Seq 1.1 简介,对论文中公式的解读 1.2 数据预 ...
- 超简单!pytorch入门教程(五):训练和测试CNN
我们按照超简单!pytorch入门教程(四):准备图片数据集准备好了图片数据以后,就来训练一下识别这10类图片的cnn神经网络吧. 按照超简单!pytorch入门教程(三):构造一个小型CNN构建好一 ...
- pytorch入门2.2构建回归模型初体验(开始训练)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
- Pytorch入门中 —— 搭建网络模型
本节内容参照小土堆的pytorch入门视频教程,主要通过查询文档的方式讲解如何搭建卷积神经网络.学习时要学会查询文档,这样会比直接搜索良莠不齐的博客更快.更可靠.讲解的内容主要是pytorch核心包中 ...
- Pytorch入门之VAE
关于自编码器的原理见另一篇博客 : 编码器AE & VAE 这里谈谈对于变分自编码器(Variational auto-encoder)即VAE的实现. 1. 稀疏编码 首先介绍一下“稀疏编码 ...
- pytorch 入门指南
两类深度学习框架的优缺点 动态图(PyTorch) 计算图的进行与代码的运行时同时进行的. 静态图(Tensorflow <2.0) 自建命名体系 自建时序控制 难以介入 使用深度学习框架的优点 ...
- Pytorch入门上 —— Dataset、Tensorboard、Transforms、Dataloader
本节内容参照小土堆的pytorch入门视频教程.学习时建议多读源码,通过源码中的注释可以快速弄清楚类或函数的作用以及输入输出类型. Dataset 借用Dataset可以快速访问深度学习需要的数据,例 ...
- Pytorch入门下 —— 其他
本节内容参照小土堆的pytorch入门视频教程. 现有模型使用和修改 pytorch框架提供了很多现有模型,其中torchvision.models包中有很多关于视觉(图像)领域的模型,如下图: 下面 ...
随机推荐
- 4.BN推导
参考博客:https://www.cnblogs.com/guoyaohua/p/8724433.html 参考知乎:https://www.zhihu.com/question/38102762/a ...
- Tag Helpers 的使用介绍
什么是 Tag Helpers ? 在 Razor 文件中,Tag Helpers 能够让服务端代码参与创建和渲染 HTML 元素.例如,内置的ImageTagHelper能够在图像名称后面追加版本号 ...
- 解决Spring boot中读取属性配置文件出现中文乱码的问题
问题描述: 在配置文件application.properties中写了 server.port=8081 server.servlet.context-path=/boy name=张三 age=2 ...
- MySql分割字符串【存储过程】
MYSql没有表变量,通过函数无法返回表. 参考网址:https://bbs.csdn.net/topics/330021055 DELIMITER $$ USE `数据库`$$ DROP PROCE ...
- 移动UI框架
---恢复内容开始--- 一,框架使用selenium+appium+po+unittest+python 1.其中po表示居于page of boject的思想,unittest是单元测试框架 2. ...
- CrypMic分析报告
一.概述 病毒伪装为NirSoft公司的软件NirCmd并加了MPRESS壳,脱壳后是一个混淆过的PE程序,运行时会用到类似PE映像切换的方式来释放出实际的恶意代码,恶意代码主要对文件进行加密. 二. ...
- pyhton 监听文件输入实例
def tail(filename): f = open(filename,encoding='utf-8') while True: line = f.readline() if line.stri ...
- ipset和iptables配合来自动封闭和解封有问题的IP
iptables封掉少量ip处理是没什么问题的,但是当有大量ip攻击的时候性能就跟不上了,iptables是O(N)的性能.而ipset就像一个集合,把需要封闭的ip地址放入这个集合中,ipset 是 ...
- 32 ArcToolBox学习系列之数据管理工具箱——属性域(Domains)的两种创建及使用方式
属性域分为两类,一种是范围域,一种是编码的值,下面将两个一起介绍,其中涉及到的编码,名称,只是试验,并非真实情况. 一.首先新建一个文件型地理数据库,将数据导入或者是新建要素类都可以 二.打开ArcT ...
- muse-ui底部导航自定义图标和字体颜色
最近在鼓捣用vue.js进行混合APP开发,遍寻许久终于找到muse-ui这款支持vue的轻量级UI框架,竟还支持按需引入,甚合萝卜意! 底部导航的点击波纹特效也是让我无比惊喜,然而自定义图标和字体颜 ...