RNN介绍

  在读本文之前,读者应该对全连接神经网络(Fully Connected Neural Network, FCNN)和卷积神经网络( Convolutional Neural Network, CNN)有一定的了解。对于FCNN和CNN来说,他们能解决很多实际问题,但是它们都只能单独的取处理一个个的输入,前一个输入和后一个输入是完全没有关系的 。而在现实生活中,我们输入的向量往往存在着前后联系,即前一个输入和后一个输入是有关联的,比如文本,语音,视频等,因此,我们需要了解深度学习中的另一类重要的神经网络,那就是循环神经网络(Recurrent Neural Network,RNN).

  循环神经网络(Recurrent Neural Network,RNN)依赖于一个重要的概念:序列(Sequence),即输入的向量是一个序列,存在着前后联系。简单RNN的结构示意图如下:



相比于之前的FCNN,RNN的结构中多出了一个自循环部分,即W所在的圆圈,这是RNN的精华所在,它展开后的结构如下:



对于t时刻的输出向量\(o_{t}\),它的输出不仅仅依赖于t时刻的输入向量\(x_{t}\),还依赖于t-1时刻的隐藏层向量\(s_{t-1}\),以下是输出向量\(o_{t}\)的计算公式:

\[s_{t}=f(Ux_{t}+Ws_{t-1})
\]

\[o_{t}=g(Vs_{t})
\]

其中,第二个式子为输出层的计算公式,输出层为全连接层,V为权重矩阵,g为激活函数。第一个式子中,U是输入x的权重矩阵,W是上一次隐藏层值s的输入权重矩阵,f为激活函数。注意到,RNN的所有权重矩阵U,V,W是共享的,这样可以减少计算量。

  本文将会用TensorFlow中已经帮我们实现好的RNN基本函数tf.contrib.rnn.BasicRNNCell(), tf.nn.dynamic_rnn()来实现简单RNN,并且用该RNN来识别MNIST数据集。

MNIST数据集

  MNIST数据集是深度学习的经典入门demo,它是由6万张训练图片和1万张测试图片构成的,每张图片都是28*28大小(如下图),而且都是黑白色构成(这里的黑色是一个0-1的浮点数,黑色越深表示数值越靠近1),这些图片是采集的不同的人手写从0到9的数字。



  在TensorFlow中,已经内嵌了MNIST数据集,笔者已经下载下来了,如下:



  接下来本文将要用MNIST数据集作为RNN应用的一个demo.

RNN大战MNIST数据集

  用CNN来识别MNIST数据集,我们好理解,这是利用了图片的空间信息。可是,RNN要求输入的向量是序列,那么,如何把图片看成是序列呢?

  图片的大小为28*28,我们把每一列向量看成是某一时刻的向量,那么每张图片就是一个序列,里面含有28个向量,每个向量含有28个元素,如下:



  下面给出如何利用TensorFlow来搭建简单RNN,用来识别MNIST数据集,完整的Python代码如下:

# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data # 获取MNIST数据
mnist = input_data.read_data_sets(r"./MNIST_data", one_hot=True) # 设置RNN结构
element_size = 28
time_steps = 28
num_classes = 10
batch_size = 128
hidden_layer_size = 150 # 输入向量和输出向量
_inputs = tf.placeholder(tf.float32, shape=[None, time_steps, element_size], name='inputs')
y = tf.placeholder(tf.float32, shape=[None, num_classes], name='inputs') # 利用TensorFlow的内置函数BasicRNNCell, dynamic_rnn来构建RNN的基本模块
rnn_cell = tf.contrib.rnn.BasicRNNCell(hidden_layer_size)
outputs, _ = tf.nn.dynamic_rnn(rnn_cell, _inputs, dtype=tf.float32)
Wl = tf.Variable(tf.truncated_normal([hidden_layer_size, num_classes], mean=0,stddev=.01))
bl = tf.Variable(tf.truncated_normal([num_classes],mean=0,stddev=.01)) def get_linear_layer(vector):
return tf.matmul(vector, Wl) + bl # 取输出的向量outputs中的最后一个向量最为最终输出
last_rnn_output = outputs[:,-1,:]
final_output = get_linear_layer(last_rnn_output) # 定义损失函数并用RMSPropOptimizer优化
softmax = tf.nn.softmax_cross_entropy_with_logits(logits=final_output, labels=y)
cross_entropy = tf.reduce_mean(softmax)
train_step = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cross_entropy) # 统计准确率
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(final_output,1))
accuracy = (tf.reduce_mean(tf.cast(correct_prediction, tf.float32)))*100 sess=tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
# 测试集
test_data = mnist.test.images[:batch_size].reshape((-1, time_steps, element_size))
test_label = mnist.test.labels[:batch_size] # 每次训练batch_size张图片,一共训练3000次
for i in range(3001):
batch_x, batch_y = mnist.train.next_batch(batch_size)
batch_x = batch_x.reshape((batch_size, time_steps, element_size))
sess.run(train_step, feed_dict={_inputs:batch_x, y:batch_y})
if i % 100 == 0:
loss = sess.run(cross_entropy, feed_dict={_inputs: batch_x, y: batch_y})
acc = sess.run(accuracy, feed_dict={_inputs:batch_x, y: batch_y})
print ("Iter " + str(i) + ", Minibatch Loss= " + \
"{:.6f}".format(loss) + ", Training Accuracy= " + \
"{:.5f}".format(acc)) # 在测试集上的准确率
print("Testing Accuracy:", sess.run(accuracy, feed_dict={_inputs:test_data, y:test_label}))

  运行上述代码,输出的结果如下:

