###仅为自己练习,没有其他用途

 1 import torch
from torch import nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt # torch.manual_seed(1) # reproducible # Hyper Parameters
EPOCH = 1 # train the training data n times, to save time, we just train 1 epoch
BATCH_SIZE = 64
TIME_STEP = 28 # rnn time step / image height
INPUT_SIZE = 28 # rnn input size / image width
LR = 0.01 # learning rate
DOWNLOAD_MNIST = True # set to True if haven't download the data # Mnist digital dataset
train_data = dsets.MNIST(
root='./mnist/',
train=True, # this is training data
transform=transforms.ToTensor(), # Converts a PIL.Image or numpy.ndarray to
# torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
download=DOWNLOAD_MNIST, # download it if you don't have it
) # # plot one example
# print(train_data.train_data.size()) # (60000, 28, 28)
# print(train_data.train_labels.size()) # (60000)
# plt.imshow(train_data.train_data[0].numpy(), cmap='gray')
# plt.title('%i' % train_data.train_labels[0])
# plt.show() # Data Loader for easy mini-batch return in training
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) # convert test data into Variable, pick 2000 samples to speed up testing
test_data = dsets.MNIST(root='./mnist/', train=False, transform=transforms.ToTensor())
test_x = test_data.test_data.type(torch.FloatTensor)[:2000]/255. # shape (2000, 28, 28) value in range(0,1)
test_y = test_data.test_labels.numpy()[:2000] # covert to numpy array class RNN(nn.Module):
def __init__(self):
super(RNN, self).__init__() self.rnn = nn.LSTM( # if use nn.RNN(), it hardly learns
input_size=INPUT_SIZE,
hidden_size=64, # rnn hidden unit
num_layers=1, # number of rnn layer
batch_first=True, # input & output will has batch size as 1s dimension. e.g. (batch, time_step, input_size)
) self.out = nn.Linear(64, 10) def forward(self, x):
# x shape (batch, time_step, input_size)
# r_out shape (batch, time_step, output_size)
# h_n shape (n_layers, batch, hidden_size)
# h_c shape (n_layers, batch, hidden_size)
r_out, (h_n, h_c) = self.rnn(x, None) # None represents zero initial hidden state # choose r_out at the last time step
out = self.out(r_out[:, -1, :])
return out rnn = RNN()
print(rnn) optimizer = torch.optim.Adam(rnn.parameters(), lr=LR) # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss() # the target label is not one-hotted # training and testing
for epoch in range(EPOCH):
for step, (b_x, b_y) in enumerate(train_loader): # gives batch data
b_x = b_x.view(-1, 28, 28) # reshape x to (batch, time_step, input_size) output = rnn(b_x) # rnn output
loss = loss_func(output, b_y) # cross entropy loss
optimizer.zero_grad() # clear gradients for this training step
loss.backward() # backpropagation, compute gradients
optimizer.step() # apply gradients if step % 50 == 0:
test_output = rnn(test_x) # (samples, time_step, input_size)
pred_y = torch.max(test_output, 1)[1].data.numpy()
accuracy = float((pred_y == test_y).astype(int).sum()) / float(test_y.size)
print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy) # print 10 predictions from test data
test_output = rnn(test_x[:10].view(-1, 28, 28))
pred_y = torch.max(test_output, 1)[1].data.numpy()
print(pred_y, 'prediction number')
print(test_y[:10], 'real number')

