导读

目前采用编码器-解码器 (Encode-Decode) 结构的模型非常热门,是因为它在许多领域较其他的传统模型方法都取得了更好的结果。这种结构的模型通常将输入序列编码成一个固定长度的向量表示,对于长度较短的输入序列而言,该模型能够学习出对应合理的向量表示。然而,这种模型存在的问题在于:当输入序列非常长时,模型难以学到合理的向量表示

在这篇博文中,我们将探索加入LSTM/RNN模型中的attention机制是如何克服传统编码器-解码器结构存在的问题的。

通过阅读这篇博文,你将会学习到:

  • 传统编码器-解码器结构存在的问题及如何将输入序列编码成固定的向量表示;
  • Attention机制是如何克服上述问题的,以及在模型输出时是如何考虑输出与输入序列的每一项关系的;
  • 基于attention机制的LSTM/RNN模型的5个应用领域:机器翻译、图片描述、语义蕴涵、语音识别和文本摘要。

让我们开始学习吧。

一、长输入序列带来的问题

使用传统编码器-解码器的RNN模型先用一些LSTM单元来对输入序列进行学习,编码为固定长度的向量表示;然后再用一些LSTM单元来读取这种向量表示并解码为输出序列。

采用这种结构的模型在许多比较难的序列预测问题(如文本翻译)上都取得了最好的结果,因此迅速成为了目前的主流方法。

例如:

这种结构在很多其他的领域上也取得了不错的结果。然而,它存在一个问题在于:输入序列不论长短都会被编码成一个固定长度的向量表示,而解码则受限于该固定长度的向量表示

这个问题限制了模型的性能,尤其是当输入序列比较长时,模型的性能会变得很差(在文本翻译任务上表现为待翻译的原始文本长度过长时翻译质量较差)。

“一个潜在的问题是,采用编码器-解码器结构的神经网络模型需要将输入序列中的必要信息表示为一个固定长度的向量,而当输入序列很长时则难以保留全部的必要信息(因为太多),尤其是当输入序列的长度比训练数据集中的更长时。”

— Dzmitry Bahdanau, et al., Neural machine translation by jointly learning to align and translate, 2015

二、使用Attention机制

Attention机制的基本思想是,打破了传统编码器-解码器结构在编解码时都依赖于内部一个固定长度向量的限制

Attention机制的实现是通过保留LSTM编码器对输入序列的中间输出结果,然后训练一个模型来对这些输入进行选择性的学习并且在模型输出时将输出序列与之进行关联

换一个角度而言,输出序列中的每一项的生成概率取决于在输入序列中选择了哪些项。

“在文本翻译任务上,使用attention机制的模型每生成一个词时都会在输入序列中找出一个与之最相关的词集合。之后模型根据当前的上下文向量 (context vectors) 和所有之前生成出的词来预测下一个目标词。

… 它将输入序列转化为一堆向量的序列并自适应地从中选择一个子集来解码出目标翻译文本。这感觉上像是用于文本翻译的神经网络模型需要“压缩”输入文本中的所有信息为一个固定长度的向量,不论输入文本的长短。”

— Dzmitry Bahdanau, et al., Neural machine translation by jointly learning to align and translate, 2015

虽然模型使用attention机制之后会增加计算量,但是性能水平能够得到提升。另外,使用attention机制便于理解在模型输出过程中输入序列中的信息是如何影响最后生成序列的。这有助于我们更好地理解模型的内部运作机制以及对一些特定的输入-输出进行debug。

“论文提出的方法能够直观地观察到生成序列中的每个词与输入序列中一些词的对齐关系,这可以通过对标注 (annotations) 权重参数可视化来实现…每个图中矩阵的每一行表示与标注相关联的权重。由此我们可以看出在生成目标词时,源句子中的位置信息会被认为更重要。”

— Dzmitry Bahdanau, et al., Neural machine translation by jointly learning to align and translate, 2015

三、大型图片带来的问题

被广泛应用于计算机视觉领域的卷积神经网络模型同样存在类似的问题: 对于特别大的图片输入,模型学习起来比较困难。

由此,一种启发式的方法是将在模型做预测之前先对大型图片进行某种近似的表示。

“人类的感知有一个重要的特性是不会立即处理外界的全部输入,相反的,人类会将注意力专注于所选择的部分来得到所需要的信息,然后结合不同时间段的局部信息来建立一个内部的场景表示,从而引导眼球的移动及做出决策。”

— Recurrent Models of Visual Attention, 2014

