tensorflow源码分析——BasicLSTMCell
BasicLSTMCell 是最简单的LSTMCell,源码位于:/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py。
BasicLSTMCell 继承了RNNCell,源码位于:/tensorflow/python/ops/rnn_cell_impl.py
注意事项:
1. input_size 这个参数不能使用,使用的是num_units
2. state_is_tuple 官方建议设置为True。此时,输入和输出的states为c(cell状态)和h(输出)的二元组
3. 输入、输出、cell的维度相同,都是 batch_size * num_units,
cell = tf.contrib.rnn.BasicLSTMCell(num_units, forget_bias=0.0, state_is_tuple=True) #指定num_units
_initial_state = cell.zero_state(batch_size, tf.float32) #指定batch_size,将c和h全部初始化为0,shape全是batch_size * num_units,
4.
class BasicLSTMCell(RNNCell):
"""Basic LSTM recurrent network cell. The implementation is based on: http://arxiv.org/abs/1409.2329. We add forget_bias (default: 1) to the biases of the forget gate in order to
reduce the scale of forgetting in the beginning of the training. It does not allow cell clipping, a projection layer, and does not
use peep-hole connections: it is the basic baseline. For advanced models, please use the full LSTMCell that follows.
""" def __init__(self, num_units, forget_bias=1.0, input_size=None,
state_is_tuple=True, activation=tanh):
"""Initialize the basic LSTM cell. Args:
num_units: int, The number of units in the LSTM cell.
forget_bias: float, The bias added to forget gates (see above).
input_size: Deprecated and unused.
state_is_tuple: If True, accepted and returned states are 2-tuples of
the `c_state` and `m_state`. If False, they are concatenated
along the column axis. The latter behavior will soon be deprecated.
activation: Activation function of the inner states.
"""
if not state_is_tuple:
logging.warn("%s: Using a concatenated state is slower and will soon be "
"deprecated. Use state_is_tuple=True.", self)
if input_size is not None:
logging.warn("%s: The input_size parameter is deprecated.", self)
self._num_units = num_units
self._forget_bias = forget_bias
self._state_is_tuple = state_is_tuple
self._activation = activation @property
def state_size(self):
return (LSTMStateTuple(self._num_units, self._num_units)
if self._state_is_tuple else 2 * self._num_units) @property
def output_size(self):
return self._num_units def __call__(self, inputs, state, scope=None):
"""Long short-term memory cell (LSTM)."""
with vs.variable_scope(scope or "basic_lstm_cell"):
# Parameters of gates are concatenated into one multiply for efficiency.
if self._state_is_tuple:
c, h = state
else:
c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1) # 线性计算 concat = [inputs, h]W + b
# 线性计算,分配W和b,W的shape为(2*num_units, 4*num_units), b的shape为(4*num_units,),共包含有四套参数,
# concat shape(batch_size, 4*num_units)
# 注意:只有cell 的input和output的size相等时才可以这样计算,否则要定义两套W,b.每套再包含四套参数
concat = _linear([inputs, h], 4 * self._num_units, True, scope=scope) # i = input_gate, j = new_input, f = forget_gate, o = output_gate
i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1) new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
self._activation(j))
new_h = self._activation(new_c) * sigmoid(o) if self._state_is_tuple:
new_state = LSTMStateTuple(new_c, new_h)
else:
new_state = array_ops.concat([new_c, new_h], 1)
return new_h, new_state
5. lstm层,每一batch的运算
with tf.variable_scope("RNN"):
for time_step in range(num_steps):
if time_step > 0: tf.get_variable_scope().reuse_variables()
(cell_output, state) = cell(inputs[:, time_step, :], state)
outputs.append(cell_output)
6. 每一epoch
7.全部运算
tensorflow源码分析——BasicLSTMCell的更多相关文章
- tensorflow源码分析
前言: 一般来说,如果安装tensorflow主要目的是为了调试些小程序的话,只要下载相应的包,然后,直接使用pip install tensorflow即可. 但有时我们需要将Tensorflow的 ...
- tensorflow源码分析——LSTMCell
LSTMCell 是最简单的LSTMCell,源码位于:/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py.LSTMCell 继承了RNN ...
- 图解tensorflow 源码分析
http://www.cnblogs.com/yao62995/p/5773578.html https://github.com/yao62995/tensorflow
- tensorflow源码分析——CTC
CTC是2006年的论文Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurren ...
- [tensorflow源码分析] Conv2d卷积运算 (前向计算,反向梯度计算)
- [图解tensorflow源码] 入门准备工作附常用的矩阵计算工具[转]
[图解tensorflow源码] 入门准备工作 附常用的矩阵计算工具[转] Link: https://www.cnblogs.com/yao62995/p/5773142.html tensorf ...
- [图解tensorflow源码] 入门准备工作
tensorflow使用了自动化构建工具bazel.脚本语言调用c或cpp的包裹工具swig.使用EIGEN作为矩阵处理工具.Nvidia-cuBLAS GPU加速计算库.结构化数据存储格式prot ...
- [图解tensorflow源码] [原创] Tensorflow 图解分析 (Session, Graph, Kernels, Devices)
TF Prepare [图解tensorflow源码] 入门准备工作 [图解tensorflow源码] TF系统概述篇 Session篇 [图解tensorflow源码] Session::Run() ...
- TensorFlow源码框架 杂记
一.为什么我们需要使用线程池技术(ThreadPool) 线程:采用“即时创建,即时销毁”策略,即接受请求后,创建一个新的线程,执行任务,完毕后,线程退出: 线程池:应用软件启动后,立即创建一定数量的 ...
随机推荐
- RPC性能优化
优化 1:元数据共享 hessian 序列化会将两种信息写到输出流: 元数据:即类全名,字段名 值数据:即各个字段对应值(如果字段是复杂类型,则会递归传递该复杂类型 的元数据和内部字段的值数据) 在 ...
- ffmpeg处理视频命令
一:视频添加图片水印 ffmpeg -i a.mp4 -vf "movie=a.jpg[watermark];[in][watermark] overlay=main_w-overlay_w ...
- Delphi 10.3.2来了!
昨晚,官方正式发布了Delphi 10.3.2,增加对Mac 64应用的开发,支持Linux桌面开发,这个是通过集成fmxlinux实现的,同时修正400个bug,编译器,102个ide,84个fmx ...
- jvm监控工具jconsole进行远程监控配置
[环境] SUSE linux11 + jdk1.6 + tomcat7 [场景] 最近在做性能测试,想通过我本地(win7)上的jdk来远程监控上述服务器的jvm相关信息. [配置] 配置上述服务器 ...
- Oracle【三表的联合查询】
,'北京','彰显大气'); ,'上海','繁华都市'); ,'广州','凸显舒适'); ,'深圳','年轻气氛'); ,'北上广深','不相信眼泪'); commit; ; ; ; ; ; 员工信息 ...
- MFC的一些常用操作
一.添加消息 MFC和win32不同的一点是MFC采用的是消息的映射机制,即每一个消息都和处理的函数做了映射,我们可以通过查找的方式来得到消息的对应的函数,当然MFC提供了一个非常简便的方法,我们通过 ...
- insightface数据裁剪过程
数据裁剪 我们用lfw数据做实验,你也可以自己找数据. lfw数据 http://vis-www.cs.umass.edu/lfw/ 我下载的是这个原图像https://drive.google.co ...
- monkeyrunner录制和回放功能
脚本录制 网上先是搜索了一下,说是SDK--tools目录下有monkey_recorder.py和monkey_playback.py的脚本,但是我的没有找到所以可以自己编辑个脚本保存即可~ 先编辑 ...
- poj3691 DNA repair[DP+AC自动机]
$给定 n 个模式串,和一个长度为 m 的原串 s,求至少修改原串中的几个字符可以使得原串中不包含任一个模式串.模式串总长度 ≤ 1000,m ≤ 1000.$ 先建出模式串的AC自动机,然后考虑怎么 ...
- 2019CCPC秦皇岛赛区(重现赛)- I
链接: http://acm.hdu.edu.cn/contests/contest_showproblem.php?pid=1009&cid=872 题意: 在 dota2 中有一个叫做祈求 ...