对于Transformer模型的positional encoding,最初在Attention is all you need的文章中提出的是进行绝对位置编码,之后Shaw在2018年的文章中提出了相对位置编码,就是本篇blog所介绍的算法RPR;2019年的Transformer-XL针对其segment的特定,引入了全局偏置信息,改进了相对位置编码的算法,在相对位置编码(二)的blog中介绍。

本文参考链接:

1. Self-Attention with Relative Position Representations (Shaw et al.2018): https://arxiv.org/pdf/1803.02155.pdf

2. Attention is all you need (Vaswani et al.2017): https://arxiv.org/pdf/1706.03762.pdf

3. How Self-Attention with Relative Position Representations works: https://medium.com/@_init_/how-self-attention-with-relative-position-representations-works-28173b8c245a

4. [NLP] 相对位置编码(二) Relative Positional Encodings - Transformer-XL: https://www.cnblogs.com/shiyublog/p/11236212.html

Motivation

RNN中,第一个"I"与第二个"I"的输出表征不同,因为用于生成这两个单词的hidden states是不同的。对于第一个"I",其hidden state是初始化的状态;对于第二个"I",其hidden state是编码了"I think therefore"的hidden state。所以RNN的hidden state 保证了在同一个输入序列中,不同位置的同样的单词的output representation是不同的。

在self-attention中,第一个"I"与第二个"I"的输出将完全相同。因为它们用于产生输出的“input”是完全相同的。即在同一个输入序列中,不同位置的相同的单词的output representation完全相同,这样就不能提现单词之间的时序关系。--所以要对单词的时序位置进行编码表征。

概述

作者提出了在Transformer模型中加入可训练的embedding编码,使得output representatino可以表征inputs的时序信息。这些embedding vectors是 在计算输入序列中的任意两个单词$i, j$ 之间的attention weight 和 value时被加入到其中。embedding vector用于表示单词$i,j$之间的距离(即为间隔的单词数),所以命名为"相对位置表征" (Relative Position Representation) (RPR)

比如一个长度为5的序列,需要学习9个embeddings。(1个表示当前单词,4个表示其左边的单词,4个表示其右边的单词。)

以下例子展示了这些embeddings的用法:

1)

以上图示显示了计算第一个"I"的output representation的过程。箭头下面的数字显示了计算attention时用到的哪个RPRs.(比如,本示例是求第一个“I”的输出,需要用第一个“I”,记为''I_1',与sequence中每一个单词两两做self-attention运算。'I_1' with 'I_1'用到 index = 4 的RPR,“I_1”with 'think'用到index = 5 的RPR--因为是右边第一个, 'I_1' with 'therefore' 用到index = 6的RPR--因为是右边第二个... )

2)

与(1)同理。

符号含义

两点需要注意:

1. 有2个RPR的表征。需要在计算$z_i$和$e_{ij}$时分别引入对应的RPR的embedding。计算$z_i$时对应的RPR vector 是$a_{ij}^V$, 计算$e_{ij}时引入的RPR vector$是$a_{ij}^K$. 不同于在做multi-head attention时引入的线性映射矩阵W——对于每个head都不同;这个RPR embedding 在同一层的attention heads之间共享,但是在不同层的RPR可能不同。

2. 最大单词数被clipped在一个绝对的值k以内。向左k个, 再左边均为0, 向右k个,再右边均为k, 所表示的index范围: 2k + 1. 

eg. 10 words, k = 3, RPR embedding lookup table

设置k值截断的意义:

1. 作者假设精确的相对位置编码在超出了一定距离之后是没有必要的

2. 截断最大距离使得模型的泛化效果好,可以更好的Generalize到没有在训练阶段出现过的序列长度上。

之后,将分别学习key, value的相对位置表征。

$$w^{K} = (w_{-k}^K, ..., w_{k} ^K), w^{V} = (w_{-k}^V, ..., w_{k} ^V)$$

其中$w_i^K, w_i^V \in \mathbb{R}^{d_a}$.

实现

1. 若不使用RPR, 计算$z_i$的过程:

2. 若使用RPR,计算$z_i$的过程:

(3) 表示在计算word i 的output representation时,对于word j的value vector进行了修改,加上了word i, j 之间的相对位置编码。

(4) 在计算query(i), key(j)的点积时,对key vector进行了修改,加上了word i, j 之间的相对位置编码。

这里用加法引入RPR的信息,是一种高效的实现方式。

高效实现

不加RPR时,Transformer计算$e_{ij}$使用了 batch_size * h 个并行的矩阵乘法运算。

其中的x是给定input sequence后的(row-wise)

将(4) 式写为以下形式:

(1) 首先看第一项,$$x_iW^Q(x_jW^K)^T$$

首先看对于一个batch,的一个head, 其中$x_i$的shape是(seq_length, dx),现在假设seq_length = 1,来简化推导过程。假设$W^Q, W^K$的shape均为(dx, dz),那么第一项运算后的shape为:[(1 * dx) * (dx, dz)] * [(dz, dx) * (dx, 1)] = (1, 1),

这是对于一个batch,一个head, seq_length = 1的情况,那么扩充到真实的情况,其shape 为: (batch_size, h, seq_length, seq_length)

所以我们的目标是产生另一个有相同shape的tensor,其内容是word i 与关于Wordi, j 的RPR的embedding的点积。

(2) A.shape: (seq_length, seq_length, d_a),

$transpose \rightarrow A^T.shape: $(seq_length, d_a, seq_length)

(3) 第二项中的$x_i W^Q.shape:$ (batch_size, h, seq_length, d_z)

$transpose \rightarrow $ (seq_length, batch_size, h, d_z)

$reshape \rightarrow $ (seq_length, batch_size * h, d_z)

之后可以与$A^T$相乘,可以看做是seq_length个并行的(batch_size * h, d_z) matmul (d_a, seq_length),因为$d_z = d_a$,所以每个并行的运算结果是:(batch_size * h, seq_length), 总的大矩阵的shape: (seq_length, batchsize * h, seq_length).

$reshape \rightarrow $(seq_length, batch_size, h, seq_length)

$transpose \rightarrow$ (batch_size, h, seq_length, seq_length)

与第一项的shape一致,可以相加。

(3)式的推导同理。

下面给出tensor2tensor中对于相对位置编码的代码:https://github.com/tensorflow/tensor2tensor/blob/9e0a894034d8090892c238df1bd9bd3180c2b9a3/tensor2tensor/layers/common_attention.py#L1556-L1587

其中x,对应上面推导中的$x_i * W^Q$, y对应上面推导中的$x_j * W^K$, z对应上面的a。

 def _relative_attention_inner(x, y, z, transpose):
"""Relative position-aware dot-product attention inner calculation.
This batches matrix multiply calculations to avoid unnecessary broadcasting.
Args:
x: Tensor with shape [batch_size, heads, length or 1, length or depth].
y: Tensor with shape [batch_size, heads, length or 1, depth].
z: Tensor with shape [length or 1, length, depth].
transpose: Whether to transpose inner matrices of y and z. Should be true if
last dimension of x is depth, not length.
Returns:
A Tensor with shape [batch_size, heads, length, length or depth].
"""
batch_size = tf.shape(x)[0]
heads = x.get_shape().as_list()[1]
length = tf.shape(x)[2] # xy_matmul is [batch_size, heads, length or 1, length or depth]
xy_matmul = tf.matmul(x, y, transpose_b=transpose)
# x_t is [length or 1, batch_size, heads, length or depth]
x_t = tf.transpose(x, [2, 0, 1, 3])
# x_t_r is [length or 1, batch_size * heads, length or depth]
x_t_r = tf.reshape(x_t, [length, heads * batch_size, -1])
# x_tz_matmul is [length or 1, batch_size * heads, length or depth]
x_tz_matmul = tf.matmul(x_t_r, z, transpose_b=transpose)
# x_tz_matmul_r is [length or 1, batch_size, heads, length or depth]
x_tz_matmul_r = tf.reshape(x_tz_matmul, [length, batch_size, heads, -1])
# x_tz_matmul_r_t is [batch_size, heads, length or 1, length or depth]
x_tz_matmul_r_t = tf.transpose(x_tz_matmul_r, [1, 2, 0, 3])
return xy_matmul + x_tz_matmul_r_t

结果

使用Attention is All You Need的机器翻译的任务。在training steos每秒去掉7%的条件下,模型的BLEU分数对于English-to-German最高提升了1.3, 对于English-to-French最高提升了0.5.

 [支付宝] 感谢您的捐赠!

But one thing I do: Forgetting what is behind and straining toward what is ahead. ~Bible.Philippians.

