pytorch-LSTM()

torch.nn包下实现了LSTM函数,实现LSTM层。多个LSTMcell组合起来是LSTM。

LSTM自动实现了前向传播,不需要自己对序列进行迭代。

LSTM的用到的参数如下:创建LSTM指定如下参数,至少指定前三个参数

input_size:
输入特征维数
hidden_size:
隐层状态的维数
num_layers:
RNN层的个数,在图中竖向的是层数,横向的是seq_len
bias:
隐层状态是否带bias,默认为true
batch_first:
是否输入输出的第一维为batch_size,因为pytorch中batch_size维度默认是第二维度,故此选项可以将 batch_size放在第一维度。如input是(4,1,5),中间的1是batch_size,指定batch_first=True后就是(1,4,5)
dropout:
是否在除最后一个RNN层外的RNN层后面加dropout层
bidirectional:
是否是双向RNN,默认为false,若为true,则num_directions=2,否则为1

为了统一,以后都batch_first=True

LSTM的输入为:LSTM(input,(h0,co))

其中,指定batch_first=True​后,input就是(batch_size,seq_len,input_size)​

(h0,c0)是初始的隐藏层,因为每个LSTM单元其实需要两个隐藏层的。记hidden=(h0,c0)

其中,h0的维度是(num_layers*num_directions, batch_size, hidden_size)

c0维度同h0。注意,即使batch_first=True,这里h0的维度依然是batch_size在第二维度

LSTM的输出为:out,(hn,cn)

其中,out是每一个时间步的最后一个隐藏层h的输出,假如有5个时间步(即seq_len=5),则有5个对应的输出,out的维度是:(batch_size,seq_len,hidden_size)

hidden=(hn,cn),他自己实现了时间步的迭代,每次迭代需要使用上一步的输出和hidden层,最后一步hidden=(hn,cn)记录了最后一各时间步的隐藏层输出,有几层对应几个输出,如果这个是RNN-encoder,则hn,cn就是中间的编码向量。hn的维度是(num_layers*num_directions,batch_size,hidden_size),cn同。

应用LSTM

创建一LSTM:

lstm = torch.nn.LSTM(input_size,hidden_size,num_layers,batch_first=True)

forward使用LSTM层:

out,hidden = lstm(input,hidden)

其中,hidden=(h0,c0)是个tuple

最终得到out,hidden

举例:

import torch
# 实现一个num_layers层的LSTM-RNN
class RNN(torch.nn.Module):
def __init__(self,input_size, hidden_size, num_layers):
super(RNN,self).__init__()
self.input_size = input_size
self.hidden_size=hidden_size
self.num_layers=num_layers
self.lstm = torch.nn.LSTM(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers,batch_first=True) def forward(self,input):
# input应该为(batch_size,seq_len,input_szie)
self.hidden = self.initHidden(input.size(0))
out,self.hidden = lstm(input,self.hidden)
return out,self.hidden def initHidden(self,batch_size):
if self.lstm.bidirectional:
return (torch.rand(self.num_layers*2,batch_size,self.hidden_size),torch.rand(self.num_layers*2,batch_size,self.hidden_size))
else:
return (torch.rand(self.num_layers,batch_size,self.hidden_size),torch.rand(self.num_layers,batch_size,self.hidden_size)) input_size = 12
hidden_size = 10
num_layers = 3
batch_size = 2
model = RNN(input_size,hidden_size,num_layers)
# input (seq_len, batch, input_size) 包含特征的输入序列,如果设置了batch_first,则batch为第一维
input = torch.rand(2,4,12)
model(input)

