pack_padded_sequence是将句子按照batch优先的原则记录每个句子的词,变化为不定长tensor,方便计算损失函数。

pad_packed_sequence是将pack_padded_sequence生成的结构转化为原先的结构,定长的tensor。

其中test.txt的内容

As they sat in a nice coffee shop,
he was too nervous to say anything and she felt uncomfortable.
Suddenly, he asked the waiter,
"Could you please give me some salt? I'd like to put it in my coffee."

具体参见如下代码

import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import wordfreq vocab = {}
token_id = 1
lengths = [] #读取文件,生成词典
with open('test.txt', 'r') as f:
lines=f.readlines()
for line in lines:
tokens = wordfreq.tokenize(line.strip(), 'en')
lengths.append(len(tokens))
#将每个词加入到vocab中,并同时保存对应的index
for word in tokens:
if word not in vocab:
vocab[word] = token_id
token_id += 1 x = np.zeros((len(lengths), max(lengths)))
l_no = 0
#将词转化为数字
with open('test.txt', 'r') as f:
lines = f.readlines()
for line in lines:
tokens = wordfreq.tokenize(line.strip(), 'en')
for i in range(len(tokens)):
x[l_no, i] = vocab[tokens[i]]
l_no += 1 x=torch.Tensor(x)
x = Variable(x)
print(x)
'''
tensor([[ 1., 2., 3., 4., 5., 6., 7., 8., 0., 0., 0., 0., 0., 0.],
[ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 0., 0., 0.],
[20., 9., 21., 22., 23., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34., 4., 7.]])
'''
lengths = torch.Tensor(lengths)
print(lengths)#tensor([ 8., 11., 5., 14.]) _, idx_sort = torch.sort(torch.Tensor(lengths), dim=0, descending=True)
print(_) #tensor([14., 11., 8., 5.])
print(idx_sort)#tensor([3, 1, 0, 2]) lengths = list(lengths[idx_sort])#按下标取元素 [tensor(14.), tensor(11.), tensor(8.), tensor(5.)]
t = x.index_select(0, idx_sort)#按下标取元素
print(t)
'''
tensor([[24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34., 4., 7.],
[ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 0., 0., 0.],
[ 1., 2., 3., 4., 5., 6., 7., 8., 0., 0., 0., 0., 0., 0.],
[20., 9., 21., 22., 23., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
'''
x_packed = nn.utils.rnn.pack_padded_sequence(input=t, lengths=lengths, batch_first=True)
print(x_packed)
'''
PackedSequence(data=tensor([24., 9., 1., 20., 25., 10., 2., 9., 26., 11., 3., 21., 27., 12.,
4., 22., 28., 13., 5., 23., 29., 14., 6., 30., 15., 7., 31., 16.,
8., 32., 17., 13., 18., 33., 19., 34., 4., 7.]), batch_sizes=tensor([4, 4, 4, 4, 4, 3, 3, 3, 2, 2, 2, 1, 1, 1]))
''' x_padded = nn.utils.rnn.pad_packed_sequence(x_packed, batch_first=True)#x_padded是tuple
print(x_padded)
'''
(tensor([[24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34., 4., 7.],
[ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 0., 0., 0.],
[ 1., 2., 3., 4., 5., 6., 7., 8., 0., 0., 0., 0., 0., 0.],
[20., 9., 21., 22., 23., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]), tensor([14, 11, 8, 5]))
'''
#还原tensor
_, idx_unsort = torch.sort(idx_sort)
output = x_padded[0].index_select(0, idx_unsort)
print(output)
'''
tensor([[ 1., 2., 3., 4., 5., 6., 7., 8., 0., 0., 0., 0., 0., 0.],
[ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 0., 0., 0.],
[20., 9., 21., 22., 23., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34., 4., 7.]])
'''

