pack_padded_sequence是将句子按照batch优先的原则记录每个句子的词,变化为不定长tensor,方便计算损失函数。

pad_packed_sequence是将pack_padded_sequence生成的结构转化为原先的结构,定长的tensor。

其中test.txt的内容

As they sat in a nice coffee shop,
he was too nervous to say anything and she felt uncomfortable.
Suddenly, he asked the waiter,
"Could you please give me some salt? I'd like to put it in my coffee."

具体参见如下代码

import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import wordfreq vocab = {}
token_id = 1
lengths = [] #读取文件,生成词典
with open('test.txt', 'r') as f:
lines=f.readlines()
for line in lines:
tokens = wordfreq.tokenize(line.strip(), 'en')
lengths.append(len(tokens))
#将每个词加入到vocab中,并同时保存对应的index
for word in tokens:
if word not in vocab:
vocab[word] = token_id
token_id += 1 x = np.zeros((len(lengths), max(lengths)))
l_no = 0
#将词转化为数字
with open('test.txt', 'r') as f:
lines = f.readlines()
for line in lines:
tokens = wordfreq.tokenize(line.strip(), 'en')
for i in range(len(tokens)):
x[l_no, i] = vocab[tokens[i]]
l_no += 1 x=torch.Tensor(x)
x = Variable(x)
print(x)
'''
tensor([[ 1., 2., 3., 4., 5., 6., 7., 8., 0., 0., 0., 0., 0., 0.],
[ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 0., 0., 0.],
[20., 9., 21., 22., 23., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34., 4., 7.]])
'''
lengths = torch.Tensor(lengths)
print(lengths)#tensor([ 8., 11., 5., 14.]) _, idx_sort = torch.sort(torch.Tensor(lengths), dim=0, descending=True)
print(_) #tensor([14., 11., 8., 5.])
print(idx_sort)#tensor([3, 1, 0, 2]) lengths = list(lengths[idx_sort])#按下标取元素 [tensor(14.), tensor(11.), tensor(8.), tensor(5.)]
t = x.index_select(0, idx_sort)#按下标取元素
print(t)
'''
tensor([[24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34., 4., 7.],
[ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 0., 0., 0.],
[ 1., 2., 3., 4., 5., 6., 7., 8., 0., 0., 0., 0., 0., 0.],
[20., 9., 21., 22., 23., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
'''
x_packed = nn.utils.rnn.pack_padded_sequence(input=t, lengths=lengths, batch_first=True)
print(x_packed)
'''
PackedSequence(data=tensor([24., 9., 1., 20., 25., 10., 2., 9., 26., 11., 3., 21., 27., 12.,
4., 22., 28., 13., 5., 23., 29., 14., 6., 30., 15., 7., 31., 16.,
8., 32., 17., 13., 18., 33., 19., 34., 4., 7.]), batch_sizes=tensor([4, 4, 4, 4, 4, 3, 3, 3, 2, 2, 2, 1, 1, 1]))
''' x_padded = nn.utils.rnn.pad_packed_sequence(x_packed, batch_first=True)#x_padded是tuple
print(x_padded)
'''
(tensor([[24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34., 4., 7.],
[ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 0., 0., 0.],
[ 1., 2., 3., 4., 5., 6., 7., 8., 0., 0., 0., 0., 0., 0.],
[20., 9., 21., 22., 23., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]), tensor([14, 11, 8, 5]))
'''
#还原tensor
_, idx_unsort = torch.sort(idx_sort)
output = x_padded[0].index_select(0, idx_unsort)
print(output)
'''
tensor([[ 1., 2., 3., 4., 5., 6., 7., 8., 0., 0., 0., 0., 0., 0.],
[ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 0., 0., 0.],
[20., 9., 21., 22., 23., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34., 4., 7.]])
'''

