tf.keras.layers.Attention( View source on GitHub )

Dot-product attention layer, a.k.a. Luong-style attention.

Inherits From: Layer, Module

tf.keras.layers.Attention(
use_scale=False, score_mode='dot', **kwargs
)

Inputs are query tensor of shape [batch_size, Tq, dim], value tensor of shape [batch_size, Tv, dim] and key tensor of shape [batch_size, Tv, dim]. The calculation follows the steps:

Calculate scores with shape [batch_size, Tq, Tv] as a query-key dot product: scores = tf.matmul(query, key, transpose_b=True).

Use scores to calculate a distribution with shape [batch_size, Tq, Tv]: distribution = tf.nn.softmax(scores).

Use distribution to create a linear combination of value with shape [batch_size, Tq, dim]: return tf.matmul(distribution, value).

Args
`use_scale` If True, will create a scalar variable to scale the attention scores.
`dropout` Float between 0 and 1. Fraction of the units to drop for the attention scores. Defaults to 0.0.
`score_mode` Function to use to compute attention scores, one of {"dot", "concat"}. "dot" refers to the dot product between the query and key vectors. "concat" refers to the hyperbolic tangent of the concatenation of the query and key vectors. **Call arguments**
`inputs` List of the following tensors:
`query`: Query Tensor of shape [batch_size, Tq, dim].
`value`: Value Tensor of shape [batch_size, Tv, dim].
`key`: Optional key Tensor of shape [batch_size, Tv, dim]. If not given, will use value for both key and value, which is the most common case. **mask List of the following tensors:**
`query_mask`: A boolean mask Tensor of shape [batch_size, Tq]. If given, the output will be zero at the positions where mask==False.
`value_mask`: A boolean mask Tensor of shape [batch_size, Tv]. If given, will apply the mask such that values at positions where mask==False do not contribute to the result.
`return_attention_scores` bool, it True, returns the attention scores (after masking and softmax) as an additional output argument.
`training` Python boolean indicating whether the layer should behave in training mode (adding dropout) or in inference mode (no dropout).
`use_causal_mask` Boolean. Set to True for decoder self-attention. Adds a mask such that position i cannot attend to positions j > i. This prevents the flow of information from the future towards the past. Defaults to False. **Output**

Attention outputs of shape [batch_size, Tq, dim]. [Optional] Attention scores after masking and softmax with shape [batch_size, Tq, Tv].

The meaning of query, value and key depend on the application. In the case of text similarity, for example, query is the sequence embeddings of the first piece of text and value is the sequence embeddings of the second piece of text. key is usually the same tensor as value.

Here is a code example for using Attention in a CNN+Attention network:

# Variable-length int sequences.
query_input = tf.keras.Input(shape=(None,), dtype='int32')
value_input = tf.keras.Input(shape=(None,), dtype='int32') # Embedding lookup.
token_embedding = tf.keras.layers.Embedding(input_dim=1000, output_dim=64)
# Query embeddings of shape [batch_size, Tq, dimension].
query_embeddings = token_embedding(query_input)
# Value embeddings of shape [batch_size, Tv, dimension].
value_embeddings = token_embedding(value_input) # CNN layer.
cnn_layer = tf.keras.layers.Conv1D(
filters=100,
kernel_size=4,
# Use 'same' padding so outputs have the same shape as inputs.
padding='same')
# Query encoding of shape [batch_size, Tq, filters].
query_seq_encoding = cnn_layer(query_embeddings)
# Value encoding of shape [batch_size, Tv, filters].
value_seq_encoding = cnn_layer(value_embeddings) # Query-value attention of shape [batch_size, Tq, filters].
query_value_attention_seq = tf.keras.layers.Attention()(
[query_seq_encoding, value_seq_encoding]) # Reduce over the sequence axis to produce encodings of shape
# [batch_size, filters].
query_encoding = tf.keras.layers.GlobalAveragePooling1D()(
query_seq_encoding)
query_value_attention = tf.keras.layers.GlobalAveragePooling1D()(
query_value_attention_seq) # Concatenate query and document encodings to produce a DNN input layer.
input_layer = tf.keras.layers.Concatenate()(
[query_encoding, query_value_attention]) # Add DNN layers, and create Model.
# ...

