RNN静态与动态
静态、多层RNN:
import numpy as np
import tensorflow as tf
# 导入 MINST 数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/data/", one_hot=True) #网络模型参数
n_input = 28 # MNIST data 输入 (img shape: 28*28)
n_steps = 28 # timesteps
n_hidden = 128 # hidden layer num of features
n_classes = 10 # MNIST 列别 (0-9 ,一共10类) #训练参数
batch_size = 128
learning_rate = 0.001
training_iters = 10000
display_step = 10 # tf Graph input
x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes]) #构建网络
stacked_rnn = []
for _ in range(3):
stacked_rnn.append(tf.contrib.rnn.LSTMCell(n_hidden))
mcell = tf.contrib.rnn.MultiRNNCell(stacked_rnn) x1=tf.unstack(x,n_steps,1)#在axis=1进行解包分解。 outputs, states = tf.contrib.rnn.static_rnn(mcell, x1, dtype=tf.float32)#inputs must be a sequence
#最后一层全连接 outputs[-1]
pred = tf.contrib.layers.fully_connected(outputs[-1],n_classes,activation_fn = None) # Define loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost) # Evaluate model
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) # 启动session
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
step = 1
# Keep training until reach max iterations
while step * batch_size < training_iters:
batch_x, batch_y = mnist.train.next_batch(batch_size)
# Reshape data to get 28 seq of 28 elements
batch_x = batch_x.reshape((batch_size, n_steps, n_input))
# Run optimization op (backprop)
sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
if step % display_step == 0:
# 计算批次数据的准确率
acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
# Calculate batch loss
loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
print ("Iter " + str(step*batch_size) + ", Minibatch Loss= " + \
"{:.6f}".format(loss) + ", Training Accuracy= " + \
"{:.5f}".format(acc))
step += 1
print (" Finished!")
# 计算准确率 for 128 mnist test images
test_len = 100
test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
test_label = mnist.test.labels[:test_len]
print ("Testing Accuracy:", sess.run(accuracy, feed_dict={x: test_data, y: test_label}))
在学习RNN这一章的时候,遇到static_rnn中输入数据 x 的格式:
[None, n_steps, n_input] 进行变换→ x1=tf.unstack(x,n_steps,1)
之后再传入:outputs, states = tf.contrib.rnn.static_rnn(mcell, x1, dtype=tf.float32)
很难理解,为什么要这样做,数据又进行了怎样的变换。
以下,为stack和unstack的详细举例:
- tf.stack(values, axis=0, name=’stack’)
以指定的轴axis,将一个维度为R的张量数组转变成一个维度为R+1的张量。即将一组张量以指定的轴,提高一个维度。
假设要转变的张量数组values的长度为N,其中的每个张量的形状为(A, B, C)。
如果轴axis=0,则转变后的张量的形状为(N, A, B, C)。
如果轴axis=1,则转变后的张量的形状为(A, N, B, C)。
如果轴axis=2,则转变后的张量的形状为(A, B, N, C)。其它情况依次类推。
举例如下:
‘x’ is [1, 4], 形状是(2),维度为1
‘y’ is [2, 5], 形状是(2),维度为1
‘z’ is [3, 6], 形状是(2),维度为1
stack([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # axis的值默认为0。输出的形状为(3, 2)
stack([x, y, z], axis=1) => [[1, 2, 3], [4, 5, 6]] # axis的值为1。输出的形状为(2, 3)
‘x’ is [[1,1,1,1],[2,2,2,2],[3,3,3,3]],形状是(3,4),维度为2
‘y’ is [[4,4,4,4],[5,5,5,5],[6,6,6,6]],形状是(3,4),维度为2
stack([x,y]) => [[[1,1,1,1],[2,2,2,2],[3,3,3,3]], [[4,4,4,4],[5,5,5,5],[6,6,6,6]]] # axis的值默认为0。输出的形状为(2, 3, 4)
stack([x,y],axis=1) => [[[1,1,1,1],[4,4,4,4]],[[2,2,2,2],[5,5,5,5]],[[3,3,3,3],[6,6,6,6]]] # axis的值为1。输出的形状为(3, 2, 4)
stack([x,y],axis=2) => [[[1,4],[1,4],[1,4],[1,4]],[[2,5],[2,5],[2,5],[2,5]],[[3,6],[3,6],[3,6],[3,6]]]# axis的值为2。输出的形状为(3, 4, 2)
axis可这样理解:stack就是要将一组相同形状的张量提高一个维度。axis就是这些张量里,将axis指定的维度用所有这些张量数组代替。如axis=2,表示指定在第2个维度,原来的元素用整个张量数组里的元素代替,即从(A, B, C)转变为(A, B, N, C)
参数:
values: 一个有相同形状与数据类型的张量数组。
axis: 以轴axis为中心来转变的整数。默认是第一个维度即axis=0。支持负数。取值范围为[-(R+1), R+1)
name: 这个操作的名字(可选)
返回:被提高一个维度后的张量
异常: ValueError: 如果轴axis超出范围[-(R+1), R+1).
- tf.unstack()
tf.unstack(value, num=None, axis=0, name=’unstack’)
以指定的轴axis,将一个维度为R的张量数组转变成一个维度为R-1的张量。即将一组张量以指定的轴,减少一个维度。正好和stack()相反。
将张量value分割成num个张量数组。如果num没有指定,则是根据张量value的形状来指定。如果value.shape[axis]不存在,则抛出ValueError的异常。
假如一个张量的形状是(A, B, C, D)。
如果axis == 0,则输出的张量是value[i, :, :, :],i取值为[0,A),每个输出的张量的形状为(B,C,D)。
如果axis == 1,则输出的张量是value[:, i, :, :],i取值为[0,B),每个输出的张量的形状为(A,C,D)。
如果axis == 2,则输出的张量是value[:, :, i, :],i取值为[0,C),每个输出的张量的形状为(A,B,D)。依次类推。
举例如下:
‘x’ is [[1,1,1,1],[2,2,2,2],[3,3,3,3]] # 形状是(3,4),维度为2
unstack(x,axis=0) =>以指定的维度0为轴,转变成3个形状为(4)张量[1,1,1,1],[2,2,2,2],[3,3,3,3]
unstack(x,axis=1) =>以指定的维度1为轴,转变成4个形状为(3)张量[1,2,3],[1,2,3],[1,2,4],[1,2,3]
axis可这样理解:unstack就是要将一个张量降低为低一个维度的张量数组。axis就是将axis指定的维度,用所有这个张量里同维度的数据代替。
参数:
value: 一个将要被降维的维度大于0的张量。
num: 整数。指定的维度axis的长度。如果设置为None(默认值),将自动求值。
axis: 整数.以轴axis指定的维度来转变 默认是第一个维度即axis=0。支持负数。取值范围为[-R, R)
name: 这个操作的名字(可选)
返回:
从张量value降维后的张量数组。
异常:
ValueError: 如果num没有指定并且无法求出来。
ValueError: 如果axis超出范围 [-R, R)。
经过下面的例子理解后,上面的1对应axis=1, nsteps对应函数中的num参数,表示axis=1的长度。该操作将数据 x 按照序列数目切开。我们传入的 x 是个3维tensor,将其按照序列数切开,得到了n_steps个 二维的tensor, [batchsize, n_input]
RNN静态与动态的更多相关文章
- Android中BroadcastReceiver的两种注册方式(静态和动态)详解
今天我们一起来探讨下安卓中BroadcastReceiver组件以及详细分析下它的两种注册方式. BroadcastReceiver也就是"广播接收者"的意思,顾名思义,它就是用来 ...
- 生成lua的静态库.动态库.lua.exe和luac.exe
前些日子准备学习下关于lua coroutine更为强大的功能,然而发现根据lua 5.1.4版本来运行一段代码的话也会导致 "lua: attempt to yield across me ...
- Delphi DLL的创建、静态及动态调用
转载:http://blog.csdn.net/welcome000yy/article/details/7905463 结合这篇博客:http://www.cnblogs.com/xumenger/ ...
- 3D touch 静态、动态设置及进入APP的跳转方式
申明Quick Action有两种方式:静态和动态 静态是在info.plist文件中申明,动态则是在代码中注册,系统支持两者同时存在. -系统限制每个app最多显示4个快捷图标,包括静态和动态 静态 ...
- C/C++ 跨平台交叉编译、静态库/动态库编译、MinGW、Cygwin、CodeBlocks使用原理及链接参数选项
目录 . 引言 . 交叉编译 . Cygwin简介 . 静态库编译及使用 . 动态库编译及使用 . MinGW简介 . CodeBlocks简介 0. 引言 UNIX是一个注册商标,是要满足一大堆条件 ...
- RT-Thread创建静态、动态线程
RT-Thread 实时操作系统核心是一个高效的硬实时核心,它具备非常优异的实时性.稳定性.可剪裁性,当进行最小配置时,内核体积可以到 3k ROM 占用. 1k RAM 占用. RT-Thread ...
- linux静态与动态库创建及使用实例
一,gcc基础语法: 基本语法结构:(由以下四部分组成) gcc -o 可执行文件名 依赖文件集(*.c/*.o) 依赖库文件及其头文件集(由-I或-L与-l指明) gcc 依赖文件集(*.c/*.o ...
- MYSQL学习笔记2--mysql 静态和动态plugin
mysql源码编译 .cmke 安装 yum install cmake .依赖的库下载机安装: yum -y install gcc* gcc-c++* autoconf* automake* zl ...
- Android SurfaceView实现静态于动态画图效果
本文是基于Android的SurfaceView的动态画图效果,实现静态和动态下的正弦波画图,可作为自己做图的简单参考,废话不多说,先上图, 静态效果: 动态效果: 比较简单,代码注释的也比较详细,易 ...
随机推荐
- eclipse导入工程报Invalid project description(转载)
转自:http://blog.sina.com.cn/s/blog_a2eab3000101k3r7.html 昨天新搭建的环境,今天把以前的项目导入eclipse时报错: 说的是我导入的项目与wor ...
- lightoj 1125【01背包变性】
题意: 从n个数里选出m个来,还要使得这m个数之和被d整除. 给一个n和q,再给n个数,再给q个询问,每个询问包含两个数,d,m; 对于每个case输出每个q个询问的可行的方案数. 思路: 每个数只能 ...
- poj1724【最短路】
题意: 给出n个城市,然后给出m条单向路,给出了每条路的距离和花费,问一个人有k coins,在不超过money的情况下从1到n最短路径路径. 思路: 我相信很多人在上面那道题的影响下,肯定会想想,在 ...
- PTA【复数相乘】
队友在比赛时A掉的.吊吊吊!!! 主要考虑这些情况吧||| 案例: /* 3i i -3 i -1-i 1+i i 1 -1-i -1-i */ -3 -3i -2i i 2i #include &l ...
- SQL_MODE 的设置
查看当前的 SQL_MODE SELECT @@sql_mode SELECT @@sql_mode 的执行结果 mysql> SELECT @@sql_mode; +------------- ...
- 7天学完Java基础之6/7
final关键字的概念与四种用法 final关键字代表最终,不可改变的 常见四种用法: 可以用来修饰一个类 当前这个类不能有任何子类 可以用来修饰一个方法 当final关键字用来修饰一个方法的时候,这 ...
- Windows、Linux、Android常用软件分享
Windows.Linux.Android常用软件分享 前言 本来没准备写这篇博客,一是没时间,还有其他很多优先级更高的事情要做.二是写这种博客对我自己来说没什么的帮助,以前我就想好了不写教程类,使用 ...
- SQL SUM函数内使用CASE函数
- 实例 - 在这个表里进行查询: 查询出如下结果(统计每天的输赢次数): - 开始查询 - 首先创建测试表: CREATE TABLE info( date ), result ) ); 插入测试数 ...
- Hdu 5452 Minimum Cut (2015 ACM/ICPC Asia Regional Shenyang Online) dfs + LCA
题目链接: Hdu 5452 Minimum Cut 题目描述: 有一棵生成树,有n个点,给出m-n+1条边,截断一条生成树上的边后,再截断至少多少条边才能使图不连通, 问截断总边数? 解题思路: 因 ...
- 洛谷 P3327 [SDOI2015]约数个数和 || Number Challenge Codeforces - 235E
https://www.luogu.org/problemnew/show/P3327 不会做. 去搜题解...为什么题解都用了一个奇怪的公式?太奇怪了啊... 公式是这样的: $d(xy)=\sum ...