本节来介绍一下使用 RNN 的 LSTM 来做 MNIST 分类的方法,RNN 相比 CNN 来说,速度可能会慢,但可以节省更多的内存空间。

初始化

首先我们可以先初始化一些变量,如学习率、节点单元数、RNN 层数等:

learning_rate = 1e-
num_units =
num_layer =
input_size =
time_step =
total_steps =
category_num =
steps_per_validate =
steps_per_test =
batch_size = tf.placeholder(tf.int32, [])
keep_prob = tf.placeholder(tf.float32, [])

然后还需要声明一下 MNIST 数据生成器:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)

接下来常规声明一下输入的数据,输入数据用 x 表示,标注数据用 y_label 表示:

x = tf.placeholder(tf.float32, [None, ])
y_label = tf.placeholder(tf.float32, [None, ])

这里输入的 x 维度是 [None, 784],代表 batch_size 不确定,输入维度 784,y_label 同理。

接下来我们需要对输入的 x 进行 reshape 操作,因为我们需要将一张图分为多个 time_step 来输入,这样才能构建一个 RNN 序列,所以这里直接将 time_step 设成 28,这样一来 input_size 就变为了 28,batch_size 不变,所以reshape 的结果是一个三维的矩阵:

x_shape = tf.reshape(x, [-, time_step, input_size])

RNN 层

接下来我们需要构建一个 RNN 模型了,这里我们使用的 RNN Cell 是 LSTMCell,而且要搭建一个三层的 RNN,所以这里还需要用到 MultiRNNCell,它的输入参数是 LSTMCell 的列表。

所以我们可以先声明一个方法用于创建 LSTMCell,方法如下:

def cell(num_units):
    cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=num_units)
    return DropoutWrapper(cell, output_keep_prob=keep_prob)

这里还加入了 Dropout,来减少训练过程中的过拟合。

接下来我们再利用它来构建多层的 RNN:

cells = tf.nn.rnn_cell.MultiRNNCell([cell(num_units) for _ in range(num_layer)])

注意这里使用了 for 循环,每循环一次新生成一个 LSTMCell,而不是直接使用乘法来扩展列表,因为这样会导致 LSTMCell 是同一个对象,导致构建完 MultiRNNCell 之后出现维度不匹配的问题。

接下来我们需要声明一个初始状态:

h0 = cells.zero_state(batch_size, dtype=tf.float32)

然后接下来调用 dynamic_rnn() 方法即可完成模型的构建了:

output, hs = tf.nn.dynamic_rnn(cells, inputs=x_shape, initial_state=h0)

这里 inputs 的输入就是 x 做了 reshape 之后的结果,初始状态通过 initial_state 传入,其返回结果有两个,一个 output 是所有 time_step 的输出结果,赋值为 output,它是三维的,第一维长度等于 batch_size,第二维长度等于 time_step,第三维长度等于 num_units。另一个 hs 是隐含状态,是元组形式,长度即 RNN 的层数 3,每一个元素都包含了 c 和 h,即 LSTM 的两个隐含状态。

这样的话 output 的最终结果可以取最后一个 time_step 的结果,所以可以使用:

output = output[:, -, :]

或者直接取隐藏状态最后一层的 h 也是相同的:

h = hs[-].h

在此模型中,二者是等价的。但注意如果用于文本处理,可能由于文本长度不一,而 padding,导致二者不同。

输出层

接下来我们再做一次线性变换和 Softmax 输出结果即可:

# Output Layer
w = tf.Variable(tf.truncated_normal([num_units, category_num], stddev=0.1), dtype=tf.float32)
b = tf.Variable(tf.constant(0.1, shape=[category_num]), dtype=tf.float32)
y = tf.matmul(output, w) + b
# Loss
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=y_label, logits=y)

这里的 Loss 直接调用了 softmax_cross_entropy_with_logits 先计算了 Softmax,然后计算了交叉熵。

训练和评估

最后再定义训练和评估的流程即可,在训练过程中每隔一定的 step 就输出 Train Accuracy 和 Test Accuracy:

# Train
train = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cross_entropy)

# Prediction
correction_prediction = tf.equal(tf.argmax(y, axis=), tf.argmax(y_label, axis=))
accuracy = tf.reduce_mean(tf.cast(correction_prediction, tf.float32))

# Train
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ):
        batch_x, batch_y = mnist.train.next_batch()
        sess.run(train, feed_dict={x: batch_x, y_label: batch_y, keep_prob: ]})
        # Train Accuracy
        :
            print('Train', step, sess.run(accuracy, feed_dict={x: batch_x, y_label: batch_y, keep_prob: 0.5,
                                                               batch_size: batch_x.shape[]}))
        # Test Accuracy
        :
            test_x, test_y = mnist.test.images, mnist.test.labels
            print('Test', step,
                  sess.run(accuracy, feed_dict={x: test_x, y_label: test_y, keep_prob: , batch_size: test_x.shape[]}))