pytorch中的pack_padded_sequence和pad_packed_sequence用法的更多相关文章

  1. [转载]PyTorch中permute的用法

    [转载]PyTorch中permute的用法 来源:https://blog.csdn.net/york1996/article/details/81876886 permute(dims) 将ten ...

  2. Pytorch中randn和rand函数的用法

    Pytorch中randn和rand函数的用法 randn torch.randn(*sizes, out=None) → Tensor 返回一个包含了从标准正态分布中抽取的一组随机数的张量 size ...

  3. Pytorch中nn.Conv2d的用法

    Pytorch中nn.Conv2d的用法 nn.Conv2d是二维卷积方法,相对应的还有一维卷积方法nn.Conv1d,常用于文本数据的处理,而nn.Conv2d一般用于二维图像. 先看一下接口定义: ...

  4. PyTorch中view的用法

    相当于numpy中resize()的功能,但是用法可能不太一样. 我的理解是: 把原先tensor中的数据按照行优先的顺序排成一个一维的数据(这里应该是因为要求地址是连续存储的),然后按照参数组合成其 ...

  5. pytorch中tensorboardX的用法

    在代码中改好存储Log的路径 命令行中输入 tensorboard --logdir /home/huihua/NewDisk1/PycharmProjects/pytorch-deeplab-xce ...

  6. [PyTorch]PyTorch中反卷积的用法

    文章来源:https://www.jianshu.com/p/01577e86e506 pytorch中的 2D 卷积层 和 2D 反卷积层 函数分别如下: class torch.nn.Conv2d ...

  7. pytorch中如何处理RNN输入变长序列padding

    一.为什么RNN需要处理变长输入 假设我们有情感分析的例子,对每句话进行一个感情级别的分类,主体流程大概是下图所示: 思路比较简单,但是当我们进行batch个训练数据一起计算的时候,我们会遇到多个训练 ...

  8. pytorch中如何使用DataLoader对数据集进行批处理

    最近搞了搞minist手写数据集的神经网络搭建,一个数据集里面很多个数据,不能一次喂入,所以需要分成一小块一小块喂入搭建好的网络. pytorch中有很方便的dataloader函数来方便我们进行批处 ...

  9. PyTorch中使用深度学习(CNN和LSTM)的自动图像标题

    介绍 深度学习现在是一个非常猖獗的领域 - 有如此多的应用程序日复一日地出现.深入了解深度学习的最佳方法是亲自动手.尽可能多地参与项目,并尝试自己完成.这将帮助您更深入地掌握主题,并帮助您成为更好的深 ...

随机推荐

  1. Seata 配置中心实现原理

    Seata 可以支持多个第三方配置中心,那么 Seata 是如何同时兼容那么多个配置中心的呢?下面我给大家详细介绍下 Seata 配置中心的实现原理. 配置中心属性加载 在 Seata 配置中心,有两 ...

  2. 【Eureka】服务端和客户端

    [Eureka]服务端和客户端 转载:https://www.cnblogs.com/yangchongxing/p/10778357.html Eureka服务端 1.添加依赖 <?xml v ...

  3. 1篇文章搞清楚8种JVM内存溢出(OOM)的原因和解决方法

    前言 撸Java的同学,多多少少会碰到内存溢出(OOM)的场景,但造成OOM的原因却是多种多样. 堆溢出 这种场景最为常见,报错信息: java.lang.OutOfMemoryError: Java ...

  4. 中国剩余定理(CRT)及其拓展(ExCRT)

    中国剩余定理 CRT 推导 给定\(n\)个同余方程 \[ \left\{ \begin{aligned} x &\equiv a_1 \pmod{m_1} \\ x &\equiv ...

  5. 表达式和运算符知识总结(js)

    文章目录: 一. 表达式和语句的区别 二. 自增自减运算符的运算规则 一. 表达式和语句的区别 表达式(expression)是JavaScript中的一个短语,JavaScript解释器会将其计算( ...

  6. AOP框架Dora.Interception 3.0 [3]: 拦截器设计

    对于所有的AOP框架来说,多个拦截器最终会应用到某个方法上.这些拦截器按照指定的顺序构成一个管道,管道的另一端就是针对目标方法的调用.从设计角度来将,拦截器和中间件本质是一样的,那么我们可以按照类似的 ...

  7. windows cmd 生成文件目录树

    一.背景 之前逛GitHub的时候看到有大佬在描述项目结构的时候使用了一种文件目录树的格式 │ └─student_information_management_system │ │ ├─build ...

  8. ASP.NET+d3.js实现Sqlserver数据库的可视化展示

    效果: 数据库端: 前端展示: 实现原理: 1.在数据段建立两个存储过程 queryUserAnsawer(id) 根据用户ID返回每一题的得分,主要是bcp exe时不能直接在sqlserver中执 ...

  9. Springboot Activiti6 工作流 集成代码生成器 shiro 权限 vue.js html 跨域 前后分离

    官网:www.fhadmin.org 特别注意: Springboot 工作流  前后分离 + 跨域 版本 (权限控制到菜单和按钮) 后台框架:springboot2.1.2+ activiti6.0 ...

  10. oracle中add_months()函数总结

    今天对add_months函数进行简单总结一下: add_months 函数主要是对日期函数进行操作,在数据查询的过程中进行日期的按月增加,其形式为: add_months(date,int);其中第 ...