pack_padded_sequence是将句子按照batch优先的原则记录每个句子的词,变化为不定长tensor,方便计算损失函数。

pad_packed_sequence是将pack_padded_sequence生成的结构转化为原先的结构,定长的tensor。

其中test.txt的内容

As they sat in a nice coffee shop,
he was too nervous to say anything and she felt uncomfortable.
Suddenly, he asked the waiter,
"Could you please give me some salt? I'd like to put it in my coffee."

具体参见如下代码

import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import wordfreq vocab = {}
token_id = 1
lengths = [] #读取文件,生成词典
with open('test.txt', 'r') as f:
lines=f.readlines()
for line in lines:
tokens = wordfreq.tokenize(line.strip(), 'en')
lengths.append(len(tokens))
#将每个词加入到vocab中,并同时保存对应的index
for word in tokens:
if word not in vocab:
vocab[word] = token_id
token_id += 1 x = np.zeros((len(lengths), max(lengths)))
l_no = 0
#将词转化为数字
with open('test.txt', 'r') as f:
lines = f.readlines()
for line in lines:
tokens = wordfreq.tokenize(line.strip(), 'en')
for i in range(len(tokens)):
x[l_no, i] = vocab[tokens[i]]
l_no += 1 x=torch.Tensor(x)
x = Variable(x)
print(x)
'''
tensor([[ 1., 2., 3., 4., 5., 6., 7., 8., 0., 0., 0., 0., 0., 0.],
[ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 0., 0., 0.],
[20., 9., 21., 22., 23., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34., 4., 7.]])
'''
lengths = torch.Tensor(lengths)
print(lengths)#tensor([ 8., 11., 5., 14.]) _, idx_sort = torch.sort(torch.Tensor(lengths), dim=0, descending=True)
print(_) #tensor([14., 11., 8., 5.])
print(idx_sort)#tensor([3, 1, 0, 2]) lengths = list(lengths[idx_sort])#按下标取元素 [tensor(14.), tensor(11.), tensor(8.), tensor(5.)]
t = x.index_select(0, idx_sort)#按下标取元素
print(t)
'''
tensor([[24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34., 4., 7.],
[ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 0., 0., 0.],
[ 1., 2., 3., 4., 5., 6., 7., 8., 0., 0., 0., 0., 0., 0.],
[20., 9., 21., 22., 23., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
'''
x_packed = nn.utils.rnn.pack_padded_sequence(input=t, lengths=lengths, batch_first=True)
print(x_packed)
'''
PackedSequence(data=tensor([24., 9., 1., 20., 25., 10., 2., 9., 26., 11., 3., 21., 27., 12.,
4., 22., 28., 13., 5., 23., 29., 14., 6., 30., 15., 7., 31., 16.,
8., 32., 17., 13., 18., 33., 19., 34., 4., 7.]), batch_sizes=tensor([4, 4, 4, 4, 4, 3, 3, 3, 2, 2, 2, 1, 1, 1]))
''' x_padded = nn.utils.rnn.pad_packed_sequence(x_packed, batch_first=True)#x_padded是tuple
print(x_padded)
'''
(tensor([[24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34., 4., 7.],
[ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 0., 0., 0.],
[ 1., 2., 3., 4., 5., 6., 7., 8., 0., 0., 0., 0., 0., 0.],
[20., 9., 21., 22., 23., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]), tensor([14, 11, 8, 5]))
'''
#还原tensor
_, idx_unsort = torch.sort(idx_sort)
output = x_padded[0].index_select(0, idx_unsort)
print(output)
'''
tensor([[ 1., 2., 3., 4., 5., 6., 7., 8., 0., 0., 0., 0., 0., 0.],
[ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 0., 0., 0.],
[20., 9., 21., 22., 23., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34., 4., 7.]])
'''

