新代码在contrib\seq2seq\python\ops\attention_decoder_fn.py

和之前代码相比 不再采用conv的方式来计算乘,直接使用乘法和linear

给出了两种attention的实现 传统的"bahdanau": additive (Bahdanau et al., ICLR'2015) Neural Machine Translation by Jointly Learning to Align and Translate

以及"luong": multiplicative (Luong et al., EMNLP'2015) Effective Approaches to Attention-based Neural Machine Translation

这里以 bahdanau为例

还是按照 Grammar as a Foreign Language的公式

对应代码里面

将input encoder outputs 也就是输入的attention states作为 attention values

也就是在prepare_attention中

attention_values = attention_states

那么attention keys 对应 W_1h_i的部分,采用linear来实现

attention_keys = layers.linear(

attention_states, num_units, biases_initializer=None, scope=scope)

在创建score function的

_create_attention_score_fn 中完整定义了计算过程

这里去掉luong的实现部分 仅仅看bahdanau部分

with variable_scope.variable_scope(name, reuse=reuse):

if attention_option == "bahdanau":

#这里对应第一个公式最右面 query_w对应W_2, query是对应d_t

query_w = variable_scope.get_variable(

"attnW", [num_units, num_units], dtype=dtype)

#对应第一个公式最左侧的v

score_v = variable_scope.get_variable("attnV", [num_units], dtype=dtype)

def attention_score_fn(query, keys, values):

"""Put attention masks on attention_values using attention_keys and query.

Args:

query: A Tensor of shape [batch_size, num_units].

keys: A Tensor of shape [batch_size, attention_length, num_units].

values: A Tensor of shape [batch_size, attention_length, num_units].

Returns:

context_vector: A Tensor of shape [batch_size, num_units].

Raises:

ValueError: if attention_option is neither "luong" or "bahdanau".

"""

if attention_option == "bahdanau":

# transform query W_2*d_t

query = math_ops.matmul(query, query_w)

# reshape query: [batch_size, 1, num_units]

query = array_ops.reshape(query, [-1, 1, num_units])

# attn_fun 对应第一个公式的最左侧结果(=左侧) math_ops.reduce_sum(v * math_ops.tanh(keys + query), [2]) * + reduce_sum操作即是dot操作

scores = _attn_add_fun(score_v, keys, query)

# Compute alignment weights

# scores: [batch_size, length]

# alignments: [batch_size, length]

# TODO(thangluong): not normalize over padding positions.

#对应第二个公式计算softmax结果

alignments = nn_ops.softmax(scores)

# Now calculate the attention-weighted vector.

alignments = array_ops.expand_dims(alignments, 2)

#利用softmax得到的权重 计算attention向量的加权加和

context_vector = math_ops.reduce_sum(alignments * values, [1])

context_vector.set_shape([None, num_units])

#context_vector即对应 第三个公式 =的左侧

return context_vector

再看下计算出contenxt_vector之后的使用,这个方法正如论文中所说也和之前旧代码基本一致

也就是说将context和query进行concat之后通过linear映射依然得到num_units的长度 作为attention

def _create_attention_construct_fn(name, num_units, attention_score_fn, reuse):

"""Function to compute attention vectors.

Args:

name: to label variables.

num_units: hidden state dimension.

attention_score_fn: to compute similarity between key and target states.

reuse: whether to reuse variable scope.

Returns:

attention_construct_fn: to build attention states.

"""

with variable_scope.variable_scope(name, reuse=reuse) as scope:

def construct_fn(attention_query, attention_keys, attention_values):

context = attention_score_fn(attention_query, attention_keys,

attention_values)

concat_input = array_ops.concat([attention_query, context], 1)

attention = layers.linear(

concat_input, num_units, biases_initializer=None, scope=scope)

return attention

return construct_fn

最终的使用,cell_output就是attention,而next_input是cell_input和attention的concat

# construct attention

attention = attention_construct_fn(cell_output, attention_keys,

attention_values)

cell_output = attention

# argmax decoder

cell_output = output_fn(cell_output) # logits

next_input_id = math_ops.cast(

math_ops.argmax(cell_output, 1), dtype=dtype)

done = math_ops.equal(next_input_id, end_of_sequence_id)

cell_input = array_ops.gather(embeddings, next_input_id)

# combine cell_input and attention

next_input = array_ops.concat([cell_input, attention], 1)

Dynamic attention in tensorflow的更多相关文章

  1. 论文翻译:2020_A Recursive Network with Dynamic Attention for Monaural Speech Enhancement

    论文地址:基于动态注意的递归网络单耳语音增强 论文代码:https://github.com/Andong-Li-speech/DARCN 引用格式:Li, A., Zheng, C., Fan, C ...

  2. Dynamic seq2seq in tensorflow

    v1.0中 tensorflow渐渐废弃了老的非dynamic的seq2seq接口,已经放到 tf.contrib.legacy_seq2seq目录下面. tf.contrib.seq2seq下面的实 ...

  3. 可视化展示attention(seq2seq with attention in tensorflow)

    目前实现了基于tensorflow的支持的带attention的seq2seq.基于tf 1.0官网contrib路径下seq2seq 由于后续版本不再支持attention,迁移到melt并做了进一 ...

  4. Effective Tensorflow[转]

    Effective TensorFlow Table of Contents TensorFlow Basics Understanding static and dynamic shapes Sco ...

  5. seq2seq attention

    1.seq2seq:分为encoder和decoder a.在decoder中,第一时刻输入的是上encoder最后一时刻的状态,如果用了双向的rnn,那么一般使用逆序的最后一个时刻的输出(网上说实验 ...

  6. attention

    attention: 时序的刻画 attention 在recommendation 中的应用: 年龄的增长, 对于商品的喜好 Dynamic attention deeo model:

  7. tensorflow 控制流操作,条件判断和循环操作

    Control flow operations: conditionals and loops When building complex models such as recurrent neura ...

  8. 论文解读(GATv2)《How Attentive are Graph Attention Networks?》

    论文信息 论文标题:How Attentive are Graph Attention Networks?论文作者:Shaked Brody, Uri Alon, Eran Yahav论文来源:202 ...

  9. [论文阅读] RNN 在阿里DIEN中的应用

    [论文阅读] RNN 在阿里DIEN中的应用 0x00 摘要 本文基于阿里推荐DIEN代码,梳理了下RNN一些概念,以及TensorFlow中的部分源码.本博客旨在帮助小伙伴们详细了解每一步骤以及为什 ...

随机推荐

  1. Html表单标签:

    表单用于收集不同的类型的用户输入,表单由不同类型的标签组成,相关标签及属性如下: (1)<form>标签 定义整体的表单区域 -- action属性 定义表单数据提交址址 -- metho ...

  2. PAT Basic 1009

    1009 说反话 (20 分) 给定一句英语,要求你编写程序,将句中所有单词的顺序颠倒输出. 输入格式: 测试输入包含一个测试用例,在一行内给出总长度不超过 80 的字符串.字符串由若干单词和若干空格 ...

  3. ios开发中字符串的常用功能总结

    1.分割字符串 NSString * str1 = @"123/456"; NSArray * arr1 = [str1 componentsSeparatedByString:@ ...

  4. 外卖ERP管理系统(一)

    京门时代外卖ERP是北京京门时代科技有限公司旗下一款专业的外卖ERP系统管理软件. 本ERP目前己经对接了百度.饿了么.美团以等各外卖平台,在配送方面对接了闪送快递.人人快递.UU跑腿以及达达配送. ...

  5. CSS魔法堂:那个被我们忽略的outline

    前言  在CSS魔法堂:改变单选框颜色就这么吹毛求疵!中我们要模拟原生单选框通过Tab键获得焦点的效果,这里涉及到一个常常被忽略的属性--outline,由于之前对其印象确实有些模糊,于是本文打算对其 ...

  6. python下的selenium和PhantomJS

    一般我们使用python的第三方库requests及框架scrapy来爬取网上的资源,但是设计javascript渲染的页面却不能抓取,此时,我们使用web自动化测试化工具Selenium+无界面浏览 ...

  7. Canvas 和 SVG 的不同

    Canvas 和 SVG 都允许您在浏览器中创建图形,但是它们在根本上是不同的. SVG SVG 是一种使用 XML 描述 2D 图形的语言. SVG 基于 XML,这意味着 SVG DOM 中的每个 ...

  8. webstorm intelliJ IDEA phpstorm 设置鼠标滚动改变字体大小

    control+shift+A功能可以搜索对应功能,把mouse:Change font size(Zoom) ...的按钮打开,然后就可以通过 ctrl+鼠标上下滚动调节字体大小

  9. Python实现多进程

    Python可以实现多线程,但是因为Global Interpreter Lock (GIL),Python的多线程只能使用一个CPU内核,即一个时间只有一个线程在运行,多线程只是不同线程之间的切换, ...

  10. PHP 实现自动加载

    自动载入主要是省去了一个个类去 include 的繁琐,在 new 时动态的去检查并 include 相应的 class 文件. 先上代码: //index.php <?php class Cl ...