Attribute 'num_units' in Tensorflow BasicLSTMCell blocks
在之前使用Tensorflow来做音乐识别时,LSTM给出了非常让人惊喜的学习能力。当时在进行Tuning的时候,有一个参数叫做num_units,字面看来是LTSM单元的个数,但最近当我试图阅读Tensorflow源代码时,和我们最初的认知大不相同,以此博文来记录。
先看当初我们是如何设置的:
rnn_cell = tf.contrib.rnn.BasicLSTMCell(num_units=300)
看起来像是,为Hidden Layer设置了300个单独的LSTM单元,然后并行工作最终输出300个值。但实际上,我们来看一下Tensorflow的源码:(github地址),从line 326,开始定义BasicLSTMCell类,在line 374行开始定义BasicLSTMCell的核心方法call方法:
def call(self, inputs, state):
"""Long short-term memory cell (LSTM)."""
sigmoid = math_ops.sigmoid
# Parameters of gates are concatenated into one multiply for efficiency.
if self._state_is_tuple:
c, h = state
else:
c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1) concat = _linear([inputs, h], 4 * self._num_units, True) # i = input_gate, j = new_input, f = forget_gate, o = output_gate
i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1) new_c = (
c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
new_h = self._activation(new_c) * sigmoid(o) if self._state_is_tuple:
new_state = LSTMStateTuple(new_c, new_h)
else:
new_state = array_ops.concat([new_c, new_h], 1)
return new_h, new_state
注意13行,改行的作用是,根据当前时刻的输入inputs,以及前一时刻的输出值h,去计算4个gates在经过activation function之前的线性组合值。而后15-17两行,我们使用四个gates去计算了新的LSTM Cell状态c,以及新的输出值h。
是的,无论num_units设置为多少,这是一个LSTM Cell!如果我们查看_linear这个函数,可以看到第二个参数是output_size,也就是说num_units和LSTM Cell的输出大小有关。事实上,Tensorflow的LSTMCell表征了整个一层Hidden Layer。而num_units则表示State Cell的存储能力,或者说维度Dimension。试想在一个LSTM Neural Network中,输入tensor X的维度是确定的,输出值Y的维度也是确定的,而LSTM各个时刻间的中间状态c,以及抽象输出h,则可以为任意维度。因为h可以经过dense层(fully-connected layer)去压缩成Y所需的维度。
所以c和h的维度越高,其蕴含的time series data细节越多,当然越容易去拟合training set。但是,容易Overfitting呀,所以tuning时平衡training set的拟合程度,以及cv set的预测精度,来达到trade off咯。
Attribute 'num_units' in Tensorflow BasicLSTMCell blocks的更多相关文章
- AttributeError: module 'tensorflow' has no attribute 'enable_eager_execution'
Traceback (most recent call last): File "linear_regression_eager_api.py", line 15, in < ...
- Sphinx 2.2.11-release reference manual
1. Introduction 1.1. About 1.2. Sphinx features 1.3. Where to get Sphinx 1.4. License 1.5. Credits 1 ...
- 转:用AutoCAD 系统变量编程
Autocad的系统变量, 我们可以通过如下得到: Autodesk.AutoCAD.ApplicationServices.Application.GetSystemVariable(/*MSG0* ...
- ubuntu14.04 安装 tensorflow9.0
ubuntu14.04 安装 tensorflow9.0 文章目录 ubuntu14.04 安装 tensorflow9.0 安装pip(笔者的版本为9.0) 仅使用 CPU 的版本的tensorfl ...
- chattr lsattr linux file system attributes - linux 文件系统扩展属性
我们使用 linux 文件系统扩展属性,能够对linux文件系统进行进一步保护:从而给文件 赋予一些额外的限制:在有些情况下,能够对我们的系统提供保护: chattr命令用来改变文件属性.这项指令可改 ...
- 关于tensorflow里面的tf.contrib.rnn.BasicLSTMCell 中num_units参数问题
这里的num_units参数并不是指这一层油多少个相互独立的时序lstm,而是lstm单元内部的几个门的参数,这几个门其实内部是一个神经网络,答案来自知乎: class TRNNConfig(obje ...
- tensorflow源码分析——BasicLSTMCell
BasicLSTMCell 是最简单的LSTMCell,源码位于:/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py.BasicLSTMC ...
- 'tensorflow' has no attribute 'sub'
在学习tensorflow的时候,照到官方的例子做,发现了一个 Traceback (most recent call last): File , in <module> sub = tf ...
- 【pycharm】pycharm上安装tensorflow,报错:AttributeError: module 'pip' has no attribute 'main' 解决方法
pycharm上安装tensorflow,报错:AttributeError: module 'pip' has no attribute 'main' 解决方法 解决方法: 在pycharm的安装目 ...
随机推荐
- FB相关
1.传包过程的错误 (中文提示)游戏必须通过我们的 CDN 引用我们支持的 SDK (英文提示)Games must reference one of our supported SDKs via o ...
- C# System.Web.Caching.Cache类 缓存 各种缓存依赖
原文:https://www.cnblogs.com/kissdodog/archive/2013/05/07/3064895.html Cache类,是一个用于缓存常用信息的类.HttpRuntim ...
- JavaScript、ES6中类的this指向问题
<!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8&quo ...
- websocket无法注入bean问题解决方案
websocket服务端往往需要和服务层打交道,因此需要将服务层的一些bean注入到websocket实现类中使用,但是呢,websocket实现类虽然顶部加上了@Component注解,依然无法通过 ...
- unity2017 光照与渲染(二)FAQs
FAQ: 场景里的物体没有影子? 1)灯光是否开了影子 2)QualitySettings 中 shadows 的设置 3) 模型MeshRenderer 的 ReciveShadows 和 Cast ...
- Express 2015 RC for Windows 10 安装
支持的操作系统 Windows 10 Technical Preview 硬件要求 1.6 GHz 或更快的处理器 1 GB RAM(如果在虚拟机上运行,则为 1.5 GB) 4 GB 可用硬盘空间 ...
- centos误删除文件如何恢复
当意识到误删除文件后,切忌千万不要再频繁写入了,否则你的数据恢复的数量将会很少. 而我们要做的是,第一时间把服务器上的服务全部停掉,直接killall 进程名 或者 kill -9 pid . 然后把 ...
- 基于TMS320C6678、FPGA XC5VLX110T的6U CPCI 8路光纤信号处理卡
基于TMS320C6678.FPGA XC5VLX110T的6U CPCI 8路光纤信号处理卡 1.板卡概述 本板卡由我公司自主研发,基于CPCI架构,符合CPCI2.0标准,采用两片TI DSP T ...
- 一、Core的布局页、起始页及错误页
一.布局页面: 使用布局页相当于一个母版页,可以将各个页面公用部分,如上方标题区.左侧导航菜单区.下方版权声明及状态显示区以及通用的js及css引用等,集中放到布局页管理,具体功能页面只需关注自己独有 ...
- 树形dp专栏
前言 自己树形dp太菜了,要重点搞 219D Choosing Capital for Treeland 终于自己做了一道不算那么毒瘤的换根dp 令 \(f[u]\) 表示以 \(u\) 为根,子树内 ...