tensorflow RNN循环神经网络 (分类例子)-【老鱼学tensorflow】
之前我们学习过用CNN(卷积神经网络)来识别手写字,在CNN中是把图片看成了二维矩阵,然后在二维矩阵中堆叠高度值来进行识别。
而在RNN中增添了时间的维度,因为我们会发现有些图片或者语言或语音等会在时间轴上慢慢展开,有点类似我们大脑认识事物时会有相关的短期记忆。
这次我们使用RNN来识别手写数字。
首先导入数据并定义各种RNN的参数:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# 导入数据
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# RNN各种参数定义
lr = 0.001 #学习速率
training_iters = 100000 #循环次数
batch_size = 128
n_inputs = 28 #手写字的大小是28*28,这里是手写字中的每行28列的数值
n_steps = 28 #这里是手写字中28行的数据,因为以一行一行像素值处理的话,正好是28行
n_hidden_units = 128 #假设隐藏单元有128个
n_classes = 10 #因为我们的手写字是0-9,因此最后要分成10个类
接着定义输入、输出以及各权重的形状:
# 定义输入和输出的placeholder
x = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.float32, [None, n_classes])
# 对weights和biases初始值定义
weights = {
# shape(28, 128)
'in': tf.Variable(tf.random_normal([n_inputs, n_hidden_units])),
# shape(128 , 10)
'out': tf.Variable(tf.random_normal([n_hidden_units, n_classes]))
}
biases = {
# shape(128, )
'in':tf.Variable(tf.constant(0.1, shape=[n_hidden_units, ])),
# shape(10, )
'out':tf.Variable(tf.constant(0.1, shape=[n_classes, ]))
}
定义 RNN 的主体结构
最主要的就是定义RNN的主体结构。
def RNN(X, weights, biases):
# X在输入时是一批128个,每批中有28行,28列,因此其shape为(128, 28, 28)。为了能够进行 weights 的矩阵乘法,我们需要把输入数据转换成二维的数据(128*28, 28)
X = tf.reshape(X, [-1, n_inputs])
# 对输入数据根据权重和偏置进行计算, 其shape为(128batch * 28steps, 128 hidden)
X_in = tf.matmul(X, weights['in']) + biases['in']
# 矩阵计算完成之后,又要转换成3维的数据结构了,(128batch, 28steps, 128 hidden)
X_in = tf.reshape(X_in, [-1, n_steps, n_hidden_units])
# cell,使用LSTM,其中state_is_tuple用来指示相关的state是否是一个元组结构的,如果是元组结构的话,会在state中包含主线状态和分线状态
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden_units, forget_bias=1.0, state_is_tuple=True)
# 初始化全0state
init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)
# 下面进行运算,我们使用dynamic rnn来进行运算。每一步的运算输出都会存储在outputs中,states中存储了主线状态和分线状态,因为我们前面指定了state_is_tuple=True
# time_major用来指示关于时间序列的数据是否在输入数据中第一个维度中。在本例中,我们的时间序列数据位于第2维中,第一维的数据只是batch数据,因此要设置为False。
outputs, states = tf.nn.dynamic_rnn(lstm_cell, X_in, initial_state=init_state, time_major=False)
# 计算结果,其中states[1]为分线state,也就是最后一个输出值
results = tf.matmul(states[1], weights['out']) + biases['out']
return results
训练RNN
定义好了 RNN 主体结构后, 我们就可以来计算 cost 和 train_op:
pred = RNN(x, weights, biases)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
train_op = tf.train.AdamOptimizer(lr).minimize(cost)
训练时, 不断输出 accuracy, 观看结果:
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
step = 0
while step*batch_size < training_iters:
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
batch_xs = batch_xs.reshape([batch_size, n_steps, n_inputs])
sess.run([train_op], feed_dict={x:batch_xs, y:batch_ys})
if step % 20 == 0:
print(sess.run(accuracy, feed_dict={x:batch_xs, y:batch_ys}))
step += 1
最终 accuracy 的结果如下:
E:\Python\Python36\python.exe E:/learn/numpy/lesson3/main.py
Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
2018-02-20 20:30:52.769108: I C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\36\tensorflow\core\platform\cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX
0.09375
0.710938
0.8125
0.789063
0.820313
0.882813
0.828125
0.867188
0.921875
0.90625
0.921875
0.890625
0.898438
0.945313
0.914063
0.945313
0.929688
0.96875
0.96875
0.929688
0.953125
0.945313
0.960938
0.992188
0.953125
0.9375
0.929688
0.96875
0.960938
0.945313
完整代码
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# 导入数据
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# RNN各种参数定义
lr = 0.001 #学习速率
training_iters = 100000 #循环次数
batch_size = 128
n_inputs = 28 #手写字的大小是28*28,这里是手写字中的每行28列的数值
n_steps = 28 #这里是手写字中28行的数据,因为以一行一行像素值处理的话,正好是28行
n_hidden_units = 128 #假设隐藏单元有128个
n_classes = 10 #因为我们的手写字是0-9,因此最后要分成10个类
# 定义输入和输出的placeholder
x = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.float32, [None, n_classes])
# 对weights和biases初始值定义
weights = {
# shape(28, 128)
'in': tf.Variable(tf.random_normal([n_inputs, n_hidden_units])),
# shape(128 , 10)
'out': tf.Variable(tf.random_normal([n_hidden_units, n_classes]))
}
biases = {
# shape(128, )
'in':tf.Variable(tf.constant(0.1, shape=[n_hidden_units, ])),
# shape(10, )
'out':tf.Variable(tf.constant(0.1, shape=[n_classes, ]))
}
def RNN(X, weights, biases):
# X在输入时是一批128个,每批中有28行,28列,因此其shape为(128, 28, 28)。为了能够进行 weights 的矩阵乘法,我们需要把输入数据转换成二维的数据(128*28, 28)
X = tf.reshape(X, [-1, n_inputs])
# 对输入数据根据权重和偏置进行计算, 其shape为(128batch * 28steps, 128 hidden)
X_in = tf.matmul(X, weights['in']) + biases['in']
# 矩阵计算完成之后,又要转换成3维的数据结构了,(128batch, 28steps, 128 hidden)
X_in = tf.reshape(X_in, [-1, n_steps, n_hidden_units])
# cell,使用LSTM,其中state_is_tuple用来指示相关的state是否是一个元组结构的,如果是元组结构的话,会在state中包含主线状态和分线状态
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden_units, forget_bias=1.0, state_is_tuple=True)
# 初始化全0state
init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)
# 下面进行运算,我们使用dynamic rnn来进行运算。每一步的运算输出都会存储在outputs中,states中存储了主线状态和分线状态,因为我们前面指定了state_is_tuple=True
# time_major用来指示关于时间序列的数据是否在输入数据中第一个维度中。在本例中,我们的时间序列数据位于第2维中,第一维的数据只是batch数据,因此要设置为False。
outputs, states = tf.nn.dynamic_rnn(lstm_cell, X_in, initial_state=init_state, time_major=False)
# 计算结果,其中states[1]为分线state,也就是最后一个输出值
results = tf.matmul(states[1], weights['out']) + biases['out']
return results
pred = RNN(x, weights, biases)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
train_op = tf.train.AdamOptimizer(lr).minimize(cost)
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
step = 0
while step*batch_size < training_iters:
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
batch_xs = batch_xs.reshape([batch_size, n_steps, n_inputs])
sess.run([train_op], feed_dict={x:batch_xs, y:batch_ys})
if step % 20 == 0:
print(sess.run(accuracy, feed_dict={x:batch_xs, y:batch_ys}))
step += 1
tensorflow RNN循环神经网络 (分类例子)-【老鱼学tensorflow】的更多相关文章
- tensorflow分类-【老鱼学tensorflow】
前面我们学习过回归问题,比如对于房价的预测,因为其预测值是个连续的值,因此属于回归问题. 但还有一类问题属于分类的问题,比如我们根据一张图片来辨别它是一只猫还是一只狗.某篇文章的内容是属于体育新闻还是 ...
- tensorflow用dropout解决over fitting-【老鱼学tensorflow】
在机器学习中可能会存在过拟合的问题,表现为在训练集上表现很好,但在测试集中表现不如训练集中的那么好. 图中黑色曲线是正常模型,绿色曲线就是overfitting模型.尽管绿色曲线很精确的区分了所有的训 ...
- tensorflow卷积神经网络-【老鱼学tensorflow】
前面我们曾有篇文章中提到过关于用tensorflow训练手写2828像素点的数字的识别,在那篇文章中我们把手写数字图像直接碾压成了一个784列的数据进行识别,但实际上,这个图像是2828长宽结构的,我 ...
- tensorflow例子-【老鱼学tensorflow】
本节主要用一个例子来讲述一下基本的tensorflow用法. 在这个例子中,我们首先伪造一些线性数据点,其实这些数据中本身就隐藏了一些规律,但我们假装不知道是什么规律,然后想通过神经网络来揭示这个规律 ...
- tensorflow Tensorboard可视化-【老鱼学tensorflow】
tensorflow自带了可视化的工具:Tensorboard.有了这个可视化工具,可以让我们在调整各项参数时有了可视化的依据. 本次我们先用Tensorboard来可视化Tensorflow的结构. ...
- tensorflow安装-【老鱼学tensorflow】
TensorFlow是谷歌基于DistBelief进行研发的第二代人工智能学习系统,其命名来源于本身的运行原理.Tensor(张量)意味着N维数组,Flow(流)意味着基于数据流图的计算,Tensor ...
- tensorflow 传入值-【老鱼学tensorflow】
上个文章中讲述了tensorflow中如何定义变量以及如何读取变量的方式,本节主要讲述关于传入值. 变量主要用于在tensorflow系统中经常会被改变的值,而对于传入值,它只是当tensorflow ...
- tensorflow建造神经网络-【老鱼学tensorflow】
上次我们添加了一个add_layer函数,这次就要创建一个神经网络来预测/拟合相应的数据. 下面我们先来创建一下虚拟的数据,这个数据为二次曲线数据,但同时增加了一些噪点,其图像为: 相应的创建这些伪造 ...
- tensorflow Tensorboard2-【老鱼学tensorflow】
前面我们用Tensorboard显示了tensorflow的程序结构,本节主要用Tensorboard显示各个参数值的变化以及损失函数的值的变化. 这里的核心函数有: histogram 例如: tf ...
随机推荐
- 【洛谷P1963】变换序列
题目大意:对于一个顺序序列,求一个合法置换,可以满足一些约束,若存在多个合法置换,则输出字典序最小的一个置换. 题解:对于序列的置换是否有解的问题,可以和二分图的完美匹配相关联.由于是字典序最小,显然 ...
- C++:普通变量C++命名规则
C++提倡使用拥有一定意义的变量名,使程序代码更有阅读性,命名是必须使用的几种简单的C++命名规则: 命名时只能使用:字母字符.数字和下划线(_); 第一个字符不能是数字: 区分大小写(C++对大小写 ...
- I/O模型系列之三:IO通信模型BIO NIO AIO
一.传统的BIO 网络编程的基本模型是Client/Server模型,也就是两个进程之间进行相互通信,其中服务端提供位置信息(绑定的IP地址和监听端口),客户端通过连接操作向服务端监听的地址发起连接请 ...
- Dynamics CRM 日常使用JS整理(三)
一.指定 Partylist 类型字段能 lookup 的实体(以 Appointment 中某个字段为例子): var control = Xrm.Page.getControl("req ...
- python--协程之特别篇
Python通过yield提供了对协程的基本支持,但是不完全.而第三方的gevent为Python提供了比较完善的协程支持. gevent是第三方库,通过greenlet实现协程,其基本思想是: 当一 ...
- 微信小程序版本自动更新弹窗提示
代码如下: onLaunch () { if (wx.canIUse('getUpdateManager')) { const updateManager = wx.getUpdateManager( ...
- mysqldump 备份数据和恢复
命令行下具体用法如下: mysqldump -u用戶名 -p密码 -d 数据库名 表名 > 脚本名; 一.导出数据: 导出整个数据库结构和数据mysqldump -h localhost -u ...
- addEventListener解决多个window.onscroll共存的2个方法
方法1.(注意第一个和第二个的先后次序) window.onscroll=function(){console.log('第一个');} var oldMethod = window.onscroll ...
- docker保存日志文件到本地
其实很简单 docker logs +你需要添加的额外参数 + 容器id >文件名称 然后查看这个文件就可以了,也可以通过ftp协议下载到本地
- 【转】一文掌握 Linux 性能分析之内存篇
[转]一文掌握 Linux 性能分析之内存篇 前面我们已经学习了 CPU 篇,这篇来看下内存篇. 01 内存信息 同样在分析内存之前,我们得知到怎么查看系统内存信息,有以下几种方法. 1.1 /pro ...