tf.contrib.rnn.static_rnn与tf.nn.dynamic_rnn区别
tf.contrib.rnn.static_rnn与tf.nn.dynamic_rnn区别
https://blog.csdn.net/u014365862/article/details/78238807
MachineLP的Github(欢迎follow):https://github.com/MachineLP
我的GitHub:https://github.com/MachineLP/train_cnn-rnn-attention 自己搭建的一个框架,包含模型有:vgg(vgg16,vgg19), resnet(resnet_v2_50,resnet_v2_101,resnet_v2_152), inception_v4, inception_resnet_v2等。
- chunk_size = 256
- chunk_n = 160
- rnn_size = 256
- num_layers = 2
- n_output_layer = MAX_CAPTCHA*CHAR_SET_LEN # 输出层
单层rnn:
tf.contrib.rnn.static_rnn:
输入:[步长,batch,input]
输出:[n_steps,batch,n_hidden]
还有rnn中加dropout
- def recurrent_neural_network(data):
- data = tf.reshape(data, [-1, chunk_n, chunk_size])
- data = tf.transpose(data, [1,0,2])
- data = tf.reshape(data, [-1, chunk_size])
- data = tf.split(data,chunk_n)
- # 只用RNN
- layer = {'w_':tf.Variable(tf.random_normal([rnn_size, n_output_layer])), 'b_':tf.Variable(tf.random_normal([n_output_layer]))}
- lstm_cell = tf.contrib.rnn.BasicLSTMCell(rnn_size)
- outputs, status = tf.contrib.rnn.static_rnn(lstm_cell, data, dtype=tf.float32)
- # outputs = tf.transpose(outputs, [1,0,2])
- # outputs = tf.reshape(outputs, [-1, chunk_n*rnn_size])
- ouput = tf.add(tf.matmul(outputs[-1], layer['w_']), layer['b_'])
- return ouput
多层rnn:
tf.nn.dynamic_rnn:
输入:[batch,步长,input]
输出:[batch,n_steps,n_hidden]
所以我们需要tf.transpose(outputs, [1, 0, 2]),这样就可以取到最后一步的output
- def recurrent_neural_network(data):
- # [batch,chunk_n,input]
- data = tf.reshape(data, [-1, chunk_n, chunk_size])
- #data = tf.transpose(data, [1,0,2])
- #data = tf.reshape(data, [-1, chunk_size])
- #data = tf.split(data,chunk_n)
- # 只用RNN
- layer = {'w_':tf.Variable(tf.random_normal([rnn_size, n_output_layer])), 'b_':tf.Variable(tf.random_normal([n_output_layer]))}
- #1
- # lstm_cell1 = tf.contrib.rnn.BasicLSTMCell(rnn_size)
- # outputs1, status1 = tf.contrib.rnn.static_rnn(lstm_cell1, data, dtype=tf.float32)
- def lstm_cell():
- return tf.contrib.rnn.LSTMCell(rnn_size)
- def attn_cell():
- return tf.contrib.rnn.DropoutWrapper(lstm_cell(), output_keep_prob=keep_prob)
- # stack = tf.contrib.rnn.MultiRNNCell([attn_cell() for _ in range(0, num_layers)], state_is_tuple=True)
- stack = tf.contrib.rnn.MultiRNNCell([lstm_cell() for _ in range(0, num_layers)], state_is_tuple=True)
- # outputs, _ = tf.nn.dynamic_rnn(stack, data, seq_len, dtype=tf.float32)
- outputs, _ = tf.nn.dynamic_rnn(stack, data, dtype=tf.float32)
- # [batch,chunk_n,rnn_size] -> [chunk_n,batch,rnn_size]
- outputs = tf.transpose(outputs, (1, 0, 2))
- ouput = tf.add(tf.matmul(outputs[-1], layer['w_']), layer['b_'])
- return ouput
tf.contrib.rnn.static_rnn与tf.nn.dynamic_rnn区别的更多相关文章
- 深度学习原理与框架-递归神经网络-RNN网络基本框架(代码?) 1.rnn.LSTMCell(生成单层LSTM) 2.rnn.DropoutWrapper(对rnn进行dropout操作) 3.tf.contrib.rnn.MultiRNNCell(堆叠多层LSTM) 4.mlstm_cell.zero_state(state初始化) 5.mlstm_cell(进行LSTM求解)
问题:LSTM的输出值output和state是否是一样的 1. rnn.LSTMCell(num_hidden, reuse=tf.get_variable_scope().reuse) # 构建 ...
- 关于tensorflow里面的tf.contrib.rnn.BasicLSTMCell 中num_units参数问题
这里的num_units参数并不是指这一层油多少个相互独立的时序lstm,而是lstm单元内部的几个门的参数,这几个门其实内部是一个神经网络,答案来自知乎: class TRNNConfig(obje ...
- tf.contrib.rnn.core_rnn_cell.BasicLSTMCell should be replaced by tf.contrib.rnn.BasicLSTMCell.
For Tensorflow 1.2 and Keras 2.0, the line tf.contrib.rnn.core_rnn_cell.BasicLSTMCell should be repl ...
- tensorflow教程:tf.contrib.rnn.DropoutWrapper
tf.contrib.rnn.DropoutWrapper Defined in tensorflow/python/ops/rnn_cell_impl.py. def __init__(self, ...
- tf.contrib.rnn.LSTMCell 里面参数的意义
num_units:LSTM cell中的单元数量,即隐藏层神经元数量.use_peepholes:布尔类型,设置为True则能够使用peephole连接cell_clip:可选参数,float类型, ...
- tensorflow笔记6:tf.nn.dynamic_rnn 和 bidirectional_dynamic_rnn:的输出,output和state,以及如何作为decoder 的输入
一.tf.nn.dynamic_rnn :函数使用和输出 官网:https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn 使用说明: A ...
- tf.nn.dynamic_rnn
tf.nn.dynamic_rnn(cell,inputs,sequence_length=None, initial_state=None,dtype=None, parallel_iteratio ...
- TF之RNN:实现利用scope.reuse_variables()告诉TF想重复利用RNN的参数的案例—Jason niu
import tensorflow as tf # 22 scope (name_scope/variable_scope) from __future__ import print_function ...
- 第十六节,使用函数封装库tf.contrib.layers
这一节,介绍TensorFlow中的一个封装好的高级库,里面有前面讲过的很多函数的高级封装,使用这个高级库来开发程序将会提高效率. 我们改写第十三节的程序,卷积函数我们使用tf.contrib.lay ...
随机推荐
- Homebrew安装卸载
安装homebrew ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/inst ...
- Docker系列之(五):使用Docker Compose编排容器
1. 前言 Docker Compose 是 Docker 容器进行编排的工具,定义和运行多容器的应用,可以一条命令启动多个容器. 使用Compose 基本上分为三步: Dockerfile 定义应用 ...
- httpwatch抓包工具的使用方法
火狐浏览器下有著名的httpfox,而HttpWatch则是IE下强大的网页数据分析工具.这个工具到底有哪些具体功能呢?这个我就不再赘述了,百度百科上列的很全面,但也比较抽象.我只想说我曾经用这个工具 ...
- STM32 CRC-32 Calculator Unit
AN4187 - Using the CRC peripheral in the STM32 family At start up, the algorithm sets CRC to the Ini ...
- Exynos4412的外部中断是如何安排的?
作者 彭东林 pengdonglin137@163.com 平台 Linux4.9 tiny4412 概述 结合tiny4412开发板分析一下Exynos4412的外部中断是如何组织的. ...
- Unity3D实践系列02,查看Scene窗口物体
删除"Hierarchy"窗口中的"Directional Light". 把鼠标放在"Scene"窗口,滑动鼠标滚轮,可以对"S ...
- Java异常(三) 《Java Puzzles》中关于异常的几个谜题
概要 本章介绍<Java Puzzles>中关于异常的几个谜题.这一章都是以代码为例,相比上一章看起来更有意思.内容包括:谜题1: 优柔寡断谜题2: 极端不可思议谜题3: 不受欢迎的宾客谜 ...
- 关于电商ERP的想法
原文地址: http://www.chinaodoo.net/thread-465-1-1.html 试用了下odoo的淘宝订单处理模块,从整个业务流程上已经打通,如果要求不是很高的话,现有的功能基本 ...
- [开源]Google code Android开源项目(一)
[Android分享] [开源]Google code Android开源项目(一) [复制链接] 449122717 2 主题 2 好友 816 积分 No.4 中级开发者 升级 19.3 ...
- 关于面试总结11-selenium面试题
前言 面试web自动化必然会问到selenium,问selenium相关的问题定位是最基本的,也是自动化的根本,所以面试离不开元素定位问题. 之前看到招聘要求里面说"只会复制粘贴xpath的 ...