Attention is all your need 谷歌的超强特征提取网络——Transformer
过年放了七天假,每年第一件事就是立一个flag——希望今年除了能够将技术学扎实之外,还希望能够将所学能够用来造福社会,好像flag立得有点大了。没关系,套用一句电影台词为自己开脱一下——人没有梦想,和咸鱼有什么区别。闲话至此,进入今天主题:Transformer。谷歌于2017年提出Transformer网络架构,此网络一经推出就引爆学术界。目前,在NLP领域,Transformer模型被认为是比CNN,RNN都要更强的特征提取器。
Transformer算法简介
Transformer引入了self-attention机制,同时还借鉴了CNN领域中残差机制(Residuals),由于以上原因导致transformer有如下优势:
- 模型表达能力较强,由于self-attention机制考虑到了句子之中词与词之间的关联,
- 抛弃了RNN的循环结构,同时借用了CNN中的残差结构加快了模型的训练速度。
接下来我们来看看transformer的一些细节:
首先Scaled Dot-Product Attention步骤是transformer的精髓所在,作者引入Q,W,V参数通过点乘相识度去计算句子中词与词之间的关联重要程度。其大致过程如图所示,笔者将会在实战部分具体介绍此过程如何实现。
Scaled Dot-Product Attention第二个是muti-head步骤,直白的解释就是将上面的Scaled Dot-Product Attention步骤重复执行,然后将每次执行的结果拼接起来,需要注意的是每次重复执行Scaled Dot-Product Attention步骤的参数并不共享。
Multi-Head- 第三个步骤就是残差网络结构——将muti-head步骤的输出和原始输入之间相加。这里不明白的可以参考笔者之前介绍残差网络的文章。
接下来就是实战部分,实战部分只使用了muti-head attention或者说是self-attention的向量表示作为最终特征进行文本分类。
Transformer文本分类实战
数据载入
下方代码的作用是将情感分析数据读入,格式为一句话和一个label:
sen_1 : 1
sen_2 : 0
1代表正面情绪,0代表负面情绪。
#! -*- coding: utf-8 -*-
from keras import backend as K
from keras.engine.topology import Layer
import numpy as np
from keras.preprocessing import sequence
from keras.layers import *
from keras import Model
from keras.callbacks import TensorBoard
data = np.load("imdb.npz")
x_test = data["x_test"]
x_train = data["x_train"]
y_test = data["y_test"]
y_train = data["y_train"]
数据预处理
由于文本数据长短不一,下面代码可将数据padding到相同的长度。
from itertools import chain
all_word = list(chain.from_iterable(list(x_train)))
all_word = set(all_word)
max_features = len(all_word)
data_train = sequence.pad_sequences(x_train,200)
Self-attention
这里详细介绍一下模型最关键的部分Scaled Dot-Product Attention的构建过程,如图一 Scaled Dot-Product Attention:

- 1.首先申明三个待优化的参数
,
- 2.将输入X分别和
进行点乘,得到
,此过程可以理解成将同一句话中的词映射到三个不同的向量空间,这里笔者将三个不同的向量空间命名为Q空间,K空间和V空间,如图二 Query,Key,Value metrix,
图二 Query,Key,Value metrix - 3.然后计算Q空间的某一个词在K空间所以词向量分别点乘得分,之后将这些得分通过softmax函计算一个重要度系数。然后用计算出来的重要度系数乘上该词在V空间的词向量并加和得到该词最终的词向量表示,整个过程如图三 Softmax所示,这样就可以得到一句话经过self-attention后的向量表示
。
图三 Softmax
上述整个过程就是Scaled Dot-Product Attention,本质上考虑到了一个句子中不同词之间的关联程度,这个过程或多或少增强了句子语义的表达。下方为keras定义的self-attention层的代码,这里加入了muti-head和mask功能的实现。
class Attention(Layer):
def __init__(self, nb_head, size_per_head, **kwargs):
self.nb_head = nb_head
self.size_per_head = size_per_head
self.output_dim = nb_head * size_per_head
super(Attention, self).__init__(**kwargs)
def build(self, input_shape):
self.WQ = self.add_weight(name='WQ',
shape=(input_shape[0][-1], self.output_dim),
initializer='glorot_uniform',
trainable=True)
self.WK = self.add_weight(name='WK',
shape=(input_shape[1][-1], self.output_dim),
initializer='glorot_uniform',
trainable=True)
self.WV = self.add_weight(name='WV',
shape=(input_shape[2][-1], self.output_dim),
initializer='glorot_uniform',
trainable=True)
super(Attention, self).build(input_shape)
def Mask(self, inputs, seq_len, mode='mul'):
if seq_len == None:
return inputs
else:
mask = K.one_hot(seq_len[:, 0], K.shape(inputs)[1])
mask = 1 - K.cumsum(mask, 1)
for _ in range(len(inputs.shape) - 2):
mask = K.expand_dims(mask, 2)
if mode == 'mul':
return inputs * mask
if mode == 'add':
return inputs - (1 - mask) * 1e12
def call(self, x):
# 如果只传入Q_seq,K_seq,V_seq,那么就不做Mask
# 如果同时传入Q_seq,K_seq,V_seq,Q_len,V_len,那么对多余部分做Mask
if len(x) == 3:
Q_seq, K_seq, V_seq = x
Q_len, V_len = None, None
elif len(x) == 5:
Q_seq, K_seq, V_seq, Q_len, V_len = x
# 对Q、K、V做线性变换
Q_seq = K.dot(Q_seq, self.WQ)
Q_seq = K.reshape(Q_seq, (-1, K.shape(Q_seq)[1], self.nb_head, self.size_per_head))
Q_seq = K.permute_dimensions(Q_seq, (0, 2, 1, 3))
K_seq = K.dot(K_seq, self.WK)
K_seq = K.reshape(K_seq, (-1, K.shape(K_seq)[1], self.nb_head, self.size_per_head))
K_seq = K.permute_dimensions(K_seq, (0, 2, 1, 3))
V_seq = K.dot(V_seq, self.WV)
V_seq = K.reshape(V_seq, (-1, K.shape(V_seq)[1], self.nb_head, self.size_per_head))
V_seq = K.permute_dimensions(V_seq, (0, 2, 1, 3))
# 计算内积,然后mask,然后softmax
A = K.batch_dot(Q_seq, K_seq, axes=[3, 3]) / self.size_per_head ** 0.5
A = K.permute_dimensions(A, (0, 3, 2, 1))
A = self.Mask(A, V_len, 'add')
A = K.permute_dimensions(A, (0, 3, 2, 1))
A = K.softmax(A)
# 输出并mask
O_seq = K.batch_dot(A, V_seq, axes=[3, 2])
O_seq = K.permute_dimensions(O_seq, (0, 2, 1, 3))
O_seq = K.reshape(O_seq, (-1, K.shape(O_seq)[1], self.output_dim))
O_seq = self.Mask(O_seq, Q_len, 'mul')
return O_seq
def compute_output_shape(self, input_shape):
return (input_shape[0][0], input_shape[0][1], self.output_dim)
位置编码
接下来定义一个位置编码层,由于是输入是句子属于一个序列,加入位置编码会使得语义表达更准确。
class Position_Embedding(Layer):
def __init__(self, size=None, mode='sum', **kwargs):
self.size = size # 必须为偶数
self.mode = mode
super(Position_Embedding, self).__init__(**kwargs)
def call(self, x):
if (self.size == None) or (self.mode == 'sum'):
self.size = int(x.shape[-1])
batch_size, seq_len = K.shape(x)[0], K.shape(x)[1]
position_j = 1. / K.pow(10000., 2 * K.arange(self.size / 2, dtype='float32') / self.size)
position_j = K.expand_dims(position_j, 0)
position_i = K.cumsum(K.ones_like(x[:, :, 0]), 1) - 1 # K.arange不支持变长,只好用这种方法生成
position_i = K.expand_dims(position_i, 2)
position_ij = K.dot(position_i, position_j)
position_ij = K.concatenate([K.cos(position_ij), K.sin(position_ij)], 2)
if self.mode == 'sum':
return position_ij + x
elif self.mode == 'concat':
return K.concatenate([position_ij, x], 2)
def compute_output_shape(self, input_shape):
if self.mode == 'sum':
return input_shape
elif self.mode == 'concat':
return (input_shape[0], input_shape[1], input_shape[2] + self.size)
而谷歌的论文直接给出了position embedding 层的公式,如下图所示。
此公式的含义是将
为
的位置映射为一个
维的位置向量,此向量的第
个元素的值就是通过上述公式算出来的
。position embeding背后的物理意义参考于参考文献第一篇:由于在数学上有以及
,这表明位置
的向量可以表示成位置
的向量的线性变换,这提供了表达
相对位置信息的可能性。
模型构建
接下来使用上方定义好的的self-attention层和position embedding层进行模型构建,这里设置的8个head,意味着将self-attention流程重复做8次,这里的代码实现不是讲8个head向量拼接,而是通过keras自带的 GlobalAveragePooling1D函数将8个head的向量求和平均一下。
K.clear_session()
callbacks = [TensorBoard("log/")]
S_inputs = Input(shape=(None,), dtype='int32')
embeddings = Embedding(max_features, 128)(S_inputs)
embeddings = Position_Embedding()(embeddings) # Position_Embedding
O_seq = Attention(8, 16)([embeddings, embeddings, embeddings])# Self Attention
O_seq = GlobalAveragePooling1D()(O_seq)
O_seq = Dropout(0.5)(O_seq)
outputs = Dense(1, activation='sigmoid')(O_seq)
model = Model(inputs=S_inputs, outputs=outputs)
model.summary()
模型的网络结构可视化输出如下:

