tf.contrib.rnn.static_rnn与tf.nn.dynamic_rnn区别

https://blog.csdn.net/u014365862/article/details/78238807

MachineLP的Github(欢迎follow):https://github.com/MachineLP

我的GitHub:https://github.com/MachineLP/train_cnn-rnn-attention 自己搭建的一个框架,包含模型有:vgg(vgg16,vgg19), resnet(resnet_v2_50,resnet_v2_101,resnet_v2_152), inception_v4, inception_resnet_v2等。

  1.  
    chunk_size = 256
  2.  
    chunk_n = 160
  3.  
    rnn_size = 256
  4.  
    num_layers = 2
  5.  
    n_output_layer = MAX_CAPTCHA*CHAR_SET_LEN # 输出层

单层rnn:

tf.contrib.rnn.static_rnn:

输入:[步长,batch,input]

输出:[n_steps,batch,n_hidden]

还有rnn中加dropout

  1.  
    def recurrent_neural_network(data):
  2.  
     
  3.  
    data = tf.reshape(data, [-1, chunk_n, chunk_size])
  4.  
    data = tf.transpose(data, [1,0,2])
  5.  
    data = tf.reshape(data, [-1, chunk_size])
  6.  
    data = tf.split(data,chunk_n)
  7.  
     
  8.  
    # 只用RNN
  9.  
    layer = {'w_':tf.Variable(tf.random_normal([rnn_size, n_output_layer])), 'b_':tf.Variable(tf.random_normal([n_output_layer]))}
  10.  
    lstm_cell = tf.contrib.rnn.BasicLSTMCell(rnn_size)
  11.  
    outputs, status = tf.contrib.rnn.static_rnn(lstm_cell, data, dtype=tf.float32)
  12.  
    # outputs = tf.transpose(outputs, [1,0,2])
  13.  
    # outputs = tf.reshape(outputs, [-1, chunk_n*rnn_size])
  14.  
    ouput = tf.add(tf.matmul(outputs[-1], layer['w_']), layer['b_'])
  15.  
     
  16.  
    return ouput

多层rnn:

tf.nn.dynamic_rnn:

输入:[batch,步长,input] 
输出:[batch,n_steps,n_hidden] 
所以我们需要tf.transpose(outputs, [1, 0, 2]),这样就可以取到最后一步的output

  1.  
    def recurrent_neural_network(data):
  2.  
    # [batch,chunk_n,input]
  3.  
    data = tf.reshape(data, [-1, chunk_n, chunk_size])
  4.  
    #data = tf.transpose(data, [1,0,2])
  5.  
    #data = tf.reshape(data, [-1, chunk_size])
  6.  
    #data = tf.split(data,chunk_n)
  7.  
     
  8.  
    # 只用RNN
  9.  
    layer = {'w_':tf.Variable(tf.random_normal([rnn_size, n_output_layer])), 'b_':tf.Variable(tf.random_normal([n_output_layer]))}
  10.  
    #1
  11.  
    # lstm_cell1 = tf.contrib.rnn.BasicLSTMCell(rnn_size)
  12.  
    # outputs1, status1 = tf.contrib.rnn.static_rnn(lstm_cell1, data, dtype=tf.float32)
  13.  
     
  14.  
    def lstm_cell():
  15.  
    return tf.contrib.rnn.LSTMCell(rnn_size)
  16.  
    def attn_cell():
  17.  
    return tf.contrib.rnn.DropoutWrapper(lstm_cell(), output_keep_prob=keep_prob)
  18.  
    # stack = tf.contrib.rnn.MultiRNNCell([attn_cell() for _ in range(0, num_layers)], state_is_tuple=True)
  19.  
    stack = tf.contrib.rnn.MultiRNNCell([lstm_cell() for _ in range(0, num_layers)], state_is_tuple=True)
  20.  
    # outputs, _ = tf.nn.dynamic_rnn(stack, data, seq_len, dtype=tf.float32)
  21.  
    outputs, _ = tf.nn.dynamic_rnn(stack, data, dtype=tf.float32)
  22.  
    # [batch,chunk_n,rnn_size] -> [chunk_n,batch,rnn_size]
  23.  
    outputs = tf.transpose(outputs, (1, 0, 2))
  24.  
     
  25.  
    ouput = tf.add(tf.matmul(outputs[-1], layer['w_']), layer['b_'])
  26.  
     
  27.  
    return ouput

