主要是用函数torch.nn.utils.rnn.PackedSequence()和torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来进行的,分别来看看这三个函数的用法。

1、torch.nn.utils.rnn.PackedSequence()

NOTE: 这个类的实例不能手动创建。它们只能被 pack_padded_sequence() 实例化。

PackedSequence对象包括:

  • 一个data对象:一个torch.Variable(令牌的总数,每个令牌的维度),在这个简单的例子中有五个令牌序列(用整数表示):(18,1)
  • 一个batch_sizes对象:每个时间步长的令牌数列表,在这个例子中为:[6,5,2,4,1]

用pack_padded_sequence函数来构造这个对象非常的简单:

如何构造一个PackedSequence对象(batch_first = True)

PackedSequence对象有一个很不错的特性,就是我们无需对序列解包(这一步操作非常慢)即可直接在PackedSequence数据变量上执行许多操作。特别是我们可以对令牌执行任何操作(即对令牌的顺序/上下文不敏感)。当然,我们也可以使用接受PackedSequence作为输入的任何一个pyTorch模块(pyTorch 0.2)。

2、torch.nn.utils.rnn.pack_padded_sequence()

这里的pack,理解成压紧比较好。 将一个 填充过的变长序列 压紧。(填充时候,会有冗余,所以压紧一下)

输入的形状可以是(T×B×* )。T是最长序列长度,Bbatch size*代表任意维度(可以是0)。如果batch_first=True的话,那么相应的 input size 就是 (B×T×*)

Variable中保存的序列,应该按序列长度的长短排序,长的在前,短的在后。即input[:,0]代表的是最长的序列,input[:, B-1]保存的是最短的序列。

NOTE: 只要是维度大于等于2的input都可以作为这个函数的参数。你可以用它来打包labels,然后用RNN的输出和打包后的labels来计算loss。通过PackedSequence对象的.data属性可以获取 Variable

参数说明:

  • input (Variable) – 变长序列 被填充后的 batch

  • lengths (list[int]) – Variable 中 每个序列的长度。

  • batch_first (bool, optional) – 如果是True,input的形状应该是B*T*size

返回值:

一个PackedSequence 对象。

3、torch.nn.utils.rnn.pad_packed_sequence()

填充packed_sequence

上面提到的函数的功能是将一个填充后的变长序列压紧。 这个操作和pack_padded_sequence()是相反的。把压紧的序列再填充回来。

返回的Varaible的值的sizeT×B×*, T 是最长序列的长度,B 是 batch_size,如果 batch_first=True,那么返回值是B×T×*

Batch中的元素将会以它们长度的逆序排列。

参数说明:

  • sequence (PackedSequence) – 将要被填充的 batch

  • batch_first (bool, optional) – 如果为True,返回的数据的格式为 B×T×*

返回值: 一个tuple,包含被填充后的序列,和batch中序列的长度列表。

例子:

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import utils as nn_utils
batch_size = 2
max_length = 3
hidden_size = 2
n_layers =1 tensor_in = torch.FloatTensor([[1, 2, 3], [1, 0, 0]]).resize_(2,3,1)
tensor_in = Variable( tensor_in ) #[batch, seq, feature], [2, 3, 1]
seq_lengths = [3,1] # list of integers holding information about the batch size at each sequence step # pack it
pack = nn_utils.rnn.pack_padded_sequence(tensor_in, seq_lengths, batch_first=True) # initialize
rnn = nn.RNN(1, hidden_size, n_layers, batch_first=True)
h0 = Variable(torch.randn(n_layers, batch_size, hidden_size)) #forward
out, _ = rnn(pack, h0) # unpack
unpacked = nn_utils.rnn.pad_packed_sequence(out)
print('111',unpacked)

输出:

111 (Variable containing:
(0 ,.,.) =
0.5406 0.3584
-0.1403 0.0308 (1 ,.,.) =
-0.6855 -0.9307
0.0000 0.0000
[torch.FloatTensor of size 2x2x2]
, [2, 1])

