self-attention详解
编写你自己的 Keras 层
对于简单、无状态的自定义操作,你也许可以通过 layers.core.Lambda 层来实现。但是对于那些包含了可训练权重的自定义层,你应该自己实现这种层。
这是一个 Keras2.0 中,Keras 层的骨架(如果你用的是旧的版本,请更新到新版)。你只需要实现三个方法即可:
build(input_shape): 这是你定义权重的地方。这个方法必须设self.built = True,可以通过调用super([Layer], self).build()完成。call(x): 这里是编写层的功能逻辑的地方。你只需要关注传入call的第一个参数:输入张量,除非你希望你的层支持masking。compute_output_shape(input_shape): 如果你的层更改了输入张量的形状,你应该在这里定义形状变化的逻辑,这让Keras能够自动推断各层的形状
本文主要讲解Self_attention方面的内容,这方面的知识是建立在attention机制之上的,因此若读者不了解attention mechanism的话,希望你们能去看我的关于深入理解attention机制。本人也将在这里稍微的解释一下。
对于encoder-decoder模型,decoder的输入包括(注意这里是包括)encoder的输出。但是根据常识来讲,某一个输出并不需要所有encoder信息,而是只需要部分信息。这句话就是attention的精髓所在。怎么理解这句话呢?举个例子来说:假如我们正在做机器翻译,将“I am a student”翻译成中文“我是一个学生”。根据encoder-decoder模型,在输出“学生”时,我们用到了“我”“是”“一个”以及encoder的输出。但事实上,我们或许并不需要“I am a ”这些无关紧要的信息,而仅仅只需要“student”这个词的信息就可以输出“学生”(或者说“I am a”这些信息没有“student”重要)。这个时候就需要用到attention机制来分别为“I”、“am”、“a”、“student”赋一个权值了。例如分别给“I am a”赋值为0.1,给“student”赋值剩下的0.7,显然这时student的重要性就体现出来了。具体怎么操作,我这里就不在讲了。
2、self-attention
self-attention显然是attentio机制的一种。上面所讲的attention是输入对输出的权重,例如在上文中,是I am a student 对学生的权重。self-attention则是自己对自己的权重,例如I am a student分别对am的权重、对student的权重。之所以这样做,是为了充分考虑句子之间不同词语之间的语义及语法联系。
那么这个权值应该怎么计算呢?我在别处看到的图片以及我自己的理解如下:

注释:q\k\v分别对应attention机制中的Q\K\V,它们是通过输入词向量分别和W(Q)、W(K)、W(V)做乘积得到的。其目的主要是计算权值。

注释:q与k做点乘、然后归一化,就得到权值(乘积越大,相似度越高,权值越高)。得到的两个权值分别与v相乘后,再相加就是输出。同理就可以得到另一个单词的输出。
以上是一个单词一个单词的输出,如果写成矩阵形式就是Q*K,得到的矩阵归一化直接得到权值。