这种启发式方法某种程度上也可以认为是考虑了attention,但在这篇博文中,这种方法并不认为是基于attention机制的。

基于attention机制的相关论文如下:

四、基于Attention模型的应用实例

这部分将列举几个具体的应用实例,介绍attention机制是如何用在LSTM/RNN模型来进行序列预测的。

1. Attention在文本翻译任务上的应用

文本翻译这个实例在前面已经提过了。

给定一个法语的句子作为输入序列,需要输出翻译为英语的句子。Attention机制被用在输出输出序列中的每个词时会专注考虑输入序列中的一些被认为比较重要的词。

我们对原始的编码器-解码器模型进行了改进,使其有一个模型来对输入内容进行搜索,也就是说在生成目标词时会有一个编码器来做这个事情。这打破了之前的模型是基于将整个输入序列强行编码为一个固定长度向量的限制,同时也让模型在生成下一个目标词时重点考虑输入中相关的信息。

— Dzmitry Bahdanau, et al., Neural machine translation by jointly learning to align and translate, 2015

Attention在文本翻译任务(输入为法语文本序列,输出为英语文本序列)上的可视化(图片来源于Dzmitry Bahdanau, et al., Neural machine translation by jointly learning to align and translate, 2015

2. Attention在图片描述上的应用

与之前启发式方法不同的是,基于序列生成的attention机制可以应用在计算机视觉相关的任务上,帮助卷积神经网络重点关注图片的一些局部信息来生成相应的序列,典型的任务就是对一张图片进行文本描述。

给定一张图片作为输入,输出对应的英文文本描述。Attention机制被用在输出输出序列的每个词时会专注考虑图片中不同的局部信息。

我们提出了一种基于attention的方法,该方法在3个标准数据集上都取得了最佳的结果……同时展现了attention机制能够更好地帮助我们理解模型地生成过程,模型学习到的对齐关系与人类的直观认知非常的接近(如下图)。

— Show, Attend and Tell: Neural Image Caption Generation with Visual Attention, 2016

Attention在图片描述任务(输入为图片,输出为描述的文本)上的可视化(图片来源于Attend and Tell: Neural Image Caption Generation with Visual Attention, 2016)

3. Attention在语义蕴涵 (Entailment) 中的应用

给定一个用英文描述的前提和假设作为输入,输出假设与前提是否矛盾、是否相关或者是否成立。

举个例子:

前提:在一个婚礼派对上拍照

假设:有人结婚了

该例子中的假设是成立的。

Attention机制被用于关联假设和前提描述文本之间词与词的关系。

我们提出了一种基于LSTM的神经网络模型,和把每个输入文本都独立编码为一个语义向量的模型不同的是,该模型同时读取前提和假设两个描述的文本序列并判断假设是否成立。我们在模型中加入了attention机制来找出假设和前提文本中词/短语之间的对齐关系。……加入attention机制能够使模型在实验结果上有2.6个点的提升,这是目前数据集上取得的最好结果…

— Reasoning about Entailment with Neural Attention, 2016

Attention在语义蕴涵任务(输入是前提文本,输出是假设文本)上的可视化(图片来源于Reasoning about Entailment with Neural Attention, 2016)

4. Attention在语音识别上的应用

给定一个英文的语音片段作为输入,输出对应的音素序列。

Attention机制被用于对输出序列的每个音素和输入语音序列中一些特定帧进行关联。

…一种基于attention机制的端到端可训练的语音识别模型,能够结合文本内容和位置信息来选择输入序列中下一个进行编码的位置。该模型有一个优点是能够识别长度比训练数据长得多的语音输入。

— Attention-Based Models for Speech Recognition, 2015.

Attention在语音识别任务(输入是音帧,输出是音素的位置)上的可视化(图片来源于Attention-Based Models for Speech Recognition, 2015)

5. Attention在文本摘要上的应用

给定一篇英文文章作为输入序列,输出一个对应的摘要序列。

Attention机制被用于关联输出摘要中的每个词和输入中的一些特定词。

… 在最近神经网络翻译模型的发展基础之上,提出了一个用于生成摘要任务的基于attention的神经网络模型。通过将这个概率模型与一个生成式方法相结合来生成出准确的摘要。

— A Neural Attention Model for Abstractive Sentence Summarization, 2015

Attention在文本摘要任务(输入为文章,输出为文本摘要)上的可视化(图片来源于A Neural Attention Model for Abstractive Sentence Summarization, 2015)

五、Attention的数学解释

1. 原来的Encoder–Decoder

在这个模型中,encoder只将最后一个输出递给了decoder,这样一来,decoder就相当于对输入只知道梗概意思,而无法得到更多输入的细节,比如输入的位置信息。所以想想就知道了,如果输入的句子比较短、意思比较简单,翻译起来还行,长了复杂了就做不好了嘛。

2. 对齐问题

前面说了,只给我递来最后一个输出,不好;但如果把每个step的输出都传给我,又有一个问题了,怎么对齐?

什么是对齐?比如说英文翻译成中文,假设英文有10个词,对应的中文翻译只有6个词,那么就有了哪些英文词对哪些中文词的问题了嘛。

传统的翻译专门有一块是搞对齐的,是一个比较独立的task(传统的NLP基本上每一块都是独立的task啦)。

3. attention机制

我们从输出端,即decoder部分,倒过来一步一步看公式。

$$  S_t=f(S_{t-1}, y_{t-1}, c_t)  \tag{1} $$

$S_t$是指decoder在$t$时刻的状态输出,$S_{t-1}$是指decoder在$t-1$时刻的状态输出,$y_{t-1}$是$t-1$时刻的label(注意是label,不是我们输出的$y$),$c_t$看下一个公式,$f$是一个RNN。

$$    {c_{t}} = \sum\limits_{j = 1}^{{T_x}} {{a_{tj}}{h_j}}     \tag{2}   $$

$h_j$是指第$j$个输入在encoder里的输出,$a_{tj}$是一个权重

$$    {a_{tj}} = \frac{{exp \left( {{e_{tj}}} \right)}}{{\sum\nolimits_{k = 1}^{{T_x}} {exp \left( {{e_{tk}}} \right)} }}   \tag{3}$$

这个公式跟softmax是何其相似,道理是一样的,是为了得到条件概率$P(a|e)$,这个$a$的意义是当前这一步decoder对齐第$j$个输入的程度。

最后一个公式,

$$  e_{tj} = g(S_{t-1}, h_j) = V\cdot \tanh { \left( W\cdot h_j+U\cdot S_{t-1}+b \right)  }    \tag{4}$$

这个$g$可以用一个小型的神经网络来逼近,它用来计算$S_{t-1}$, $h_j$这两者的关系分数,如果分数大则说明关注度较高,注意力分布就会更加集中在这个输入单词上,这个函数在文章Neural Machine Translation by Jointly Learning to Align and Translate(2014)中称之为校准模型(alignment model),文中提到这个函数是RNN前馈网络中的一系列参数,在训练过程会训练这些参数, 基于Attention-Based LSTM模型的文本分类技术的研究(2016)给出了上式的右侧部分作为拓展。

好了,把四个公式串起来看,这个attention机制可以总结为一句话:当前一步输出$S_t$应该对齐哪一步输入,主要取决于前一步输出$S_{t-1}$和这一步输入的encoder结果$h_j$。

看了这个方法的感受是,计算力发达的这个年代,真是什么复杂的东西都有人敢试了啊。这要是放在以前,得跑多久才能收敛啊......

神经网络搞NLP虽然还有诸多受限的地方,但这种end-to-end 的one task方式,太吸引人,有前途。

进一步的阅读

如果你想进一步地学习如何在LSTM/RNN模型中加入attention机制,可阅读以下论文:

目前Keras官方还没有单独将attention模型的代码开源,下面有一些第三方的实现:

总结

通过这篇博文,你应该学习到了attention机制是如何应用在LSTM/RNN模型中来解决序列预测存在的问题。

具体而言,采用传统编码器-解码器结构的LSTM/RNN模型存在一个问题:不论输入长短都将其编码成一个固定长度的向量表示,这使模型对于长输入序列的学习效果很差(解码效果很差)。而attention机制则克服了上述问题,原理是在模型输出时会选择性地专注考虑输入中的对应相关的信息。使用attention机制的方法被广泛应用在各种序列预测任务上,包括文本翻译、语音识别等。

感谢原作者Jason Brownlee。原文链接见:Attention in Long Short-Term Memory Recurrent Neural Networks

转载:http://www.jeyzhang.com/understand-attention-in-rnn.html

Deep Learning基础--理解LSTM/RNN中的Attention机制的更多相关文章

  1. 理解LSTM/RNN中的Attention机制

    转自:http://www.jeyzhang.com/understand-attention-in-rnn.html,感谢分享! 导读 目前采用编码器-解码器 (Encode-Decode) 结构的 ...

  2. Deep Learning基础--理解LSTM网络

    循环神经网络(RNN) 人们的每次思考并不都是从零开始的.比如说你在阅读这篇文章时,你基于对前面的文字的理解来理解你目前阅读到的文字,而不是每读到一个文字时,都抛弃掉前面的思考,从头开始.你的记忆是有 ...

  3. LSTM/RNN中的Attention机制

    一.解决的问题 采用传统编码器-解码器结构的LSTM/RNN模型存在一个问题,不论输入长短都将其编码成一个固定长度的向量表示,这使模型对于长输入序列的学习效果很差(解码效果很差). 注意下图中,ax ...

  4. Deep Learning基础--CNN的反向求导及练习

    前言: CNN作为DL中最成功的模型之一,有必要对其更进一步研究它.虽然在前面的博文Stacked CNN简单介绍中有大概介绍过CNN的使用,不过那是有个前提的:CNN中的参数必须已提前学习好.而本文 ...

  5. Deep Learning基础--参数优化方法

    1. 深度学习流程简介 1)一次性设置(One time setup)          -激活函数(Activation functions) - 数据预处理(Data Preprocessing) ...

  6. 深度学习中的Attention机制

    1.深度学习的seq2seq模型 从rnn结构说起 根据输出和输入序列不同数量rnn可以有多种不同的结构,不同结构自然就有不同的引用场合.如下图, one to one 结构,仅仅只是简单的给一个输入 ...

  7. Deep Learning基础--word2vec 中的数学原理详解

    word2vec 是 Google 于 2013 年开源推出的一个用于获取 word vector 的工具包,它简单.高效,因此引起了很多人的关注.由于 word2vec 的作者 Tomas Miko ...

  8. Deep Learning基础--随时间反向传播 (BackPropagation Through Time,BPTT)推导

    1. 随时间反向传播BPTT(BackPropagation Through Time, BPTT) RNN(循环神经网络)是一种具有长时记忆能力的神经网络模型,被广泛用于序列标注问题.一个典型的RN ...

  9. 对循环神经网络参数的理解|LSTM RNN Input_size Batch Sequence

    在很多博客和知乎中我看到了许多对于pytorch框架中RNN接口的一些解析,但都较为浅显甚至出现一些不准确的理解,在这里我想阐述下我对于pytorch中RNN接口的参数的理解. 我们经常看到的RNN网 ...

