TensorFlow框架下的RNN实践小结
截至目前,TensorFlow的RNN APIs还处于Draft阶段。不过据官方解释,RNN的相关API已经出现在Tutorials里了,大幅度的改动应该是不大可能,现在入手TF的RNN APIs风险应该是不大的。
目前TF的RNN APIs主要集中在tensorflow.models.rnn中的rnn和rnn_cell两个模块。其中,后者定义了一些常用的RNN cells,包括RNN和优化的LSTM、GRU等等;前者则提供了一些helper方法。
创建一个基础的RNN很简单:
1 |
from tensorflow.models.rnn import rnn_cell |
2 |
cell = rnn_cell.BasicRNNCell(inputs, state) |
创建一个LSTM或者GRU的cell?
1 |
cell = rnn_cell.BasicLSTMCell(num_units) #最最基础的,不带peephole。 |
2 |
cell = rnn_cell.LSTMCell(num_units, input_size) #可以设置peephole等属性。 |
3 |
cell = rnn_cell.GRUCell(num_units) |
调用呢?
1 |
output, state = cell(input, state) |
这样自己按timestep调用需要设置variable_scope的reuse属性为True,懒人怎么做,TF也给想好了:
1 |
state = cell.zero_state(batch_size, dtype=tf.float32) |
2 |
outputs, states = rnn.rnn(cell, inputs, initial_state=state) |
再懒一点:
1 |
outputs, states = rnn.rnn(cell, inputs, dtype=tf.float32) |
怕overfit,加个Dropout如何?
1 |
cell = rnn_cell.DropoutWrapper(cell, input_keep_prob=0.5, output_keep_prob=0.5) |
做个三层的带Dropout的网络?
1 |
cell = rnn_cell.DropoutWrapper(cell, output_keep_prob=0.5) |
2 |
cell = rnn_cell.MultiRNNCell([cell] * 3) |
3 |
inputs = tf.nn.dropout(inputs, 0.5) #给第一层单独加个Dropout。 |
一个坑——用rnn.rnn要按照timestep来转换一下输入数据,比如像这样:
1 |
inputs = [tf.reshape(t, (input_dim[0], 1)) for t in tf.split(1, input_dim[1], inputs)] |
rnn.rnn()的输出也是对应每一个timestep的,如果只关心最后一步的输出,取outputs[-1]即可。
注意一下子返回值的dimension和对应关系,损失函数和其它情况没有大的区别。
目前饱受诟病的是TF本身还不支持Theano中scan()那样可以轻松实现的不定长输入的RNN,不过有人反馈说Theano中不定长训练起来还不如提前给inputs加个padding改成定长的训练快。
TensorFlow框架下的RNN实践小结的更多相关文章
- TensorFlow框架(5)之机器学习实践
1. Iris data set Iris数据集是常用的分类实验数据集,由Fisher, 1936收集整理.Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集.数据集包含150个数据集,分为3类, ...
- TensorFlow框架(6)之RNN循环神经网络详解
1. RNN循环神经网络 1.1 结构 循环神经网络(recurrent neural network,RNN)源自于1982年由Saratha Sathasivam 提出的霍普菲尔德网络.RNN的主 ...
- python机器学习TensorFlow框架
TensorFlow框架 关注公众号"轻松学编程"了解更多. 一.简介 TensorFlow是谷歌基于DistBelief进行研发的第二代人工智能学习系统,其命名来源于本身的运 ...
- AlexeyAB DarkNet YOLOv3框架解析与应用实践(五)
AlexeyAB DarkNet YOLOv3框架解析与应用实践(五) RNNs in Darknet 递归神经网络是表示随时间变化的数据的强大模型.为了更好地介绍RNNs,我强烈推荐Andrej K ...
- MySQL在Django框架下的基本操作(MySQL在Linux下配置)
[原]本文根据实际操作主要介绍了Django框架下MySQL的一些常用操作,核心内容如下: ------------------------------------------------------ ...
- 人工智能 tensorflow框架-->简介及安装01
简介:Tensorflow是google于2015年11月开源的第二代机器学习框架. Tensorflow名字理解:图形边中流动的数据叫张量(Tensor),因此叫Tensorflow 既 张量流动 ...
- 【TensorFlow篇】--Tensorflow框架实现SoftMax模型识别手写数字集
一.前述 本文讲述用Tensorflow框架实现SoftMax模型识别手写数字集,来实现多分类. 同时对模型的保存和恢复做下示例. 二.具体原理 代码一:实现代码 #!/usr/bin/python ...
- Tensorflow之MNIST的最佳实践思路总结
Tensorflow之MNIST的最佳实践思路总结 在上两篇文章中已经总结出了深层神经网络常用方法和Tensorflow的最佳实践所需要的知识点,如果对这些基础不熟悉,可以返回去看一下.在< ...
- 基于TensorFlow的循环神经网络(RNN)
RNN适用场景 循环神经网络(Recurrent Neural Network)适合处理和预测时序数据 RNN的特点 RNN的隐藏层之间的节点是有连接的,他的输入是输入层的输出向量.extend(上一 ...
随机推荐
- MVC与单元测试实践之健身网站(八)-统计分析
统计分析模块与之前的内容相对独立,用于记录并跟踪各部位围度的变化.还需提供对所作计划的分析,辅助使计划更合理. 一 围度记录 这儿可以记录各项身体围度指标,现在包括体重在内身体上上下下基本全部提供了 ...
- iOS开发NSDate、NSString、时间戳之间的转化
//将UTCDate(世界标准时间)转化为当地时区的标准Date(钟表显示的时间) //NSDate *date = [NSDate date]; 2018-03-27 06:54:41 +0000 ...
- Selenium Webdriver 动态设置 Proxy
Step 1: Visiting "about:config" driver.get("about:config"); Step 2 : Run script ...
- Java:[面向对象:继承,多态]
本文内容: 继承 多态 首发时期:2018-03-23 继承: 介绍: 如果多个类中存在相同的属性和行为,可以将这些内容抽取到单独一个类中,那么多个类(子类)无需再定义这些属性和行为,只要继承那个类( ...
- maven(二):创建一个可用的maven项目,完整过程
环境:eclipse4.5 (内置maven插件) 创建maven项目 文件菜单--新建--其他-- maven project 下一步 选择web 结构 group id: 指项目在maven本地 ...
- Spring MVC 的工作原理
引自:https://www.cnblogs.com/xiaoxi/p/6164383.html SpringMVC的工作原理图: SpringMVC流程 1. 用户发送请求至前端控制器Dispat ...
- centos-7 虚拟机安装图形界面
centos-7 虚拟机安装图形界面 想到安装一个docker环境,于是拿出了以前装的虚拟机centos7,记得装完后,没进行任何配置(默认安装的是命令行界面). 配置网络 现有的虚拟机是没有办法联网 ...
- C# 生成强命名程序集并添加到GAC
针对一些类库项目或用户控件项目(一般来说,这类项目最后编译生成的是一个或多个dll文件),在程序开发完成后,有时需要将开发的程序集(dll文件)安装部署到GAC(全局程序集缓存)中,以便其他的程序也可 ...
- MySQL sql_mode=only_full_group_by错误
今天在测试服务器上突然出现了这么一个MySQL的问题,同样的代码正式服没有问题,那肯定就是出在了配置上,查了一下原因才明白原来是数据库版本为5.7以上的版本, 默认是开启了 only_full_gro ...
- C语句详细(初学者)
C程序的执行部分是由语句组成的.程序的功能也是由执行语句实现的. C语句分为以下六类: 1.表达式语句:表达式加上分号“:”组成.执行表达式语句就是计算表达式的值. 2.函数调用语句:函数名.实际参数 ...