【TensorFlow】自主实现包含全节点Cell的LSTM层 Cell
0x00 前言
常用的LSTM,或是双向LSTM,输出的结果通常是以下两个:
1) outputs,包括所有节点的hidden
2) 末节点的state,包括末节点的hidden和cell
大部分任务有这些就足够了,state是随着节点间信息的传递依次变化并容纳更多信息,
所以通常末状态的cell就囊括了所有信息,不需要中间每个节点的cell信息,
但如果我们的研究过程中需要用到这些cell该如何是好呢?
近期的任务中,需要每个节点的前后节点cell信息来做某种判断,
所以属于一个较为特殊的任务,自主实现了一下这个同样也会反馈cell的LSTM,
哦顺带一提Cell-Holding,是强行为了简称成CHD取的名字(笑)
0x01 分析与设计
首先分析源码,看一下通常LSTM层调用使用 dynamic_rnn 的实现逻辑,
原逻辑大概是这样的:
|
|
那么其实……我们只需要重新实现一个简化的版本,让cell留下来即可。
此处使用的逻辑大概是这样的:
|
|
为了实现这些,就需要做到以下几件事情:
1) 获取或共享已有LSTM层的BasicLSTMCell
2) 编写Cell相关计算,保留LSTM计算途中的信息,可自定义获取输出的格式
3) 采用设计的输出格式使用这些节点信息,以完成其他任务
0x02 Source Code
Advanced LSTM Layer
[LstmLayer] in tf_layers
首先要在不影响功能的情况下改写原有的LSTM Layer,令其支持获取BasicCell的操作
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465 class LstmLayer(object):true# based on LSTM Layer, thanks for @lhw446def __init__(self, input_dim, num_units, sequence_length=None, bidirection=False, name="lstm"):self.input_dim = input_dimself.num_units = num_unitsself.bidirection = bidirectionself.sequence_length = sequence_lengthself.name = name# `with ... as...` remains assignment work.self.lstm_fw_cell = Noneself.lstm_bw_cell = Nonewith tf.name_scope('%s_def' % (self.name)):self.lstm_fw_cell = tf.nn.rnn_cell.BasicLSTMCell(self.num_units, state_is_tuple=True)if self.bidirection:self.lstm_bw_cell = tf.nn.rnn_cell.BasicLSTMCell(self.num_units, state_is_tuple=True)def __call__(self, inputs, sequence_length=None, time_major=False,initial_state_fw=None, initial_state_bw=None):inputs_shape = tf.shape(inputs)inputs = tf.reshape(inputs, [-1, inputs_shape[-2], self.input_dim])sequence_length = self.sequence_length if sequence_length is Noneelse tf.reshape(sequence_length, [-1])if initial_state_fw is not None:initial_state_fw = tf.nn.rnn_cell.LSTMStateTuple(tf.reshape(initial_state_fw[0], [-1, self.num_units]),tf.reshape(initial_state_fw[1], [-1, self.num_units]))if initial_state_bw is not None:initial_state_bw = tf.nn.rnn_cell.LSTMStateTuple(tf.reshape(initial_state_bw[0], [-1, self.num_units]),tf.reshape(initial_state_bw[1], [-1, self.num_units]))resh_1 = lambda tensors: tf.reshape(tensors, tf.concat([inputs_shape[:-1], [tf.shape(tensors)[-1]]], 0))resh_2 = lambda tensors: tf.reshape(tensors, tf.concat([inputs_shape[:-2], [tf.shape(tensors)[-1]]], 0))with tf.variable_scope('%s_cal' % (self.name)):if self.bidirection:outputs, output_states = tf.nn.bidirectional_dynamic_rnn(self.lstm_fw_cell, self.lstm_bw_cell, inputs,sequence_length=sequence_length,initial_state_fw=initial_state_fw,initial_state_bw=initial_state_bw,time_major=time_major, dtype=tf.float32)# (fw_outputs, bw_outputs)大专栏 【TensorFlow】自主实现包含全节点Cell的LSTM层 Cell> outputs = tf.nn.rnn_cell.LSTMStateTuple(resh_1(outputs[0]), resh_1(outputs[1]))# ((fw_c_states, fw_m_states), (bw_c_states, bw_m_states))output_states = tf.nn.rnn_cell.LSTMStateTuple(tf.nn.rnn_cell.LSTMStateTuple(resh_2(output_states[0][0]), resh_2(output_states[0][1])),tf.nn.rnn_cell.LSTMStateTuple(resh_2(output_states[1][0]), resh_2(output_states[1][1])))else:outputs, output_states = tf.nn.dynamic_rnn(self.lstm_fw_cell, inputs, sequence_length=sequence_length,initial_state=initial_state_fw,time_major=time_major, dtype=tf.float32)outputs = resh_1(outputs) # (outputs)# (c_states, m_states)output_states = tf.nn.rnn_cell.LSTMStateTuple(resh_2(output_states[0]), resh_2(output_states[1]))return outputs, output_states
Cell-HolDing Layer
chd_lstm_layer in network
然后基于目标LSTM层,构建使用相同基本单元的scope,设定初始零状态,逐层计算
(此处仅剪枝了所有的padding位,没有特意做加速,用了简单的python-like的for循环)
(且为了本次实验需要,没有将hidden和cell区分开来,而是直接保存了state整体,可自行修改)
123456789101112131415161718 def chd_lstm_layer(self, inputs, target_layer):cell = target_layer.lstm_fw_cellwith tf.variable_scope('%s_cal' % (target_layer.name)):# generate initial states for current inputsstates_case = []for batch_idx in range(self.batch_size):batch_state_case = []state = cell.zero_state(1, tf.float32)for time_step in range(self.seg_len[batch_idx]):tf_input = inputs[batch_idx, time_step]output, _state = cell(tf.reshape(tf_input, [1, -1]), state)batch_state_case.append(_state)state = _statestates_case.append(batch_state_case)# a nested list of states [batch_size, seg_len]return states_case, cell
上述是任务需要,
主要演示了可以简单的循环调用给定LSTM层的Cell进行计算,
在对齐的情况下还可以通过stack等操作拼成一个tf的矩阵使用。
其中用作循环迭代次数的参数 self.batch_size self.seg_len等,
不可以是tf.placeholder,因为range内必须为一个固定的数值而不能为一个占位符(tf.loop不知道能不能做到)
所以在feed_dict前,我做了如下的操作,将这些固定数值作为 instance_variables 传给网络以供使用。
|
|
Further usage on states_case
others_layer in network
获取了states_case之后,可以用于各个位置的使用
下文中给出一个使用案例,此处用于计算相同LSTM序列中,替换其中任意节点为其他节点的输出。
1234567891011121314151617181920212223242526 def replace_layer(self, forward_emb, candidate_emb):backward_emb = self.get_reverse(forward_emb, rev_length=self.cell_lens + 2)fw_states, fw_cell = self.chd_lstm_layer(forward_emb, self.forward_lstm)bw_states, bw_cell = self.chd_lstm_layer(backward_emb, self.backward_lstm)hidden_case = []for batch_idx in range(self.batch_size):batch_case = []for time_step in range(self.seg_len[batch_idx]):time_case = []for candidate_idx in range(self.can_len[batch_idx, time_step]):tf_input = candidate_emb[batch_idx, time_step, candidate_idx]fw_hidden, _ = fw_cell(tf.reshape(tf_input, [1, -1]),fw_states[batch_idx][time_step])bw_hidden, _ = bw_cell(tf.reshape(tf_input, [1, -1]),bw_states[batch_idx][-time_step])hidden = tf.concat([fw_hidden, bw_hidden], -1)time_case.append(hidden)batch_case.append(time_case)hidden_case.append(batch_case)return hidden_case # a nested list.
0x03 后记
cell因其持续更新且后者包含前者信息的特性通常不被保存,
但是 LSTMCell RNNCell 的调用却需要完整的state(包括hidden和cell),
在我们对已经计算完毕的LSTM序列中内部的某些节点有所想法时,就很难回溯了,
所以说不定这种layer也是有一定价值的,目前tensorflow里还没有整合成类似的层,
所以自行手写了一个,虽说不是太复杂,不过提供了这样一种想法,记录一下~
(说不定以后就加了这个层呢~ 到时候我可以指着这篇文章说我早就想到咯^_^)
【TensorFlow】自主实现包含全节点Cell的LSTM层 Cell的更多相关文章
- 查看tensorflow pb模型文件的节点信息
查看tensorflow pb模型文件的节点信息: import tensorflow as tf with tf.Session() as sess: with open('./quantized_ ...
- jQuery 获取当前节点的html包含当前节点的方法
在开发过程中,jQuery.html() 是获取当前节点下的html代码,并不包含当前节点本身的代码,然后我们有时候确需要,找遍jQuery api文档也没有任何方法可以拿到. 看到有的人通过pare ...
- jquery 获取 outerHtml 包含当前节点本身的代码
在开发过程中,jQuery.html() 是获取当前节点下的html代码,并不包含当前节点本身的代码,然后我们有时候确需要,找遍jQuery api文档也没有任何方法可以拿到. 看到有的人通过pare ...
- 比特币全节点(bitcoind) eth 全节点
运行全节点的用途: 1.挖矿 2.钱包 运行全节点,可以做关于btc的任何事情,例如创建钱包地址.管理钱包地址.发送交易.查询全网的交易信息等等 选个节点钱包:bitcoind 1.配置文件: ...
- 以太坊geth主网全节点部署
以太坊geth主网全节点部署 #环境 ubuntu 16.4 #硬盘500GB(目前占用200G) #客户端安装 # 查看下载页面最新版 # https://ethereum.github.io/go ...
- 比特币BTC全节点搭建
比特币BTC全节点搭建 #环境 ubuntu 16.4 #硬盘500GB #截止2018-12-31磁盘占用超过230GB #客户端安装 #下载页面 #https://bitcoin.org/zh_C ...
- 以太坊go-ethereum客户端(三)两种全节点启动模式
这篇博客介绍一下go-ethereum全节点的两种启动模式:主网络快速启动和测试网络快速启动.这也是客户端所提供的两种启动方式,直接拿来使用即可.下面具体介绍一下使用方法. 主网络快速启动 其实,我们 ...
- 100万套PPT模板,包含全宇宙所有主题类型PPT,绕宇宙100圈,持续更新
100万套PPT模板,包含全宇宙所有主题类型PPT(全部免费,都是精品,没有一张垃圾不好看的PPT,任何一张PPT拿来套入自己的信息就可以立马使用),绕宇宙100圈,任意一个模板在某文库上都价不菲.强 ...
- JS获取包含当前节点本身的代码内容(outerHtml)
原生JS DOM的内置属性 outerHTML 可用来获取当前节点的html代码(包含当前节点),且此属性可使用jQuery的prop()获取 <div id="demo-test-0 ...
随机推荐
- 简单的tab栏切换
<!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8&quo ...
- ubuntu 卸载软件
ubuntu完全卸载一个软件 今天卸载一个软件,老是有配置残留,网上找到了解决方案: 查看已安装的软件: dpkg -l |grep 软件名 找到一大堆相关的包,然后卸载核心的包: sudo ap ...
- 14 微服务电商【黑马乐优商城】:day02-springcloud(理论篇四:配置Robbin负载均衡)
本项目的笔记和资料的Download,请点击这一句话自行获取. day01-springboot(理论篇) :day01-springboot(实践篇) day02-springcloud(理论篇一) ...
- 许家印67亿买下FF恒大是要雪中送炭吗?
从大环境来看,当下新能源汽车已经是备受投资者青睐的领域.据不完全统计,当下国内已经有300余家电动汽车企业.而蔚来.小鹏.威马等动辄都融资上百亿元,显现出火爆的发展趋势.甚至就连董明珠董大姐也有着自己 ...
- Tokyocabinet/Tokyotyrant文档大合集
1. 前言 这里不是我个人原创,是我对网络上整理到的资料的再加工,以更成体系,更方便研究阅读.主要是对其中跟主题无关的文字删除,部分人称稍做修改;本人无版权,您可以将本页面视为对参考页面的镜像.第二部 ...
- Office 365 邮件流
进入Exchange管理中心->点击左侧的“邮件流”->进入邮件流配置页面. 一.规则 规则也称传输规则,对通过组织传递的邮件,根据设定条件进行匹配,并对其进行操作.传输规则与众多电子邮件 ...
- 没有更好的,五种操作系统助力研发,IMX6开发板做得到
核心板参数 尺寸 51mm*61mm 四核商业级-2G NXP 四核 i.MX6Q,主频 1 GHz 内存:2GB DDR3:存储:16GB EMMC:SATA接口:支持 双核商业级-1G NXP 双 ...
- springboot支付项目之springboot集成jpa
springboot集成spring-jpa 本文主要内容: 1:spring boot怎么集成spring-jpa以及第一个jpa查询示例 如jpa几个常用注解.lombok注解使用 2:怎么设置i ...
- python实现个人信息随机生成
""" 生成随机姓名.电话号码.身份证号.性别.应行卡号.邮箱 """ import random from firstname impor ...
- redis的集群:
集群策略:主从复制哨兵集群 参考:https://blog.csdn.net/q649381130/article/details/79931791 集群又分为如下:客户端分片基于代理的分片路由查询参 ...