tf.keras.layers.Attention

Dot-product attention layer, a.k.a. Luong-style attention.

Inherits From: Layer, Module

tf.keras.layers.Attention(
use_scale=False, score_mode='dot', **kwargs
)

Inputs are a query tensor of shape [batch_size, Tq, dim], a value tensor of shape [batch_size, Tv, dim], and a key tensor of shape [batch_size, Tv, dim]. The calculation follows these steps:

1. Calculate scores with shape [batch_size, Tq, Tv] as a query-key dot product: scores = tf.matmul(query, key, transpose_b=True).

2. Use scores to calculate a distribution with shape [batch_size, Tq, Tv]: distribution = tf.nn.softmax(scores).

3. Use distribution to create a linear combination of value with shape [batch_size, Tq, dim]: return tf.matmul(distribution, value).
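These three steps can be reproduced with raw TensorFlow ops and compared against the layer's output. The shapes below are purely illustrative, and the sketch assumes the default arguments (use_scale=False, score_mode='dot', no masks, no dropout):

import tensorflow as tf

# Toy shapes: batch_size=2, Tq=3, Tv=4, dim=8.
query = tf.random.normal([2, 3, 8])
value = tf.random.normal([2, 4, 8])
key = value  # The most common case: key is the same tensor as value.

# Step 1: query-key dot product -> scores of shape [2, 3, 4].
scores = tf.matmul(query, key, transpose_b=True)
# Step 2: softmax over the last axis -> distribution of shape [2, 3, 4].
distribution = tf.nn.softmax(scores)
# Step 3: linear combination of value -> output of shape [2, 3, 8].
manual_output = tf.matmul(distribution, value)

# With default arguments, the layer should produce the same result.
layer_output = tf.keras.layers.Attention()([query, value])
tf.debugging.assert_near(manual_output, layer_output)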

Args
`use_scale` If True, will create a scalar variable to scale the attention scores.
`dropout` Float between 0 and 1. Fraction of the units to drop for the attention scores. Defaults to 0.0.
`score_mode` Function to use to compute attention scores, one of {"dot", "concat"}. "dot" refers to the dot product between the query and key vectors. "concat" refers to the hyperbolic tangent of the concatenation of the query and key vectors.

Call arguments
`inputs` List of the following tensors:
`query`: Query Tensor of shape [batch_size, Tq, dim].
`value`: Value Tensor of shape [batch_size, Tv, dim].
`key`: Optional key Tensor of shape [batch_size, Tv, dim]. If not given, will use value for both key and value, which is the most common case.
`mask` List of the following tensors:
`query_mask`: A boolean mask Tensor of shape [batch_size, Tq]. If given, the output will be zero at the positions where mask==False.
`value_mask`: A boolean mask Tensor of shape [batch_size, Tv]. If given, will apply the mask such that values at positions where mask==False do not contribute to the result.
`return_attention_scores` bool, if True, returns the attention scores (after masking and softmax) as an additional output argument.
`training` Python boolean indicating whether the layer should behave in training mode (adding dropout) or in inference mode (no dropout).
`use_causal_mask` Boolean. Set to True for decoder self-attention. Adds a mask such that position i cannot attend to positions j > i. This prevents the flow of information from the future towards the past. Defaults to False.

Output

Attention outputs of shape [batch_size, Tq, dim]. [Optional] Attention scores after masking and softmax with shape [batch_size, Tq, Tv].
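As an illustration of the call arguments above, here is a minimal sketch with made-up shapes and mask values (the variable names are hypothetical, not part of the documentation text):

import tensorflow as tf

query = tf.random.normal([2, 3, 8])
value = tf.random.normal([2, 4, 8])
# Boolean masks marking which timesteps are valid in each sequence.
query_mask = tf.constant([[True, True, False],
                          [True, True, True]])
value_mask = tf.constant([[True, True, True, False],
                          [True, True, False, False]])

attention = tf.keras.layers.Attention(use_scale=True)
# outputs: [batch_size, Tq, dim]; scores: [batch_size, Tq, Tv].
outputs, scores = attention(
    [query, value],
    mask=[query_mask, value_mask],
    return_attention_scores=True)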

The meaning of query, value and key depends on the application. In the case of text similarity, for example, query is the sequence embeddings of the first piece of text and value is the sequence embeddings of the second piece of text. key is usually the same tensor as value.

Here is a code example for using Attention in a CNN+Attention network:

# Variable-length int sequences.
query_input = tf.keras.Input(shape=(None,), dtype='int32')
value_input = tf.keras.Input(shape=(None,), dtype='int32')

# Embedding lookup.
token_embedding = tf.keras.layers.Embedding(input_dim=1000, output_dim=64)
# Query embeddings of shape [batch_size, Tq, dimension].
query_embeddings = token_embedding(query_input)
# Value embeddings of shape [batch_size, Tv, dimension].
value_embeddings = token_embedding(value_input)

# CNN layer.
cnn_layer = tf.keras.layers.Conv1D(
    filters=100,
    kernel_size=4,
    # Use 'same' padding so outputs have the same shape as inputs.
    padding='same')
# Query encoding of shape [batch_size, Tq, filters].
query_seq_encoding = cnn_layer(query_embeddings)
# Value encoding of shape [batch_size, Tv, filters].
value_seq_encoding = cnn_layer(value_embeddings)

# Query-value attention of shape [batch_size, Tq, filters].
query_value_attention_seq = tf.keras.layers.Attention()(
    [query_seq_encoding, value_seq_encoding])

# Reduce over the sequence axis to produce encodings of shape
# [batch_size, filters].
query_encoding = tf.keras.layers.GlobalAveragePooling1D()(
    query_seq_encoding)
query_value_attention = tf.keras.layers.GlobalAveragePooling1D()(
    query_value_attention_seq)

# Concatenate query and document encodings to produce a DNN input layer.
input_layer = tf.keras.layers.Concatenate()(
    [query_encoding, query_value_attention])

# Add DNN layers, and create Model.
# ...
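One possible continuation, as a sketch assuming a binary-classification task (the Dense sizes, activations, and loss below are illustrative choices, not prescribed by the example):

# Hypothetical classification head on top of the concatenated encodings.
hidden = tf.keras.layers.Dense(256, activation='relu')(input_layer)
output = tf.keras.layers.Dense(1, activation='sigmoid')(hidden)

model = tf.keras.Model(inputs=[query_input, value_input], outputs=output)
model.compile(optimizer='adam', loss='binary_crossentropy')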