Extracting ./MNIST_data\train-images-idx3-ubyte.gz
Extracting ./MNIST_data\train-labels-idx1-ubyte.gz
Extracting ./MNIST_data\t10k-images-idx3-ubyte.gz
Extracting ./MNIST_data\t10k-labels-idx1-ubyte.gz
Iter 0, Minibatch Loss= 2.301171, Training Accuracy= 11.71875
Iter 100, Minibatch Loss= 1.718483, Training Accuracy= 47.65625
Iter 200, Minibatch Loss= 0.862968, Training Accuracy= 71.09375
Iter 300, Minibatch Loss= 0.513068, Training Accuracy= 86.71875
Iter 400, Minibatch Loss= 0.570475, Training Accuracy= 83.59375
Iter 500, Minibatch Loss= 0.254566, Training Accuracy= 92.96875
Iter 600, Minibatch Loss= 0.457989, Training Accuracy= 85.93750
Iter 700, Minibatch Loss= 0.151181, Training Accuracy= 96.87500
Iter 800, Minibatch Loss= 0.171168, Training Accuracy= 94.53125
Iter 900, Minibatch Loss= 0.142494, Training Accuracy= 94.53125
Iter 1000, Minibatch Loss= 0.155114, Training Accuracy= 97.65625
Iter 1100, Minibatch Loss= 0.096007, Training Accuracy= 96.87500
Iter 1200, Minibatch Loss= 0.341476, Training Accuracy= 88.28125
Iter 1300, Minibatch Loss= 0.133509, Training Accuracy= 96.87500
Iter 1400, Minibatch Loss= 0.076408, Training Accuracy= 98.43750
Iter 1500, Minibatch Loss= 0.122228, Training Accuracy= 98.43750
Iter 1600, Minibatch Loss= 0.099382, Training Accuracy= 96.87500
Iter 1700, Minibatch Loss= 0.084686, Training Accuracy= 97.65625
Iter 1800, Minibatch Loss= 0.067009, Training Accuracy= 98.43750
Iter 1900, Minibatch Loss= 0.189703, Training Accuracy= 94.53125
Iter 2000, Minibatch Loss= 0.116077, Training Accuracy= 96.09375
Iter 2100, Minibatch Loss= 0.028867, Training Accuracy= 100.00000
Iter 2200, Minibatch Loss= 0.064198, Training Accuracy= 99.21875
Iter 2300, Minibatch Loss= 0.078259, Training Accuracy= 97.65625
Iter 2400, Minibatch Loss= 0.106613, Training Accuracy= 97.65625
Iter 2500, Minibatch Loss= 0.078722, Training Accuracy= 98.43750
Iter 2600, Minibatch Loss= 0.045871, Training Accuracy= 98.43750
Iter 2700, Minibatch Loss= 0.030953, Training Accuracy= 99.21875
Iter 2800, Minibatch Loss= 0.062823, Training Accuracy= 96.87500
Iter 2900, Minibatch Loss= 0.040367, Training Accuracy= 99.21875
Iter 3000, Minibatch Loss= 0.017787, Training Accuracy= 100.00000
Testing Accuracy: 97.6563

可以看到,用简单RNN来识别MNIST数据集,也能取得很好的效果!

  本次分享到此结束,欢迎大家交流~

注意:本人现已开通微信公众号: 轻松学会Python爬虫(微信号为:easy_web_scrape), 欢迎大家关注哦~~

