#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'denny'
__time__ = '2017-9-9 9:03' import torch
import torchvision
from torch.autograd import Variable
import torch.utils.data.dataloader as Data train_data = torchvision.datasets.MNIST(
'./mnist', train=True, transform=torchvision.transforms.ToTensor(), download=True
)
test_data = torchvision.datasets.MNIST(
'./mnist', train=False, transform=torchvision.transforms.ToTensor()
)
print("train_data:", train_data.train_data.size())
print("train_labels:", train_data.train_labels.size())
print("test_data:", test_data.test_data.size()) train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = Data.DataLoader(dataset=test_data, batch_size=64) class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = torch.nn.Sequential(
torch.nn.Conv2d(1, 32, 3, 1, 1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2))
self.conv2 = torch.nn.Sequential(
torch.nn.Conv2d(32, 64, 3, 1, 1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2)
)
self.conv3 = torch.nn.Sequential(
torch.nn.Conv2d(64, 64, 3, 1, 1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2)
)
self.dense = torch.nn.Sequential(
torch.nn.Linear(64 * 3 * 3, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 10)
) def forward(self, x):
conv1_out = self.conv1(x)
conv2_out = self.conv2(conv1_out)
conv3_out = self.conv3(conv2_out)
res = conv3_out.view(conv3_out.size(0), -1)
out = self.dense(res)
return out model = Net()
print(model) optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss() for epoch in range(10):
print('epoch {}'.format(epoch + 1))
# training-----------------------------
train_loss = 0.
train_acc = 0.
for batch_x, batch_y in train_loader:
batch_x, batch_y = Variable(batch_x), Variable(batch_y)
out = model(batch_x)
loss = loss_func(out, batch_y)
train_loss += loss.data[0]
pred = torch.max(out, 1)[1]
train_correct = (pred == batch_y).sum()
train_acc += train_correct.data[0]
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
train_data)), train_acc / (len(train_data)))) # evaluation--------------------------------
model.eval()
eval_loss = 0.
eval_acc = 0.
for batch_x, batch_y in test_loader:
batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
out = model(batch_x)
loss = loss_func(out, batch_y)
eval_loss += loss.data[0]
pred = torch.max(out, 1)[1]
num_correct = (pred == batch_y).sum()
eval_acc += num_correct.data[0]
print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
test_data)), eval_acc / (len(test_data))))

Pytorch入门实例:mnist分类训练的更多相关文章

  1. Pytorch入门——手把手教你MNIST手写数字识别

    MNIST手写数字识别教程 要开始带组内的小朋友了,特意出一个Pytorch教程来指导一下 [!] 这里是实战教程,默认读者已经学会了部分深度学习原理,若有不懂的地方可以先停下来查查资料 目录 MNI ...

  2. Pytorch学习记录-torchtext和Pytorch的实例( 使用神经网络训练Seq2Seq代码)

    Pytorch学习记录-torchtext和Pytorch的实例1 0. PyTorch Seq2Seq项目介绍 1. 使用神经网络训练Seq2Seq 1.1 简介,对论文中公式的解读 1.2 数据预 ...

  3. 超简单!pytorch入门教程(五):训练和测试CNN

    我们按照超简单!pytorch入门教程(四):准备图片数据集准备好了图片数据以后,就来训练一下识别这10类图片的cnn神经网络吧. 按照超简单!pytorch入门教程(三):构造一个小型CNN构建好一 ...

  4. pytorch入门2.2构建回归模型初体验(开始训练)

    pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...

  5. Pytorch入门中 —— 搭建网络模型

    本节内容参照小土堆的pytorch入门视频教程,主要通过查询文档的方式讲解如何搭建卷积神经网络.学习时要学会查询文档,这样会比直接搜索良莠不齐的博客更快.更可靠.讲解的内容主要是pytorch核心包中 ...

  6. Pytorch入门之VAE

    关于自编码器的原理见另一篇博客 : 编码器AE & VAE 这里谈谈对于变分自编码器(Variational auto-encoder)即VAE的实现. 1. 稀疏编码 首先介绍一下“稀疏编码 ...

  7. pytorch 入门指南

    两类深度学习框架的优缺点 动态图(PyTorch) 计算图的进行与代码的运行时同时进行的. 静态图(Tensorflow <2.0) 自建命名体系 自建时序控制 难以介入 使用深度学习框架的优点 ...

  8. Pytorch入门上 —— Dataset、Tensorboard、Transforms、Dataloader

    本节内容参照小土堆的pytorch入门视频教程.学习时建议多读源码,通过源码中的注释可以快速弄清楚类或函数的作用以及输入输出类型. Dataset 借用Dataset可以快速访问深度学习需要的数据,例 ...

  9. Pytorch入门下 —— 其他

    本节内容参照小土堆的pytorch入门视频教程. 现有模型使用和修改 pytorch框架提供了很多现有模型,其中torchvision.models包中有很多关于视觉(图像)领域的模型,如下图: 下面 ...

随机推荐

  1. Linux多线程编程,为什么要使用线程,使用线程的理由和优点等

    线程?为什么有了进程还需要线程呢,他们有什么区别?使用线程有什么优势呢?还有多线程编程的一些细节问题,(http://www.0830120.com)如线程之间怎样同步.互斥,这些东西将在本文中介绍. ...

  2. K-means算法性能评估及其优化

    1. SSE误差平方和(Sum of Square due to Error): 聚类情况: 计算公式: 注:SSE参数计算的内容为当前迭代得到的中心位置到各自中心点簇的欧式距离总和,这个值越小表示当 ...

  3. Tencent Cloud 腾讯云上部署 EMR Cluster + Kafka + Confluent (Schema-Registry)

    腾讯云上有些操作比起 Amazon AWS 还是很方便的, 尤其部署EMR Cluster,下面详细介绍步骤:

  4. AddIn 中当前完整文件名的获取

    Me.Application.ActiveWorkbook.Name 需要注意的是:只有当前文件已经存档的情况下,才能获得后缀名.

  5. Jmeter中实现base64加密

    Jmeter已不再提供内置base64加密函数,遇到base64加密需求,需要通过beanshell实现 直接上beanshell代码: import org.apache.commons.net.u ...

  6. 【2019雅礼集训】【最大费用流】【模型转换】D2T3 sum

    目录 题意 输入格式 输出格式 思路 代码 题意 现在你有一个集合{1,2,3,...,n},要求你从中取出一些元素,使得这些元素两两互质.问你能够取出的元素总和最多是多少? 输入格式 一个整数n 输 ...

  7. Kafka详细配置

    转自:http://blog.csdn.net/suifeng3051/article/details/38321043?utm_source=tuicool&utm_medium=refer ...

  8. 【尺取法】Jurisdiction Disenchantment

    [尺取法]Jurisdiction Disenchantment PROBLEM 时间限制: 1 Sec 内存限制: 128 MB 题目描述 The Super League of Paragons ...

  9. [LeetCode] Magic Squares In Grid 网格中的神奇正方形

    A 3 x 3 magic square is a 3 x 3 grid filled with distinct numbers from 1 to 9 such that each row, co ...

  10. vue项目开发时怎么解决跨域

    vue项目中,前端与后台进行数据请求或者提交的时候,如果后台没有设置跨域,前端本地调试代码的时候就会报“No 'Access-Control-Allow-Origin' header is prese ...