pytorch-LSTM()

torch.nn包下实现了LSTM函数,实现LSTM层。多个LSTMcell组合起来是LSTM。

LSTM自动实现了前向传播,不需要自己对序列进行迭代。

LSTM的用到的参数如下:创建LSTM指定如下参数,至少指定前三个参数

input_size:
输入特征维数
hidden_size:
隐层状态的维数
num_layers:
RNN层的个数,在图中竖向的是层数,横向的是seq_len
bias:
隐层状态是否带bias,默认为true
batch_first:
是否输入输出的第一维为batch_size,因为pytorch中batch_size维度默认是第二维度,故此选项可以将 batch_size放在第一维度。如input是(4,1,5),中间的1是batch_size,指定batch_first=True后就是(1,4,5)
dropout:
是否在除最后一个RNN层外的RNN层后面加dropout层
bidirectional:
是否是双向RNN,默认为false,若为true,则num_directions=2,否则为1

为了统一,以后都batch_first=True

LSTM的输入为:LSTM(input,(h0,co))

其中,指定batch_first=True​后,input就是(batch_size,seq_len,input_size)​

(h0,c0)是初始的隐藏层,因为每个LSTM单元其实需要两个隐藏层的。记hidden=(h0,c0)

其中,h0的维度是(num_layers*num_directions, batch_size, hidden_size)

c0维度同h0。注意,即使batch_first=True,这里h0的维度依然是batch_size在第二维度

LSTM的输出为:out,(hn,cn)

其中,out是每一个时间步的最后一个隐藏层h的输出,假如有5个时间步(即seq_len=5),则有5个对应的输出,out的维度是:(batch_size,seq_len,hidden_size)

hidden=(hn,cn),他自己实现了时间步的迭代,每次迭代需要使用上一步的输出和hidden层,最后一步hidden=(hn,cn)记录了最后一各时间步的隐藏层输出,有几层对应几个输出,如果这个是RNN-encoder,则hn,cn就是中间的编码向量。hn的维度是(num_layers*num_directions,batch_size,hidden_size),cn同。

应用LSTM

创建一LSTM:

lstm = torch.nn.LSTM(input_size,hidden_size,num_layers,batch_first=True)

forward使用LSTM层:

out,hidden = lstm(input,hidden)

其中,hidden=(h0,c0)是个tuple

最终得到out,hidden

举例:

import torch
# 实现一个num_layers层的LSTM-RNN
class RNN(torch.nn.Module):
def __init__(self,input_size, hidden_size, num_layers):
super(RNN,self).__init__()
self.input_size = input_size
self.hidden_size=hidden_size
self.num_layers=num_layers
self.lstm = torch.nn.LSTM(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers,batch_first=True) def forward(self,input):
# input应该为(batch_size,seq_len,input_szie)
self.hidden = self.initHidden(input.size(0))
out,self.hidden = lstm(input,self.hidden)
return out,self.hidden def initHidden(self,batch_size):
if self.lstm.bidirectional:
return (torch.rand(self.num_layers*2,batch_size,self.hidden_size),torch.rand(self.num_layers*2,batch_size,self.hidden_size))
else:
return (torch.rand(self.num_layers,batch_size,self.hidden_size),torch.rand(self.num_layers,batch_size,self.hidden_size)) input_size = 12
hidden_size = 10
num_layers = 3
batch_size = 2
model = RNN(input_size,hidden_size,num_layers)
# input (seq_len, batch, input_size) 包含特征的输入序列,如果设置了batch_first,则batch为第一维
input = torch.rand(2,4,12)
model(input)

