使用torch实现RNN
(本文对https://blog.csdn.net/out_of_memory_error/article/details/81456501的结果进行了复现。)
在实验室的项目遇到了困难,弄不明白LSTM的原理。到网上搜索,发现LSTM是RNN的变种,那就从RNN开始学吧。
带隐藏状态的RNN可以用下面两个公式来表示:

可以看出,一个RNN的参数有W_xh,W_hh,b_h,W_hq,b_q和H(t)。其中H(t)是步数的函数。
参考的文章考虑了这样一个问题,对于x轴上的一列点,有一列sin值,我们想知道它对应的cos值,但是即使sin值相同,cos值也不同,因为输出结果不仅依赖于当前的输入值sinx,还依赖于之前的sin值。这时候可以用RNN来解决问题
用到的核心函数:torch.nn.RNN() 参数如下:
input_size – 输入
x的特征数量。hidden_size – 隐藏层的特征数量。
num_layers – RNN的层数。
nonlinearity – 指定非线性函数使用
tanh还是relu。默认是tanh。bias – 如果是
False,那么RNN层就不会使用偏置权重 $b_ih$和$b_hh$,默认是Truebatch_first – 如果
True的话,那么输入Tensor的shape应该是[batch_size, time_step, feature],输出也是这样。dropout – 如果值非零,那么除了最后一层外,其它层的输出都会套上一个
dropout层。bidirectional – 如果
True,将会变成一个双向RNN,默认为False。
下面是代码:
# encoding:utf-8
import torch
import numpy as np
import matplotlib.pyplot as plt # 导入作图相关的包
from torch import nn # 定义RNN模型
class Rnn(nn.Module):
def __init__(self, INPUT_SIZE):
super(Rnn, self).__init__() # 定义RNN网络,输入单个数字.隐藏层size为[feature, hidden_size]
self.rnn = nn.RNN(
input_size=INPUT_SIZE,
hidden_size=32,
num_layers=1,
batch_first=True # 注意这里用了batch_first=True 所以输入形状为[batch_size, time_step, feature]
)
# 定义一个全连接层,本质上是令RNN网络得以输出
self.out = nn.Linear(32, 1) # 定义前向传播函数
def forward(self, x, h_state):
# 给定一个序列x,每个x.size=[batch_size, feature].同时给定一个h_state初始状态,RNN网络输出结果并同时给出隐藏层输出
r_out, h_state = self.rnn(x, h_state)
outs = []
for time in range(r_out.size(1)): # r_out.size=[1,10,32]即将一个长度为10的序列的每个元素都映射到隐藏层上.
outs.append(self.out(r_out[:, time, :])) # 依次抽取序列中每个单词,将之通过全连接层并输出.r_out[:, 0, :].size()=[1,32] -> [1,1]
return torch.stack(outs, dim=1), h_state # stack函数在dim=1上叠加:10*[1,1] -> [1,10,1] 同时h_state已经被更新 TIME_STEP = 10
INPUT_SIZE = 1
LR = 0.02 model = Rnn(INPUT_SIZE)
print(model) loss_func = nn.MSELoss() # 使用均方误差函数
optimizer = torch.optim.Adam(model.parameters(), lr=LR) # 使用Adam算法来优化Rnn的参数,包括一个nn.RNN层和nn.Linear层 h_state = None # 初始化h_state为None for step in range(300):
# 人工生成输入和输出,输入x.size=[1,10,1],输出y.size=[1,10,1]
start, end = step * np.pi, (step + 1)*np.pi steps = np.linspace(start, end, TIME_STEP, dtype=np.float32)
x_np = np.sin(steps)
y_np = np.cos(steps) x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis])
y = torch.from_numpy(y_np[np.newaxis, :, np.newaxis]) # 将x通过网络,长度为10的序列通过网络得到最终隐藏层状态h_state和长度为10的输出prediction:[1,10,1]
prediction, h_state = model(x, h_state)
h_state = h_state.data # 这一步只取了h_state.data.因为h_state包含.data和.grad 舍弃了梯度
# 反向传播
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward() # 优化网络参数具体应指W_xh, W_hh, b_h.以及W_hq, b_q
optimizer.step() # 对最后一次的结果作图查看网络的预测效果
plt.plot(steps, y_np.flatten(), 'r-')
plt.plot(steps, prediction.data.numpy().flatten(), 'b-')
plt.show()
最后一步预测和实际y的结果作图如下:

可看出,训练RNN网络之后,对网络输入一个序列sinx,能正确输出对应的序列cosx
使用torch实现RNN的更多相关文章
- PyTorch官方中文文档:torch.nn
torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...
- pytorch实现rnn并且对mnist进行分类
1.RNN简介 rnn,相比很多人都已经听腻,但是真正用代码操练起来,其中还是有很多细节值得琢磨. 虽然大家都在说,我还是要强调一次,rnn实际上是处理的是序列问题,与之形成对比的是cnn,cnn不能 ...
- 从网络架构方面简析循环神经网络RNN
一.前言 1.1 诞生原因 在普通的前馈神经网络(如多层感知机MLP,卷积神经网络CNN)中,每次的输入都是独立的,即网络的输出依赖且仅依赖于当前输入,与过去一段时间内网络的输出无关.但是在现实生活中 ...
- [PyTorch] rnn,lstm,gru中输入输出维度
本文中的RNN泛指LSTM,GRU等等 CNN中和RNN中batchSize的默认位置是不同的. CNN中:batchsize的位置是position 0. RNN中:batchsize的位置是pos ...
- Pytorch基础——使用 RNN 生成简单序列
一.介绍 内容 使用 RNN 进行序列预测 今天我们就从一个基本的使用 RNN 生成简单序列的例子中,来窥探神经网络生成符号序列的秘密. 我们首先让神经网络模型学习形如 0^n 1^n 形式的上下文无 ...
- pytorch RNN层api的几个参数说明
classtorch.nn.RNN(*args, **kwargs) input_size – The number of expected features in the input x hidde ...
- 【学习笔记】RNN算法的pytorch实现
一些新理解 之前我有个疑惑,RNN的网络窗口,换句话说不也算是一个卷积核嘛?那所有的网络模型其实不都是一个东西吗?今天又听了一遍RNN,发现自己大错特错,还是没有学明白阿.因为RNN的窗口所包含的那一 ...
- DeepLearning常用库简要介绍与对比
网上近日流传一张DL相关库在Github上的受关注度对比(数据应该是2016/03/15左右统计的): 其中tensorflow,caffe,keras和Theano排名比较靠前. 今日组会报告上tj ...
- pytorch, LSTM介绍
本文中的RNN泛指LSTM,GRU等等 CNN中和RNN中batchSize的默认位置是不同的. CNN中:batchsize的位置是position 0. RNN中:batchsize的位置是pos ...
随机推荐
- uni-app运行到浏览器跨域H5页面的跨域问题解决方案
官方文档对跨域的解决方案推荐: https://ask.dcloud.net.cn/article/35267 更方便的解决方案 项目根目录直接创建一个vue.config.js文件,并在里面配置代理 ...
- 编辑器、编译器、文件、IDE等常见概念辨析
一.编辑器与编译器 1.编辑器与编译器有什么区别? 简单讲,编译器就是将"一种语言(通常为高级语言)"翻译为"另一种语言(通常为低级语言)"的程序.一个现代编译 ...
- 整理总结数据库常用sql语句,建议收藏,忘记了可以来看一下
第一节课:sql语言介绍(参照PPT)及基本查询sql学习 1.数据库表的介绍 emp表:员工表 dept表:部门表 salgrady:薪资水平表 Balance: 2.基本的查询语句: 知识点: s ...
- 《计算机网络》课程笔记 (Ch03-运输层)
为运行在不同主机上的应用进程之间提供逻辑通信功能. 将应用层报文切分为块,然后加上运输层首部,形成报文段,交付给网络层. 多路复用与多路分解 将网络层提供的主机到主机交付服务延伸到进程到进程交付服务. ...
- MRCTF 2020 WP
MRCTF 2020 WP 引言 周末趁上课之余,做了一下北邮的CTF,这里记录一下做出来的几题的WP ez_bypass 知识点:MD5强类型比较,is_numeric()函数绕过 题目源码: I ...
- 使用fileupload组件
1. 进行文件上传时, 表单需要做的准备: 1). 请求方式为 POST: <form action="uploadServlet" method="post&qu ...
- 中文的csv文件的编码改成utf-8的方法
直奔主题:把包含中文的csv文件的编码改成utf-8的方法: https://stackoverflow.com/questions/191359/how-to-convert-a-file-to-u ...
- Unity 游戏框架搭建 2019 (五十二~五十四) 什么是库?&第四章总结&第五章简介
在上一篇,我们对框架和架构进行了一点探讨.我们在这一篇再接着探讨. 什么是库呢? 来自同一位大神的解释: 库, 插到 既有 架构 中, 补充 特定 功能. 很形象,库就是搞这个的.我们的库最初存在的目 ...
- F5忘记密码修改教程
!!!首先查看系统版本,13版本和14版本修改密码方式不一致 首先介绍13版本修改密码 注:12版本也适用,11版本未测试,应该也可以,有问题欢迎留言) 1. 将终端连接到BIG-IP串行控制台端口. ...
- 【Redis面试题】如何使用Redis实现微信步数排行榜?
1. 前言 之前写过一篇博客,讲解的是Redis的5种数据结构及其常用命令,当时有读者评论,说希望了解下这5种数据结构各自的使用场景,不过一直也没来得及写. 碰巧,在3月份找工作面试时,有个面试官先问 ...