#self-attentiom模型的搭建:
from keras.preprocessing import sequence
from keras.datasets import imdb
from matplotlib import pyplot as plt
import pandas as pd from keras import backend as K
from keras.engine.topology import Layer class Self_Attention(Layer): def __init__(self, output_dim, **kwargs):
self.output_dim = output_dim
super(Self_Attention, self).__init__(**kwargs) def build(self, input_shape):
# 为该层创建一个可训练的权重
#inputs.shape = (batch_size, time_steps, seq_len)
self.kernel = self.add_weight(name='kernel',
shape=(3,input_shape[2], self.output_dim),
initializer='uniform',
trainable=True) super(Self_Attention, self).build(input_shape) # 一定要在最后调用它 def call(self, x):
WQ = K.dot(x, self.kernel[0])
WK = K.dot(x, self.kernel[1])
WV = K.dot(x, self.kernel[2]) print("WQ.shape",WQ.shape) print("K.permute_dimensions(WK, [0, 2, 1]).shape",K.permute_dimensions(WK, [0, 2, 1]).shape) QK = K.batch_dot(WQ,K.permute_dimensions(WK, [0, 2, 1])) QK = QK / (64**0.5) #64*5是归一化的值,不同问题不一样 QK = K.softmax(QK) print("QK.shape",QK.shape) V = K.batch_dot(QK,WV) return V def compute_output_shape(self, input_shape): return (input_shape[0],input_shape[1],self.output_dim)
在Keras上对IMDB进行简单的测试(不做Mask):
from __future__ import print_function
from keras.preprocessing import sequence
from keras.datasets import imdb max_features = 20000
maxlen = 80
batch_size = 32 print('Loading data...')
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
print(len(x_train), 'train sequences')
print(len(x_test), 'test sequences') print('Pad sequences (samples x time)')
x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
print('x_train shape:', x_train.shape)
print('x_test shape:', x_test.shape) from keras.models import Model
from keras.layers import * S_inputs = Input(shape=(None,), dtype='int32')
embeddings = Embedding(max_features, 128)(S_inputs)
# embeddings = Position_Embedding()(embeddings) # 增加Position_Embedding能轻微提高准确率
O_seq = Attention(8,16)([embeddings,embeddings,embeddings])
O_seq = GlobalAveragePooling1D()(O_seq)
O_seq = Dropout(0.5)(O_seq)
outputs = Dense(1, activation='sigmoid')(O_seq) model = Model(inputs=S_inputs, outputs=outputs)
# try using different optimizers and different optimizer configs
model.compile(loss='binary_crossentropy',
optimizer='adam',
metrics=['accuracy']) print('Train...')
model.fit(x_train, y_train,
batch_size=batch_size,
epochs=5,
validation_data=(x_test, y_test))
参考博客:
https://blog.csdn.net/xiaosongshine/article/details/90600028
https://blog.csdn.net/cpluss/article/details/85330256
self-attention详解的更多相关文章
- Attention is all you need 论文详解(转)
一.背景 自从Attention机制在提出之后,加入Attention的Seq2Seq模型在各个任务上都有了提升,所以现在的seq2seq模型指的都是结合rnn和attention的模型.传统的基于R ...
- Transform详解(超详细) Attention is all you need论文
一.背景 自从Attention机制在提出 之后,加入Attention的Seq2 Seq模型在各个任务上都有了提升,所以现在的seq2seq模型指的都是结合rnn和attention的模型.传统的基 ...
- Residual Attention Network for Image Classification(CVPR 2017)详解
一.Residual Attention Network 简介 这是CVPR2017的一篇paper,是商汤.清华.香港中文和北邮合作的文章.它在图像分类问题上,首次成功将极深卷积神经网络与人类视觉注 ...
- Attention和Transformer详解
目录 Transformer引入 Encoder 详解 输入部分 Embedding 位置嵌入 注意力机制 人类的注意力机制 Attention 计算 多头 Attention 计算 残差及其作用 B ...
- quartz配置文件详解
quartz配置文件详解(转载) quartz学习总结: 一.关于job: 用Quartz的行话讲,作业是一个执行任务的简单Java类.任务可以是任何Java代码.只需你实现org.qu ...
- Android特效 五种Toast详解
Toast是Android中用来显示显示信息的一种机制,和Dialog不一样的是,Toast是没有焦点的,而且Toast显示的时间有限,过一定的时间就会自动消失.而且Toast主要用于向用户显示提示消 ...
- linux syslog详解
linux syslog详解 分三部分 一.syslog协议介绍 二.syslog函数 三.linux syslog配置 一.syslog协议介绍 1.介绍 在Unix类操作系统上,syslog广 ...
- centos7.2环境nginx+mysql+php-fpm+svn配置walle自动化部署系统详解
centos7.2环境nginx+mysql+php-fpm+svn配置walle自动化部署系统详解 操作系统:centos 7.2 x86_64 安装walle系统服务端 1.以下安装,均在宿主机( ...
- Transformer各层网络结构详解!面试必备!(附代码实现)
1. 什么是Transformer <Attention Is All You Need>是一篇Google提出的将Attention思想发挥到极致的论文.这篇论文中提出一个全新的模型,叫 ...
- seq2seq模型详解及对比(CNN,RNN,Transformer)
一,概述 在自然语言生成的任务中,大部分是基于seq2seq模型实现的(除此之外,还有语言模型,GAN等也能做文本生成),例如生成式对话,机器翻译,文本摘要等等,seq2seq模型是由encoder, ...
随机推荐
- [TJOI2015]弦论(第k小子串)
题意: 对于一个给定的长度为n的字符串,求出它的第k小子串. 有参数t,t为0则表示不同位置的相同子串算作一个,t为1则表示不同位置的相同子串算作多个. 题解: 首先,因为t的原因,后缀数组较难实现, ...
- scrapy 分布式爬虫- RedisSpider
爬去当当书籍信息 多台机器同时爬取,共用一个redis记录 scrapy_redis 带爬取的request对象储存在redis中,每台机器读取request对象并删除记录,经行爬取.实现分布式爬虫 ...
- 安装less/sass
安装sass npm i node-sass 安装wepy-compiler-sass插件 npm install wepy-compiler-sass --save-dev 在我的项目中使用才有用.
- 【原创】go语言学习(八)切片
目录: 切片定义 切片基本操作 切片传参 make和new的区别 切片定义 1. 切片是基于数组类型做的一层封装.它非常灵活,可以自动扩容. var a []int //定义一个int类型的空切⽚ 2 ...
- P4410 [HNOI2009]无归岛
P4410 [HNOI2009]无归岛 显然这还是一个仙人掌图 对于同一个岛上的任意两个生物,他们有且仅有一个公共朋友 要求求最大独立集,和树形dp一样,遇到环时单独提出来处理一下就好了 #inclu ...
- spring源码分析:PropertyPlaceholderConfigurer
简介 最近工作中需要使用zookeeper配置中心管理各系统的配置,也就是需要在项目启动时,加载zookeeper中节点的子节点的数据(例如数据库的地址,/config/db.properties/d ...
- 面试题小议---BY gremount
Problem 1: 两个烧杯,一个放糖一个放盐,用勺子舀一勺糖到盐,搅拌均匀,然后舀一勺混合物会放糖的烧杯,问你两个烧杯哪个杂质多? 提示:相同.(1)可以用一个特殊数据计算一下,可以得到两个烧杯 ...
- 网络爬虫requests-bs4-re-1
最近了解了爬虫,嗯--------,有时候会搞得有点头晕. 跟着线上老师实现了两个实例.可以用python下载源代码玩玩,爬淘宝的很刺激,虽然违反了ROBOTS协议. GIT地址
- 如果项目在IIS发布后,出现System.ComponentModel.Win32Exception: 拒绝访问。
如果项目在IIS发布后,出现System.ComponentModel.Win32Exception: 拒绝访问. 那么就试试下面的办法. 步骤如下: 应用程序池=>设置应用程序池默认设置 将标 ...
- (三)OpenCV-Python学习—图像平滑
由于种种原因,图像中难免会存在噪声,需要对其去除.噪声可以理解为灰度值的随机变化,即拍照过程中引入的一些不想要的像素点.噪声可分为椒盐噪声,高斯噪声,加性噪声和乘性噪声等,参见:https://zhu ...