本文中的RNN泛指LSTM,GRU等等

CNN中和RNNbatchSize的默认位置是不同的。

  • CNN中:batchsize的位置是position 0.
  • RNN中:batchsize的位置是position 1.

在RNN中输入数据格式:

对于最简单的RNN,我们可以使用两种方式来调用,torch.nn.RNNCell(),它只接受序列中的单步输入,必须显式的传入隐藏状态torch.nn.RNN()可以接受一个序列的输入,默认会传入一个全0的隐藏状态,也可以自己申明隐藏状态传入。

  1. 输入大小是三维tensor[seq_len,batch_size,input_dim]
  • input_dim是输入的维度,比如是128
  • batch_size是一次往RNN输入句子的数目,比如是5
  • seq_len是一个句子的最大长度,比如15
    所以千万注意,RNN输入的是序列,一次把批次的所有句子都输入了,得到的ouptuthidden都是这个批次的所有的输出和隐藏状态,维度也是三维。
    **可以理解为现在一共有batch_size个独立的RNN组件,RNN的输入维度是input_dim,总共输入seq_len个时间步,则每个时间步输入到这个整个RNN模块的维度是[batch_size,input_dim]
# 构造RNN网络,x的维度5,隐层的维度10,网络的层数2
rnn_seq = nn.RNN(5, 10,2)
# 构造一个输入序列,句长为 6,batch 是 3, 每个单词使用长度是 5的向量表示
x = torch.randn(6, 3, 5)
#out,ht = rnn_seq(x,h0)
out,ht = rnn_seq(x) #h0可以指定或者不指定

问题1:这里outhtsize是多少呢?
回答out:6 * 3 * 10, ht: 2 * 3 * 10,out的输出维度[seq_len,batch_size,output_dim],ht的维度[num_layers * num_directions, batch, hidden_size],如果是单向单层的RNN那么一个句子只有一个hidden
问题2out[-1]ht[-1]是否相等?
回答:相等,隐藏单元就是输出的最后一个单元,可以想象,每个的输出其实就是那个时间步的隐藏单元

  1. RNN的其他参数
RNN(input_dim ,hidden_dim ,num_layers ,…)
– input_dim 表示输入的特征维度
– hidden_dim 表示输出的特征维度,如果没有特殊变化,相当于out
– num_layers 表示网络的层数
– nonlinearity 表示选用的非线性激活函数,默认是 ‘tanh’
– bias 表示是否使用偏置,默认使用
– batch_first 表示输入数据的形式,默认是 False,就是这样形式,(seq, batch, feature),也就是将序列长度放在第一位,batch 放在第二位
– dropout 表示是否在输出层应用 dropout
– bidirectional 表示是否使用双向的 rnn,默认是 False

LSTM的输出多了一个memory单元

# 输入维度 50,隐层100维,两层
lstm_seq = nn.LSTM(50, 100, num_layers=2)
# 输入序列seq= 10,batch =3,输入维度=50
lstm_input = torch.randn(10, 3, 50)
out, (h, c) = lstm_seq(lstm_input) # 使用默认的全 0 隐藏状态

问题1:out(h,c)的size各是多少?
回答:out:(10 * 3 * 100),(h,c):都是(2 * 3 * 100)
问题2:out[-1,:,:]h[-1,:,:]相等吗?
回答: 相等

GRU比较像传统的RNN

gru_seq = nn.GRU(10, 20,2) # x_dim,h_dim,layer_num
gru_input = torch.randn(3, 32, 10) # seq,batch,x_dim
out, h = gru_seq(gru_input)

 
 

