tensorflow.nn.bidirectional_dynamic_rnn()函数的用法
在分析Attention-over-attention源码过程中,对于tensorflow.nn.bidirectional_dynamic_rnn()函数的总结:
首先来看一下,函数:
def bidirectional_dynamic_rnn(
cell_fw, # 前向RNN
cell_bw, # 后向RNN
inputs, # 输入
sequence_length=None,# 输入序列的实际长度(可选,默认为输入序列的最大长度)
initial_state_fw=None, # 前向的初始化状态(可选)
initial_state_bw=None, # 后向的初始化状态(可选)
dtype=None, # 初始化和输出的数据类型(可选)
parallel_iterations=None,
swap_memory=False,
time_major=False,
# 决定了输入输出tensor的格式:如果为true, 向量的形状必须为 `[max_time, batch_size, depth]`.
# 如果为false, tensor的形状必须为`[batch_size, max_time, depth]`.
scope=None
)
返回值:
元组:(outputs, output_states)
outputs = (output_fw, output_bw)
output_states = (output_state_fw, output_state_bw)
其中,
- outputs为(output_fw, output_bw),是一个包含前向cell输出tensor和后向cell输出tensor组成的元组。假设
 
time_major=false,tensor的shape为[batch_size, max_time, depth]。实验中使用tf.concat(outputs, 2)将其拼接。
- output_states为(output_state_fw, output_state_bw),包含了前向和后向最后的隐藏状态的组成的元组。
 
output_state_fw和output_state_bw的类型为LSTMStateTuple。
LSTMStateTuple由(c,h)组成,分别代表memory cell和hidden state。
c_fw,h_fw = output_state_fw
c_bw,h_bw = output_state_bw
最后再分别将c和h状态concat起来,用tf.contrib.rnn.LSTMStateTuple()函数生成decoder端的初始状态。
# lstm模型 正方向传播的RNN
lstm_fw_cell = tf.nn.rnn_cell.BasicLSTMCell(embedding_size, forget_bias=1.0)
# 反方向传播的RNN
lstm_bw_cell = tf.nn.rnn_cell.BasicLSTMCell(embedding_size, forget_bias=1.0)
但是看来看去,输入两个cell都是相同的啊? 
其实在bidirectional_dynamic_rnn函数的内部,会把反向传播的cell使用array_ops.reverse_sequence的函数将输入的序列逆序排列,使其可以达到反向传播的效果。 
在实现的时候,我们是需要传入两个cell作为参数就可以了:
(outputs, output_states) = tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell, lstm_bw_cell,
embedded_chars, dtype=tf.float32)
embedded_chars为输入的tensor,[batch_szie, max_time, depth]。batch_size为模型当中batch的大小,应用在文
本中时,max_time可以为句子的长度(一般以最长的句子为准,短句需要做padding),depth为输入句子词向量的维度。
当为双向GRU时,跟LSTM类似:
  with tf.variable_scope('document', initializer=orthogonal_initializer()):#生成正交矩阵的初始化器。
    fwd_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size)#变长动态RNN的实现
    back_cell = tf.contrib.rnn.GRUCell(FLAGS.hidden_size)
    doc_len = tf.reduce_sum(doc_mask, reduction_indices=1)#在第二维上压缩求和,可用来降维
    h, _ = tf.nn.bidirectional_dynamic_rnn(
        fwd_cell, back_cell, doc_emb, sequence_length=tf.to_int64(doc_len), dtype=tf.float32)
        #doc_len求得的结果可能是其他类型,然后将他转化为64为整型
        #doc_emb前面已经确定它的[batch_size,max_time,depth]
        #dype输出类型
h_doc = tf.concat(h, 2)
可参考:
变长双向rnn的正确使用姿势:https://blog.csdn.net/lijin6249/article/details/78955175
tensorflow.nn.bidirectional_dynamic_rnn()函数的用法:https://blog.csdn.net/wuzqChom/article/details/75453327
tensorflow.nn.bidirectional_dynamic_rnn()函数的用法的更多相关文章
- 【TensorFlow】tf.nn.embedding_lookup函数的用法
		
tf.nn.embedding_lookup函数的用法主要是选取一个张量里面索引对应的元素.tf.nn.embedding_lookup(tensor, id):tensor就是输入张量,id就是张量 ...
 - tf.nn.embedding_lookup函数的用法
		
关于np.random.RandomState.np.random.rand.np.random.random.np.random_sample参考https://blog.csdn.net/lanc ...
 - Tensorflow踩坑之tf.nn.bidirectional_dynamic_rnn()报错 “ValueError: None values not supported.”
		