RNN入门(一)识别MNIST数据集的更多相关文章

  1. Python实现bp神经网络识别MNIST数据集

    title: "Python实现bp神经网络识别MNIST数据集" date: 2018-06-18T14:01:49+08:00 tags: [""] cat ...

  2. SGD与Adam识别MNIST数据集

    几种常见的优化函数比较:https://blog.csdn.net/w113691/article/details/82631097 ''' 基于Adam识别MNIST数据集 ''' import t ...

  3. 卷积神经网络CNN识别MNIST数据集

    这次我们将建立一个卷积神经网络,它可以把MNIST手写字符的识别准确率提升到99%,读者可能需要一些卷积神经网络的基础知识才能更好的理解本节的内容. 程序的开头是导入TensorFlow: impor ...

  4. mxnet实战系列(一)入门与跑mnist数据集

    最近在摸mxnet和tensorflow.两个我都搭起来了.tensorflow跑了不少代码,总的来说用得比较顺畅,文档很丰富,api熟悉熟悉写代码没什么问题. 今天把两个平台做了一下对比.同是跑mn ...

  5. 81、Tensorflow实现LeNet-5模型,多层卷积层,识别mnist数据集

    ''' Created on 2017年4月22日 @author: weizhen ''' import os import tensorflow as tf import numpy as np ...

  6. 吴裕雄 python 神经网络TensorFlow实现LeNet模型处理手写数字识别MNIST数据集

    import tensorflow as tf tf.reset_default_graph() # 配置神经网络的参数 INPUT_NODE = 784 OUTPUT_NODE = 10 IMAGE ...

  7. 关于入门深度学习mnist数据集前向计算的记录

    import osimport lr as lrimport tensorflow as tffrom pyspark.sql.functions import stddevfrom tensorfl ...

  8. 吴裕雄 python 神经网络——TensorFlow实现AlexNet模型处理手写数字识别MNIST数据集

    import tensorflow as tf # 输入数据 from tensorflow.examples.tutorials.mnist import input_data mnist = in ...

  9. 机器学习(2) - KNN识别MNIST

    代码 https://github.com/s055523/MNISTTensorFlowSharp 数据的获得 数据可以由http://yann.lecun.com/exdb/mnist/下载.之后 ...

随机推荐

  1. Mac 下 python 环境问题

    一.Mac下,可能存在的 python 环境: 1.Mac系统自带的python环境在(由于不同的 mac 系统,默认自带的 python 版本可能不一样): Python 2.7.10: /Syst ...

  2. Python环境安装及IDE介绍

    因为最近时间比较松散,公司的业务也不多,所以想趁机赶紧投入到人工智能的学习大业当中.经过多次比较,看到目前市面上还是使用Python做为基础语言较多,进儿学习算法.人工智能组件.机器学习.数据挖掘等课 ...

  3. 在CentOS 7上安装和使用GlusterFS

    GlusterFS aggregates various storage servers over Ethernet or Infiniband RDMA interconnect into one ...

  4. Python基础理论 - 函数

    函数是第一类对象:可以当做数据来传 1.  可以被引用 2.  可以作为函数参数 3.  可以作为函数返回值 4.  可以作为容器类型的元素 小例子: def func1(): print('func ...

  5. 【腾讯Bugly干货分享】职场中脱颖而出的成长秘诀

    本文来自于腾讯Bugly公众号(weixinBugly),未经作者同意,请勿转载,原文地址:http://mp.weixin.qq.com/s/uQKpVg7HMLfogGzzMyc9iQ 导语 时光 ...

  6. Python面向对象1:类与对象

    Python的面向对象- 面向对象编程 - 基础 - 公有私有 - 继承 - 组合,Mixin- 魔法函数 - 魔法函数概述 - 构造类魔法函数 - 运算类魔法函数 # 1. 面向对象概述(Objec ...

  7. 博客Hexo + github pages + 阿里云绑定域名搭建个人博客

    申请域名 万网购买的域名,地址:https://wanwang.aliyun.com/domain/com?spm=5176.8142029.388261.137.LoKzy7 控制台进行解析 控制台 ...

  8. 深度学习环境配置:Ubuntu16.04安装GTX1080Ti+CUDA9.0+cuDNN7.0完整安装教程(多链接多参考文章)

    本来就对Linux不熟悉,经过几天惨痛的教训,参考了不知道多少篇文章,终于把环境装好了,每篇文章或多或少都有一些用,但没有一篇完整的能解决我安装过程碰到的问题,所以决定还是自己写一篇我安装过程的教程, ...

  9. aaa配置(第十三组)

    拓扑 网络情况 A PING B A PING C PC-B PING PC-C 2.R1的配置 a.console线 R1(config)#username admin1 password Admi ...

  10. H5在WebView上开发小结

    背景 来自我司业务方要求,需开发一款APP.但由于时间限制,只能采取套壳app方式,即原生app内嵌webview展示前端页面.本文主要记述JavaScript与原生app间通信,以及内嵌webvie ...