【pytorch】pytorch-LSTM的更多相关文章

  1. 【翻译】理解 LSTM 网络

    目录 理解 LSTM 网络 递归神经网络 长期依赖性问题 LSTM 网络 LSTM 的核心想法 逐步解析 LSTM 的流程 长短期记忆的变种 结论 鸣谢 本文翻译自 Christopher Olah ...

  2. 【翻译】理解 LSTM 及其图示

    目录 理解 LSTM 及其图示 本文翻译自 Shi Yan 的博文 Understanding LSTM and its diagrams,原文阐释了作者对 Christopher Olah 博文 U ...

  3. 【转载】PyTorch系列 (二):pytorch数据读取

    原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...

  4. 【转载】Pytorch tutorial 之Datar Loading and Processing

    前言 上文介绍了数据读取.数据转换.批量处理等等.了解到在PyTorch中,数据加载主要有两种方式: 1.自定义的数据集对象.数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Datase ...

  5. 【转载】 pytorch笔记:06)requires_grad和volatile

    原文地址: https://blog.csdn.net/jiangpeng59/article/details/80667335 作者:PJ-Javis 来源:CSDN --------------- ...

  6. 【转载】 Pytorch 细节记录

    原文地址: https://www.cnblogs.com/king-lps/p/8570021.html ---------------------------------------------- ...

  7. 【转载】 pytorch之添加BN

    原文地址: https://blog.csdn.net/weixin_40123108/article/details/83509838 ------------------------------- ...

  8. 【转载】 pytorch自定义网络结构不进行参数初始化会怎样?

    原文地址: https://blog.csdn.net/u011668104/article/details/81670544 ------------------------------------ ...

  9. 【转载】 Pytorch中的学习率调整lr_scheduler,ReduceLROnPlateau

    原文地址: https://blog.csdn.net/happyday_d/article/details/85267561 ------------------------------------ ...

  10. 【转载】 PyTorch学习之六个学习率调整策略

    原文地址: https://blog.csdn.net/shanglianlm/article/details/85143614 ----------------------------------- ...

随机推荐

  1. RDIFramework.NET V3.3 Web框架主界面新增横向菜单功能

    功能描述 响应重多客户的要求与心声,RDIFramework.NET框架Web版本主界面新增横向菜单功能.横向菜单更加直观,用户可操作与展示的空间更多,符合实际应用要求. 一.效果展示 最终界面效果: ...

  2. 服务器配置java

    先去链接下载jdk or jre(服务器上这个就好) 然后解压 tar 下载的文件,放到/usr/local/java/jdk_xxx下面 -v: 可视化显示进度. Enables verbose m ...

  3. 记一次按需加载和npm模块发布实践

    按需加载 在使用 lodash 的时候我们可以使用这样的代码 //一 import {omit} from "lodash"; //二 import l from "lo ...

  4. 折腾Java设计模式之状态模式

    原文地址 折腾Java设计模式之状态模式 状态模式 在状态模式(State Pattern)中,类的行为是基于它的状态改变的.这种类型的设计模式属于行为型模式.在状态模式中,我们创建表示各种状态的对象 ...

  5. vue-cli3 中跨域解决方案

    此方案只能用于开发环境,线上最好设置同源策略(遇到个后端,装你妈批) 前后端不在同一服务器的情况下,前端要访问后端API,可通过在vue.config.js中配置代理服务器. 0:前提条件 1:安装v ...

  6. php禁用函数设置及查看方法详解

    这篇文章主要介绍了php禁用函数设置及查看方法,结合实例形式分析了php禁用函数的方法及使用php探针查看禁用函数信息的相关实现技巧,需要的朋友可以参考下 本文实例讲述了php禁用函数设置及查看方法. ...

  7. 关于javascript异步

    1.简单的理解 JavaScript是单线程的!总所周知,正常代码是从上而下,一条一条顺序执行的.就好比下楼梯,第一条代码先获得内存或者先执行操作.当遇到漫长的处理操作时(比如读取庞大的文件时,执行大 ...

  8. openlayers三:添加图片和图标

    openlayers添加图片是指: 添加在地图上的图片会跟随地图同步放大缩小 而添加图标是指: 添加在地图上的图片不会跟随地图同步放大缩小 添加图片: 首先初始化图片图层: initImageLaye ...

  9. 华途软件受控XML转EXCEL

    公司加密系统用的是华途的产品.最近公司高层想要重新梳理公司信息安全管理情况,华途加密系统的梳理和优化是重中之重. 今天公司领导要求IT导出目前系统中所有软件.后缀的受控情况,然后IT吭哧吭哧地把华途软 ...

  10. 每天五分钟-javascript数据类型

    javascript数据类型分为基本数据类型与复杂数据类型 基本数据类型包括:string,number,boolean,null,undefined,symbol(es6) 复杂数据类型包括:obj ...