pytorch中的pack_padded_sequence和pad_packed_sequence用法的更多相关文章

  1. [转载]PyTorch中permute的用法

    [转载]PyTorch中permute的用法 来源:https://blog.csdn.net/york1996/article/details/81876886 permute(dims) 将ten ...

  2. Pytorch中randn和rand函数的用法

    Pytorch中randn和rand函数的用法 randn torch.randn(*sizes, out=None) → Tensor 返回一个包含了从标准正态分布中抽取的一组随机数的张量 size ...

  3. Pytorch中nn.Conv2d的用法

    Pytorch中nn.Conv2d的用法 nn.Conv2d是二维卷积方法,相对应的还有一维卷积方法nn.Conv1d,常用于文本数据的处理,而nn.Conv2d一般用于二维图像. 先看一下接口定义: ...

  4. PyTorch中view的用法

    相当于numpy中resize()的功能,但是用法可能不太一样. 我的理解是: 把原先tensor中的数据按照行优先的顺序排成一个一维的数据(这里应该是因为要求地址是连续存储的),然后按照参数组合成其 ...

  5. pytorch中tensorboardX的用法

    在代码中改好存储Log的路径 命令行中输入 tensorboard --logdir /home/huihua/NewDisk1/PycharmProjects/pytorch-deeplab-xce ...

  6. [PyTorch]PyTorch中反卷积的用法

    文章来源:https://www.jianshu.com/p/01577e86e506 pytorch中的 2D 卷积层 和 2D 反卷积层 函数分别如下: class torch.nn.Conv2d ...

  7. pytorch中如何处理RNN输入变长序列padding

    一.为什么RNN需要处理变长输入 假设我们有情感分析的例子,对每句话进行一个感情级别的分类,主体流程大概是下图所示: 思路比较简单,但是当我们进行batch个训练数据一起计算的时候,我们会遇到多个训练 ...

  8. pytorch中如何使用DataLoader对数据集进行批处理

    最近搞了搞minist手写数据集的神经网络搭建,一个数据集里面很多个数据,不能一次喂入,所以需要分成一小块一小块喂入搭建好的网络. pytorch中有很方便的dataloader函数来方便我们进行批处 ...

  9. PyTorch中使用深度学习(CNN和LSTM)的自动图像标题

    介绍 深度学习现在是一个非常猖獗的领域 - 有如此多的应用程序日复一日地出现.深入了解深度学习的最佳方法是亲自动手.尽可能多地参与项目,并尝试自己完成.这将帮助您更深入地掌握主题,并帮助您成为更好的深 ...

随机推荐

  1. Python3 并发编程2

    目录 进程互斥锁 基本概念 互斥锁的使用 IPC 基本概念 队列 生产者消费者模型 基本概念 代码实现 线程 基本概念 创建线程 线程互斥锁 进程互斥锁 基本概念 临界资源: 一次仅允许一个进程使用的 ...

  2. 纯净版SSM

    pom.xml <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w ...

  3. Electron 设置 -webkit-app-region 后无法响应鼠标点击事件的解决方式

    参考博客:https://blog.csdn.net/qq_20264891/article/details/87721163

  4. 深入浅出Object.defineProperty()

    深入浅出Object.defineProperty() 红宝书对应知识点页码:139页 红宝书150页:hasOwnProperty( )方法可以检测一个属性是存在于实例中,还是存在于原型中,给定属性 ...

  5. python爬虫--爬虫介绍

    一 爬虫 1.什么是互联网? 互联网是由网络设备(网线,路由器,交换机,防火墙等等)和一台台计算机连接而成,像一张网一样 2.互联网建立的目的? 互联网的核心价值在于数据的共享/传递:数据是存放于一台 ...

  6. Spring boot采坑记--- 在启动时RequstMappingHandlerMapping无法找到部分contorller类文件的解决方案

    最近有一个心得需求,需要在一个现有的springboot项目中增加一些新的功能,于是就在controller文件包下面创建新的包和类文件,但是后端开发完之后,本地测试发现前端访问报404错误,第一反应 ...

  7. 在 Windows 10 上搭建 Cordova 跨平台开发 Android 环境

    目录 安装 Cordova 安装 Java 和 Android 环境 创建 Cordova 应用程序 构建和运行 Cordova Cordova 简介:Cordova 原名 PhoneGap,是一个开 ...

  8. [ASP.NET Core 3框架揭秘] 跨平台开发体验: Linux

    如果想体验Linux环境下开发.NET Core应用,我们有多种选择.一种就是在一台物理机上安装原生的Linux,我们可以根据自身的喜好选择某种Linux Distribution,目前来说像RHEL ...

  9. JPA中实现双向一对多的关联关系

    场景 JPA入门简介与搭建HelloWorld(附代码下载): https://blog.csdn.net/BADAO_LIUMANG_QIZHI/article/details/103473937 ...

  10. 传统js和jsx与ts和tsx的区别

    一.从定义文件格式方面说 1.传统的开发模式可以定义js文件或者jsx文件2.利用ts开发定义的文件格式tsx二.定义state的状态来说 1.传统的方式直接在构造函数中使用 constructor( ...