tensorflow学习笔记(三十九):双向rnn
tensorflow 双向 rnn
如何在tensorflow中实现双向rnn
单层双向rnn
单层双向rnn (cs224d)
tensorflow
中已经提供了双向rnn
的接口,它就是tf.nn.bidirectional_dynamic_rnn()
. 我们先来看一下这个接口怎么用.
bidirectional_dynamic_rnn(
cell_fw, #前向 rnn cell
cell_bw, #反向 rnn cell
inputs, #输入序列.
sequence_length=None,# 序列长度
initial_state_fw=None,#前向rnn_cell的初始状态
initial_state_bw=None,#反向rnn_cell的初始状态
dtype=None,#数据类型
parallel_iterations=None,
swap_memory=False,
time_major=False,
scope=None
)
返回值:一个tuple(outputs, outputs_states), 其中,outputs
是一个tuple(outputs_fw, outputs_bw). 关于outputs_fw
和outputs_bw
,如果time_major=True
则它俩也是time_major
的,vice versa. 如果想要concatenate
的话,直接使用tf.concat(outputs, 2)
即可.
如何使用:
bidirectional_dynamic_rnn 在使用上和 dynamic_rn
n是非常相似的. 定义前向和反向rnn_cell
定义前向和反向rnn_cell的初始状态
准备好序列
调用bidirectional_dynamic_rnn
import tensorflow as tf
from tensorflow.contrib import rnn
cell_fw = rnn.LSTMCell(10)
cell_bw = rnn.LSTMCell(10)
initial_state_fw = cell_fw.zero_state(batch_size)
initial_state_bw = cell_bw.zero_state(batch_size)
seq = ...
seq_length = ...
(outputs, states)=tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, seq,
seq_length, initial_state_fw,initial_state_bw)
out = tf.concat(outputs, 2)
# ....
多层双向rnn
多层双向rnn(cs224d)
单层双向rnn可以通过上述方法简单的实现,但是多层的双向rnn就不能使将MultiRNNCell
传给bidirectional_dynamic_rnn
了.
想要知道为什么,我们需要看一下bidirectional_dynamic_rnn
的源码片段.
with vs.variable_scope(scope or "bidirectional_rnn"):
# Forward direction
with vs.variable_scope("fw") as fw_scope:
output_fw, output_state_fw = dynamic_rnn(
cell=cell_fw, inputs=inputs, sequence_length=sequence_length,
initial_state=initial_state_fw, dtype=dtype,
parallel_iterations=parallel_iterations, swap_memory=swap_memory,
time_major=time_major, scope=fw_scope)
这只是一小部分代码,但足以看出,bi-rnn
实际上是依靠dynamic-rnn
实现的,如果我们使用MuitiRNNCell
的话,那么每层之间不同方向之间交互就被忽略了.所以我们可以自己实现一个工具函数,通过多次调用bidirectional_dynamic_rnn
来实现多层的双向RNN 这是我对多层双向RNN的一个精简版的实现,如有错误,欢迎指出
bidirectional_dynamic_rnn源码一探
上面我们已经看到了正向过程的代码实现,下面来看一下剩下的反向部分的实现.
其实反向的过程就是做了两次reverse
1. 第一次reverse
:将输入序列进行reverse
,然后送入dynamic_rnn
做一次运算.
2. 第二次reverse
:将上面dynamic_rnn
返回的outputs
进行reverse
,保证正向和反向输出的time
是对上的.
def _reverse(input_, seq_lengths, seq_dim, batch_dim):
if seq_lengths is not None:
return array_ops.reverse_sequence(
input=input_, seq_lengths=seq_lengths,
seq_dim=seq_dim, batch_dim=batch_dim)
else:
return array_ops.reverse(input_, axis=[seq_dim]) with vs.variable_scope("bw") as bw_scope:
inputs_reverse = _reverse(
inputs, seq_lengths=sequence_length,
seq_dim=time_dim, batch_dim=batch_dim)
tmp, output_state_bw = dynamic_rnn(
cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length,
initial_state=initial_state_bw, dtype=dtype,
parallel_iterations=parallel_iterations, swap_memory=swap_memory,
time_major=time_major, scope=bw_scope) output_bw = _reverse(
tmp, seq_lengths=sequence_length,
seq_dim=time_dim, batch_dim=batch_dim) outputs = (output_fw, output_bw)
output_states = (output_state_fw, output_state_bw) return (outputs, output_states)
tf.reverse_sequence
对序列中某一部分进行反转
reverse_sequence(
input,#输入序列,将被reverse的序列
seq_lengths,#1Dtensor,表示输入序列长度
seq_axis=None,# 哪维代表序列
batch_axis=None, #哪维代表 batch
name=None,
seq_dim=None,
batch_dim=None
)
官网上的例子给的非常好,这里就直接粘贴过来:
# Given this:
batch_dim = 0
seq_dim = 1
input.dims = (4, 8, ...)
seq_lengths = [7, 2, 3, 5] # then slices of input are reversed on seq_dim, but only up to seq_lengths:
output[0, 0:7, :, ...] = input[0, 7:0:-1, :, ...]
output[1, 0:2, :, ...] = input[1, 2:0:-1, :, ...]
output[2, 0:3, :, ...] = input[2, 3:0:-1, :, ...]
output[3, 0:5, :, ...] = input[3, 5:0:-1, :, ...] # while entries past seq_lens are copied through:
output[0, 7:, :, ...] = input[0, 7:, :, ...]
output[1, 2:, :, ...] = input[1, 2:, :, ...]
output[2, 3:, :, ...] = input[2, 3:, :, ...]
output[3, 2:, :, ...] = input[3, 2:, :, ...]
例二:
# Given this:
batch_dim = 2
seq_dim = 0
input.dims = (8, ?, 4, ...)
seq_lengths = [7, 2, 3, 5] # then slices of input are reversed on seq_dim, but only up to seq_lengths:
output[0:7, :, 0, :, ...] = input[7:0:-1, :, 0, :, ...]
output[0:2, :, 1, :, ...] = input[2:0:-1, :, 1, :, ...]
output[0:3, :, 2, :, ...] = input[3:0:-1, :, 2, :, ...]
output[0:5, :, 3, :, ...] = input[5:0:-1, :, 3, :, ...] # while entries past seq_lens are copied through:
output[7:, :, 0, :, ...] = input[7:, :, 0, :, ...]
output[2:, :, 1, :, ...] = input[2:, :, 1, :, ...]
output[3:, :, 2, :, ...] = input[3:, :, 2, :, ...]
output[2:, :, 3, :, ...] = input[2:, :, 3, :, ...]
tensorflow学习笔记(三十九):双向rnn的更多相关文章
- tensorflow学习笔记(三十四):Saver(保存与加载模型)
Savertensorflow 中的 Saver 对象是用于 参数保存和恢复的.如何使用呢? 这里介绍了一些基本的用法. 官网中给出了这么一个例子: v1 = tf.Variable(..., nam ...
- 【Unity 3D】学习笔记三十九:控制组件
控制组件 角色控制组件和刚体组件都具备物理引擎的功能,须要绑定游戏对象才干实现对应的物理效果,而且同一个游戏对象中两者仅仅能存在一个,不能共存.刚体组件能够很精确的模拟现实世界中的一切物理效果,而角色 ...
- angular学习笔记(三十)-指令(7)-compile和link(2)
继续上一篇:angular学习笔记(三十)-指令(7)-compile和link(1) 上一篇讲了compile函数的基本概念,接下来详细讲解compile和link的执行顺序. 看一段三个指令嵌套的 ...
- angular学习笔记(三十)-指令(2)-restrice,replace,template
本篇主要讲解指令中的 restrict属性, replace属性, template属性 这三个属性 一. restrict: 字符串.定义指令在视图中的使用方式,一共有四种使用方式: 1. 元素: ...
- angular学习笔记(三十)-指令(10)-require和controller
本篇介绍指令的最后两个属性,require和controller 当一个指令需要和父元素指令进行通信的时候,它们就会用到这两个属性,什么意思还是要看栗子: html: <outer‐direct ...
- angular学习笔记(三十)-指令(7)-compile和link(1)
这篇主要讲解指令中的compile,以及它和link的微妙的关系. link函数在之前已经讲过了,而compile函数,它和link函数是不能共存的,如果定义了compile属性又定义link属性,那 ...
- angular学习笔记(三十)-指令(6)-transclude()方法(又称linker()方法)-模拟ng-repeat指令
在angular学习笔记(三十)-指令(4)-transclude文章的末尾提到了,如果在指令中需要反复使用被嵌套的那一坨,需要使用transclude()方法. 在angular学习笔记(三十)-指 ...
- angular学习笔记(三十)-指令(5)-link
这篇主要介绍angular指令中的link属性: link:function(scope,iEle,iAttrs,ctrl,linker){ .... } link属性值为一个函数,这个函数有五个参数 ...
- 【转】 Pro Android学习笔记(十九):用户界面和控制(7):ListView
目录(?)[-] 点击List的item触发 添加其他控件以及获取item数据 ListView控件以垂直布局方式显示子view.系统的android.app.ListActivity已经实现了一个只 ...
随机推荐
- 在winform中,禁止combobox随着鼠标一起滑动!
在winform中,如果form上或者是控件上有一个combobox控件,当你选择这个控件,当你鼠标移动其他地方,滑动鼠标时,这时combobox的选择值就会随之鼠标一起变化,如果你不想让comboB ...
- GIT生成 SSH Key步骤
//设置user.name和email 提交到git之后会显示用户名(在随意一个目录打开git-bash执行就行)Administrator@DESKTOP-BP3H0HS MINGW64 /d/mi ...
- Web性能优化——缓存
Ehcache: ehcache的配置文件ehcache.xml <?xml version="1.0" encoding="UTF-8"?> &l ...
- RQN 273 马棚问题 dp
PID273 / 马棚问题 2016-07-29 18:21:55 运行耗时:1624 ms 运行内存:16248 KB 题目描述 每天,小明和他的马外出,然后他们一边跑一边玩耍.当他们结束的时候, ...
- Ubuntu 16 安装redis客户端
https://snapcraft.io/redis-desktop-manager sudo snap install redis-desktop-manager 很好用! 支持模糊过滤,两边加星号 ...
- js 小秘密
1.RegExp 对象方法 test检索字符串中指定的值.返回 true 或 false. 支持正则表达式的 String 对象的方法
- python中的set类型
一. 定义 set是一个无序且不重复的元素集合 set和dict类似,是一组key的集合,但不存储value set有以下特性: 1. 由于key不能重复,所有set中没有重复的key 2. 元素为不 ...
- react-quill 富文本编辑器
适合react的一款轻量级富文本编辑器 1.http://blog.csdn.net/xiaoxiao23333/article/details/62055128 (推荐一款Markdown富文本编辑 ...
- 个人作业4——alpha阶段个人小结
一.个人总结 在alpha 结束之后, 每位同学写一篇个人博客, 总结自己的alpha 过程: 请用自我评价表:http://www.cnblogs.com/xinz/p/3852177.html 有 ...
- Spring 在xml配置里配置事务
事先准备:配置数据源对象用<bean>实例化各个业务对象. 1.配置事务管理器. <bean id="transactionManager" class=&quo ...