1.使用RNN做MNIST分类
第一次用LSTM,从简单做起吧~~
注意事项:
- batch_first=True 意味着输入的格式为(batch_size,time_step,input_size),False 意味着输入的格式为(time_step,batch_size,input_size)
- 取r_out[:,-1,:],即取时间步最后一步的结果,相当于LSTM把一张图片全部扫描完后的返回的状态向量(此时的维度变为(64,64),前面的64是batch_size,后面的64是隐藏层的神经元个数)
import torch
from torch.autograd import Variable
from torchvision import datasets,transforms
#超参数
EPOCH=1
BATCH_SIZE=64
TIME_STEP=28#run time step/image height
INPUT_SIZE=28#run input size/image width
LR=0.01
DOWNLOAD_MNIST=True train_data=datasets.MNIST(root='./mnist',train=True,transform=transforms.ToTensor(),download=DOWNLOAD_MNIST)
train_loader=torch.utils.data.DataLoader(dataset=train_data,batch_size=BATCH_SIZE,shuffle=True) test_data=datasets.MNIST(root='./mnist',train=False,transform=transforms.ToTensor(),download=DOWNLOAD_MNIST)
test_loader=torch.utils.data.DataLoader(dataset=test_data,batch_size=BATCH_SIZE,shuffle=True) class RNN(torch.nn.Module):
def __init__(self):
super(RNN,self).__init__() self.rnn=torch.nn.LSTM(
input_size=INPUT_SIZE,
hidden_size=64,
num_layers=1, batch_first=True,
)
self.out=torch.nn.Linear(64,10)
def forward(self, x):
r_out,(h_n,h_c)=self.rnn(x,None)#[64,28,64]
out=self.out(r_out[:,-1,:])#[64,10]
return out #time_step,batch,input batch_first=False,
rnn=RNN()
print(rnn) optimizer=torch.optim.Adam(rnn.parameters(),lr=LR)
loss_func=torch.nn.CrossEntropyLoss() for epoch in range(EPOCH):
for step,(x,y) in enumerate(train_loader):
b_x=Variable(x.view(-1,28,28))#reshape x to (batch,time_step.input_size) b_y=Variable(y).squeeze()
output=rnn(b_x)
loss=loss_func(output,b_y)
optimizer.zero_grad()
loss.backward()
optimizer.step() if step %50==0:
for test_x,test_y in test_loader:
test_output=rnn(test_x.view(-1,28,28))
pred_y=torch.max(test_output,1)[1].data.numpy().squeeze()
test_y=test_y.numpy()
acc=sum(pred_y==test_y)/test_y.size
print(acc)
1.使用RNN做MNIST分类的更多相关文章
- 芝麻HTTP:TensorFlow LSTM MNIST分类
本节来介绍一下使用 RNN 的 LSTM 来做 MNIST 分类的方法,RNN 相比 CNN 来说,速度可能会慢,但可以节省更多的内存空间. 初始化 首先我们可以先初始化一些变量,如学习率.节点单元数 ...
- 2.使用RNN做诗歌生成
诗歌生成比分类问题要稍微麻烦一些,而且第一次使用RNN做文本方面的问题,还是有很多概念性的东西~~~ 数据下载: 链接:https://pan.baidu.com/s/1uCDup7U5rGuIlIb ...
- 使用CNN做文本分类——将图像2维卷积换成1维
使用CNN做文本分类 from __future__ import division, print_function, absolute_import import tensorflow as tf ...
- 《机器学习系统设计》之应用scikit-learn做文本分类(上)
前言: 本系列是在作者学习<机器学习系统设计>([美] WilliRichert)过程中的思考与实践,全书通过Python从数据处理.到特征project,再到模型选择,把机器学习解决这个 ...
- Tensorflow实战第十课(RNN MNIST分类)
设置RNN的参数 我们本节采用RNN来进行分类的训练(classifiction).会继续使用手写数据集MNIST. 让RNN从每张图片的第一行像素读到最后一行,然后进行分类判断.接下来我们导入MNI ...
- 深度学习原理与框架-Tensorflow卷积神经网络-卷积神经网络mnist分类 1.tf.nn.conv2d(卷积操作) 2.tf.nn.max_pool(最大池化操作) 3.tf.nn.dropout(执行dropout操作) 4.tf.nn.softmax_cross_entropy_with_logits(交叉熵损失) 5.tf.truncated_normal(两个标准差内的正态分布)
1. tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME') # 对数据进行卷积操作 参数说明:x表示输入数据,w表示卷积核, stride ...
- TensorFlow入门(三)多层 CNNs 实现 mnist分类
欢迎转载,但请务必注明原文出处及作者信息. 深入MNIST refer: http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mni ...
- 用keras的cnn做人脸分类
keras介绍 Keras是一个简约,高度模块化的神经网络库.采用Python / Theano开发. 使用Keras如果你需要一个深度学习库: 可以很容易和快速实现原型(通过总模块化,极简主义,和可 ...
- tensorflow RNN循环神经网络 (分类例子)-【老鱼学tensorflow】
之前我们学习过用CNN(卷积神经网络)来识别手写字,在CNN中是把图片看成了二维矩阵,然后在二维矩阵中堆叠高度值来进行识别. 而在RNN中增添了时间的维度,因为我们会发现有些图片或者语言或语音等会在时 ...
随机推荐
- python学习日记(OOP——@property)
在绑定属性时,如果我们直接把属性暴露出去,虽然写起来很简单,但是,没办法检查参数,导致可以把成绩随便改: s = Student() s.score = 9999 这显然不合逻辑.为了限制score的 ...
- beego框架的最简单登入演示
一.controllers逻辑代码 func (c *UserController) Get() { c.TplName="login.html" } func (c *UserC ...
- ASP.NET概念
ASP.NET :是一个开发框架,用于通过 HTML.CSS.JavaScript 以及服务器脚本来构建网页和网站. ASP.NET两种开发语言:VB C#
- tty
tty一词源于Teletypes,或teletypewriters,原来指的是电传打字机,是通过串行线用打印机键盘通过阅读和发送信息的东西,后来这东西被键盘和显示器取代,所以现在叫终端比较合适. 终端 ...
- zabbix Server 4.0 部署及之内置item使用案例
zabbix Server 4.0 部署及之内置item使用案例 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 一.zabbix组件架构概述(图片摘自网络) 1>.zabbi ...
- EF CodeFirst系列(2)---CodeFirst的数据库初始化
1. CodeFirst的默认约定 1.领域类和数据库架构的映射约定 在介绍数据库的初始化之前我们需要先了解领域类和数据库之间映射的一些约定.在CodeFirst模式中,约定指的是根据领域类(如Stu ...
- 中间件方法必须返回Response对象实例(tp5.1+小程序结合时候出的问题)
前言:在最近开发小程序通过中间件检查是否携带token时候报的一个错误 解决方法: 根据手册中需要return出去才可以不报错
- [物理学与PDEs]第1章第2节 预备知识 2.2 Ampere-Biot-Savart 定律, 静磁场的散度与旋度
1. 电流密度, 电荷守恒定律 (1) 电荷的定向移动形成电流. (2) 电流密度 ${\bf j}$, 是描述导体内一点在某一时刻电流流动情况的物理量, 用单位时间内通过垂直于电流方向的单位面积的电 ...
- java Concurrent并发容器类 小结
Java1.5提供了多种并发容器类来改进同步容器的性能. 同步容器将所有对容器的访问都串行化,以实现他们的线程安全性.这种方法的代价是严重降低并发性,当多个线程竞争容器的锁时,吞吐量将严重减低. 一 ...
- mysql常见的问题
1.为什么选择某一个版本 各个版本之间的区别及优缺点 首先,服务器特性 mysql percona mysql mariaDB 开源 开源 开源 支持分区表 支持分区表 支持分区表 innodb Xt ...