[NLP] 相对位置编码(一) Relative Position Representatitons (RPR) - Transformer的更多相关文章

  1. [NLP] 相对位置编码(二) Relative Positional Encodings - Transformer-XL

    参考: 1. Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context https://arxiv.org/pdf ...

  2. 中文NER的那些事儿5. Transformer相对位置编码&TENER代码实现

    这一章我们主要关注transformer在序列标注任务上的应用,作为2017年后最热的模型结构之一,在序列标注任务上原生transformer的表现并不尽如人意,效果比bilstm还要差不少,这背后有 ...

  3. ICCV2021 | Vision Transformer中相对位置编码的反思与改进

    ​前言  在计算机视觉中,相对位置编码的有效性还没有得到很好的研究,甚至仍然存在争议,本文分析了相对位置编码中的几个关键因素,提出了一种新的针对2D图像的相对位置编码方法,称为图像RPE(IRPE). ...

  4. 13-[CSS]-postion位置:相relative,绝absolute,固fixed,static(默认),z-index

    1.postion位置属性 <!DOCTYPE html> <html lang="en"> <head> <meta charset=& ...

  5. 第五课第四周实验一:Embedding_plus_Positional_encoding 嵌入向量加入位置编码

    目录 变压器预处理 包 1 - 位置编码 1.1 - 位置编码可视化 1.2 - 比较位置编码 1.2.1 - 相关性 1.2.2 - 欧几里得距离 2 - 语义嵌入 2.1 - 加载预训练嵌入 2. ...

  6. Dedecms当前位置{dede:field name='position'/}修改

    这个实在list_article.htm模板出现的,而这个模板通过loadtemplage等等一系列操作是调用的include 下的arc.archives.class.php $this->F ...

  7. Dedecms当前位置{dede:field name='position'/}修改,去掉>方法

    Dedecms当前位置{dede:field name='position'/}修改,如何去掉> 一.修改{dede:field name='position'/}的文字间隔符,官方默认的是&g ...

  8. css背景图片位置:background的position(转)

    css背景图片位置:background的position   position的两个参数:水平方向的位置,垂直方向的位置----------该位置是指背景图片相对于前景对象的 1.backgroun ...

  9. DIV滚动条滚动到指定位置(jquery的position()与offset()方法区别小记)

    相对浏览器,将指定div滚到到指定位置,其用法如下 $("html,body").animate({scrollTop: $(obj).offset().top},speed); ...

随机推荐

  1. Ubuntu 下压缩软件的安装

    在ubuntu下,系统就自带一个压缩包管理软件,但是,它默认是不支持rar和7zip格式的.因此,我们可以给它直接“增强”一下.就成了万能的了.安装方法,终端里面: sudo apt-get inst ...

  2. java多线程之管道流

    java语言中提供了各种各样的流供我们操纵数据,其中管道流(pipeStream)是一种特殊的流,用于在不同线程间直接传送数据. 一个线程发送数据到输出管道,另一个线程从输入管道读取数据,通过使用管道 ...

  3. 打印第二列为oldboy的第一列内容(awk,grep,sed用法)

    [root@goldtest ~]# cat ip.log 10.0.0.1 oldboy 10.0.0.2 oldgirl 10.0.0.4 tingting 10.0.0.4 oldboy old ...

  4. IOS关于数据加密(主要为登录加密)想总结的

    首先上来就来说一下,IOS常见的几种加密算法  *哈希(散列)函数 : MD5.SHA  *对称加密算法:DES.3DES.AES  *非对称加密算法:RSA 一.哈希(散列)函数  1.MD5 MD ...

  5. 【Go】使用压缩文件优化io (一)

    原文连接:https://blog.thinkeridea.com/201906/go/compress_file_io_optimization1.html 最近遇到一个日志备份 io 过高的问题, ...

  6. python trojan development 2nd —— use python to send mail and listen to the key board then combine them

    请勿用于非法用途!!!!!本人概不负责!!!原创作品,转载说明出处!!!!! from pynput.keyboard import Key,Listener import logging impor ...

  7. Storm 学习之路(六)—— Storm项目三种打包方式对比分析

    一.简介 在将Storm Topology提交到服务器集群运行时,需要先将项目进行打包.本文主要对比分析各种打包方式,并将打包过程中需要注意的事项进行说明.主要打包方式有以下三种: 第一种:不加任何插 ...

  8. ParrotSec 中文社区 QQ群认证 Openssl解密

    ParrotSec 中文社区 QQ群认证 Openssl解密 下载Key.txt 打开parrot 系统,复制文件到系统.打开命令行输入 openssl enc -aes-256-cfb -d -in ...

  9. VUE、微信for动态变量取值(拼接取值)

    item.value是其它循的值如value=[1,2,3] {{'images[arrAy' + item.value+']'}} 那么拼接结果是 {{images[arrAy1]}}, {{ima ...

  10. 【React】react学习笔记01-概念与基本使用

      俺为啥要学这玩意:   家里的杂事好不容易处理完了,开始正式静下心来学习~博主是做后端开发的,js基础不深,之前也是用React写过许多东西了,但是基本上都是用的CV大法,别人的组 件修修改改拿来 ...