随机推荐

  1. wireshark系列之wireshark简介

    前言:为什么要学wireshark?工欲善其事必先利其器,wireshark是一款工具软件,主要作用是抓取数据封包,可以帮助我们更加直观更加具象的学习各种网路协议(http.TLS.TCP.UDP.I ...

  2. 英文报道:China challenged Australian warships in South China Sea, reports say

    学习地道新闻英语表达,以下文章来自CNN By Ben Westcott and Jamie Tarabay, CNN Updated 0830 GMT (1630 HKT) April 20, 20 ...

  3. pbuilder编译构建工具分析

    1. 简介 pbuilder(personal Debian package builder)是ubuntu环境下维护debian包的专业工具,能够为每个deb包创建纯净的编译构建环境,自动解析和安装 ...

  4. Linux内核分析3

    周子轩创作品转载请注明出处  <Linux内核分析>MOOC课程http://mooc.study.163.com/course/USTC-1000029000 gdb跟踪start_ke ...

  5. C++11——引入的新关键字

    1.auto auto是旧关键字,在C++11之前,auto用来声明自动变量,表明变量存储在栈,很少使用.在C++11中被赋予了新的含义和作用,用于类型推断. auto关键字主要有两种用途:一是在变量 ...

  6. Zookeeper(四)Hadoop HA高可用集群搭建

    一.高可就集群搭建 1.集群规划 2.集群服务器准备 (1) 修改主机名(2) 修改 IP 地址(3) 添加主机名和 IP 映射(4) 同步服务器时间(5) 关闭防火墙(6) 配置免密登录(7) 安装 ...

  7. MapReduce(二)常用三大组件

    mapreduce三大组件:Combiner\Sort\Partitioner 默认组件:排序,分区(不设置,系统有默认值) 一.mapreduce中的Combiner 1.什么是combiner C ...

  8. pxp

    Time Limit: 2000 ms Memory Limit: 512 MB Description 给定 \(n\), 求\(\sum\limits_{p,q∈primes}[pq≤n]\) ( ...

  9. 2017 Multi-University Training Contest - 2

    HDU 6045 #pragma comment(linker, "/STACK:102400000,102400000") #include <cstdio> #in ...

  10. lnmp集成环境tp nginx vhost配置

    server { listen 80; server_name xxx; charset utf-8; root /home/wwwroot/axxx/public; index index.html ...