pytorch之 RNN classifier的更多相关文章

  1. pytorch实现rnn并且对mnist进行分类

    1.RNN简介 rnn,相比很多人都已经听腻,但是真正用代码操练起来,其中还是有很多细节值得琢磨. 虽然大家都在说,我还是要强调一次,rnn实际上是处理的是序列问题,与之形成对比的是cnn,cnn不能 ...

  2. pytorch之 RNN 参数解释

    上次通过pytorch实现了RNN模型,简易的完成了使用RNN完成mnist的手写数字识别,但是里面的参数有点不了解,所以对问题进行总结归纳来解决. 总述:第一次看到这个函数时,脑袋有点懵,总结了下总 ...

  3. 用Keras搭建神经网络 简单模版(四)—— RNN Classifier 循环神经网络(手写数字图片识别)

    # -*- coding: utf-8 -*- import numpy as np np.random.seed(1337) from keras.datasets import mnist fro ...

  4. pytorch之 RNN regression

    关于RNN模型参数的解释,可以参看RNN参数解释 1 import torch from torch import nn import numpy as np import matplotlib.py ...

  5. pytorch中如何处理RNN输入变长序列padding

    一.为什么RNN需要处理变长输入 假设我们有情感分析的例子,对每句话进行一个感情级别的分类,主体流程大概是下图所示: 思路比较简单,但是当我们进行batch个训练数据一起计算的时候,我们会遇到多个训练 ...

  6. Pytorch基础——使用 RNN 生成简单序列

    一.介绍 内容 使用 RNN 进行序列预测 今天我们就从一个基本的使用 RNN 生成简单序列的例子中,来窥探神经网络生成符号序列的秘密. 我们首先让神经网络模型学习形如 0^n 1^n 形式的上下文无 ...

  7. RNN,写起来真的烦

    曾经,为了处理一些序列相关的数据,我稍微了解了一点递归网络 (RNN) 的东西.由于当时只会 tensorflow,就从官网上找了一些 tensorflow 相关的 demo,中间陆陆续续折腾了两个多 ...

  8. [转] Torch中实现mini-batch RNN

    工作中需要把一个SGD的LSTM改造成mini-batch的LSTM, 两篇比较有用的博文,转载mark https://zhuanlan.zhihu.com/p/34418001 http://ww ...

  9. RNN网络【转】

    本文转载自:https://zhuanlan.zhihu.com/p/29212896 简单的Char RNN生成文本 Sherlock I want to create some new thing ...

随机推荐

  1. Java 设置Word页边距、页面大小、页面方向、页面边框

    本文将通过Java示例介绍如何设置Word页边距(包括上.下.左.右).页面大小(可设置Letter/A3/A4/A5/A6/B4/B5/B6/Envelop DL/Half Letter/Lette ...

  2. vue vuex开发中遇到的问题及解决小技巧

    1.在vue的开发中,如果使用了vuex,数据的组装,修改时在mutations中,页面是建议修改变量值的,如果强制修改,控制台就会出现错误.如下: 这种错误虽然不会影响结果,但是是vuex不提倡的方 ...

  3. VMware Workstation CentOS7 Linux 学习之路(1)--系统安装

      前言 很早就想学习Linux了,出去面试很多家公司都问会不会Linux,都很尴尬,一直没学过Linux,在网上也看过很多资料,也安装了VM,自己摸索着学习Linux,之前看网上的一些命令一顿操作, ...

  4. SpringCloud之Feign(五)

    Feign简介 Feign 是一个声明web服务客户端,这便得编写web服务客户端更容易,使用Feign 创建一个接口并对它进行注解,它具有可插拔的注解支持包括Feign注解与JAX-RS注解,Fei ...

  5. 《C# 爬虫 破境之道》:第一境 爬虫原理 — 第一节:整体思路

    在构建本章节内容的时候,笔者也在想一个问题,究竟什么样的采集器框架,才能算得上是一个“全能”的呢?就我自己以往项目经历而言,可以归纳以下几个大的分类: 根据通讯协议:HTTP的.HTTPS的.TCP的 ...

  6. pyhton 线程锁

    问题:已经有了全局解释器锁为什么还需要锁? 答:全局解释器锁是在Cpython解释器下,同一时刻,多个线程只能有一个线程被cpu调度 它是在线程和cpu之间加锁,线程和cpu之间有传递时间,即使有GI ...

  7. MyBatis3——输出参数ResultType、动语态sql

    输出参数ResultType 1.输出参数为简单类型(8个基本+String) 2.输出参数为对象类型 3.输出参数为实体对象类型的集合:虽然输出类型为集合,但是resultType依然写集合的元素类 ...

  8. [bzoj1041] [洛谷P2508] [HAOI2008] 圆上的整点

    Description 求一个给定的圆(x^2+y^2=r^2),在圆周上有多少个点的坐标是整数. Input 只有一个正整数n,n<=2000 000 000 Output 整点个数 Samp ...

  9. CTF-Keylead(ASIS CTF 2015)

    将keylead下载到本地用7-ZIP打开,发现主要文件 keylead~ 在ubuntu里跑起来,发现是个游戏,按回车后要摇出3,1,3,3,7就能获得flag. 拖进IDA 直接开启远程调试,跑起 ...

  10. java面试| 线程面试题集合

    集合的面试题就不罗列了,基本上在深入理解集合系列已覆盖 「 深入浅出 」java集合Collection和Map 「 深入浅出 」集合List 「 深入浅出 」集合Set 这里搜罗网上常用线程面试题, ...