[PyTorch] rnn,lstm,gru中输入输出维度
本文中的RNN泛指LSTM,GRU等等
CNN中和RNN中batchSize的默认位置是不同的。
- CNN中:batchsize的位置是
position 0
. - RNN中:batchsize的位置是
position 1
.
在RNN中输入数据格式:
对于最简单的RNN,我们可以使用两种方式来调用,torch.nn.RNNCell()
,它只接受序列中的单步输入,必须显式的传入隐藏状态。torch.nn.RNN()
可以接受一个序列的输入,默认会传入一个全0的隐藏状态,也可以自己申明隐藏状态传入。
- 输入大小是三维tensor
[seq_len,batch_size,input_dim]
input_dim
是输入的维度,比如是128
batch_size
是一次往RNN
输入句子的数目,比如是5
。seq_len
是一个句子的最大长度,比如15
所以千万注意,RNN
输入的是序列,一次把批次的所有句子都输入了,得到的ouptut
和hidden
都是这个批次的所有的输出和隐藏状态,维度也是三维。
**可以理解为现在一共有batch_size
个独立的RNN
组件,RNN
的输入维度是input_dim
,总共输入seq_len
个时间步,则每个时间步输入到这个整个RNN
模块的维度是[batch_size,input_dim]
# 构造RNN网络,x的维度5,隐层的维度10,网络的层数2
rnn_seq = nn.RNN(5, 10,2)
# 构造一个输入序列,句长为 6,batch 是 3, 每个单词使用长度是 5的向量表示
x = torch.randn(6, 3, 5)
#out,ht = rnn_seq(x,h0)
out,ht = rnn_seq(x) #h0可以指定或者不指定
问题1:这里out
、ht
的size是多少呢?
回答:out
:6 * 3 * 10, ht
: 2 * 3 * 10,out
的输出维度[seq_len,batch_size,output_dim]
,ht的维度[num_layers * num_directions, batch, hidden_size]
,如果是单向单层的RNN那么一个句子只有一个hidden
。
问题2:out[-1]
和ht[-1]
是否相等?
回答:相等,隐藏单元就是输出的最后一个单元,可以想象,每个的输出其实就是那个时间步的隐藏单元
RNN
的其他参数
RNN(input_dim ,hidden_dim ,num_layers ,…)
– input_dim 表示输入的特征维度
– hidden_dim 表示输出的特征维度,如果没有特殊变化,相当于out
– num_layers 表示网络的层数
– nonlinearity 表示选用的非线性激活函数,默认是 ‘tanh’
– bias 表示是否使用偏置,默认使用
– batch_first 表示输入数据的形式,默认是 False,就是这样形式,(seq, batch, feature),也就是将序列长度放在第一位,batch 放在第二位
– dropout 表示是否在输出层应用 dropout
– bidirectional 表示是否使用双向的 rnn,默认是 False

