pytorch-LSTM()

torch.nn包下实现了LSTM函数,实现LSTM层。多个LSTMcell组合起来是LSTM。

LSTM自动实现了前向传播,不需要自己对序列进行迭代。

LSTM的用到的参数如下:创建LSTM指定如下参数,至少指定前三个参数

input_size:
输入特征维数
hidden_size:
隐层状态的维数
num_layers:
RNN层的个数,在图中竖向的是层数,横向的是seq_len
bias:
隐层状态是否带bias,默认为true
batch_first:
是否输入输出的第一维为batch_size,因为pytorch中batch_size维度默认是第二维度,故此选项可以将 batch_size放在第一维度。如input是(4,1,5),中间的1是batch_size,指定batch_first=True后就是(1,4,5)
dropout:
是否在除最后一个RNN层外的RNN层后面加dropout层
bidirectional:
是否是双向RNN,默认为false,若为true,则num_directions=2,否则为1

为了统一,以后都batch_first=True

LSTM的输入为:LSTM(input,(h0,co))

其中,指定batch_first=True​后,input就是(batch_size,seq_len,input_size)​

(h0,c0)是初始的隐藏层,因为每个LSTM单元其实需要两个隐藏层的。记hidden=(h0,c0)

其中,h0的维度是(num_layers*num_directions, batch_size, hidden_size)

c0维度同h0。注意,即使batch_first=True,这里h0的维度依然是batch_size在第二维度

LSTM的输出为:out,(hn,cn)

其中,out是每一个时间步的最后一个隐藏层h的输出,假如有5个时间步(即seq_len=5),则有5个对应的输出,out的维度是:(batch_size,seq_len,hidden_size)

hidden=(hn,cn),他自己实现了时间步的迭代,每次迭代需要使用上一步的输出和hidden层,最后一步hidden=(hn,cn)记录了最后一各时间步的隐藏层输出,有几层对应几个输出,如果这个是RNN-encoder,则hn,cn就是中间的编码向量。hn的维度是(num_layers*num_directions,batch_size,hidden_size),cn同。

应用LSTM

创建一LSTM:

lstm = torch.nn.LSTM(input_size,hidden_size,num_layers,batch_first=True)

forward使用LSTM层:

out,hidden = lstm(input,hidden)

其中,hidden=(h0,c0)是个tuple

最终得到out,hidden

举例:

import torch
# 实现一个num_layers层的LSTM-RNN
class RNN(torch.nn.Module):
def __init__(self,input_size, hidden_size, num_layers):
super(RNN,self).__init__()
self.input_size = input_size
self.hidden_size=hidden_size
self.num_layers=num_layers
self.lstm = torch.nn.LSTM(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers,batch_first=True) def forward(self,input):
# input应该为(batch_size,seq_len,input_szie)
self.hidden = self.initHidden(input.size(0))
out,self.hidden = lstm(input,self.hidden)
return out,self.hidden def initHidden(self,batch_size):
if self.lstm.bidirectional:
return (torch.rand(self.num_layers*2,batch_size,self.hidden_size),torch.rand(self.num_layers*2,batch_size,self.hidden_size))
else:
return (torch.rand(self.num_layers,batch_size,self.hidden_size),torch.rand(self.num_layers,batch_size,self.hidden_size)) input_size = 12
hidden_size = 10
num_layers = 3
batch_size = 2
model = RNN(input_size,hidden_size,num_layers)
# input (seq_len, batch, input_size) 包含特征的输入序列,如果设置了batch_first,则batch为第一维
input = torch.rand(2,4,12)
model(input)

