Iterators

对torchtext的batch实现的修改算法原理

Batching matters a ton for speed. We want to have very evenly divided batches, with absolutely minimal padding. To do this we have to hack a bit around the default torchtext batching. This code patches their default batching to make sure we search over enough sentences to find tight batches.

这里是对torchtext中默认的batching操作进行的优化修改。

参考:https://towardsdatascience.com/how-to-use-torchtext-for-neural-machine-translation-plus-hack-to-make-it-5x-faster-77f3884d95

Torchtext本身已经很好了,并且sort_key使得dataset中的数据排序,这样batching后序列长度相近的会被放在同一个batch中,可以很大程度上降低padding的个数。

但是下面代码又进行了优化:根据每个batch中序列的最大长度,动态更改batch_size,使得可以更好的利用计算资源。

举个例子:

假设你的RAM每个iteration可以处理1500个tokens, batch_size = 20, 那么只有当batch中的序列长度为sequence length = 1500 / 20 = 75时,才可以将计算资源利用完全。

现实中,每个batch的sequence length的显然是在变化的,那么如果希望尽量多的利用计算资源,就需要可以动态调整当前的batch_size.

Transformer中的MyIterator重载了data.Iterator中的create_batches函数:

 class MyIterator(data.Iterator):
def create_batches(self):
if self.train:
def pool(d, random_shuffler):
for p in data.batch(d, self.batch_size * 100):
p_batch = data.batch(
sorted(p, key=self.sort_key),
self.batch_size, self.batch_size_fn)
for b in random_shuffler(list(p_batch)):
yield b
self.batches = pool(self.data(), self.random_shuffler) else:
self.batches = []
for b in data.batch(self.data(), self.batch_size,
self.batch_size_fn):
self.batches.append(sorted(b, key=self.sort_key)) def rebatch(pad_idx, batch):
"Fix order in torchtext to match ours"
src, trg = batch.src.transpose(0, 1), batch.trg.transpose(0, 1)
return Batch(src, trg, pad_idx)

pool函数

其中pool函数的功能与https://github.com/pytorch/text/blob/master/torchtext/data/iterator.py中定义的class BucketIterator(Iterator)的pool函数功能类似。

1. 将原始的data分成大小为 100 * batch_size的一些chunks => (以上迭代 p 即为 每个chunk)

2. 在每个chunk中根据 sort_key 对examples进行排序,并对每个chunk按照batch_size分成100个batch =>

( p_batch = data.batch( sorted(p, key=self.sort_key), self.batch_size, self.batch_size_fn) )

3. 将这些chunks进行shuffle  => (random_shuffler(list(p_batch)))

4. 在每个chunk中再把examples分成 大小为 batch_size 的 100 个 batch => (以上 b 即为每个 batch)

5. 生成器每次 yield一个batch  => (yield b)

[The Annotated Transformer] Iterators的更多相关文章

  1. [NLP] The Annotated Transformer 代码修正

    1. RuntimeError: "exp" not implemented for 'torch.LongTensor' class PositionalEncoding(nn. ...

  2. 【译】图解Transformer

    目录 从宏观上看Transformer 把张量画出来 开始编码! 从宏观上看自注意力 自注意力的细节 自注意力的矩阵计算 "多头"自注意力 用位置编码表示序列的顺序 残差 解码器 ...

  3. 深入理解Transformer及其源码解读

    深度学习广泛应用于各个领域.基于transformer的预训练模型(gpt/bertd等)基本已统治NLP深度学习领域,可见transformer的重要性.本文结合<Attention is a ...

  4. Transformer模型---encoder

    一.简介 论文链接:<Attention is all you need> 由google团队在2017年发表于NIPS,Transformer 是一种新的.基于 attention 机制 ...

  5. zz全面拥抱Transformer

    放弃幻想,全面拥抱Transformer:自然语言处理三大特征抽取器(CNN/RNN/TF)比较 在辞旧迎新的时刻,大家都在忙着回顾过去一年的成绩(或者在灶台前含泪数锅),并对2019做着规划,当然也 ...

  6. Transformer的PyTorch实现

    Google 2017年的论文 Attention is all you need 阐释了什么叫做大道至简!该论文提出了Transformer模型,完全基于Attention mechanism,抛弃 ...

  7. 动手学Transformer

    动手实现Transformer,所有代码基于tensorflow2.0,配合illustrated-transformer更香. 模型架构 Encoder+Decoder Encoder Decode ...

  8. 深入浅出Transformer

    Transformer Transformer是NLP的颠覆者,它创造性地用非序列模型来处理序列化的数据,而且还获得了大成功.更重要的是,NLP真的可以"深度"学习了,各种基于tr ...

  9. 从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史(转载)

    转载 https://zhuanlan.zhihu.com/p/49271699 首发于深度学习前沿笔记 写文章   从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史 张 ...

随机推荐

  1. redis集群搭建和哨兵模式以及AOF和RDB持久化

    Redis主从+哨兵模式 1.环境准备 (1)三台独立的linux主机 (2)IP分别为:10.150.200.182 (从) 10.150.200.184(从)  10.150.200.195(主) ...

  2. Java开发者想尝试转行大数据,学习方向建议?

      ​前言 相信很多Java开发者都对大数据有一定的了解,随着大数据时代的到来,也有很多Java程序员想要转行大数据.大数据技术中大多数平台使用的都是Java语言,因此,对于大数据技术的学习来说,Ja ...

  3. 帝国cms 获取一条数据,但是从第二条开始获取

    /*这里的1指的是获取一条数据,2指的是从第二条开始获取*/ [e:loop={"select * from phome_ecms_news where classid='2' limit ...

  4. PostgreSQL 按照日期范围查询

    method 1 select * from user_info where create_date >= '2019-05-01' and create_date < '2019-08- ...

  5. 无线传输模块HC-12

    无线传输模块HC-12使用 因为实验室的无人机需要使用一款无线传输模块进行遥控控制,我们讨论的中测试了HC-12,并对HC-12传输距离进行了简单测试.在此做下使用记录. 模块概述 HC-12 无线串 ...

  6. redis加入systemctl服务

    来自:https://blog.csdn.net/weixin_41114593/article/details/82383716 第一步 安装redis去官网下载最新的redis版本    安装官网 ...

  7. 在mysql 上如何在不影响生产的情况下删除一个大表

    mysql 中常用的删除的方法基本上有下面三种方式: 1.delete 一般用于删除少量表中的数据 优化建议,一定要加上where 条件,并且where条件的列上 一定要有主键或者索引.否则会出现全表 ...

  8. Delphi 数组与记录类型

  9. 网络初级篇之OSPF(二)实验

    一.实验目的:     下面关于OSPF的实验,仔细看配置过程,以增加对OSPF的理解. 二.实现目标:     使用OSPF实现所有主机之间的通信 三.配置过程: 1.AR1的配置过程:      ...

  10. 〇二——body内标签之交互输入标签二

    我们在上一章讲了一些要通过后台程序实现交互的标签,下面我们看一看一些不通过后台实现交互的标签. 一.a标签 a标签主要实现超链接的功能 1.跳转效果 这个效果比较简单,直接在属性里添加一个网址,然后可 ...