详细解决方法见链接:https://stackoverflow.com/questions/39808336/tensorflow-bidirectional-dynamic-rnn-none-val ...
 - Tensorflow BatchNormalization详解:4_使用tf.nn.batch_normalization函数实现Batch Normalization操作
		
使用tf.nn.batch_normalization函数实现Batch Normalization操作 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 吴恩达deeplearnin ...
 - Tensorflow Batch normalization函数
		
Tensorflow Batch normalization函数 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 stackoverflow上tensorflow实现BN的不同函数的 ...
 - Tensorflow教程(2)Tensorflow的常用函数介绍
		
1.tf.constant tf.constant用来定义一个常量,所谓常量,广义上讲就是“不变化的量”.我们先看下官方api是如何对constant函数来定义的: tf.constant( valu ...
 - 有关日期的函数操作用法总结,to_date(),trunc(),add_months();
		
相关知识链接: Oracle trunc()函数的用法 oracle add_months函数 Oracle日期格式转换,tochar(),todate() №2:取得当前日期是一个星期中的第几天,注 ...
 - Oracle to_date()函数的用法
		
Oracle to_date()函数的用法 to_date()是Oracle数据库函数的代表函数之一,下文对Oracle to_date()函数的几种用法作了详细的介绍说明,供您参考学习. 在Orac ...
 - js中bind、call、apply函数的用法
		
最近一直在用 js 写游戏服务器,我也接触 js 时间不长,大学的时候用 js 做过一个 H3C 的 web的项目,然后在腾讯实习的时候用 js 写过一些奇怪的程序,自己也用 js 写过几个的网站.但 ...
 
随机推荐
- PDO笔记
			
<?php/* * 查询操作主要是PDO::query().PDO::exec().PDO::prepare().PDO::query()主要是用于有记录结果返回的操作,特别是SELECT操作, ...
 - nodejs之Buffer
			
Buffer是什么? 简单点理解,buff就是固定长度的uint8array.(es6已实现TypedArray). 由于是固定长度所以没有了splice,concat方法. 由于是固定类型所以没有了 ...
 - 第161天:CSS3实现兼容性的渐变背景(gradient)效果
			
CSS实现兼容性的渐变背景(gradient)效果 一.有点俗态的开场白 在对CSS3支持日趋完善的今天,实现兼容性的渐变背景效果已经完全成为可能,本文就将展示如何实现兼容性的渐变背景效果.在众多的浏 ...
 - 第132天:移动web端-rem布局(进阶)
			
rem布局(进阶版) 该方案使用相当简单,把下面这段已压缩过的 原生JS(仅1kb,源码已在文章底部更新,2017/5/3) 放到 HTML 的 head 标签中即可(注:不要手动设置viewport ...
 - 对Spark2.2.0文档的学习3-Spark Programming Guide
			
Spark Programming Guide Link:http://spark.apache.org/docs/2.2.0/rdd-programming-guide.html 每个Spark A ...
 - 精通android学习笔记(一)---广播
			
普通广播:sendBroadcast 有序广播:sendOrderedBroadcast,有序广播优先级可以再manifest中设置,数值越大,最先收到.-1000~1000 <receiver ...
 - SpringBoot多数据源配置事务
			
除了消费降级,这将会是娱乐继续下沉的一年. 36氪从多个信源处获悉,资讯阅读应用趣头条已经完成了腾讯领投的Pre-IPO轮融资,交易金额预计达上亿美元,本轮融资估值在13-15亿美金之间:完成此轮融资 ...
 - 【UOJ#79】一般图最大匹配(带花树)
			
[UOJ#79]一般图最大匹配(带花树) 题面 UOJ 题解 带花树模板题 关于带花树的详细内容 #include<iostream> #include<cstdio> #in ...
 - Alpha 冲刺 —— 十分之一
			
队名 火箭少男100 组长博客 林燊大哥 作业博客 Alpha 冲鸭! 成员冲刺阶段情况 林燊(组长) 过去两天完成了哪些任务 协调各成员之间的工作,对多个目标检测及文字识别模型进行评估.实验,选取较 ...
 - BZOJ4408 [Fjoi 2016]神秘数  【主席树】
			
题目链接 BZOJ4408 题解 假如我们已经求出一个集合所能凑出连续数的最大区间\([1,max]\),那么此时答案为\(max + 1\) 那么我们此时加入一个数\(x\),假若\(x > ...