nn.LSTM(input_dim,hidden_dim,nums_layer,batch_first)

各参数理解:

  • input_dim:输入的张量维度,表示自变量特征数
  • hidden_dim:输出张量维度
  • bias:True or False 是否使用偏置
  • batch_first:True or False,nn.LSTM 接收的输入是(seq_len,batch_size,input_dim),将batch_first设置为True将输入变为(batch_size,seq_len,input_dim)
  • dropout:除了最后层外都引入随机失活
  • bidirectional:True or False 是否使用双向LSTM

举例:10000个句子,每个句子10个词,batch_size=10,embedding_size=300(input_dim)

此时各个参数为:

  • input_size=embedding_size=300
  • batch=batch_size=10
  • seq_len=10
  • 另外设置hidden_dim=128,num_layers=2
	import torch
import torch.nn as nn
from torch.autograd import Variable rnn = nn.LSTM(input_size=300,hidden_size=128,num_layers=2)
inputs = torch.randn(10,10,300)#输入(seq_len, batch_size, input_size) 序列长度为10 batch_size为10 输入维度为300
h_0 = torch.randn(2,10,128)#(num_layers * num_directions, batch, hidden_size) num_layers = 2 ,batch_size=10 ,hidden_size = 128,如果LSTM的bidirectional=True,num_directions=2,否则就是1,表示只有一个方
c_0 = torch.randn(2,10,128)#c_0和h_0的形状相同,它包含的是在当前这个batch_size中的每个句子的初始细胞状态。h_0,c_0如果不提供,那么默认是0
num_directions=1# 因为是单向LSTM #输出格式为(output,(h_n,c_n))
output,(h_n,c_n) = rnn(inputs,(h0,c0))#输入格式为lstm(input,(h_0, c_0))
print("out:", output.shape)
print("h_n:", h_n.shape)
print("c_n:", c_n.shape) 输出结果:
out: torch.Size([10, 10, 128])
h_n: torch.Size([2, 10, 128])
c_n: torch.Size([2, 10, 128])

输出结果:

  • output的shape为(seq_len=5,batch_size=3,num_directions*hidden_size),hidden_size为20,num_directions为1。它包含的LSTM的最后一层的输出特征(h_t),t是batch_size中每个句子的长度。
  • h_n.shape为(num_directions*num_layers=2,batch_size=3,hidden_size=20)
  • c_n.shape==h_n.shape
  • h_n是句子最后一个单词的隐藏状态,c_n包含句子最后一个单词的细胞状态,它们与句子长度无关
  • LSTM中的隐藏状态就是输出。