运行

直接运行之后,只训练了几轮就可以达到 98% 的准确率:

Train  0.27
Test  0.2223
Train  0.87
Train  0.91
Train  0.94
Train  0.94
Train  0.99
Test  0.9595
Train  0.95
Train  0.97
Train  0.98

可以看出来 LSTM 在做 MNIST 字符分类的任务上还是比较有效的。

芝麻HTTP:TensorFlow LSTM MNIST分类的更多相关文章

  1. 深入浅出TensorFlow(二):TensorFlow解决MNIST问题入门

    2017年2月16日,Google正式对外发布Google TensorFlow 1.0版本,并保证本次的发布版本API接口完全满足生产环境稳定性要求.这是TensorFlow的一个重要里程碑,标志着 ...

  2. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  3. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  4. 2、TensorFlow训练MNIST

    装载自:http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html TensorFlow训练MNIST 这个教程的目标读者是对机器学习和T ...

  5. tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)

    tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...

  6. tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)

    续集请点击我:tensorflow学习笔记——使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...

  7. TensorFlow LSTM 注意力机制图解

    TensorFlow LSTM Attention 机制图解 深度学习的最新趋势是注意力机制.在接受采访时,现任OpenAI研究主管的Ilya Sutskever提到,注意力机制是最令人兴奋的进步之一 ...

  8. TensorFlow-Bitcoin-Robot:一个基于 TensorFlow LSTM 模型的 Bitcoin 价格预测机器人

    简介 TensorFlow-Bitcoin-Robot:一个基于 TensorFlow LSTM 模型的 Bitcoin 价格预测机器人. 文章包括一下几个部分: 1.为什么要尝试做这个项目? 2.为 ...

  9. Ubuntu16.04安装TensorFlow及Mnist训练

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com TensorFlow是Google开发的开源的深度学习框架,也是当前使用最广泛的深度学习框架. 一.安 ...

随机推荐

  1. Vue.js搭建路由报错 router.map is not a function,Cannot read property ‘component’ of undefined

    错误: 解决办法: 2.0已经没有map了,使用npm install vue-router@0.7.13 命令兼容1.0版本vue 但是安装完之后会出现一个错误: Cannot read prope ...

  2. POJ 2653 Pick-up sticks [线段相交 迷之暴力]

    Pick-up sticks Time Limit: 3000MS   Memory Limit: 65536K Total Submissions: 12861   Accepted: 4847 D ...

  3. BZOJ 2329: [HNOI2011]括号修复 [splay 括号]

    题目描述 一个合法的括号序列是这样定义的: 空串是合法的. 如果字符串 S 是合法的,则(S)也是合法的. 如果字符串 A 和 B 是合法的,则 AB 也是合法的. 现在给你一个长度为 N 的由‘(' ...

  4. Phalcon调试大杀器之phalcon-debugbar安装

    Phalcon 是一款非常火的高性能C扩展php开发框架.特点是高性能低耦合,但遗憾的是长期缺少一款得力的调试辅助工具. 目前版本主要以Laravel debugbar的具有功能为蓝本开发,并针对ph ...

  5. LeetCode - 620. Not Boring Movies

    X city opened a new cinema, many people would like to go to this cinema. The cinema also gives out a ...

  6. 页面刷新方式实时检测cookie是否失效

    在浏览器端每隔10秒钟刷新一次页面,可用于检查cookie值是否失效. 在study.php文件中存在这样一条语句: <meta http-equiv="refresh" c ...

  7. iOS 8 UIAlertController 和 UIAlertAction

    将alertView 和 actionSheet 封装在UIAlertController 里面化整为零,使开发者更便利 当我们一味的追求高内聚,低耦合的时候,伟大的苹果反其道而行之,这也告诉了我们一 ...

  8. 洛谷 P2622 关灯问题II【状压DP;隐式图搜索】

    题目描述 现有n盏灯,以及m个按钮.每个按钮可以同时控制这n盏灯--按下了第i个按钮,对于所有的灯都有一个效果.按下i按钮对于第j盏灯,是下面3中效果之一:如果a[i][j]为1,那么当这盏灯开了的时 ...

  9. 获取对象属性类型、属性名称、属性值的研究:反射和JEXL解析引擎

    同步发布:http://www.yuanrengu.com/index.php/20170511.html 先简单介绍下反射的概念:java反射机制是在运行状态中,对于任意一个类,都能够知道这个类的所 ...

  10. explorer.exe 该文件没有与之关联的程序来执行该操作

    删了点右键的东西搞出来的问题 其实就是关联出错了,解决:(新建一个temp.reg,内容如下,然后双击导入注册表即可) Windows Registry Editor Version 5.00 [[H ...