pytorch-LSTM()

torch.nn包下实现了LSTM函数,实现LSTM层。多个LSTMcell组合起来是LSTM。

LSTM自动实现了前向传播,不需要自己对序列进行迭代。

LSTM的用到的参数如下:创建LSTM指定如下参数,至少指定前三个参数

input_size:
输入特征维数
hidden_size:
隐层状态的维数
num_layers:
RNN层的个数,在图中竖向的是层数,横向的是seq_len
bias:
隐层状态是否带bias,默认为true
batch_first:
是否输入输出的第一维为batch_size,因为pytorch中batch_size维度默认是第二维度,故此选项可以将 batch_size放在第一维度。如input是(4,1,5),中间的1是batch_size,指定batch_first=True后就是(1,4,5)
dropout:
是否在除最后一个RNN层外的RNN层后面加dropout层
bidirectional:
是否是双向RNN,默认为false,若为true,则num_directions=2,否则为1

为了统一,以后都batch_first=True

LSTM的输入为:LSTM(input,(h0,co))

其中,指定batch_first=True​后,input就是(batch_size,seq_len,input_size)​

(h0,c0)是初始的隐藏层,因为每个LSTM单元其实需要两个隐藏层的。记hidden=(h0,c0)

其中,h0的维度是(num_layers*num_directions, batch_size, hidden_size)

c0维度同h0。注意,即使batch_first=True,这里h0的维度依然是batch_size在第二维度

LSTM的输出为:out,(hn,cn)

其中,out是每一个时间步的最后一个隐藏层h的输出,假如有5个时间步(即seq_len=5),则有5个对应的输出,out的维度是:(batch_size,seq_len,hidden_size)

hidden=(hn,cn),他自己实现了时间步的迭代,每次迭代需要使用上一步的输出和hidden层,最后一步hidden=(hn,cn)记录了最后一各时间步的隐藏层输出,有几层对应几个输出,如果这个是RNN-encoder,则hn,cn就是中间的编码向量。hn的维度是(num_layers*num_directions,batch_size,hidden_size),cn同。

应用LSTM

创建一LSTM:

lstm = torch.nn.LSTM(input_size,hidden_size,num_layers,batch_first=True)

forward使用LSTM层:

out,hidden = lstm(input,hidden)

其中,hidden=(h0,c0)是个tuple

最终得到out,hidden

举例:

import torch
# 实现一个num_layers层的LSTM-RNN
class RNN(torch.nn.Module):
def __init__(self,input_size, hidden_size, num_layers):
super(RNN,self).__init__()
self.input_size = input_size
self.hidden_size=hidden_size
self.num_layers=num_layers
self.lstm = torch.nn.LSTM(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers,batch_first=True) def forward(self,input):
# input应该为(batch_size,seq_len,input_szie)
self.hidden = self.initHidden(input.size(0))
out,self.hidden = lstm(input,self.hidden)
return out,self.hidden def initHidden(self,batch_size):
if self.lstm.bidirectional:
return (torch.rand(self.num_layers*2,batch_size,self.hidden_size),torch.rand(self.num_layers*2,batch_size,self.hidden_size))
else:
return (torch.rand(self.num_layers,batch_size,self.hidden_size),torch.rand(self.num_layers,batch_size,self.hidden_size)) input_size = 12
hidden_size = 10
num_layers = 3
batch_size = 2
model = RNN(input_size,hidden_size,num_layers)
# input (seq_len, batch, input_size) 包含特征的输入序列,如果设置了batch_first,则batch为第一维
input = torch.rand(2,4,12)
model(input)

