【转载】 【TensorFlow】static_rnn 和dynamic_rnn的区别
原文地址:
https://blog.csdn.net/qq_20135597/article/details/88980975
---------------------------------------------------------------------------------------------
tensorflow中提供了rnn接口有两种,一种是静态的rnn,一种是动态的rnn
通常用法:
1、静态接口:static_rnn
主要使用 tf.contrib.rnn
x = tf.placeholder("float", [None, n_steps, n_input])
x1 = tf.unstack(x, n_steps, 1)
lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
outputs, states = tf.contrib.rnn.static_rnn(lstm_cell, x1, dtype=tf.float32)
pred = tf.contrib.layers.fully_connected(outputs[-1],n_classes,activation_fn = None)
静态 rnn 的意思就是在图中创建一个固定长度(n_steps)的网络。这将导致
缺点:
- 生成过程耗时更长,占内存更多,导出的模型更大;
- 无法传递比最初指定的更长的序列(> n_steps)。
优点:
模型中带有某个序列中间台的信息,便与调试。
2、动态接口:dynamic_rnn
主要使用 tf.nn.dynamic_rnn
x = tf.placeholder("float", [None, n_steps, n_input])
lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
outputs,_  = tf.nn.dynamic_rnn(lstm_cell ,x,dtype=tf.float32)
outputs = tf.transpose(outputs, [1, 0, 2])
pred = tf.contrib.layers.fully_connected(outputs[-1],n_classes,activation_fn = None)
动态的tf.nn.dynamic_rnn被执行时,它使用循环来动态构建图形。这意味着
优点:
- 图形创建速度更快,占用内存更少;
- 并且可以提供可变大小的批处理。
缺点:
- 模型中只有最后的状态。
动态rnn的意思是只创建样本中的一个序列RNN,其他序列数据会通过循环进入该RNN运算
区别:
1、输入输出不同:
dynamic_rnn实现的功能就是可以让不同迭代传入的batch可以是长度不同数据,但同一次迭代一个batch内部的所有数据长度仍然是固定的。例如,第一时刻传入的数据shape=[batch_size, 10],第二时刻传入的数据shape=[batch_size, 12],第三时刻传入的数据shape=[batch_size, 8]等等。
但是static_rnn不能这样,它要求每一时刻传入的batch数据的[batch_size, max_seq],在每次迭代过程中都保持不变。
2、训练方式不同:
具体参见参考文献1
多层LSTM的代码实现对比:
1、静态多层RNN
import tensorflow as tf
# 导入 MINST 数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("c:/user/administrator/data/", one_hot=True) n_input = 28 # MNIST data 输入 (img shape: 28*28)
n_steps = 28 # timesteps
n_hidden = 128 # hidden layer num of features
n_classes = 10 # MNIST 列别 (0-9 ,一共10类)
batch_size = 128 tf.reset_default_graph()
# tf Graph input
x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes]) gru = tf.contrib.rnn.GRUCell(n_hidden*2)
lstm_cell = tf.contrib.rnn.LSTMCell(n_hidden)
mcell = tf.contrib.rnn.MultiRNNCell([lstm_cell,gru]) x1 = tf.unstack(x, n_steps, 1)
outputs, states = tf.contrib.rnn.static_rnn(mcell, x1, dtype=tf.float32) pred = tf.contrib.layers.fully_connected(outputs[-1],n_classes,activation_fn = None) learning_rate = 0.001
# Define loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost) # Evaluate model
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) training_iters = 100000 display_step = 10 # 启动session
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
step = 1
# Keep training until reach max iterations
while step * batch_size < training_iters:
batch_x, batch_y = mnist.train.next_batch(batch_size)
# Reshape data to get 28 seq of 28 elements
batch_x = batch_x.reshape((batch_size, n_steps, n_input))
# Run optimization op (backprop)
sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
if step % display_step == 0:
# 计算批次数据的准确率
acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
# Calculate batch loss
loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
print ("Iter " + str(step*batch_size) + ", Minibatch Loss= " + \
"{:.6f}".format(loss) + ", Training Accuracy= " + \
"{:.5f}".format(acc))
step += 1
print (" Finished!") # 计算准确率 for 128 mnist test images
test_len = 100
test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
test_label = mnist.test.labels[:test_len]
print ("Testing Accuracy:", \
sess.run(accuracy, feed_dict={x: test_data, y: test_label}))
2、动态多层RNN
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("c:/user/administrator/data/", one_hot=True) n_input = 28 # MNIST data 输入 (img shape: 28*28)
n_steps = 28 # timesteps
n_hidden = 128 # hidden layer num of features
n_classes = 10 # MNIST 列别 (0-9 ,一共10类)
batch_size = 128 tf.reset_default_graph()
# tf Graph input
x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes]) gru = tf.contrib.rnn.GRUCell(n_hidden*2)
lstm_cell = tf.contrib.rnn.LSTMCell(n_hidden)
mcell = tf.contrib.rnn.MultiRNNCell([lstm_cell,gru]) outputs,states = tf.nn.dynamic_rnn(mcell,x,dtype=tf.float32)#(?, 28, 256)
outputs = tf.transpose(outputs, [1, 0, 2])#(28, ?, 256) 28个时序,取最后一个时序outputs[-1]=(?,256) pred = tf.contrib.layers.fully_connected(outputs[-1],n_classes,activation_fn = None) learning_rate = 0.001
# Define loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost) # Evaluate model
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) training_iters = 100000 display_step = 10 # 启动session
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
step = 1
# Keep training until reach max iterations
while step * batch_size < training_iters:
batch_x, batch_y = mnist.train.next_batch(batch_size)
# Reshape data to get 28 seq of 28 elements
batch_x = batch_x.reshape((batch_size, n_steps, n_input))
# Run optimization op (backprop)
sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
if step % display_step == 0:
# 计算批次数据的准确率
acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
# Calculate batch loss
loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
print ("Iter " + str(step*batch_size) + ", Minibatch Loss= " + \
"{:.6f}".format(loss) + ", Training Accuracy= " + \
"{:.5f}".format(acc))
step += 1
print (" Finished!") # 计算准确率 for 128 mnist test images
test_len = 100
test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
test_label = mnist.test.labels[:test_len]
print ("Testing Accuracy:", \
sess.run(accuracy, feed_dict={x: test_data, y: test_label}))
【参考文献】:
1、https://www.jianshu.com/p/1b1ea45fab47
2、What's the difference between tensorflow dynamic_rnn and rnn?
------------------------------------------------------------------------
【转载】 【TensorFlow】static_rnn 和dynamic_rnn的区别的更多相关文章
- 【转载】      LSTM构建步骤以及static_rnn与dynamic_rnn之间的区别
		原文地址: https://blog.csdn.net/qq_23981335/article/details/89097757 --------------------- 作者:周卫林 来源:CSD ... 
