pytorch中的pack_padded_sequence和pad_packed_sequence用法
pack_padded_sequence是将句子按照batch优先的原则记录每个句子的词,变化为不定长tensor,方便计算损失函数。
pad_packed_sequence是将pack_padded_sequence生成的结构转化为原先的结构,定长的tensor。
其中test.txt的内容
As they sat in a nice coffee shop,
he was too nervous to say anything and she felt uncomfortable.
Suddenly, he asked the waiter,
"Could you please give me some salt? I'd like to put it in my coffee."
具体参见如下代码
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import wordfreq vocab = {}
token_id = 1
lengths = [] #读取文件,生成词典
with open('test.txt', 'r') as f:
lines=f.readlines()
for line in lines:
tokens = wordfreq.tokenize(line.strip(), 'en')
lengths.append(len(tokens))
#将每个词加入到vocab中,并同时保存对应的index
for word in tokens:
if word not in vocab:
vocab[word] = token_id
token_id += 1 x = np.zeros((len(lengths), max(lengths)))
l_no = 0
#将词转化为数字
with open('test.txt', 'r') as f:
lines = f.readlines()
for line in lines:
tokens = wordfreq.tokenize(line.strip(), 'en')
for i in range(len(tokens)):
x[l_no, i] = vocab[tokens[i]]
l_no += 1 x=torch.Tensor(x)
x = Variable(x)
print(x)
'''
tensor([[ 1., 2., 3., 4., 5., 6., 7., 8., 0., 0., 0., 0., 0., 0.],
[ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 0., 0., 0.],
[20., 9., 21., 22., 23., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34., 4., 7.]])
'''
lengths = torch.Tensor(lengths)
print(lengths)#tensor([ 8., 11., 5., 14.]) _, idx_sort = torch.sort(torch.Tensor(lengths), dim=0, descending=True)
print(_) #tensor([14., 11., 8., 5.])
print(idx_sort)#tensor([3, 1, 0, 2]) lengths = list(lengths[idx_sort])#按下标取元素 [tensor(14.), tensor(11.), tensor(8.), tensor(5.)]
t = x.index_select(0, idx_sort)#按下标取元素
print(t)
'''
tensor([[24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34., 4., 7.],
[ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 0., 0., 0.],
[ 1., 2., 3., 4., 5., 6., 7., 8., 0., 0., 0., 0., 0., 0.],
[20., 9., 21., 22., 23., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
'''
x_packed = nn.utils.rnn.pack_padded_sequence(input=t, lengths=lengths, batch_first=True)
print(x_packed)
'''
PackedSequence(data=tensor([24., 9., 1., 20., 25., 10., 2., 9., 26., 11., 3., 21., 27., 12.,
4., 22., 28., 13., 5., 23., 29., 14., 6., 30., 15., 7., 31., 16.,
8., 32., 17., 13., 18., 33., 19., 34., 4., 7.]), batch_sizes=tensor([4, 4, 4, 4, 4, 3, 3, 3, 2, 2, 2, 1, 1, 1]))
''' x_padded = nn.utils.rnn.pad_packed_sequence(x_packed, batch_first=True)#x_padded是tuple
print(x_padded)
'''
(tensor([[24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34., 4., 7.],
[ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 0., 0., 0.],
[ 1., 2., 3., 4., 5., 6., 7., 8., 0., 0., 0., 0., 0., 0.],
[20., 9., 21., 22., 23., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]), tensor([14, 11, 8, 5]))
'''
#还原tensor
_, idx_unsort = torch.sort(idx_sort)
output = x_padded[0].index_select(0, idx_unsort)
print(output)
'''
tensor([[ 1., 2., 3., 4., 5., 6., 7., 8., 0., 0., 0., 0., 0., 0.],
[ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 0., 0., 0.],
[20., 9., 21., 22., 23., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34., 4., 7.]])
'''
pytorch中的pack_padded_sequence和pad_packed_sequence用法的更多相关文章
- [转载]PyTorch中permute的用法
[转载]PyTorch中permute的用法 来源:https://blog.csdn.net/york1996/article/details/81876886 permute(dims) 将ten ...
- Pytorch中randn和rand函数的用法
Pytorch中randn和rand函数的用法 randn torch.randn(*sizes, out=None) → Tensor 返回一个包含了从标准正态分布中抽取的一组随机数的张量 size ...
- Pytorch中nn.Conv2d的用法
Pytorch中nn.Conv2d的用法 nn.Conv2d是二维卷积方法,相对应的还有一维卷积方法nn.Conv1d,常用于文本数据的处理,而nn.Conv2d一般用于二维图像. 先看一下接口定义: ...
- PyTorch中view的用法
相当于numpy中resize()的功能,但是用法可能不太一样. 我的理解是: 把原先tensor中的数据按照行优先的顺序排成一个一维的数据(这里应该是因为要求地址是连续存储的),然后按照参数组合成其 ...
- pytorch中tensorboardX的用法
在代码中改好存储Log的路径 命令行中输入 tensorboard --logdir /home/huihua/NewDisk1/PycharmProjects/pytorch-deeplab-xce ...
- [PyTorch]PyTorch中反卷积的用法
文章来源:https://www.jianshu.com/p/01577e86e506 pytorch中的 2D 卷积层 和 2D 反卷积层 函数分别如下: class torch.nn.Conv2d ...
- pytorch中如何处理RNN输入变长序列padding
一.为什么RNN需要处理变长输入 假设我们有情感分析的例子,对每句话进行一个感情级别的分类,主体流程大概是下图所示: 思路比较简单,但是当我们进行batch个训练数据一起计算的时候,我们会遇到多个训练 ...
- pytorch中如何使用DataLoader对数据集进行批处理
最近搞了搞minist手写数据集的神经网络搭建,一个数据集里面很多个数据,不能一次喂入,所以需要分成一小块一小块喂入搭建好的网络. pytorch中有很方便的dataloader函数来方便我们进行批处 ...
- PyTorch中使用深度学习(CNN和LSTM)的自动图像标题
介绍 深度学习现在是一个非常猖獗的领域 - 有如此多的应用程序日复一日地出现.深入了解深度学习的最佳方法是亲自动手.尽可能多地参与项目,并尝试自己完成.这将帮助您更深入地掌握主题,并帮助您成为更好的深 ...
随机推荐
- 深入探讨多态性及其在Java中的好处
多态是面向对象软件的基本原理之一.该术语通常表示可以具有多种形式的事物.在面向对象的方法中,多态使编写具有后期绑定引用的程序成为可能.尽管在Java中创建多态引用很容易,但其背后的概念对整体编程产生了 ...
- python基础入门while循环 格式化 编码初识
一.while循环 1.格式 while+空格+条件+英文冒号: 缩进+结果(循环体) #若条件为真则一直执行,条件为假则不执行 while True: print('痒') print('. ...
- C语言每日一练——第3题
一.题目要求 程序功能:计算100以内满足以下条件的所有整数i的个数cnt以及这些i之和sum.条件:i, i+4 ,i+10都是素数,同时i+10小于100.最后电影函数writeDAT()函数把结 ...
- TrueTime的安装、运行例程
一.前言 Truetime的安装是为了完成课程相关需求,但在安装过程中遇到一些问题,想到自己之前注册了博客所以打算把这个作为第一篇的内容.请放心这个的安装过程并不困难,可以放心食用. 二.准备 Tru ...
- Sql将一列数据拆分为多行显示的两种方法
原始数据与期望结果有表tb, 如下:id value----------- -----------1 aa,bb2 aaa,bbb,ccc欲按 ...
- Sqlite—查询语句(Select)
基本语法如下 sqlite> select * from tb_user; sqlite> select userid,username from tb_user; 格式化的查询输出 sq ...
- GitHub 设置和取消代理,加速 git clone
git 设置代理: git config --global git 取消代理: git config --global --unset http.proxy 针对 github.com 设置代理: g ...
- [洛谷P1967][题解]货车运输
题目 这道题让我们求最小限重的最大值 显然可以先求出最大生成树,然后在树上进行操作 因为如果两点之间有多条路径的话一定会走最大的,而其他小的路径是不会被走的 然后考虑求最小权值 可以采用倍增求LCA, ...
- 比较3个开源数据库:PostgreSQL,MariaDB和SQLite
在现代企业技术世界里,开源软件已牢固地确立了自己作为不可忽视的,最大力量之一的地位.由于开源运动的出现,推动了几十年来的一些最著名的技术发展. 不难理解为什么:尽管基于Linux的开源网络标准可能不像 ...
- Comet OJ - Contest #11 B题 usiness
###题目链接### 题目大意:一开始手上有 0 个节点,有 n 天抉择,m 种方案,在每天中可以选择任意种方案.任意次地花费 x 个节点(手上的节点数不能为负),使得在 n 天结束后,获得 y 个节 ...