【pytorch】pytorch-LSTM的更多相关文章

  1. 【翻译】理解 LSTM 网络

    目录 理解 LSTM 网络 递归神经网络 长期依赖性问题 LSTM 网络 LSTM 的核心想法 逐步解析 LSTM 的流程 长短期记忆的变种 结论 鸣谢 本文翻译自 Christopher Olah ...

  2. 【翻译】理解 LSTM 及其图示

    目录 理解 LSTM 及其图示 本文翻译自 Shi Yan 的博文 Understanding LSTM and its diagrams,原文阐释了作者对 Christopher Olah 博文 U ...

  3. 【转载】PyTorch系列 (二):pytorch数据读取

    原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...

  4. 【转载】Pytorch tutorial 之Datar Loading and Processing

    前言 上文介绍了数据读取.数据转换.批量处理等等.了解到在PyTorch中,数据加载主要有两种方式: 1.自定义的数据集对象.数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Datase ...

  5. 【转载】 pytorch笔记:06)requires_grad和volatile

    原文地址: https://blog.csdn.net/jiangpeng59/article/details/80667335 作者:PJ-Javis 来源:CSDN --------------- ...

  6. 【转载】 Pytorch 细节记录

    原文地址: https://www.cnblogs.com/king-lps/p/8570021.html ---------------------------------------------- ...

  7. 【转载】 pytorch之添加BN

    原文地址: https://blog.csdn.net/weixin_40123108/article/details/83509838 ------------------------------- ...

  8. 【转载】 pytorch自定义网络结构不进行参数初始化会怎样?

    原文地址: https://blog.csdn.net/u011668104/article/details/81670544 ------------------------------------ ...

  9. 【转载】 Pytorch中的学习率调整lr_scheduler,ReduceLROnPlateau

    原文地址: https://blog.csdn.net/happyday_d/article/details/85267561 ------------------------------------ ...

  10. 【转载】 PyTorch学习之六个学习率调整策略

    原文地址: https://blog.csdn.net/shanglianlm/article/details/85143614 ----------------------------------- ...

随机推荐

  1. 2.2Bind建立配置文件和实体的映射「深入浅出ASP.NET Core系列」

    希望给你3-5分钟的碎片化学习,可能是坐地铁.等公交,积少成多,水滴石穿,谢谢关注. 新建MVC项目 这次我们没有使用控制台项目,而是使用mvc来测试. 如下图所示,选择空的项目,建完后,记得把项目设 ...

  2. 陪你解读Spring Batch(一)Spring Batch介绍

    前言 整个章节由浅入深了解Spring Batch,让你掌握批处理利器.面对大批量数据毫无惧色.本章只做介绍,后面章节有代码示例.好了,接下来是我们的主角Spring Batch. 1.1 背景介绍 ...

  3. 【Zabbix】CentOS6.9系统下部署Zabbix-server 3.0

    目录 安装Zabbix 关闭selinux 删除旧版本MySQL5.1数据库 安装MySQL 5.6数据库 安装PHP 5.6 Zabbix-server的部署安装 1.安装Yum源 2.安装Zabb ...

  4. 解决Windows对JDK默认版本切换问题

    注意修改path路径,或者修改控制面板下的java控制面板并不有效,原因是由于在WINDOWS\System32环境变量中的优先级高于JAVA_HOME设置的环境变量优先级,故如果只修改环境变量JAV ...

  5. java的设计模式 - 单例模式

    java 面试中单例模式基本都是必考的,都有最推荐的方式,也不知道问来干嘛.下面记录一下 饿汉式(也不知道为何叫这个名字) public class Singleton { private stati ...

  6. SAP 没有激活HUM功能照常可以使用Handling Unit

    SAP 没有激活HUM功能照常可以使用Handling Unit 笔者所在的项目上的公司间STO的流程里,发货公司在做PGI之后系统自动触发收货公司的inbound delivery单据,发货公司发出 ...

  7. Android为TV端助力:UDP协议(接收组播和单播)

    private static String MulticastHost="224.9.9.98";private static int POST=19999;private sta ...

  8. (最完美)MIUI12系统的Usb调试模式在哪里开启的步骤

    当我们使用安卓手机通过数据线链接到Pc的时候,或者使用的有些app比如我们公司营销小组当使用的app引号精灵,之前的老版本就需要开启usb调试模式下使用,现当新版本不需要了,如果手机没有开启usb调试 ...

  9. DVWA 黑客攻防演练(十)反射型 XSS 攻击 Reflected Cross Site Scripting

    XSS (Cross-site scripting) 攻击,为和 CSS 有所区分,所以叫 XSS.又是一种防不胜防的攻击,应该算是一种 "HTML注入攻击",原本开发者想的是显示 ...

  10. d3.csv()后获取的数据不是数组,而是对象

    我的csv文件: year,population 1953,5.94 1964,6.95 1982,10.08 1990,11.34 2000,12.66 2010,13.40 使用d3.csv()输 ...