tf.keras.layers.Attention: Dot-product attention layer, a.k.a. Luong-style attention.的更多相关文章

  1. tf.keras.layers.MaxPool2D 简介

    tf.keras.layers.Max2D( pool_size=(2, 2), strides=None, padding='valid', data_format=None ) pool_size ...

  2. TensorFlow2.0(11):tf.keras建模三部曲

    .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...

  3. 一文上手Tensorflow2.0之tf.keras(三)

    系列文章目录: Tensorflow2.0 介绍 Tensorflow 常见基本概念 从1.x 到2.0 的变化 Tensorflow2.0 的架构 Tensorflow2.0 的安装(CPU和GPU ...

  4. Tensorflow2(一)深度学习基础和tf.keras

    代码和其他资料在 github 一.tf.keras概述 首先利用tf.keras实现一个简单的线性回归,如 \(f(x) = ax + b\),其中 \(x\) 代表学历,\(f(x)\) 代表收入 ...

  5. 基于tensorflow2.0 使用tf.keras实现Fashion MNIST

    本次使用的是2.0测试版,正式版估计会很快就上线了 tf2好像更新了蛮多东西 虽然教程不多 还是找了个试试 的确简单不少,但是还是比较喜欢现在这种写法 老样子先导入库 import tensorflo ...

  6. 【tf.keras】实现 F1 score、precision、recall 等 metric

    tf.keras.metric 里面竟然没有实现 F1 score.recall.precision 等指标,一开始觉得真不可思议.但这是有原因的,这些指标在 batch-wise 上计算都没有意义, ...

  7. 【tf.keras】tf.keras使用tensorflow中定义的optimizer

    Update:2019/09/21 使用 tf.keras 时,请使用 tf.keras.optimizers 里面的优化器,不要使用 tf.train 里面的优化器,不然学习率衰减会出现问题. 使用 ...

  8. tensorflow 2.0 技巧 | 自定义tf.keras.Model的坑

    自定义tf.keras.Model需要注意的点 model.save() subclass Model 是不能直接save的,save成.h5,但是能够save_weights,或者save_form ...

  9. 【tf.keras】AdamW: Adam with Weight decay

    论文 Decoupled Weight Decay Regularization 中提到,Adam 在使用时,L2 与 weight decay 并不等价,并提出了 AdamW,在神经网络需要正则项时 ...

  10. tf.keras遇见的坑:Output tensors to a Model must be the output of a TensorFlow `Layer`

    经过网上查找,找到了问题所在:在使用keras编程模式是,中间插入了tf.reshape()方法便遇到此问题. 解决办法:对于遇到相同问题的任何人,可以使用keras的Lambda层来包装张量流操作, ...

随机推荐

  1. Mybatis 框架课程第四天

    目录 1 Mybatis 延迟加载策略 1.1 何为延迟加载 1.2 实现需求 1.3 使用 assocation 实现延迟加载 1.3.1 账户的持久层 DAO 接口 1.3.2 账户的持久层映射文 ...

  2. 【记录】BASE64|解决JS和C++中文传输乱码,内含两种语言的Base64编码解码的代码

    JS 解决方法来源于知乎新码笔记的文章 function b64Encode(str) { return btoa(unescape(encodeURIComponent(str))); } func ...

  3. MySql技术之"虚拟表增加索引"

    一.虚拟表增加索引 创建虚拟表,并且增加SKU索引:INDEX idx_sku (sku) CREATE TEMPORARY TABLE t_sku_analy_temp ( sku VARCHAR( ...

  4. Python 3.14 新特性盘点,更新了些什么?

    Python 3.14.0 稳定版将于 2025 年 10 月正式发布,目前已进入 beta 测试阶段.这意味着在往后的几个月里,3.14 的新功能已冻结,不再合入新功能(除了修复问题和完善文档). ...

  5. C# 从PDF文档中提取图片

    当 PDF 文件中包含有价值的图片,如艺术画作.设计素材.报告图表等,提取图片可以将这些图像资源进行单独保存,方便后续在不同的项目中使用,避免每次都要从 PDF 中查找.本文将介绍如何使用C#通过代码 ...

  6. 数字孪生工厂实战指南:基于Unreal Engine/Omniverse的虚实同步系统开发

    引言:工业元宇宙的基石技术 在智能制造2025与工业元宇宙的交汇点,数字孪生技术正重塑传统制造业.本文将手把手指导您构建基于Unreal Engine 5.4与NVIDIA Omniverse的实时数 ...

  7. Torch-Pruning工具箱

    Torch-Pruning 通道剪枝网络实现加速的工作. Torch pruning是进行结构剪枝的pytorch工具箱,和pytorch官方提供的基于mask的非结构化剪枝不同,工具箱移除整个通道剪 ...

  8. C++ 迭代器(STL迭代器)iterator详解

    要访问顺序容器和关联容器中的元素,需要通过"迭代器(iterator)"进行,迭代器是一个变量,相当于容器和操作容器的算法之间的中介.迭代器可以指向容器中的某个元素,通过迭代器就可 ...

  9. VS2019 配置 protobuf3.8.0

    1.下载protobuf3.8.0 https://github.com/protocolbuffers/protobuf/releases/tag/v3.8.0 2.准备工作 解压文件并在同级目录建 ...

  10. codeup之字符串求最大值

    Description 从键盘上输入3个字符串,求出其中最大者. Input 输入3行,每行均为一个字符串. Output 一行,输入三个字符串中最大者. Sample Input Copy Engl ...