pytorch1.0实现RNN-LSTM for Classification
import torch
from torch import nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt # 超参数
# Hyper Parameters
# 训练整批数据多少次, 为了节约时间, 只训练一次
EPOCH = 1 # train the training data n times, to save time, we just train 1 epoch
BATCH_SIZE = 64
TIME_STEP = 28 # rnn time step / image height 时间步数 / 图片高度
INPUT_SIZE = 28 # rnn input size / image width 每步输入值 / 图片每行像素
LR = 0.01 # learning rate
DOWNLOAD_MNIST = True # set to True if haven't download the data # Mnist 手写数字
# Mnist digital dataset
train_data = dsets.MNIST(
root='./mnist/', # 保存或者提取位置
train=True, # this is training data
transform=transforms.ToTensor(), # Converts a PIL.Image or numpy.ndarray to
# torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
download=DOWNLOAD_MNIST, # download it if you don't have it
) # plot one example
print(train_data.train_data.size()) # (60000, 28, 28)
print(train_data.train_labels.size()) # (60000)
plt.imshow(train_data.train_data[0].numpy(), cmap='gray')
plt.title('%i' % train_data.train_labels[0])
plt.show() # 数据加载
# Data Loader for easy mini-batch return in training 批训练 50samples, 1 channel, 28x28 (50, 1, 28, 28)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
# 测试数据
# convert test data into Variable, pick 2000 samples to speed up testing
test_data = dsets.MNIST(root='./mnist/', train=False, transform=transforms.ToTensor())
test_x = test_data.test_data.type(torch.FloatTensor)[:2000]/255. # shape (2000, 28, 28) value in range(0,1)
test_y = test_data.test_labels.numpy()[:2000] # covert to numpy array # 用一个 class 来建立 RNN 模型.
# 这个 RNN 整体流程:
# (input0, state0) -> LSTM -> (output0, state1);
# (input1, state1) -> LSTM -> (output1, state2);
# …
# (inputN, stateN)-> LSTM -> (outputN, stateN+1);
# outputN -> Linear -> prediction.
# 通过LSTM分析每一时刻的值, 并且将这一时刻和前面时刻的理解合并在一起, 生成当前时刻对前面数据的理解或记忆. 传递这种理解给下一时刻分析.
# 定义神经网络
class RNN(nn.Module):
def __init__(self):
super(RNN, self).__init__() self.rnn = nn.LSTM( # if use nn.RNN(), it hardly learns LSTM 效果要比 nn.RNN() 好多了
input_size=INPUT_SIZE,
hidden_size=64, # rnn hidden unit
num_layers=1, # number of rnn layer 有几层 RNN layers
batch_first=True, # input & output will has batch size as 1s dimension. e.g. (batch, time_step, input_size) input & output 会是以 batch size 为第一维度的特征集 e.g. (batch, time_step, input_size)
)
# 输出层
self.out = nn.Linear(64, 10) def forward(self, x):
# x shape (batch, time_step, input_size)
# r_out shape (batch, time_step, output_size)
# h_n shape (n_layers, batch, hidden_size) # LSTM 有两个 hidden states, h_n 是分线, h_c 是主线
# h_c shape (n_layers, batch, hidden_size)
r_out, (h_n, h_c) = self.rnn(x, None) # None represents zero initial hidden state # 选取最后一个时间点的 r_out 输出
# choose r_out at the last time step
out = self.out(r_out[:, -1, :]) # 这里 r_out[:, -1, :] 的值也是 h_n 的值
return out rnn = RNN()
print(rnn)
# 选择优化器
optimizer = torch.optim.Adam(rnn.parameters(), lr=LR) # optimize all cnn parameters
# 选择损失函数
loss_func = nn.CrossEntropyLoss() # the target label is not one-hotted # 训练和测试
# 将图片数据看成一个时间上的连续数据, 每一行的像素点都是这个时刻的输入,
# 读完整张图片就是从上而下的读完了每行的像素点. 然后我们就可以拿出 RNN 在最后一步的分析值判断图片是哪一类了.
# training and testing
for epoch in range(EPOCH):
for step, (b_x, b_y) in enumerate(train_loader): # gives batch data
b_x = b_x.view(-1, 28, 28) # reshape x to (batch, time_step, input_size) output = rnn(b_x) # rnn output # 喂给 rnn net 训练数据 b_x, 输出预测值
loss = loss_func(output, b_y) # cross entropy loss # 计算两者的误差
optimizer.zero_grad() # clear gradients for this training step # 清空上一步的残余更新参数值
loss.backward() # backpropagation, compute gradients # 误差反向传播, 计算参数更新值
optimizer.step() # apply gradients # 将参数更新值施加到 rnn net 的 parameters 上 if step % 50 == 0:
test_output = rnn(test_x) # (samples, time_step, input_size)
pred_y = torch.max(test_output, 1)[1].data.numpy()
accuracy = float((pred_y == test_y).astype(int).sum()) / float(test_y.size)
print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy) # print 10 predictions from test data
test_output = rnn(test_x[:10].view(-1, 28, 28))
pred_y = torch.max(test_output, 1)[1].data.numpy()
print(pred_y, 'prediction number')
print(test_y[:10], 'real number')
pytorch1.0实现RNN-LSTM for Classification的更多相关文章
- pytorch1.0实现RNN for Regression
import torch from torch import nn import numpy as np import matplotlib.pyplot as plt # 超参数 # Hyper P ...
- RNN/LSTM/GRU/seq2seq公式推导
概括:RNN 适用于处理序列数据用于预测,但却受到短时记忆的制约.LSTM 和 GRU 采用门结构来克服短时记忆的影响.门结构可以调节流经序列链的信息流.LSTM 和 GRU 被广泛地应用到语音识别. ...
- 时间序列(六): 炙手可热的RNN: LSTM
目录 炙手可热的LSTM 引言 RNN的问题 恐怖的指数函数 梯度消失* 解决方案 LSTM 设计初衷 LSTM原理 门限控制* LSTM 的 BPTT 参考文献: 炙手可热的LSTM 引言 上一讲说 ...
- [NL系列] RNN & LSTM 网络结构及应用
http://www.jianshu.com/p/f3bde26febed/ 这篇是 The Unreasonable Effectiveness of Recurrent Neural Networ ...
- RNN,LSTM,GRU基本原理的个人理解
记录一下对RNN,LSTM,GRU基本原理(正向过程以及简单的反向过程)的个人理解 RNN Recurrent Neural Networks,循环神经网络 (注意区别于recursive neura ...
- 用pytorch1.0快速搭建简单的神经网络
用pytorch1.0搭建简单的神经网络 import torch import torch.nn.functional as F # 包含激励函数 # 建立神经网络 # 先定义所有的层属性(__in ...
- 用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识
用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识 循环神经网络RNN相比传统的神经网络在处理序列化数据时更有优势,因为RNN能够将加入上(下)文信息进行考虑.一个简单的RNN如 ...
- 深度学习中的序列模型演变及学习笔记(含RNN/LSTM/GRU/Seq2Seq/Attention机制)
[说在前面]本人博客新手一枚,象牙塔的老白,职业场的小白.以下内容仅为个人见解,欢迎批评指正,不喜勿喷![认真看图][认真看图] [补充说明]深度学习中的序列模型已经广泛应用于自然语言处理(例如机器翻 ...
- pytorch1.0 用torch script导出模型
python的易上手和pytorch的动态图特性,使得pytorch在学术研究中越来越受欢迎,但在生产环境,碍于python的GIL等特性,可能达不到高并发.低延迟的要求,存在需要用c++接口的情况. ...
随机推荐
- python find和index的区别
如果找不到目标元素,index会报错,find会返回-1 >>> s="hello world" >>> s.find("llo&qu ...
- Linux文件的权限的基本介绍
一. ls -l 显示的内容如下: 二.rwx权限详解 1.rwx作用到文件 2. rwx作用在目录 三.文件及目录实际案例 四.修改权限 - chmod 1. 基本说明: 2.第一种方式 ...
- NOIP提高组初战告捷
前天第一次参加NOIP初赛,竟然提高组考了57分进入复赛啊啊!原本自己估分是52竟然估少了[滑稽]这个成绩 是我们学校初一提高组成绩最高 还是不错(出乎我意料之外)的!
- ubuntu之路——day4(今天主要看了神经网络的概念)
感谢两位老师做的免费公开课: 第一个是由吴恩达老师放在网易云课堂的神经网络和深度学习,比较偏理论,使用numpy包深入浅出的介绍了向量版神经网络的处理方式,当然由于视频有点老,虽然理论很好但是工具有点 ...
- Guided Hacking DLL Injector 3.3
Guided Hacking DLL Injector 3.3 https://guidedhacking.com/resources/guided-hacking-dll-injector.4/ I ...
- css3画半圆的两种方法
<html lang="en"> <head> <meta charset="UTF-8"> <meta name=& ...
- vue-router 使用query传参跳转了两次(首次带参数,跳转到了不带参数)
问题: 在做项目的过程中,使用query传参数,发现跳转过程中第一次有参数,但是路由马上又跳转了一次,然后 ?和它之后的参数都不见了 问题分析: 因为路由加载了两次 解决办法: ·1. 找到总的 la ...
- PHP判断文件大小是MB、GB、TB...
<?php date_default_timezone_set ("PRC" ); function getFilePro($fileName){ if (!file_exi ...
- Docs-.NET-C#-指南-语言参考-预处理器指令:#pragma(C# 参考)
ylbtech-Docs-.NET-C#-指南-语言参考-预处理器指令:#pragma(C# 参考) 1.返回顶部 1. #pragma(C# 参考) 2015/07/20 #pragma 为编译器给 ...
- Leetcode: Stream of Characters
Implement the StreamChecker class as follows: StreamChecker(words): Constructor, init the data struc ...