pytorch-LSTM()

torch.nn包下实现了LSTM函数,实现LSTM层。多个LSTMcell组合起来是LSTM。

LSTM自动实现了前向传播,不需要自己对序列进行迭代。

LSTM的用到的参数如下:创建LSTM指定如下参数,至少指定前三个参数

input_size:
输入特征维数
hidden_size:
隐层状态的维数
num_layers:
RNN层的个数,在图中竖向的是层数,横向的是seq_len
bias:
隐层状态是否带bias,默认为true
batch_first:
是否输入输出的第一维为batch_size,因为pytorch中batch_size维度默认是第二维度,故此选项可以将 batch_size放在第一维度。如input是(4,1,5),中间的1是batch_size,指定batch_first=True后就是(1,4,5)
dropout:
是否在除最后一个RNN层外的RNN层后面加dropout层
bidirectional:
是否是双向RNN,默认为false,若为true,则num_directions=2,否则为1

为了统一,以后都batch_first=True

LSTM的输入为:LSTM(input,(h0,co))

其中,指定batch_first=True​后,input就是(batch_size,seq_len,input_size)​

(h0,c0)是初始的隐藏层,因为每个LSTM单元其实需要两个隐藏层的。记hidden=(h0,c0)

其中,h0的维度是(num_layers*num_directions, batch_size, hidden_size)

c0维度同h0。注意,即使batch_first=True,这里h0的维度依然是batch_size在第二维度

LSTM的输出为:out,(hn,cn)

其中,out是每一个时间步的最后一个隐藏层h的输出,假如有5个时间步(即seq_len=5),则有5个对应的输出,out的维度是:(batch_size,seq_len,hidden_size)

hidden=(hn,cn),他自己实现了时间步的迭代,每次迭代需要使用上一步的输出和hidden层,最后一步hidden=(hn,cn)记录了最后一各时间步的隐藏层输出,有几层对应几个输出,如果这个是RNN-encoder,则hn,cn就是中间的编码向量。hn的维度是(num_layers*num_directions,batch_size,hidden_size),cn同。

应用LSTM

创建一LSTM:

lstm = torch.nn.LSTM(input_size,hidden_size,num_layers,batch_first=True)

forward使用LSTM层:

out,hidden = lstm(input,hidden)

其中,hidden=(h0,c0)是个tuple

最终得到out,hidden

举例:

import torch
# 实现一个num_layers层的LSTM-RNN
class RNN(torch.nn.Module):
def __init__(self,input_size, hidden_size, num_layers):
super(RNN,self).__init__()
self.input_size = input_size
self.hidden_size=hidden_size
self.num_layers=num_layers
self.lstm = torch.nn.LSTM(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers,batch_first=True) def forward(self,input):
# input应该为(batch_size,seq_len,input_szie)
self.hidden = self.initHidden(input.size(0))
out,self.hidden = lstm(input,self.hidden)
return out,self.hidden def initHidden(self,batch_size):
if self.lstm.bidirectional:
return (torch.rand(self.num_layers*2,batch_size,self.hidden_size),torch.rand(self.num_layers*2,batch_size,self.hidden_size))
else:
return (torch.rand(self.num_layers,batch_size,self.hidden_size),torch.rand(self.num_layers,batch_size,self.hidden_size)) input_size = 12
hidden_size = 10
num_layers = 3
batch_size = 2
model = RNN(input_size,hidden_size,num_layers)
# input (seq_len, batch, input_size) 包含特征的输入序列,如果设置了batch_first,则batch为第一维
input = torch.rand(2,4,12)
model(input)