pytorch, LSTM介绍的更多相关文章

  1. pytorch学习笔记(九):PyTorch结构介绍

    PyTorch结构介绍对PyTorch架构的粗浅理解,不能保证完全正确,但是希望可以从更高层次上对PyTorch上有个整体把握.水平有限,如有错误,欢迎指错,谢谢! 几个重要的类型和数值相关的Tens ...

  2. 网络流量预测入门(二)之LSTM介绍

    目录 网络流量预测入门(二)之LSTM介绍 LSTM简介 Simple RNN的弊端 LSTM的结构 细胞状态(Cell State) 门(Gate) 遗忘门(Forget Gate) 输入门(Inp ...

  3. LSTM介绍

    转自:https://blog.csdn.net/gzj_1101/article/details/79376798 LSTM网络 long short term memory,即我们所称呼的LSTM ...

  4. RNN LSTM 介绍

    [RNN以及LSTM的介绍和公式梳理]http://blog.csdn.net/Dark_Scope/article/details/47056361 [知乎 对比 rnn  lstm  简单代码] ...

  5. pytorch lstm crf 代码理解 重点

    好久没有写博客了,这一次就将最近看的pytorch 教程中的lstm+crf的一些心得与困惑记录下来. 原文 PyTorch Tutorials 参考了很多其他大神的博客,https://blog.c ...

  6. pytorch lstm crf 代码理解

    好久没有写博客了,这一次就将最近看的pytorch 教程中的lstm+crf的一些心得与困惑记录下来. 原文 PyTorch Tutorials 参考了很多其他大神的博客,https://blog.c ...

  7. Pytorch LSTM 词性判断

    首先,我们定义好一个LSTM网络,然后给出一个句子,每个句子都有很多个词构成,每个词可以用一个词向量表示,这样一句话就可以形成一个序列,我们将这个序列依次传入LSTM,然后就可以得到与序列等长的输出, ...

  8. pytorch LSTM情感分类全部代码

    先运行main.py进行文本序列化,再train.py模型训练 dataset.py from torch.utils.data import DataLoader,Dataset import to ...

  9. RNN、LSTM介绍以及梯度消失问题讲解

    写在最前面,感谢这两篇文章,基本上的框架是从这两篇文章中得到的: https://zhuanlan.zhihu.com/p/28687529 https://zhuanlan.zhihu.com/p/ ...

随机推荐

  1. git && gitlab 使用

    安装略过 使用 基于公钥的认证登录,方便对用户进行权限控制 useradd -s /usr/bin/git-shell testgit #创建一个用户 或者直接useradd testgit 然后去/ ...

  2. ArcGis Python脚本——根据接图表批量裁切分幅影像

    年前写了一个用渔网工具制作图幅接图表的文章,链接在这里: 使用ArcMap做一个1:5000标准分幅图并编号 本文提供一个使用ArcMap利用接图表图斑裁切一幅影像为多幅的方法. 第一步,将接图表拆分 ...

  3. Python通过分页对数据进行展示

    # 通过分页对数据进行展示 """ 要求: 每页显示10条数据 让用户输入要查看的页面:页码 """ USER_LIST = [] for ...

  4. HTML(七)HTML 表单(form元素介绍,input元素的常用type类型,input元素的常用属性)

    前言 表单是网页与用户的交互工具,由一个<form>元素作为容器构成,封装其他任何数量的表单控件,还有其他任何<body>元素里可用的标签 表单能够包含<input> ...

  5. 3D Slicer中文教程(一)—下载及安装方法

    3D Slicer是用于医学图像信息学,图像处理和三维可视化的开源软件平台. 通过国家卫生研究院和全球开发人员社区的支持,二十多年来,Slicer为医生,研究人员和公众提供了免费,强大的跨平台加工工具 ...

  6. NB-IoT不一定最完美 但足以成为决定ofo与摩拜物联网胜负的关键【转】

    转自:http://news.rfidworld.com.cn/2017_11/3d5ed5c5d8cb9949.html 2018年到来之前,如果还不懂物联网,你会被淘汰. 今年1月,工信部< ...

  7. 【原创】大数据基础之词频统计Word Count

    对文件进行词频统计,是一个大数据领域的hello word级别的应用,来看下实现有多简单: 1 Linux单机处理 egrep -o "\b[[:alpha:]]+\b" test ...

  8. 对Java框架spring、hibernate、Struts的粗浅理解

    对 Struts 的理解:1. struts 是一个按 MVC 模式设计的 Web 层框架,其实它就是一个大大的 servlet,这个Servlet 名为 ActionServlet,或是 Actio ...

  9. oAuth2授权协议 & 微信授权登陆和绑定 & 多环境共用一个微信开发平台回调设置

    OAuth2(open Auth)开放授权协议 授权码模式流程: 1.浏览器(客户端)点击一个比如使用微信登陆按钮 2.会跳到认证服务器页面,让用户选择是否授权 3.如果用户点击授权,那么会跳转到开始 ...

  10. python3 基础语法(二)

    一.python3的基本数据类型: 和其他语言一样都包含了以下数据类型: 类型 含义 实例 INT 整型(integer) 1 FLOAT 浮点型 1.1 BOOL 布尔值 TRUE/FALSE ST ...