tf.contrib.rnn.static_rnn与tf.nn.dynamic_rnn区别的更多相关文章

  1. 深度学习原理与框架-递归神经网络-RNN网络基本框架(代码?) 1.rnn.LSTMCell(生成单层LSTM) 2.rnn.DropoutWrapper(对rnn进行dropout操作) 3.tf.contrib.rnn.MultiRNNCell(堆叠多层LSTM) 4.mlstm_cell.zero_state(state初始化) 5.mlstm_cell(进行LSTM求解)

    问题:LSTM的输出值output和state是否是一样的 1. rnn.LSTMCell(num_hidden, reuse=tf.get_variable_scope().reuse)  # 构建 ...

  2. 关于tensorflow里面的tf.contrib.rnn.BasicLSTMCell 中num_units参数问题

    这里的num_units参数并不是指这一层油多少个相互独立的时序lstm,而是lstm单元内部的几个门的参数,这几个门其实内部是一个神经网络,答案来自知乎: class TRNNConfig(obje ...

  3. tf.contrib.rnn.core_rnn_cell.BasicLSTMCell should be replaced by tf.contrib.rnn.BasicLSTMCell.

    For Tensorflow 1.2 and Keras 2.0, the line tf.contrib.rnn.core_rnn_cell.BasicLSTMCell should be repl ...

  4. tensorflow教程:tf.contrib.rnn.DropoutWrapper

    tf.contrib.rnn.DropoutWrapper Defined in tensorflow/python/ops/rnn_cell_impl.py. def __init__(self, ...

  5. tf.contrib.rnn.LSTMCell 里面参数的意义

    num_units:LSTM cell中的单元数量,即隐藏层神经元数量.use_peepholes:布尔类型,设置为True则能够使用peephole连接cell_clip:可选参数,float类型, ...

  6. tensorflow笔记6:tf.nn.dynamic_rnn 和 bidirectional_dynamic_rnn:的输出,output和state,以及如何作为decoder 的输入

    一.tf.nn.dynamic_rnn :函数使用和输出 官网:https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn 使用说明: A ...

  7. tf.nn.dynamic_rnn

    tf.nn.dynamic_rnn(cell,inputs,sequence_length=None, initial_state=None,dtype=None, parallel_iteratio ...

  8. TF之RNN:实现利用scope.reuse_variables()告诉TF想重复利用RNN的参数的案例—Jason niu

    import tensorflow as tf # 22 scope (name_scope/variable_scope) from __future__ import print_function ...

  9. 第十六节,使用函数封装库tf.contrib.layers

    这一节,介绍TensorFlow中的一个封装好的高级库,里面有前面讲过的很多函数的高级封装,使用这个高级库来开发程序将会提高效率. 我们改写第十三节的程序,卷积函数我们使用tf.contrib.lay ...

随机推荐

  1. Boost StateChart实现状态机----秒表例程

    Boost 提供了状态机的实现接口,采用了CRTP技术实现,下面以秒表为例子实现一个状态机,这是一个官方的例子,也可以参考资料:Boost Statechart 庫,状态机的状态转换图如下所示: 实现 ...

  2. 使用 IntraWeb (25) - 基本控件之 TIWRegion

    这应该是 IW 中最重要的容器了, 和它同父的还有 TIWTabControl TIWRegion 所在单元及继承链: IWRegion.TIWRegion 主要成员: property Align: ...

  3. dubbox REST服务使用fastjson替换jackson

    上一节讲解了resteasy如何使用fastjson来替换默认的jackson,虽然dubbox内部采用的就是resteasy,但是大多数情况下,dubbox服务是一个独立的app,并不需要以war包 ...

  4. 【Go命令教程】10. go fix 与 go tool fix

    命令 go fix 会把指定 代码包 的所有 Go 语言源码文件中的旧版本代码修正为新版本的代码.这里所说的版本即 Go 语言的版本.代码包的所有 Go 语言源码文件不包括其子代码包(如果有的话)中的 ...

  5. Revit API封装一个通用函数“过名称找元素”

    感觉这个函数不错.通过这种方式寻找元素经常需要用到. )         {  ];         }         // cannot find it.         return null; ...

  6. WebLogic使用总结(六)——WebLogic创建虚拟主机和修改启动端口号

    一.在WebLogic中创建一个虚拟主机 找到虚拟主机面板,如下图所示:

  7. android中Bitmap的放大和缩小的方法

    android中Bitmap的放大和缩小的方法 时间 2013-06-20 19:02:34  CSDN博客原文  http://blog.csdn.net/ada168855/article/det ...

  8. 【转】Android开发在路上:少去踩坑,多走捷径

    本文是我订阅"腾讯大讲堂"公众帐号时,他们推送的一篇文章,但在腾讯大讲堂官网上我并没有找到这篇文章,不过其它专门"爬"公众号文章的网站倒是有.我觉得写的很不错. ...

  9. ArcGIS教程:树状图

    摘要 构造可显示特征文件里连续合并类之间的属性距离的树示意图(树状图). 使用方法 · 输入特征文件必须採用预定的特征文件格式. 特征文件可使用 Iso 聚类或创建特征工具来创建.该文件必须至少包括两 ...

  10. [rrdtool]监控和自己主动绘图,简单的监控.md

    如今想要监控服务的流量和并发数,但是又没那么多时间来写系统.其它的运维系统又不熟悉,于是就用现有的rrdtool shell做了个简单的监控界面,暂时用下,也算是个小实验把. rrdtool也是刚接触 ...