【pytorch】pytorch-LSTM的更多相关文章

  1. 【翻译】理解 LSTM 网络

    目录 理解 LSTM 网络 递归神经网络 长期依赖性问题 LSTM 网络 LSTM 的核心想法 逐步解析 LSTM 的流程 长短期记忆的变种 结论 鸣谢 本文翻译自 Christopher Olah ...

  2. 【翻译】理解 LSTM 及其图示

    目录 理解 LSTM 及其图示 本文翻译自 Shi Yan 的博文 Understanding LSTM and its diagrams,原文阐释了作者对 Christopher Olah 博文 U ...

  3. 【转载】PyTorch系列 (二):pytorch数据读取

    原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...

  4. 【转载】Pytorch tutorial 之Datar Loading and Processing

    前言 上文介绍了数据读取.数据转换.批量处理等等.了解到在PyTorch中,数据加载主要有两种方式: 1.自定义的数据集对象.数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Datase ...

  5. 【转载】 pytorch笔记:06)requires_grad和volatile

    原文地址: https://blog.csdn.net/jiangpeng59/article/details/80667335 作者:PJ-Javis 来源:CSDN --------------- ...

  6. 【转载】 Pytorch 细节记录

    原文地址: https://www.cnblogs.com/king-lps/p/8570021.html ---------------------------------------------- ...

  7. 【转载】 pytorch之添加BN

    原文地址: https://blog.csdn.net/weixin_40123108/article/details/83509838 ------------------------------- ...

  8. 【转载】 pytorch自定义网络结构不进行参数初始化会怎样?

    原文地址: https://blog.csdn.net/u011668104/article/details/81670544 ------------------------------------ ...

  9. 【转载】 Pytorch中的学习率调整lr_scheduler,ReduceLROnPlateau

    原文地址: https://blog.csdn.net/happyday_d/article/details/85267561 ------------------------------------ ...

  10. 【转载】 PyTorch学习之六个学习率调整策略

    原文地址: https://blog.csdn.net/shanglianlm/article/details/85143614 ----------------------------------- ...

随机推荐

  1. 数据结构系列(2)之 AVL 树

    本文将主要讲解平衡二叉树中的 AVL 树,其中将重点讲解二叉树的重平衡方法,即左旋和右旋,以及 3+4 重构:这些方法都是后面要讲的 B 树,红黑树等 BBST 的重要基础:此外在看本文之前最好先看一 ...

  2. GoLang structTag说明

    在处理json格式字符串的时候,经常会看到声明struct结构的时候,属性的右侧还有小米点括起来的内容.形如 type User struct { UserId int `json:"use ...

  3. Keepalived 的使用

    1.什么是keepalived Keepalived的作用是检测web服务器的状态,如果有一台web服务器死机,或工作出现故障,Keepalived将检测到,并将有故障的web服务器从系统中剔除,当w ...

  4. Web后端 JAVA学习之路

    1.Java分类 Java按应用来分,可以分为J2ME(手机版),J2SE(标准版),J2EE(企业版)三部分. ・J2ME:已经被安卓开发取代. ・J2SE:Java的核心类,其中包括桌面应用,但一 ...

  5. 开源GIS知识

    ---恢复内容开始--- 2.1.3组件层 数据库组件层按照功能可分为两类:数据管理组件和分析组件. 2.1.3.1数据管理组件 (1)GDAL GDAL(http://www.gdal.org/)是 ...

  6. Android studio怎么使用自定义的framework而避免冲突报错和点不进去报红。

    文件:xx\project_abc\video\build.gradle保证可以运行到自定义的framework而不报错,可能因为project和module名字相同所以导致下面的路径是绝对路径,其他 ...

  7. datatable动态列处理,重绘表格(敲黑板,划重点!!!我肝了一天半才彻底弄懂这个东西,TAT)

    datatable动态列处理,重绘表格 前言:至于动态列的绘画,我前面博客已经写过了,就是动态列的配置问题,不懂的去我博客看下,今天要写的呢,就是你已经写了一个动态列在datatable,现在你想重新 ...

  8. [20190419]shared latch spin count 2.txt

    [20190419]shared latch spin count 2.txt --//上午测试shared latch XX模式的情况,链接:http://blog.itpub.net/267265 ...

  9. sql 视图学习

    视图语句 在 SQL 中,视图是基于 SQL 语句的结果集的可视化的表. 视图包含行和列,就像一个真实的表.视图中的字段就是来自一个或多个数据库中的真实的表中的字段. 您可以向视图添加 SQL 函数. ...

  10. vtop工具使用分析

    vtop工具可以为esxtop提供图形界面,并且可以显示实时统计数据,对于我们监控esxi主机的需求匹配度很高,同时,相对于vcenter中的数据统计选项实时性更高,操作简便,可作为工作使用 为便于我 ...