【pytorch】pytorch-LSTM的更多相关文章

  1. 【翻译】理解 LSTM 网络

    目录 理解 LSTM 网络 递归神经网络 长期依赖性问题 LSTM 网络 LSTM 的核心想法 逐步解析 LSTM 的流程 长短期记忆的变种 结论 鸣谢 本文翻译自 Christopher Olah ...

  2. 【翻译】理解 LSTM 及其图示

    目录 理解 LSTM 及其图示 本文翻译自 Shi Yan 的博文 Understanding LSTM and its diagrams,原文阐释了作者对 Christopher Olah 博文 U ...

  3. 【转载】PyTorch系列 (二):pytorch数据读取

    原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...

  4. 【转载】Pytorch tutorial 之Datar Loading and Processing

    前言 上文介绍了数据读取.数据转换.批量处理等等.了解到在PyTorch中,数据加载主要有两种方式: 1.自定义的数据集对象.数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Datase ...

  5. 【转载】 pytorch笔记:06)requires_grad和volatile

    原文地址: https://blog.csdn.net/jiangpeng59/article/details/80667335 作者:PJ-Javis 来源:CSDN --------------- ...

  6. 【转载】 Pytorch 细节记录

    原文地址: https://www.cnblogs.com/king-lps/p/8570021.html ---------------------------------------------- ...

  7. 【转载】 pytorch之添加BN

    原文地址: https://blog.csdn.net/weixin_40123108/article/details/83509838 ------------------------------- ...

  8. 【转载】 pytorch自定义网络结构不进行参数初始化会怎样?

    原文地址: https://blog.csdn.net/u011668104/article/details/81670544 ------------------------------------ ...

  9. 【转载】 Pytorch中的学习率调整lr_scheduler,ReduceLROnPlateau

    原文地址: https://blog.csdn.net/happyday_d/article/details/85267561 ------------------------------------ ...

  10. 【转载】 PyTorch学习之六个学习率调整策略

    原文地址: https://blog.csdn.net/shanglianlm/article/details/85143614 ----------------------------------- ...

随机推荐

  1. 云HBase发布全文索引服务,轻松应对复杂查询

    云HBase发布了“全文索引服务”功能,自2019年01月25日后创建的云HBase实例,可以在控制台免费开启此“全文索引服务”功能.使用此功能可以让用户在HBase之上构建功能更丰富的搜索业务,不再 ...

  2. springcloud情操陶冶-springcloud config server(二)

    承接前文springcloud情操陶冶-springcloud config server(一),本文将在前文的基础上讲解config server的涉外接口 前话 通过前文笔者得知,cloud co ...

  3. 关于CSS引入方式的详细见解

    关于CSS的发展史这里不做介绍.写博客的原因之一是想帮助那些与我一样喜欢纠结的初入前端的伙伴,希望自己写的帖子能对伙伴有些许帮助:原因之二这些帖子也算自己的一个知识的整理.现在还没有一定的顺序可循,但 ...

  4. 简述C#中IO的应用

    在.NET Framework 中. System.IO 命名空间主要包含基于文件(和基于内存)的输入输出(I/O)服务的相关基础类库.和其他命名空间一样. System.IO 定义了一系列类.接口. ...

  5. android 资源

    在进行APP开发的过程当中,会用到许多资源,比如:图片,字符串等.现对android资源知识进行简单记录. 具体的详细信息及用法,点击查看官方文档 分类      一般android资源分为可直接访问 ...

  6. Docker+SpringBoot远程发布

    Docker+SpringBoot远程发布 发布成功后启动: docker run -di --name demo1.1 -p 8080:8085 demo:1.0 docker run 命令大全:h ...

  7. 虹软免费人脸识别SDK注册指南

    成为开发者三步完成账号的基本注册与认证:STEP1:点击注册虹软AI开放平台右上角注册选项,完成注册流程.STEP2:首次使用,登录后进入开发者中心,点击账号管理完成企业或者个人认证,若未进行实名认证 ...

  8. Android视频录制从不入门到入门系列教程(三)————视频方向

    运行Android视频录制从不入门到入门系列教程(二)————显示视频图像中的Demo后,我们应该能发现视频的方向是错误的. 由于Android中,Camera给我们的视频图片的原始方向是下图这个样子 ...

  9. 用markdown写博客

    目录 用markdown写博客 前言 标题 段落 引用区块 代码块 列表 分隔线 链接 强调.加粗.下划线.删除线 图片 智能链接 表格 转义序列 用markdown写博客 前言 博客园支持用mark ...

  10. 如何将Eclipse的javaWeb项目改为IDEA的maven项目

    1.首先去IDEA开发工具创建一个maven项目,把该项目改为Web项目, a.在pom.xml中,添加packaging标签,值为war b.右键File,选中project structure, ...