【pytorch】pytorch-LSTM
pytorch-LSTM()
torch.nn包下实现了LSTM函数,实现LSTM层。多个LSTMcell组合起来是LSTM。
LSTM自动实现了前向传播,不需要自己对序列进行迭代。
LSTM的用到的参数如下:创建LSTM指定如下参数,至少指定前三个参数
input_size:
输入特征维数
hidden_size:
隐层状态的维数
num_layers:
RNN层的个数,在图中竖向的是层数,横向的是seq_len
bias:
隐层状态是否带bias,默认为true
batch_first:
是否输入输出的第一维为batch_size,因为pytorch中batch_size维度默认是第二维度,故此选项可以将 batch_size放在第一维度。如input是(4,1,5),中间的1是batch_size,指定batch_first=True后就是(1,4,5)
dropout:
是否在除最后一个RNN层外的RNN层后面加dropout层
bidirectional:
是否是双向RNN,默认为false,若为true,则num_directions=2,否则为1
为了统一,以后都batch_first=True
LSTM的输入为:LSTM(input,(h0,co))
其中,指定batch_first=True后,input就是(batch_size,seq_len,input_size)
(h0,c0)是初始的隐藏层,因为每个LSTM单元其实需要两个隐藏层的。记hidden=(h0,c0)
其中,h0的维度是(num_layers*num_directions, batch_size, hidden_size)
c0维度同h0。注意,即使batch_first=True,这里h0的维度依然是batch_size在第二维度
LSTM的输出为:out,(hn,cn)
其中,out是每一个时间步的最后一个隐藏层h的输出,假如有5个时间步(即seq_len=5),则有5个对应的输出,out的维度是:(batch_size,seq_len,hidden_size)
而hidden=(hn,cn),他自己实现了时间步的迭代,每次迭代需要使用上一步的输出和hidden层,最后一步hidden=(hn,cn)记录了最后一各时间步的隐藏层输出,有几层对应几个输出,如果这个是RNN-encoder,则hn,cn就是中间的编码向量。hn的维度是(num_layers*num_directions,batch_size,hidden_size),cn同。
应用LSTM
创建一LSTM:
lstm = torch.nn.LSTM(input_size,hidden_size,num_layers,batch_first=True)
forward使用LSTM层:
out,hidden = lstm(input,hidden)
其中,hidden=(h0,c0)是个tuple
最终得到out,hidden
举例:
import torch
# 实现一个num_layers层的LSTM-RNN
class RNN(torch.nn.Module):
def __init__(self,input_size, hidden_size, num_layers):
super(RNN,self).__init__()
self.input_size = input_size
self.hidden_size=hidden_size
self.num_layers=num_layers
self.lstm = torch.nn.LSTM(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers,batch_first=True)
def forward(self,input):
# input应该为(batch_size,seq_len,input_szie)
self.hidden = self.initHidden(input.size(0))
out,self.hidden = lstm(input,self.hidden)
return out,self.hidden
def initHidden(self,batch_size):
if self.lstm.bidirectional:
return (torch.rand(self.num_layers*2,batch_size,self.hidden_size),torch.rand(self.num_layers*2,batch_size,self.hidden_size))
else:
return (torch.rand(self.num_layers,batch_size,self.hidden_size),torch.rand(self.num_layers,batch_size,self.hidden_size))
input_size = 12
hidden_size = 10
num_layers = 3
batch_size = 2
model = RNN(input_size,hidden_size,num_layers)
# input (seq_len, batch, input_size) 包含特征的输入序列,如果设置了batch_first,则batch为第一维
input = torch.rand(2,4,12)
model(input)
【pytorch】pytorch-LSTM的更多相关文章
- 【翻译】理解 LSTM 网络
目录 理解 LSTM 网络 递归神经网络 长期依赖性问题 LSTM 网络 LSTM 的核心想法 逐步解析 LSTM 的流程 长短期记忆的变种 结论 鸣谢 本文翻译自 Christopher Olah ...
- 【翻译】理解 LSTM 及其图示
目录 理解 LSTM 及其图示 本文翻译自 Shi Yan 的博文 Understanding LSTM and its diagrams,原文阐释了作者对 Christopher Olah 博文 U ...
- 【转载】PyTorch系列 (二):pytorch数据读取
原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...
- 【转载】Pytorch tutorial 之Datar Loading and Processing
前言 上文介绍了数据读取.数据转换.批量处理等等.了解到在PyTorch中,数据加载主要有两种方式: 1.自定义的数据集对象.数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Datase ...
- 【转载】 pytorch笔记:06)requires_grad和volatile
原文地址: https://blog.csdn.net/jiangpeng59/article/details/80667335 作者:PJ-Javis 来源:CSDN --------------- ...
- 【转载】 Pytorch 细节记录
原文地址: https://www.cnblogs.com/king-lps/p/8570021.html ---------------------------------------------- ...
- 【转载】 pytorch之添加BN
原文地址: https://blog.csdn.net/weixin_40123108/article/details/83509838 ------------------------------- ...
- 【转载】 pytorch自定义网络结构不进行参数初始化会怎样?
原文地址: https://blog.csdn.net/u011668104/article/details/81670544 ------------------------------------ ...
- 【转载】 Pytorch中的学习率调整lr_scheduler,ReduceLROnPlateau
原文地址: https://blog.csdn.net/happyday_d/article/details/85267561 ------------------------------------ ...
- 【转载】 PyTorch学习之六个学习率调整策略
原文地址: https://blog.csdn.net/shanglianlm/article/details/85143614 ----------------------------------- ...
随机推荐
- Spring Boot使用Spring Data Jpa对MySQL数据库进行CRUD操作
只需两步!Eclipse+Maven快速构建第一个Spring Boot项目 构建了第一个Spring Boot项目. Spring Boot连接MySQL数据库 连接了MySQL数据库. 本文在之前 ...
- Java 加密、解密PDF文档
本篇文章将介绍通过Java编程来设置PDF文档保护的方法.我们可以设置仅用于查阅文档的密码,即该通过该密码打开文档仅用于文档阅读,无法编辑:也可以设置文档编辑权限的密码,即通过该密码打开文档时,文档为 ...
- Spring MVC深入学习
一.MVC思想 MVC思想简介: MVC并不是java所特有的设计思想,也不是Web应用所特有的思想,它是所有面向对象程序设计语言都应该遵守的规范:MVC思想将一个应用部分分成三个基本部 ...
- JavaScript 运行机制 (Event Loop)
单线程就意味着,所有任务需要排队,前一个任务结束,才会执行后一个任务.如果前一个任务耗时很长,后一个任务就不得不一直等着. 所有任务可以分成两种,一种是同步任务(synchronous),另一种是异步 ...
- 解决PostGIS打开shp文件输入输出模块出现"找不到文件libintl-9.dll"的问题
找到shp2pgsql-gui.exe这个程序的目录 复制一份libintl-8.dll副本,改名为libintl-9.dll即可.
- Python 基于Python及zookeeper实现简单分布式任务调度系统设计思路及核心代码实现
基于Python及zookeeper实现简单分布式任务调度系统设计思路及核心代码实现 by:授客 QQ:1033553122 测试环境 功能需求 实现思路 代码实践(关键技术点实现) 代码模块组织 ...
- 小米平板7.0系统如何不root激活Xposed框架的方法
在越来越多公司的引流或业务操作中,基本都需要使用安卓的强大XPOSED框架,这段时间我们公司买来了一批新的小米平板7.0系统,基本都都是基于7.0以上版本,基本都不能够获取root超级权限,即使小部分 ...
- Android视频录制从不入门到入门系列教程(二)————显示视频图像
1.创建一个空的工程,注意声明下列权限: <uses-permission android:name="android.permission.CAMERA"/> < ...
- FPGA设计千兆以太网MAC(3)——数据缓存及位宽转换模块设计与验证
本文设计思想采用明德扬至简设计法.上一篇博文中定制了自定义MAC IP的结构,在用户侧需要位宽转换及数据缓存.本文以TX方向为例,设计并验证发送缓存模块.这里定义该模块可缓存4个最大长度数据包,用户根 ...
- anaconda 环境新建/删除/拷贝 jupyter notebook上使用python虚拟环境 TensorFlow
naconda修改国内镜像源 国外网络有时太慢,可以通过配置把下载源改为国内的通过 conda config 命令生成配置文件,这里使用清华的镜像: https://mirrors.tuna.tsin ...