pytorch中的pack_padded_sequence和pad_packed_sequence用法的更多相关文章

  1. [转载]PyTorch中permute的用法

    [转载]PyTorch中permute的用法 来源:https://blog.csdn.net/york1996/article/details/81876886 permute(dims) 将ten ...

  2. Pytorch中randn和rand函数的用法

    Pytorch中randn和rand函数的用法 randn torch.randn(*sizes, out=None) → Tensor 返回一个包含了从标准正态分布中抽取的一组随机数的张量 size ...

  3. Pytorch中nn.Conv2d的用法

    Pytorch中nn.Conv2d的用法 nn.Conv2d是二维卷积方法,相对应的还有一维卷积方法nn.Conv1d,常用于文本数据的处理,而nn.Conv2d一般用于二维图像. 先看一下接口定义: ...

  4. PyTorch中view的用法

    相当于numpy中resize()的功能,但是用法可能不太一样. 我的理解是: 把原先tensor中的数据按照行优先的顺序排成一个一维的数据(这里应该是因为要求地址是连续存储的),然后按照参数组合成其 ...

  5. pytorch中tensorboardX的用法

    在代码中改好存储Log的路径 命令行中输入 tensorboard --logdir /home/huihua/NewDisk1/PycharmProjects/pytorch-deeplab-xce ...

  6. [PyTorch]PyTorch中反卷积的用法

    文章来源:https://www.jianshu.com/p/01577e86e506 pytorch中的 2D 卷积层 和 2D 反卷积层 函数分别如下: class torch.nn.Conv2d ...

  7. pytorch中如何处理RNN输入变长序列padding

    一.为什么RNN需要处理变长输入 假设我们有情感分析的例子,对每句话进行一个感情级别的分类,主体流程大概是下图所示: 思路比较简单,但是当我们进行batch个训练数据一起计算的时候,我们会遇到多个训练 ...

  8. pytorch中如何使用DataLoader对数据集进行批处理

    最近搞了搞minist手写数据集的神经网络搭建,一个数据集里面很多个数据,不能一次喂入,所以需要分成一小块一小块喂入搭建好的网络. pytorch中有很方便的dataloader函数来方便我们进行批处 ...

  9. PyTorch中使用深度学习(CNN和LSTM)的自动图像标题

    介绍 深度学习现在是一个非常猖獗的领域 - 有如此多的应用程序日复一日地出现.深入了解深度学习的最佳方法是亲自动手.尽可能多地参与项目,并尝试自己完成.这将帮助您更深入地掌握主题,并帮助您成为更好的深 ...

随机推荐

  1. Python爬虫技术:爬虫时如何知道是否代理ip伪装成功?

    前言本文的文字及图片来源于网络,仅供学习.交流使用,不具有任何商业用途,版权归原作者所有,如有问题请及时联系我们以作处理. python爬虫时如何知道是否代理ip伪装成功: 有时候我们的爬虫程序添加了 ...

  2. 国内下载vscode速度慢解决

    找到对应的文件,点击下载,会出现一个类似下面的链接: https://az764295.vo.msecnd.net/stable/f06011ac164ae4dc8e753a3fe7f9549844d ...

  3. Jmeter查看结果树之查看响应的13种方法[详解]

    查看结果树查看响应有哪几种方法,可通过左侧面板底部的下拉框选择 1.Text 查看结果树中请求的默认格式为text,会显示请求的取样器结果.请求.响应数据3个部分内容. 取样器结果: 默认Raw视图, ...

  4. Wireshark数据包分析入门

    Wireshark数据包分析(一)——使用入门   Wireshark简介: Wireshark是一款最流行和强大的开源数据包抓包与分析工具,没有之一.在SecTools安全社区里颇受欢迎,曾一度超越 ...

  5. NodeJS3-1基础API----Path(路径)

    path 和路径有关的操作 Path(路径)  path 模块提供用于处理文件路径和目录路径的实用工具. 它可以使用以下方式访问 const path = require('path');  1. p ...

  6. django基础之day08,分页器从无到有,动态思路解析全过程

    *********分页器从无到有的全过程,动态思路解析如下:******** 1.通过book_queryset = models.Book.objects.all()[start_num:end_n ...

  7. 《Java基础知识》一维,二维数组的申明和使用

    为什么要使用数组: 因为不使用数组计算多个变量的时候太繁琐,不利于数据的处理. --------   数组也是一个变量,是存储一组相同类型的变量 声明一个变量就是在内存中划出一块合适的空间 声明一个数 ...

  8. 面试连环炮系列(二十三): StringBuffer与StringBuild的区别

    StringBuffer与StringBuild的区别 频繁修改字符串时,建议使用StringBuffer和StringBuilder类.StringBuilder相较于StringBuffer有速度 ...

  9. CMAKE同时编译C++和CUDA文件

    1. 首先是运行环境 Ubuntu 16.04 G++ 5.4.0 CUDA 8.0 2. 文件结构 cv@cv:~/myproject$ tree src src/ |-- CMakeLists.t ...

  10. 学习python这么久,有没有考虑发布一个属于自己的模块?

    ​ 1. 为什么需要对项目分发打包? 平常我们习惯了使用 pip 来安装一些第三方模块,这个安装过程之所以简单,是因为模块开发者为我们默默地为我们做了所有繁杂的工作,而这个过程就是 打包. 打包,就是 ...