新代码在contrib\seq2seq\python\ops\attention_decoder_fn.py

和之前代码相比 不再采用conv的方式来计算乘,直接使用乘法和linear

给出了两种attention的实现 传统的"bahdanau": additive (Bahdanau et al., ICLR'2015) Neural Machine Translation by Jointly Learning to Align and Translate

以及"luong": multiplicative (Luong et al., EMNLP'2015) Effective Approaches to Attention-based Neural Machine Translation

这里以 bahdanau为例

还是按照 Grammar as a Foreign Language的公式

对应代码里面

将input encoder outputs 也就是输入的attention states作为 attention values

也就是在prepare_attention中

attention_values = attention_states

那么attention keys 对应 W_1h_i的部分,采用linear来实现

attention_keys = layers.linear(

attention_states, num_units, biases_initializer=None, scope=scope)

在创建score function的

_create_attention_score_fn 中完整定义了计算过程

这里去掉luong的实现部分 仅仅看bahdanau部分

with variable_scope.variable_scope(name, reuse=reuse):

if attention_option == "bahdanau":

#这里对应第一个公式最右面 query_w对应W_2, query是对应d_t

query_w = variable_scope.get_variable(

"attnW", [num_units, num_units], dtype=dtype)

#对应第一个公式最左侧的v

score_v = variable_scope.get_variable("attnV", [num_units], dtype=dtype)

def attention_score_fn(query, keys, values):

"""Put attention masks on attention_values using attention_keys and query.

Args:

query: A Tensor of shape [batch_size, num_units].

keys: A Tensor of shape [batch_size, attention_length, num_units].

values: A Tensor of shape [batch_size, attention_length, num_units].

Returns:

context_vector: A Tensor of shape [batch_size, num_units].

Raises:

ValueError: if attention_option is neither "luong" or "bahdanau".

"""

if attention_option == "bahdanau":

# transform query W_2*d_t

query = math_ops.matmul(query, query_w)

# reshape query: [batch_size, 1, num_units]

query = array_ops.reshape(query, [-1, 1, num_units])

# attn_fun 对应第一个公式的最左侧结果(=左侧) math_ops.reduce_sum(v * math_ops.tanh(keys + query), [2]) * + reduce_sum操作即是dot操作

scores = _attn_add_fun(score_v, keys, query)

# Compute alignment weights

# scores: [batch_size, length]

# alignments: [batch_size, length]

# TODO(thangluong): not normalize over padding positions.

#对应第二个公式计算softmax结果

alignments = nn_ops.softmax(scores)

# Now calculate the attention-weighted vector.

alignments = array_ops.expand_dims(alignments, 2)

#利用softmax得到的权重 计算attention向量的加权加和

context_vector = math_ops.reduce_sum(alignments * values, [1])

context_vector.set_shape([None, num_units])

#context_vector即对应 第三个公式 =的左侧

return context_vector

再看下计算出contenxt_vector之后的使用,这个方法正如论文中所说也和之前旧代码基本一致

也就是说将context和query进行concat之后通过linear映射依然得到num_units的长度 作为attention

def _create_attention_construct_fn(name, num_units, attention_score_fn, reuse):

"""Function to compute attention vectors.

Args:

name: to label variables.

num_units: hidden state dimension.

attention_score_fn: to compute similarity between key and target states.

reuse: whether to reuse variable scope.

Returns:

attention_construct_fn: to build attention states.

"""

with variable_scope.variable_scope(name, reuse=reuse) as scope:

def construct_fn(attention_query, attention_keys, attention_values):

context = attention_score_fn(attention_query, attention_keys,

attention_values)

concat_input = array_ops.concat([attention_query, context], 1)

attention = layers.linear(

concat_input, num_units, biases_initializer=None, scope=scope)

return attention

return construct_fn

最终的使用,cell_output就是attention,而next_input是cell_input和attention的concat

# construct attention

attention = attention_construct_fn(cell_output, attention_keys,

attention_values)

cell_output = attention

# argmax decoder

cell_output = output_fn(cell_output) # logits

next_input_id = math_ops.cast(

math_ops.argmax(cell_output, 1), dtype=dtype)

done = math_ops.equal(next_input_id, end_of_sequence_id)

cell_input = array_ops.gather(embeddings, next_input_id)

# combine cell_input and attention

next_input = array_ops.concat([cell_input, attention], 1)

Dynamic attention in tensorflow的更多相关文章

  1. 论文翻译:2020_A Recursive Network with Dynamic Attention for Monaural Speech Enhancement

    论文地址:基于动态注意的递归网络单耳语音增强 论文代码:https://github.com/Andong-Li-speech/DARCN 引用格式:Li, A., Zheng, C., Fan, C ...

  2. Dynamic seq2seq in tensorflow

    v1.0中 tensorflow渐渐废弃了老的非dynamic的seq2seq接口,已经放到 tf.contrib.legacy_seq2seq目录下面. tf.contrib.seq2seq下面的实 ...

  3. 可视化展示attention(seq2seq with attention in tensorflow)

    目前实现了基于tensorflow的支持的带attention的seq2seq.基于tf 1.0官网contrib路径下seq2seq 由于后续版本不再支持attention,迁移到melt并做了进一 ...

  4. Effective Tensorflow[转]

    Effective TensorFlow Table of Contents TensorFlow Basics Understanding static and dynamic shapes Sco ...

  5. seq2seq attention

    1.seq2seq:分为encoder和decoder a.在decoder中,第一时刻输入的是上encoder最后一时刻的状态,如果用了双向的rnn,那么一般使用逆序的最后一个时刻的输出(网上说实验 ...

  6. attention

    attention: 时序的刻画 attention 在recommendation 中的应用: 年龄的增长, 对于商品的喜好 Dynamic attention deeo model:

  7. tensorflow 控制流操作,条件判断和循环操作

    Control flow operations: conditionals and loops When building complex models such as recurrent neura ...

  8. 论文解读(GATv2)《How Attentive are Graph Attention Networks?》

    论文信息 论文标题:How Attentive are Graph Attention Networks?论文作者:Shaked Brody, Uri Alon, Eran Yahav论文来源:202 ...

  9. [论文阅读] RNN 在阿里DIEN中的应用

    [论文阅读] RNN 在阿里DIEN中的应用 0x00 摘要 本文基于阿里推荐DIEN代码,梳理了下RNN一些概念,以及TensorFlow中的部分源码.本博客旨在帮助小伙伴们详细了解每一步骤以及为什 ...

随机推荐

  1. PAT Basic 1020

    1020 月饼 (25 分) 月饼是中国人在中秋佳节时吃的一种传统食品,不同地区有许多不同风味的月饼.现给定所有种类月饼的库存量.总售价.以及市场的最大需求量,请你计算可以获得的最大收益是多少. 注意 ...

  2. IDEA15 创建javaweb 并配置Tomcat(转)

    1.打开IDEA选择Web Application 2.在WEB-INF下创建如图两个文件夹 3.在右上角找到这个,准备配置Tomcat 4.在弹出的窗口里像这样配置LocaleHost 5.设置好本 ...

  3. python之迭代器篇

    一.迭代器 只要对象本身有_iter_()_方法,那它就是可迭代的 执行__iter__就会生成迭代器 迭代器有__next__用于获取值 __next__超出界限了会报StopIteration异常 ...

  4. Linux之通配符实验

    作业五:通配符实验 反引号与()在此时都是表死获取结果 但是一般使用()的方式,因为反引号在多个反引号的时候无法正确指代 获取当前bash 的变量 echo $变量名 echo $? 表示上一次命令的 ...

  5. Miller Rabin素数检测

    #include<iostream> #include<cstdio> #include<queue> #include<cstring> #inclu ...

  6. 加速Android Studio编译速度

    一.修改运行内存 进入项目,菜单栏-help-Edit Custom VM Option   Paste_Image.png 添加或修改为: -Xms2048m -Xmx2048m -XX:MaxPe ...

  7. C# Monitor实现

    Monitor的code如下,非常简单: public static class Monitor { public static extern void Enter(Object obj); publ ...

  8. RGBA alpha 透明度混合算法实现和测试

    目录 1.算法叙述 1.1.透明度混合算法1 1.3.简易Alpha混合算法 2.算法实现代码和测试 2.1.透明度混合算法1实现代码 2.1.AlphaBlend算法实现代码 2.3.测试截图 2. ...

  9. Reading table information for completion of table and column names You can turn off this feature to get a quicker startup with -

    mysql -A不预读数据库信息(use dbname 更快)—Reading table information for completion of table and column names Y ...

  10. Effective Java 第三版——69. 仅在发生异常的条件下使用异常

    Tips 书中的源代码地址:https://github.com/jbloch/effective-java-3e-source-code 注意,书中的有些代码里方法是基于Java 9 API中的,所 ...