- 转载:Ajax及 GET、POST 区别
		转载:Ajax及 GET.POST 区别 收获: xhr.setRequestHeader(), xhr.getResponseHeader() 可以设置和获取请求头/响应头信息; new FormD ... 
- [转载]java int与integer的区别
		声明: 本篇文章属于转载文章,来源: 
- 【转载】new和malloc的区别
		本篇随笔为转载,原贴地址:C++中new和malloc的十点区别. 前言 几个星期前去面试C++研发的实习岗位,面试官问了个问题: new与malloc有什么区别? 这是个老生常谈的问题.当时我回答n ... 
- 【转载】gcc和g++的区别
		[说明]本文转载自 静心 的文章 http://blog.163.com/lu_jun520/blog/static/5699613420116205148239/ 一般linux系统都自带了gcc编 ... 
- Java_类和对象(完美总结)_转载_覆盖和隐藏的区别,覆盖就不能使用了,而隐藏提供全局方法名或者全局变量名还可以使用
		转载自海子:http://www.cnblogs.com/dolphin0520/p/3803432.html Java:类与继承 对于面向对象的程序设计语言来说,类毫无疑问是其最重要的基础.抽象.封 ... 
- [转载] Rss 与 Feed 的概念区别
		转载自http://www.chinaz.com/news/2011/0831/207961.shtml 可能很多刚刚接触博客的童鞋们,也和我一样不太了解:rss和feed概念或者说不了解rss和fe ... 
- [转载]Tensorflow 的reduce_sum()函数的axis,keep_dim这些参数到底是什么意思?
		转载链接:https://www.zhihu.com/question/51325408/answer/125426642来源:知乎 这个问题无外乎有三个难点: 什么是sum 什么是reduce 什么 ... 
- tensorflow 笔记12:函数区别:placeholder,variable,get_variable,参数共享
		一.函数意义: 1.tf.Variable() 变量 W = tf.Variable(<initial-value>, name=<optional-name>) 用于生成一个 ... 
随机推荐
- nodejs 删除空文件
			var fs = require("fs") var path = require("path") var listRealPath = path.resolv ... 
- Activity知识点详解
			Activity知识点详解 一.什么是Activity 官方解释: The Activity class is a crucial component of an Android app, and t ... 
- OpenStack核心组件-keystone
			1. Keystone介绍 keystone是OpenStack的组件之一,用于为OpenStack家族中的其它组件成员提供统一的认证服务,包括身份验证.令牌的发放和校验.服务列表.用户权限的定义等等 ... 
- spring cloud (六) 将一个普通的springcloud项目 非feign或ribbon项目,改造成turbine可聚合监听的项目
			改造之前一个项目 service-a 1 pom.xml添加如下 <dependency> <groupId>org.springframework.cloud</gro ... 
- python测试开发django-rest-framework-64.序列化(serializers.Serializer)
			前言 REST framework中的serializers与Django的Form和ModelForm类非常像.我们提供了一个Serializer类,它为你提供了强大的通用方法来控制响应的输出, 以 ... 
- 解读 v8 排序源码
			前言 v8 是 Chrome 的 JavaScript 引擎,其中关于数组的排序完全采用了 JavaScript 实现. 排序采用的算法跟数组的长度有关,当数组长度小于等于 10 时,采用插入排序,大 ... 
- stm32flash的读写特性
			在使用stm32自带的flash保存数据时候,如下特点必须知道: 1.必须是先擦除一个扇区,才能写入 2.读数据没有限制 3.写数据必须是2字节,同时写入地址以一定要考虑字节对齐, 4.一般都是在最后 ... 
- TreeMap 的简单解释
			TreeMap的构造函数 可以传入 自定义的比较器.Map.SortedMap. put方法: public V put(K key, V value) { Entry<K,V> ... 
- janusgraph-图数据库的学习(2)
			janusgraph的简单使用 当安装好以后简单的使用janusgraph 1.进入janusgraph的shell命令界面 [root@had214 janusgraph-0.3.1-hadoop2 ... 
- B/S开发——文件夹的上传和下载
			本人在2010年时使用swfupload为核心进行文件的批量上传的解决方案.见文章:WEB版一次选择多个文件进行批量上传(swfupload)的解决方案. 本人在2013年时使用plupload为核心 ... 