【pytorch】pytorch-LSTM的更多相关文章

  1. 【翻译】理解 LSTM 网络

    目录 理解 LSTM 网络 递归神经网络 长期依赖性问题 LSTM 网络 LSTM 的核心想法 逐步解析 LSTM 的流程 长短期记忆的变种 结论 鸣谢 本文翻译自 Christopher Olah ...

  2. 【翻译】理解 LSTM 及其图示

    目录 理解 LSTM 及其图示 本文翻译自 Shi Yan 的博文 Understanding LSTM and its diagrams,原文阐释了作者对 Christopher Olah 博文 U ...

  3. 【转载】PyTorch系列 (二):pytorch数据读取

    原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...

  4. 【转载】Pytorch tutorial 之Datar Loading and Processing

    前言 上文介绍了数据读取.数据转换.批量处理等等.了解到在PyTorch中,数据加载主要有两种方式: 1.自定义的数据集对象.数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Datase ...

  5. 【转载】 pytorch笔记:06)requires_grad和volatile

    原文地址: https://blog.csdn.net/jiangpeng59/article/details/80667335 作者:PJ-Javis 来源:CSDN --------------- ...

  6. 【转载】 Pytorch 细节记录

    原文地址: https://www.cnblogs.com/king-lps/p/8570021.html ---------------------------------------------- ...

  7. 【转载】 pytorch之添加BN

    原文地址: https://blog.csdn.net/weixin_40123108/article/details/83509838 ------------------------------- ...

  8. 【转载】 pytorch自定义网络结构不进行参数初始化会怎样?

    原文地址: https://blog.csdn.net/u011668104/article/details/81670544 ------------------------------------ ...

  9. 【转载】 Pytorch中的学习率调整lr_scheduler,ReduceLROnPlateau

    原文地址: https://blog.csdn.net/happyday_d/article/details/85267561 ------------------------------------ ...

  10. 【转载】 PyTorch学习之六个学习率调整策略

    原文地址: https://blog.csdn.net/shanglianlm/article/details/85143614 ----------------------------------- ...

随机推荐

  1. #3 Python面向对象(二)

    前言 上一节主要记录面向对象编程的思想以及Python类的简单创建,这节继续深入类中变量的相关知识,Here we go! Python中类的各种变量 1.1 类变量 类变量定义:在类中,在函数体(方 ...

  2. Jquery简单学习

    Jquery是一个JavaScript的函数库,Jquery是一个写得少但做的多的轻量级JavaScript库 Jquery用美元$定义. Jquery的action执行对元素的操作 文档就绪函数: ...

  3. Java 处理 multipart/mixed 请求

    一.multipart/mixed 请求   multipart/mixed 和 multipart/form-date 都是多文件上传的格式.区别在于,multipart/form-data 是一种 ...

  4. ArcPy 创建图层空间索引

    使用Python脚本进行图层的空间索引的创建. 附上Python代码: # -*- coding: utf-8 -*- # nightroad import sys import arcpy relo ...

  5. ASP.Net Core开发(踩坑)指南

    ASP.NET与ASP.NET Core很类似,但它们之间存在一些细微区别以及ASP.NET Core中新增特性的使用方法,在此之前也写过一篇简单的对比文章ASP.NET MVC应用迁移到ASP.NE ...

  6. SQL学习笔记---非select操作

    非select命令 数据库 1.创建     //create database 库名 2.删除     //drop database 库名,... 2.重命名//exec sp_renamedb ...

  7. JS直接调用C#后台方法(ajax调用)

    1. 先手动引用DLL或者通过NuGet查找引用,这里提供一个AjaxPro.2.dll的下载: 2. 之后的的过程不想写了,网上都大同小异的,直接参考以前大佬写的: AjaxPro2完整入门教程 总 ...

  8. lua_local变量在new时不会被清空

    前言 我的运行环境 Lua5.3 按照我们以往的Java或C#编程经验,如果一个class被new,那么这个class中所有成员变量的值都是默值或是构造函数中赋的值,但在Lua中的local变量却并不 ...

  9. LeetCode算法题-Flood Fill(Java实现)

    这是悦乐书的第306次更新,第325篇原创 01 看题和准备 今天介绍的是LeetCode算法题中Easy级别的第173题(顺位题号是733).图像由二维整数数组表示,每个整数表示图像的像素值(从0到 ...

  10. kerberos环境下spark消费kafka写入到Hbase

    一.准备环境: 创建Kafka Topic和HBase表 1. 在kerberos环境下创建Kafka Topic 1.1 因为kafka默认使用的协议为PLAINTEXT,在kerberos环境下需 ...