pytorch中LSTM各参数理解的更多相关文章

  1. [PyTorch]PyTorch中模型的参数初始化的几种方法(转)

    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 本文目录 1. xavier初始化 2. kaiming初始化 3. 实际使用中看到的初始化 3.1 ResNeXt,de ...

  2. Pytorch中stack()方法的理解

    Torch.stack() 1. 概念 在一个新的维度上连接一个张量序列 2. 参数 tensors (sequence)需要连接的张量序列 dim (int)在第dim个维度上连接 注意输入的张量s ...

  3. CreateProcess中的部分参数理解

    函数原型,这里写Unicode版本 WINBASEAPIBOOLWINAPICreateProcessW( _In_opt_ LPCWSTR lpApplicationName, //可执行文件名字 ...

  4. pytorch 中LSTM模型获取最后一层的输出结果,单向或双向

    单向LSTM import torch.nn as nn import torch seq_len = 20 batch_size = 64 embedding_dim = 100 num_embed ...

  5. PyTorch中使用深度学习(CNN和LSTM)的自动图像标题

    介绍 深度学习现在是一个非常猖獗的领域 - 有如此多的应用程序日复一日地出现.深入了解深度学习的最佳方法是亲自动手.尽可能多地参与项目,并尝试自己完成.这将帮助您更深入地掌握主题,并帮助您成为更好的深 ...

  6. Pytorch中的自动求导函数backward()所需参数含义

    摘要:一个神经网络有N个样本,经过这个网络把N个样本分为M类,那么此时backward参数的维度应该是[N X M] 正常来说backward()函数是要传入参数的,一直没弄明白backward需要传 ...

  7. [转载]Pytorch中nn.Linear module的理解

    [转载]Pytorch中nn.Linear module的理解 本文转载并援引全文纯粹是为了构建和分类自己的知识,方便自己未来的查找,没啥其他意思. 这个模块要实现的公式是:y=xAT+*b 来源:h ...

  8. 如何理解javaSript中函数的参数是按值传递

    本文是我基于红宝书<Javascript高级程序设计>中的第四章,4.1.3传递参数小节P70,进一步理解javaSript中函数的参数,当传递的参数是对象时的传递方式. (结合资料的个人 ...

  9. 关于vue自定义事件中,传递参数的一点理解

    例如有如下场景 先熟悉一下Vue事件处理 <!-- 父组件 --> <template> <div> <!--我们想在这个dealName的方法中传递额外参数 ...

  10. 深入理解python中函数传递参数是值传递还是引用传递

    深入理解python中函数传递参数是值传递还是引用传递 目前网络上大部分博客的结论都是这样的: Python不允许程序员选择采用传值还是传 引用.Python参数传递采用的肯定是"传对象引用 ...

随机推荐

  1. AT_arc041_b 题解

    洛谷链接&Atcoder 链接 本篇题解为此题较简单做法及较少码量,并且码风优良,请放心阅读. 题目简述 给定一个 \(N \times M\) 的矩阵,此矩阵的每一个元素都向上.下.左.右 ...

  2. 从DDPM到DDIM (一) 极大似然估计与证据下界

    从DDPM到DDIM (一) 极大似然估计与证据下界   现在网络上关于DDPM和DDIM的讲解有很多,但无论什么样的讲解,都不如自己推到一遍来的痛快.笔者希望就这篇文章,从头到尾对扩散模型做一次完整 ...

  3. Java 监听POST请求

    要监听POST请求,我们可以使用Java中的HttpServlet类.以下是一个使用Servlet API监听POST请求的完整示例.这个示例使用了Servlet 3.1规范,不需要在web.xml中 ...

  4. JavaScript 中的闭包和事件委托

    包 (Closures) 闭包是 JavaScript 中一个非常强大的特性,它允许函数访问其外部作用域中的变量,即使在该函数被调用时,外部作用域已经执行完毕.闭包可以帮助我们实现数据的私有化.封装和 ...

  5. 3、SpringMVC之RequestMapping注解

    3.1.环境搭建 创建名为spring_mvc_demo的新module,过程参考2.1节 3.1.1.创建SpringMVC的配置文件 <?xml version="1.0" ...

  6. 【Server】对象存储OSS - Minio

    官方文档: https://docs.min.io/docs/minio-quickstart-guide.html 看中文文档CV命令发现下不下来安装包,应该是地址问题 单击搭建非常简单,只有三个步 ...

  7. 【Project】原生JavaWeb工程 02 登陆业务的流程(第一阶段样例)

    1.对用户信息的描述 首先用户有一些基本信息: 最简单的: 用户名称 + 用户密码 然后是用户状态,例如封号,注销,停用,等等 用户名称 + 用户密码 + 账号状态 接着为了防止脚本攻击,又产生了图形 ...

  8. 电视家APP,从此以后电视盒子只是盒子,再与电视毫无关系

    广电总局封掉了电视家APP,于是我决定把我的"当贝盒子"挂咸鱼了,从此以后电视盒子就只是个盒子. PS: 广电的一刀切简直是绝了,绝绝子.

  9. pytorch的显存释放机制torch.cuda.empty_cache()

    参考: https://cloud.tencent.com/developer/article/1626387 据说在pytorch中使用torch.cuda.empty_cache()可以释放缓存空 ...

  10. Longley数据集——强共线性的宏观经济数据,包含GNP deflator(GNP平减指数)、GNP(国民生产总值)、Unemployed(失业率)、ArmedForces(武装力量)、Population(人口)、year(年份),Emlpoyed(就业率)。LongLey数据集因存在严重的多重共线性问题,在早期经常用来检验各种算法或计算机的计算精度

    Longley数据集来自J.W.Longley(1967)发表在JASA上的一篇论文,是强共线性的宏观经济数据,包含GNP deflator(GNP平减指数).GNP(国民生产总值).Unemploy ...