pytorch nn.LSTM()参数详解
输入数据格式:
input(seq_len, batch, input_size)
h0(num_layers * num_directions, batch, hidden_size)
c0(num_layers * num_directions, batch, hidden_size)
输出数据格式:
output(seq_len, batch, hidden_size * num_directions)
hn(num_layers * num_directions, batch, hidden_size)
cn(num_layers * num_directions, batch, hidden_size)
import torch
import torch.nn as nn
from torch.autograd import Variable
#构建网络模型---输入矩阵特征数input_size、输出矩阵特征数hidden_size、层数num_layers
inputs = torch.randn(5,3,10) ->(seq_len,batch_size,input_size)
rnn = nn.LSTM(10,20,2) -> (input_size,hidden_size,num_layers)
h0 = torch.randn(2,3,20) ->(num_layers* 1,batch_size,hidden_size)
c0 = torch.randn(2,3,20) ->(num_layers*1,batch_size,hidden_size)
num_directions=1 因为是单向LSTM
'''
Outputs: output, (h_n, c_n)
'''
output,(hn,cn) = rnn(inputs,(h0,c0))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
batch_first: 输入输出的第一维是否为 batch_size,默认值 False。因为 Torch 中,人们习惯使用Torch中带有的dataset,dataloader向神经网络模型连续输入数据,这里面就有一个 batch_size 的参数,表示一次输入多少个数据。 在 LSTM 模型中,输入数据必须是一批数据,为了区分LSTM中的批量数据和dataloader中的批量数据是否相同意义,LSTM 模型就通过这个参数的设定来区分。 如果是相同意义的,就设置为True,如果不同意义的,设置为False。 torch.LSTM 中 batch_size 维度默认是放在第二维度,故此参数设置可以将 batch_size 放在第一维度。如:input 默认是(4,1,5),中间的 1 是 batch_size,指定batch_first=True后就是(1,4,5)。所以,如果你的输入数据是二维数据的话,就应该将 batch_first 设置为True;
inputs = torch.randn(5,3,10) :seq_len=5,bitch_size=3,input_size=10
我的理解:有3个句子,每个句子5个单词,每个单词用10维的向量表示;而句子的长度是不一样的,所以seq_len可长可短,这也是LSTM可以解决长短序列的特殊之处。只有seq_len这一参数是可变的。
关于hn和cn一些参数的详解看这里
而在遇到文本长度不一致的情况下,将数据输入到模型前的特征工程会将同一个batch内的文本进行padding使其长度对齐。但是对齐的数据在单向LSTM甚至双向LSTM的时候有一个问题,LSTM会处理很多无意义的填充字符,这样会对模型有一定的偏差,这时候就需要用到函数torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()
详情解释看这里
BiLSTM
BILSTM是双向LSTM;将前向的LSTM与后向的LSTM结合成LSTM。视图举例如下:
LSTM结构推导:
更详细公式推导https://blog.csdn.net/songhk0209/article/details/71134698
GRU公式推导:(网上的图看着有点费劲,就自己画了个数据流图)
---------------------
作者:向阳争渡
来源:CSDN
原文:https://blog.csdn.net/yangyang_yangqi/article/details/84585998
版权声明:本文为博主原创文章,转载请附上博文链接!
pytorch nn.LSTM()参数详解的更多相关文章
- tcpdump常用参数详解
tcpdump常用参数详解 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 好久没有更新我的博客了,看来自己最近还没有在放假中回过神来啊,哈哈~是不是也有小伙伴跟我一样呢?回归正题, ...
- tcpdump的使用以及参数详解
平时分析客户端和服务器网络交互的问题时,很多情况下需要在客户端和服务器抓包分析报文.一般win下抓包使用WireShark即可,但是linux下就需要用到tcpdump了,下面是一些对于tcpdump ...
- Nginx主配置参数详解,Nginx配置网站
1.Niginx主配置文件参数详解 a.上面博客说了在Linux中安装nginx.博文地址为:http://www.cnblogs.com/hanyinglong/p/5102141.html b.当 ...
- iptables参数详解
iptables参数详解 搬运工:尹正杰 注:此片文章来源于linux社区. Iptalbes 是用来设置.维护和检查Linux内核的IP包过滤规则的. 可以定义不同的表,每个表都包含几个内部的链,也 ...
- chattr的常用参数详解
chattr的常用参数详解 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 在实际生产环境中,有的运维工程师不得不和开发和测试打交道,在我们公司最常见的就是部署接口.每天每个人部署的 ...
- mha配置参数详解
mha配置参数详解: 参数名字 是否必须 参数作用域 默认值 示例 hostname Yes Local Only - hostname=mysql_server1, hostname=192.168 ...
- $.ajax()方法所有参数详解;$.get(),$.post(),$.getJSON(),$.ajax()详解
[一]$.ajax()所有参数详解 url: 要求为String类型的参数,(默认为当前页地址)发送请求的地址. type: 要求为String类型的参数,请求方式(post或get)默认为get.注 ...
- linux PHP 编译安装参数详解
linux PHP 编译安装参数详解 ./configure --prefix=/usr/local/php --with-config-file-path=/usr/local/php/etc -- ...
- 【转】jqGrid 各种参数 详解
[原文]http://www.cnblogs.com/younggun/archive/2012/08/27/2657922.htmljqGrid 各种参数 详解 JQGrid JQGrid是一个 ...
随机推荐
- linux 下批量在多文件中替换字符串
sed -i "s/原字符串/新字符串/g" `grep 原字符串 -rl 所在目录` 注意:`` 符号在shell里面正式的名称叫做backquote , 一般叫做命令替换其作用 ...
- python-None
今天偶然间,有人问了一个问题,项目中出现了一个这样的错误. 看到后,就想到是前后数据类型不一致.当时他定义了一些默认初始值为None(刚接触python代码,之前是c,java),然后就后边出现了这样 ...
- 【Leetcode 二分】 滑动窗口中位数(480)
题目 中位数是有序序列最中间的那个数.如果序列的大小是偶数,则没有最中间的数:此时中位数是最中间的两个数的平均数. 例如: [2,3,4],中位数是 3 [2,3],中位数是 (2 + 3) / 2 ...
- ServletConfig详解 (转载)
ServletConfig详解 (转载) 容器初始化一个servlet时,会为这个servlet建一个唯一的ServletConfig.容器从DD读出Servlet初始化参数,并把这些参数交给S ...
- Directx11教程(13) D3D11管线(1)
原文:Directx11教程(13) D3D11管线(1) 从本篇教程开始,我们暂停代码的学习,先来了解一下D3D11的管线,这些管线不涉及具体的硬件,而是着重于理解能够支持D3D11的管 ...
- 学习iOS设计--iOS8的颜色、文字和布局学习
在去年,Apple针对新时代用户彻底更新了其设计语言.现在的设计语言相对之前大为简化,能够让设计师将精力集中到动画和功能上,而不是繁复的视觉细节上. 很多人都曾问过我:设计应当如何入门?成为一名优秀设 ...
- Flask 第二篇
Flask 中的 Render Redirect HttpResponse 1.Flask中的HTTPResponse 在Flask 中的HttpResponse 在我们看来其实就是直接返回字符串 2 ...
- 笔记:在 Windows 10 WSL Ubuntu 18.04 安装 Odoo12 (2019-06-09)
笔记:在 Windows 10 WSL Ubuntu 18.04 安装 Odoo12 原因 为了和服务器一样的运行环境. 使用 Ubuntu 运行 Odoo 运行更快. 方便使用 Windows 10 ...
- select引起的服务端程序崩溃问题
现象: 某个线上的服务最近频繁崩溃.该服务使用C++编写,是个网络服务端程序.作为TCP服务端,接收和转发客户端发来的消息,并给客户端发送消息.该服务跑在CentOS上,8G内存.线上环境中,与客户端 ...
- SELinux 宽容模式(permissive) 强制模式(enforcing) 关闭(disabled) 几种模式之间的转换
http://blog.sina.com.cn/s/blog_5aee9eaf0100y44q.html 在CentOS6.2 中安装intel 的c++和fortran 的编译器时,遇到来一个关于S ...