LSTM的输出多了一个memory单元
# 输入维度 50,隐层100维,两层
lstm_seq = nn.LSTM(50, 100, num_layers=2)
# 输入序列seq= 10,batch =3,输入维度=50
lstm_input = torch.randn(10, 3, 50)
out, (h, c) = lstm_seq(lstm_input) # 使用默认的全 0 隐藏状态
问题1:out
和(h,c)
的size各是多少?
回答:out
:(10 * 3 * 100),(h,c)
:都是(2 * 3 * 100)
问题2:out[-1,:,:]
和h[-1,:,:]
相等吗?
回答: 相等
GRU比较像传统的RNN
gru_seq = nn.GRU(10, 20,2) # x_dim,h_dim,layer_num
gru_input = torch.randn(3, 32, 10) # seq,batch,x_dim
out, h = gru_seq(gru_input)
作者:VanJordan
链接:https://www.jianshu.com/p/b942e65cb0a3
来源:简书
简书著作权归作者所有,任何形式的转载都请联系作者获得授权并注明出处。
[PyTorch] rnn,lstm,gru中输入输出维度的更多相关文章
- 深度学习中的序列模型演变及学习笔记(含RNN/LSTM/GRU/Seq2Seq/Attention机制)
[说在前面]本人博客新手一枚,象牙塔的老白,职业场的小白.以下内容仅为个人见解,欢迎批评指正,不喜勿喷![认真看图][认真看图] [补充说明]深度学习中的序列模型已经广泛应用于自然语言处理(例如机器翻 ...
- RNN,LSTM,GRU基本原理的个人理解
记录一下对RNN,LSTM,GRU基本原理(正向过程以及简单的反向过程)的个人理解 RNN Recurrent Neural Networks,循环神经网络 (注意区别于recursive neura ...
- RNN/LSTM/GRU/seq2seq公式推导
概括:RNN 适用于处理序列数据用于预测,但却受到短时记忆的制约.LSTM 和 GRU 采用门结构来克服短时记忆的影响.门结构可以调节流经序列链的信息流.LSTM 和 GRU 被广泛地应用到语音识别. ...
- RNN - LSTM - GRU
循环神经网络 (Recurrent Neural Network,RNN) 是一类具有短期记忆能力的神经网络,因而常用于序列建模.本篇先总结 RNN 的基本概念,以及其训练中时常遇到梯度爆炸和梯度消失 ...
- RNN & LSTM & GRU 的原理与区别
RNN 循环神经网络,是非线性动态系统,将序列映射到序列,主要参数有五个:[Whv,Whh,Woh,bh,bo,h0][Whv,Whh,Woh,bh,bo,h0],典型的结构图如下: 和普通神经网 ...
- RNN, LSTM, GRU cells
项目需要,先简记cell,有时间再写具体改进原因 RNN cell LSTM cell: GRU cell: reference: 1.https://towardsdatascience.com/a ...
- 【pytorch学习笔记0】-CNN与LSTM输入输出维度含义
卷积data的四个维度: batch, input channel, height, width Conv2d的四个维度: input channel, output channel, kernel, ...
- NLP教程(5) - 语言模型、RNN、GRU与LSTM
作者:韩信子@ShowMeAI 教程地址:http://www.showmeai.tech/tutorials/36 本文地址:http://www.showmeai.tech/article-det ...
- RNN,GRU,LSTM
2019-08-29 17:17:15 问题描述:比较RNN,GRU,LSTM. 问题求解: 循环神经网络 RNN 传统的RNN是维护了一个隐变量 ht 用来保存序列信息,ht 基于 xt 和 ht- ...
随机推荐
- hdu5335(搜索)
Walk Out Time Limit: 2000/1000 MS (Java/Others) Memory Limit: 65536/65536 K (Java/Others)Total Su ...
- PHP小函数集-篇一
一. 验证 /** * 判断用户名是否规范 */ function is_username($username) { if (preg_match("/^[a-zA-Z]{1}([0-9a- ...
- MFC project for a non-Unicode character set is deprecated
error MSB8031: Building an MFC project for a non-Unicode character set is deprecated. You must chang ...
- PCB CAM自动化后台配置说明
CAM自动化项目经历9个月时间里,在我们IT团队与工程部深入合作下,依据PCB各种场景定制特定功能,且这几个月里在不断的改进与迭代脚本功能,在此期间攻破了一个又一个难题,最终项目第一阶段已顺立上线运行 ...
- java数组与字符串相互转换、整型与字符串相互转换【详解】
java 数组->字符串 1.char数组(字符数组)->字符串 可以通过:使用String.copyValueOf(charArray)函数实现. 举例: char[] arr={ ...
- poj 2987 Firing【最大权闭合子图+玄学计数 || BFS】
玄学计数 LYY Orz 第一次见这种神奇的计数方式,乍一看非常不靠谱但是仔细想想还卡不掉 就是把在建图的时候把正权变成w*10000-1,负权变成w*10000+1,跑最大权闭合子图.后面的1作用是 ...
- POJ 2833 The Average(优先队列)
原题目网址:http://poj.org/problem?id=2833 本题中文翻译: 描述 在演讲比赛中,当选手完成演讲时,评委将对他的演出进行评分. 工作人员删除最高成绩和最低成绩,并计算其余成 ...
- 【数据结构(C语言版)系列一】 线性表
最近开始看数据结构,该系列笔记简单记录总结下所学的知识,更详细的推荐博主StrayedKing的数据结构系列,笔记部分也摘抄了博主总结的比较好的内容. 一些基本概念和术语 数据是对客观事物的符号表示, ...
- [TJOI2013]松鼠聚会
Description 有N个小松鼠,它们的家用一个点x,y表示,两个点的距离定义为:点(x,y)和它周围的8个点即上下左右四个点和对角的四个点,距离为1.现在N个松鼠要走到一个松鼠家去,求走过的最短 ...
- [POI2009]石子游戏Kam
Description 有N堆石子,除了第一堆外,每堆石子个数都不少于前一堆的石子个数.两人轮流操作每次操作可以从一堆石子中移走任意多石子,但是要保证操作后仍然满足初始时的条件谁没有石子可移时输掉游戏 ...