#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'denny'
__time__ = '2017-9-9 9:03' import torch
import torchvision
from torch.autograd import Variable
import torch.utils.data.dataloader as Data train_data = torchvision.datasets.MNIST(
'./mnist', train=True, transform=torchvision.transforms.ToTensor(), download=True
)
test_data = torchvision.datasets.MNIST(
'./mnist', train=False, transform=torchvision.transforms.ToTensor()
)
print("train_data:", train_data.train_data.size())
print("train_labels:", train_data.train_labels.size())
print("test_data:", test_data.test_data.size()) train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = Data.DataLoader(dataset=test_data, batch_size=64) class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = torch.nn.Sequential(
torch.nn.Conv2d(1, 32, 3, 1, 1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2))
self.conv2 = torch.nn.Sequential(
torch.nn.Conv2d(32, 64, 3, 1, 1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2)
)
self.conv3 = torch.nn.Sequential(
torch.nn.Conv2d(64, 64, 3, 1, 1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2)
)
self.dense = torch.nn.Sequential(
torch.nn.Linear(64 * 3 * 3, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 10)
) def forward(self, x):
conv1_out = self.conv1(x)
conv2_out = self.conv2(conv1_out)
conv3_out = self.conv3(conv2_out)
res = conv3_out.view(conv3_out.size(0), -1)
out = self.dense(res)
return out model = Net()
print(model) optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss() for epoch in range(10):
print('epoch {}'.format(epoch + 1))
# training-----------------------------
train_loss = 0.
train_acc = 0.
for batch_x, batch_y in train_loader:
batch_x, batch_y = Variable(batch_x), Variable(batch_y)
out = model(batch_x)
loss = loss_func(out, batch_y)
train_loss += loss.data[0]
pred = torch.max(out, 1)[1]
train_correct = (pred == batch_y).sum()
train_acc += train_correct.data[0]
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
train_data)), train_acc / (len(train_data)))) # evaluation--------------------------------
model.eval()
eval_loss = 0.
eval_acc = 0.
for batch_x, batch_y in test_loader:
batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
out = model(batch_x)
loss = loss_func(out, batch_y)
eval_loss += loss.data[0]
pred = torch.max(out, 1)[1]
num_correct = (pred == batch_y).sum()
eval_acc += num_correct.data[0]
print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
test_data)), eval_acc / (len(test_data))))

Pytorch入门实例:mnist分类训练的更多相关文章

  1. Pytorch入门——手把手教你MNIST手写数字识别

    MNIST手写数字识别教程 要开始带组内的小朋友了,特意出一个Pytorch教程来指导一下 [!] 这里是实战教程,默认读者已经学会了部分深度学习原理,若有不懂的地方可以先停下来查查资料 目录 MNI ...

  2. Pytorch学习记录-torchtext和Pytorch的实例( 使用神经网络训练Seq2Seq代码)

    Pytorch学习记录-torchtext和Pytorch的实例1 0. PyTorch Seq2Seq项目介绍 1. 使用神经网络训练Seq2Seq 1.1 简介,对论文中公式的解读 1.2 数据预 ...

  3. 超简单!pytorch入门教程(五):训练和测试CNN

    我们按照超简单!pytorch入门教程(四):准备图片数据集准备好了图片数据以后,就来训练一下识别这10类图片的cnn神经网络吧. 按照超简单!pytorch入门教程(三):构造一个小型CNN构建好一 ...

  4. pytorch入门2.2构建回归模型初体验(开始训练)

    pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...

  5. Pytorch入门中 —— 搭建网络模型

    本节内容参照小土堆的pytorch入门视频教程,主要通过查询文档的方式讲解如何搭建卷积神经网络.学习时要学会查询文档,这样会比直接搜索良莠不齐的博客更快.更可靠.讲解的内容主要是pytorch核心包中 ...

  6. Pytorch入门之VAE

    关于自编码器的原理见另一篇博客 : 编码器AE & VAE 这里谈谈对于变分自编码器(Variational auto-encoder)即VAE的实现. 1. 稀疏编码 首先介绍一下“稀疏编码 ...

  7. pytorch 入门指南

    两类深度学习框架的优缺点 动态图(PyTorch) 计算图的进行与代码的运行时同时进行的. 静态图(Tensorflow <2.0) 自建命名体系 自建时序控制 难以介入 使用深度学习框架的优点 ...

  8. Pytorch入门上 —— Dataset、Tensorboard、Transforms、Dataloader

    本节内容参照小土堆的pytorch入门视频教程.学习时建议多读源码,通过源码中的注释可以快速弄清楚类或函数的作用以及输入输出类型. Dataset 借用Dataset可以快速访问深度学习需要的数据,例 ...

  9. Pytorch入门下 —— 其他

    本节内容参照小土堆的pytorch入门视频教程. 现有模型使用和修改 pytorch框架提供了很多现有模型,其中torchvision.models包中有很多关于视觉(图像)领域的模型,如下图: 下面 ...

随机推荐

  1. 【MSVC】_invalid_parameter和_invoke_watson

    最近遇到一个崩溃问题,崩溃在CRT中,最后调用的函数是_invoke_watson.调用关系大概如下: _invoke_watson _invalid_parameter _invalid_param ...

  2. 基于docker/dockerfile实现redis主从复制

    今天我们来搭建基于docker实现redis主从复制集群 为什么要使用redis集群模式? Redis可以说是内存数据库,mysql的数据库是真实存储在硬盘里的,因此,redis的读取速度要比mysq ...

  3. Java - 数组详解(图解数组的基本操作)

    目录 什么是数组 数组的定义和内存分配 数组的赋值和访问 数组的注意事项 数组的内存图解 数组的插入 数组的删除 数组的扩容 数组的反转 首先 什么是数组 数组是一组地址连续.长度固定的具有相同类型的 ...

  4. SpringBoot的自动配置原理

    一.入口 上篇注解@SpringBootApplication简单分析,说到了@SpringBootApplication注解的内部结构, 其中@EnableAutoConfiguration利用En ...

  5. nginx配置ssl验证

    在上一篇博客中,我们只是通过nginx搭建了反向代理服务,由于需要在小程序中使用https服务,所以需要申请安全证书. 1.在所购买的域名商那申请免费的ssl证书,我买的是阿里的,所以直接在阿里上申请 ...

  6. 1.3 正则表达式和python语言-1.3.6匹配多个字符串

    1.3.6 匹配多个字符串(2018-05-08) 我们在正则表达式 bat|bet|bit 中使用了择一匹配(|)符号.如下为在 Python中使用正则表达式的方法. import re #bat| ...

  7. 20181115 python-第一章学习小结part3

    第一章,基本数据类型-------仅学三种,字符型,数字型,布尔型 仅学三种数据类型: 字符型,加了引号的都可以被认为是字符串,字符串可以拼接 数字型,int,float,long三种,可以进行运算 ...

  8. Java 什么是线程安全

    当多个线程访问某个类时,不管运行时环境采用何种调度方式或者这些线程将如何交替执行,并且在主调代码中不需要额外的同步或协同,这个类都能表现出正确的行为,那么这个类就是线程安全的.其中,正确性指某个类的行 ...

  9. python学习:列表

    列表 a = ['abc','bcd','cde','def','efg']print(a)列表的操作:增删改查 1)查:切片print(a[1:3]) #从'bcd'取到'cde',列表取值顾头不顾 ...

  10. Windows下编译jcef

    依赖软件参考 本文参考官方网站上的jcef编译过程 编译成功的环境如下: windows 10 64 bit JDK 1.8.0_121 64 bit Python 2.7.13 git versio ...