模型训练
将之前预处理好的数据喂给模型,同时设置好batch size 和 epoch就可以跑起来了。由于笔者是使用的是笔记本的cpu,所以只跑一个epoch。
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(data_train, y_train,
batch_size=2,
epochs=1,
callbacks=callbacks,
validation_split=0.2)

结语
Transformer在各方面性能上都超过了RNN和CNN,但是其最主要的思想还是引入了self-attention,使得模型可以考虑到句子中词与词之间的相互联系,这个思想在NLP很多领域,如机器阅读(R-Net)中也曾出现。所以如何在embeding时的更好挖掘句子的语义,才是深度学习在nlp领域最需要解决的难题。
参考文献
https://spaces.ac.cn/archives/4765
https://blog.csdn.net/qq_41664845/article/details/84969266
Attention Is All You Need
作者:王鹏你妹
链接:https://www.jianshu.com/p/704893b996f9
来源:简书
简书著作权归作者所有,任何形式的转载都请联系作者获得授权并注明出处。
Attention is all your need 谷歌的超强特征提取网络——Transformer的更多相关文章
- android.os.NetworkOnMainThreadException 在4.0之后谷歌强制要求连接网络不能在主线程进行访问
谷歌在4.0系统以后就禁止在主线程中进行网络访问了,原因是: 主线程是负责UI的响应,如果在主线程进行网络访问,超过5秒的话就会引发强制关闭, 所以这种耗时的操作不能放在主线程里.放在子线程里,而子线 ...
- 《Attention is All You Need》
https://www.jianshu.com/p/25fc600de9fb 谷歌最近的一篇BERT取得了卓越的效果,为了研究BERT的论文,我先找出了<Attention is All You ...
- 2. Attention Is All You Need(Transformer)算法原理解析
1. 语言模型 2. Attention Is All You Need(Transformer)算法原理解析 3. ELMo算法原理解析 4. OpenAI GPT算法原理解析 5. BERT算法原 ...
- 从Seq2seq到Attention模型到Self Attention
Seq2seq Seq2seq全名是Sequence-to-sequence,也就是从序列到序列的过程,是近年当红的模型之一.Seq2seq被广泛应用在机器翻译.聊天机器人甚至是图像生成文字等情境. ...
- Attention模型
李宏毅深度学习 https://www.bilibili.com/video/av9770302/?p=8 Generation 生成模型基本结构是这样的, 这个生成模型有个问题是我不能干预数据生成, ...
- 【转载】细粒度图像识别Object-Part Attention Driven Discriminative Localization for Fine-grained Image Classification
细粒度图像识别Object-Part Attention Driven Discriminative Localization for Fine-grained Image Classificatio ...
- Deep attention tracking via Reciprocative Learning
文章:Deep attention tracking via Reciprocative Learning 出自NIPS2018 文章链接:https://arxiv.org/pdf/1810.038 ...
- [深度概念]·Attention Model(注意力模型)学习笔记
此文源自一个博客,笔者用黑体做了注释与解读,方便自己和大家深入理解Attention model,写的不对地方欢迎批评指正.. 1.Attention Model 概述 深度学习里的Attention ...
- Sequence Model-week3编程题1-Neural Machine Translation with Attention
1. Neural Machine Translation 下面将构建一个神经机器翻译(NMT)模型,将人类可读日期 ("25th of June, 2009") 转换为机器可读日 ...
随机推荐
- select @@identity的用法 转
用select @@identity得到上一次插入记录时自动产生的ID 如果你使用存储过程的话,将非常简单,代码如下:SET @NewID=@@IDENTITY 说明: 在一条 INSERT.SELE ...
- 从0开始学习 GitHub 系列之「08.如何发现优秀的开源项目」
之前发过一系列有关 GitHub 的文章,有同学问了,GitHub 我大概了解了,Git 也差不多会使用了,但是 还是搞不清 GitHub 如何帮助我的工作,怎么提升我的工作效率? 问到点子上了,Gi ...
- margin负值5种应用
最近做的项目中经常会用到margin的负值,这里就总结一下关于margin负值的5种使用及相关bug的解决. 1. 在流动性布局中的应用 如WordPress的两栏式不固定布局就是使用margin负值 ...
- java 3类的继承
模板类 泛型程序设计方法 类的组合 类的继承 java只有单继承 隐藏和覆盖 用super.x调用 访问静态属性 静态属性不继承 静态成员只有一个,不会有副本 静态成员只有一个所有的超类和子类 方法的 ...
- SimpleDateFormat 以及java8 - DateTimeFormatter
https://www.cnblogs.com/zhisheng/p/9206758.html 在看的过程中有这么一条: [强制]SimpleDateFormat 是线程不安全的类,一般不要定义为 s ...
- 基于docker的php调用基于docker的mysql数据库的方法
1:建立基于docker的mysql,参考 Mac上将brew安装的MySql改用Docker执行 2:建立基于docker�php image 在当前目录,建立Dockerfile,内容如下 FRO ...
- java reference(转)
http://blog.163.com/xubin_3@126/blog/static/112987702200962211145825/ 在Java中的引用类型,是指除了基本的变量类型之外的所有类型 ...
- PLAY2.6-SCALA(三) 数据的返回与保存
1.修改默认的Content-Type 自动设置内容类型为text/plain val textResult = Ok("Hello World!") 自动设置内容类型为appli ...
- 让超出div内容的显示滚动条:overflow:auto,以及overflow其它属性
css的属性,以前没用过遇到了,记录一下: 虽然layui本来自带这个处理,但是为了灵活,抛弃layui原有的加载,只是用layui的样样式,就要使用到这个css属性 总结overflow属性: /* ...
- 2019-8-31-dotnet-使用-System.CommandLine-写命令行程序
title author date CreateTime categories dotnet 使用 System.CommandLine 写命令行程序 lindexi 2019-08-31 16:55 ...