pytorch对可变长度序列的处理的更多相关文章

  1. 使用PyTorch建立你的第一个文本分类模型

    概述 学习如何使用PyTorch执行文本分类 理解解决文本分类时所涉及的要点 学习使用包填充(Pack Padding)特性 介绍 我总是使用最先进的架构来在一些比赛提交模型结果.得益于PyTorch ...

  2. PyTorch专栏(六): 混合前端的seq2seq模型部署

    欢迎关注磐创博客资源汇总站: http://docs.panchuang.net/ 欢迎关注PyTorch官方中文教程站: http://pytorch.panchuang.net/ 专栏目录: 第一 ...

  3. 库、教程、论文实现,这是一份超全的PyTorch资源列表(Github 2.2K星)

    项目地址:https://github.com/bharathgs/Awesome-pytorch-list 列表结构: NLP 与语音处理 计算机视觉 概率/生成库 其他库 教程与示例 论文实现 P ...

  4. 混合前端seq2seq模型部署

    混合前端seq2seq模型部署 本文介绍,如何将seq2seq模型转换为PyTorch可用的前端混合Torch脚本.要转换的模型来自于聊天机器人教程Chatbot tutorial. 1.混合前端 在 ...

  5. [源码解析] NVIDIA HugeCTR,GPU版本参数服务器---(3)

    [源码解析] NVIDIA HugeCTR,GPU版本参数服务器---(3) 目录 [源码解析] NVIDIA HugeCTR,GPU版本参数服务器---(3) 0x00 摘要 0x01 回顾 0x0 ...

  6. http2协议翻译(转)

    超文本传输协议版本 2 IETF HTTP2草案(draft-ietf-httpbis-http2-13) 摘要 本规范描述了一种优化的超文本传输协议(HTTP).HTTP/2通过引进报头字段压缩以及 ...

  7. 【HTTP 2】HTTP/2 协议概述(HTTP/2 Protocol Overview)

    前情提要 在上一篇文章<[HTTP 2.0] 简介(Introduction)>中,我们简单介绍了 HTTP 2. 在本篇文章中,我们将会了解到 HTTP 2 协议概述部分的内容. HTT ...

  8. 论文翻译_Tracking The Untrackable_Learning To Track Multiple Cues with Long-Term Dependencies_IEEE2017

    Tracking The Untrackable: Learning to Track Multiple Cues with Long-Term Dependencies 跟踪不可跟踪:学习跟踪具有长 ...

  9. SCALA-基础知识学习(一)

    概述 本人开始学习scala的时候,是在使用和开发spark程序的时候,在此为了整理.记录和分享scala的基础知识,我写这篇关于scala的基础知识,希望与广大读者共同学习沟通进步.如果有些代码比较 ...

随机推荐

  1. 使用maven时出现Failure to transfer 错误的解决方法

    在eclipse里使用maven,连接nexus私服. 添加依赖之后,总是报添加的依赖jar文件找不到,但是在nexus的库里面能找到这个依赖的jar文件,但是在本地的maven库里面找不到,于是我将 ...

  2. 开发工具IntelliJ IDEA的安装步骤及首次启动和创建项目

    开发工具IDEA概述 DEA是一个专门针对Java的集成开发工具(IDE),由Java语言编写.所以,需要有JRE运行环境并配置好环境变量.它可以极大地提升我们的开发效率.可以自动编译,检查错误.在公 ...

  3. SqlServer2008_r2安装功能选择

    勾上数据引擎服务.客户端工具链接.sdk.管理工具.客户连接SDK.最后一个 sql2008安装时,怎么选择服务账户NT Authority\System ,系统内置账号,对本地系统拥有完全控制权限: ...

  4. C Programming Style 总结

    对材料C Programming Style for Engineering Computation的总结. 原文如下: C Programming Style for Engineering Com ...

  5. vue2.0 子组件和父组件之间的传值(转载)

    Vue是一个轻量级的渐进式框架,对于它的一些特性和优点在此就不做赘述,本篇文章主要来探讨一下Vue子父组件通信的问题 首先我们先搭好开发环境,我们首先得装好git和npm这两个工具(如果有不清楚的同学 ...

  6. 面试题(一续Spring)

    9.Spring体系结构和jar用途 参考https://blog.csdn.net/sunchen2012/article/details/53939253 spring官网给出了一张spring3 ...

  7. Python——爬虫——爬虫的原理与数据抓取

    一.使用Fiddler抓取HTTPS设置 (1)菜单栏 Tools > Telerik Fiddler Options 打开“Fiddler Options”对话框 (2)HTTPS设置:选中C ...

  8. vhdl 数据类型转换 使用IEEE标准库numeric_std 需要进行两次转换 use ieee.numeric_std.all;

    use ieee.numeric_std.all;

  9. luogu4770 [NOI2018]你的名字 (SAM+主席树)

    对S建SAM,拿着T在上面跑 跑的时候不仅无法转移要跳parent,转移过去不在范围内也要跳parent(注意因为范围和长度有关,跳的时候应该把长度一点一点地缩) 这样就能得到对于T的每个前缀,它最长 ...

  10. elasticsearch5之Elastalert 安装使用 配置邮件报警和微信报警

    简介 Elastalert是用python2写的一个报警框架(目前支持python2.6和2.7,不支持3.x),github地址为 https://github.com/Yelp/elastaler ...