pytorch1.0实现RNN-LSTM for Classification
import torch
from torch import nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt # 超参数
# Hyper Parameters
# 训练整批数据多少次, 为了节约时间, 只训练一次
EPOCH = 1 # train the training data n times, to save time, we just train 1 epoch
BATCH_SIZE = 64
TIME_STEP = 28 # rnn time step / image height 时间步数 / 图片高度
INPUT_SIZE = 28 # rnn input size / image width 每步输入值 / 图片每行像素
LR = 0.01 # learning rate
DOWNLOAD_MNIST = True # set to True if haven't download the data # Mnist 手写数字
# Mnist digital dataset
train_data = dsets.MNIST(
root='./mnist/', # 保存或者提取位置
train=True, # this is training data
transform=transforms.ToTensor(), # Converts a PIL.Image or numpy.ndarray to
# torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
download=DOWNLOAD_MNIST, # download it if you don't have it
) # plot one example
print(train_data.train_data.size()) # (60000, 28, 28)
print(train_data.train_labels.size()) # (60000)
plt.imshow(train_data.train_data[0].numpy(), cmap='gray')
plt.title('%i' % train_data.train_labels[0])
plt.show() # 数据加载
# Data Loader for easy mini-batch return in training 批训练 50samples, 1 channel, 28x28 (50, 1, 28, 28)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
# 测试数据
# convert test data into Variable, pick 2000 samples to speed up testing
test_data = dsets.MNIST(root='./mnist/', train=False, transform=transforms.ToTensor())
test_x = test_data.test_data.type(torch.FloatTensor)[:2000]/255. # shape (2000, 28, 28) value in range(0,1)
test_y = test_data.test_labels.numpy()[:2000] # covert to numpy array # 用一个 class 来建立 RNN 模型.
# 这个 RNN 整体流程:
# (input0, state0) -> LSTM -> (output0, state1);
# (input1, state1) -> LSTM -> (output1, state2);
# …
# (inputN, stateN)-> LSTM -> (outputN, stateN+1);
# outputN -> Linear -> prediction.
# 通过LSTM分析每一时刻的值, 并且将这一时刻和前面时刻的理解合并在一起, 生成当前时刻对前面数据的理解或记忆. 传递这种理解给下一时刻分析.
# 定义神经网络
class RNN(nn.Module):
def __init__(self):
super(RNN, self).__init__() self.rnn = nn.LSTM( # if use nn.RNN(), it hardly learns LSTM 效果要比 nn.RNN() 好多了
input_size=INPUT_SIZE,
hidden_size=64, # rnn hidden unit
num_layers=1, # number of rnn layer 有几层 RNN layers
batch_first=True, # input & output will has batch size as 1s dimension. e.g. (batch, time_step, input_size) input & output 会是以 batch size 为第一维度的特征集 e.g. (batch, time_step, input_size)
)
# 输出层
self.out = nn.Linear(64, 10) def forward(self, x):
# x shape (batch, time_step, input_size)
# r_out shape (batch, time_step, output_size)
# h_n shape (n_layers, batch, hidden_size) # LSTM 有两个 hidden states, h_n 是分线, h_c 是主线
# h_c shape (n_layers, batch, hidden_size)
r_out, (h_n, h_c) = self.rnn(x, None) # None represents zero initial hidden state # 选取最后一个时间点的 r_out 输出
# choose r_out at the last time step
out = self.out(r_out[:, -1, :]) # 这里 r_out[:, -1, :] 的值也是 h_n 的值
return out rnn = RNN()
print(rnn)
# 选择优化器
optimizer = torch.optim.Adam(rnn.parameters(), lr=LR) # optimize all cnn parameters
# 选择损失函数
loss_func = nn.CrossEntropyLoss() # the target label is not one-hotted # 训练和测试
# 将图片数据看成一个时间上的连续数据, 每一行的像素点都是这个时刻的输入,
# 读完整张图片就是从上而下的读完了每行的像素点. 然后我们就可以拿出 RNN 在最后一步的分析值判断图片是哪一类了.
# training and testing
for epoch in range(EPOCH):
for step, (b_x, b_y) in enumerate(train_loader): # gives batch data
b_x = b_x.view(-1, 28, 28) # reshape x to (batch, time_step, input_size) output = rnn(b_x) # rnn output # 喂给 rnn net 训练数据 b_x, 输出预测值
loss = loss_func(output, b_y) # cross entropy loss # 计算两者的误差
optimizer.zero_grad() # clear gradients for this training step # 清空上一步的残余更新参数值
loss.backward() # backpropagation, compute gradients # 误差反向传播, 计算参数更新值
optimizer.step() # apply gradients # 将参数更新值施加到 rnn net 的 parameters 上 if step % 50 == 0:
test_output = rnn(test_x) # (samples, time_step, input_size)
pred_y = torch.max(test_output, 1)[1].data.numpy()
accuracy = float((pred_y == test_y).astype(int).sum()) / float(test_y.size)
print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy) # print 10 predictions from test data
test_output = rnn(test_x[:10].view(-1, 28, 28))
pred_y = torch.max(test_output, 1)[1].data.numpy()
print(pred_y, 'prediction number')
print(test_y[:10], 'real number')
pytorch1.0实现RNN-LSTM for Classification的更多相关文章
- pytorch1.0实现RNN for Regression
import torch from torch import nn import numpy as np import matplotlib.pyplot as plt # 超参数 # Hyper P ...
- RNN/LSTM/GRU/seq2seq公式推导
概括:RNN 适用于处理序列数据用于预测,但却受到短时记忆的制约.LSTM 和 GRU 采用门结构来克服短时记忆的影响.门结构可以调节流经序列链的信息流.LSTM 和 GRU 被广泛地应用到语音识别. ...
- 时间序列(六): 炙手可热的RNN: LSTM
目录 炙手可热的LSTM 引言 RNN的问题 恐怖的指数函数 梯度消失* 解决方案 LSTM 设计初衷 LSTM原理 门限控制* LSTM 的 BPTT 参考文献: 炙手可热的LSTM 引言 上一讲说 ...
- [NL系列] RNN & LSTM 网络结构及应用
http://www.jianshu.com/p/f3bde26febed/ 这篇是 The Unreasonable Effectiveness of Recurrent Neural Networ ...
- RNN,LSTM,GRU基本原理的个人理解
记录一下对RNN,LSTM,GRU基本原理(正向过程以及简单的反向过程)的个人理解 RNN Recurrent Neural Networks,循环神经网络 (注意区别于recursive neura ...
- 用pytorch1.0快速搭建简单的神经网络
用pytorch1.0搭建简单的神经网络 import torch import torch.nn.functional as F # 包含激励函数 # 建立神经网络 # 先定义所有的层属性(__in ...
- 用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识
用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识 循环神经网络RNN相比传统的神经网络在处理序列化数据时更有优势,因为RNN能够将加入上(下)文信息进行考虑.一个简单的RNN如 ...
- 深度学习中的序列模型演变及学习笔记(含RNN/LSTM/GRU/Seq2Seq/Attention机制)
[说在前面]本人博客新手一枚,象牙塔的老白,职业场的小白.以下内容仅为个人见解,欢迎批评指正,不喜勿喷![认真看图][认真看图] [补充说明]深度学习中的序列模型已经广泛应用于自然语言处理(例如机器翻 ...
- pytorch1.0 用torch script导出模型
python的易上手和pytorch的动态图特性,使得pytorch在学术研究中越来越受欢迎,但在生产环境,碍于python的GIL等特性,可能达不到高并发.低延迟的要求,存在需要用c++接口的情况. ...
随机推荐
- Sklearn多元线性回归
Sklearn多元线性回归 1 正文 2 参考资料 Sklearn多元线性回归
- WSL2(Ubuntu)安装Docker
原文链接:https://www.cnblogs.com/blog5277/p/12071400.html 原文作者:博客园--曲高终和寡 *******************如果你看到这一行,说明 ...
- docker技术入门(2)
接上一篇文章 [容器技术]Docker容器技术入门(一) 今天接着上次聊一聊有关Docker网络.数据存储相关的技术点 Docker网络模式 01 Dokcer 通过使用 Linux 桥接提供容器之间 ...
- postman上传excel,java后台读取excel生成到指定位置进行备份,并且把excel中的数据添加到数据库
最近要做个前端网页上传excel,数据直接添加到数据库的功能..在此写个读取excel的demo. 首先新建springboot的web项目 导包,读取excel可以用poi也可以用jxl,这里本文用 ...
- Ibatis自动解决sql注入机制
疑问1:为什么IBatis解决了大部分的sql注入?(实际上还有部分sql语句需要关心sql注入,比如like) 之前写Java web,一直使用IBatis,从来没有考虑过sql注入:最近写php( ...
- 熔断机制hystrix
一.问题产生 雪崩效应:是一种因服务提供者的不可用导致服务调用者的不可用,并将不可用逐渐放大的过程 正常情况下的服务: 某一服务出现异常,拖垮整个服务链路,消耗整个线程队列,造成服务不可用,资源耗尽: ...
- html5中output元素详解
html5中output元素详解 一.总结 一句话总结: output元素是HTML5新增的元素,用来设置不同数据的输出,没什么大用,了解即可 <form action="L3_01. ...
- MySQL Error 1170 (42000): BLOB/TEXT Column Used in Key Specification Without a Key Length【转】
今天有开发反应他的建表语句错误,我看了下,提示: MySQL Error 1170 (42000): BLOB/TEXT Column Used in Key Specification Withou ...
- XML 中的 xmlns 等属性的意义
原文:https://blog.csdn.net/lengxiao1993/article/details/77914155 Maven 是一个 java 开发人员很难绕过的构建工具, 因为有众多的开 ...
- git如何将旧commit的相关notes复制到新commit?
答: git notes copy <old-commit> <new-commit>