一 RNN概述
    前面我们叙述了BP算法, CNN算法, 那么为什么还会有RNN呢?? 什么是RNN, 它到底有什么不同之处? RNN的主要应用领域有哪些呢?这些都是要讨论的问题.

1) BP算法,CNN之后, 为什么还有RNN?

细想BP算法,CNN(卷积神经网络)我们会发现, 他们的输出都是只考虑前一个输入的影响而不考虑其它时刻输入的影响, 比如简单的猫,狗,手写数字等单个物体的识别具有较好的效果. 但是, 对于一些与时间先后有关的, 比如视频的下一时刻的预测,文档前后文内容的预测等, 这些算法的表现就不尽如人意了.因此, RNN就应运而生了.

2) 什么是RNN?

RNN是一种特殊的神经网络结构, 它是根据"人的认知是基于过往的经验和记忆"这一观点提出的. 它与DNN,CNN不同的是: 它不仅考虑前一时刻的输入,而且赋予了网络对前面的内容的一种'记忆'功能.

RNN之所以称为循环神经网路,即一个序列当前的输出与前面的输出也有关。具体的表现形式为网络会对前面的信息进行记忆并应用于当前输出的计算中,即隐藏层之间的节点不再无连接而是有连接的,并且隐藏层的输入不仅包括输入层的输出还包括上一时刻隐藏层的输出。

3) RNN的主要应用领域有哪些呢?

RNN的应用领域有很多, 可以说只要考虑时间先后顺序的问题都可以使用RNN来解决.这里主要说一下几个常见的应用领域:

① 自然语言处理(NLP): 主要有视频处理, 文本生成, 语言模型, 图像处理

② 机器翻译, 机器写小说

③ 语音识别

④ 图像描述生成

⑤ 文本相似度计算

⑥ 音乐推荐、网易考拉商品推荐、Youtube视频推荐等新的应用领域.

二 RNN(循环神经网络)
    1) RNN模型结构
    前面我们说了RNN具有时间"记忆"的功能, 那么它是怎么实现所谓的"记忆"的呢?

图1 RNN结构图

如图1所示, 我们可以看到RNN层级结构较之于CNN来说比较简单, 它主要有输入层,Hidden Layer, 输出层组成.

并且会发现在Hidden Layer 有一个箭头表示数据的循环更新, 这个就是实现时间记忆功能的方法.

如果到这里你还是没有搞懂RNN到底是什么意思,那么请继续往下看!

图2 Hidden Layer的层级展开图

如图2所示为Hidden Layer的层级展开图. t-1, t, t+1表示时间序列. X表示输入的样本. St表示样本在时间t处的的记忆,St = f(W*St-1 +U*Xt). W表示输入的权重, U表示此刻输入的样本的权重, V表示输出的样本权重.

在t =1时刻, 一般初始化输入S0=0, 随机初始化W,U,V, 进行下面的公式计算:

其中,f和g均为激活函数. 其中f可以是tanh,relu,sigmoid等激活函数,g通常是softmax也可以是其他。

时间就向前推进,此时的状态s1作为时刻1的记忆状态将参与下一个时刻的预测活动,也就是:

以此类推, 可以得到最终的输出值为:

注意: 1. 这里的W,U,V在每个时刻都是相等的(权重共享).

2. 隐藏状态可以理解为:  S=f(现有的输入+过去记忆总结)

2) RNN的反向传播
    前面我们介绍了RNN的前向传播的方式, 那么RNN的权重参数W,U,V都是怎么更新的呢?

每一次的输出值Ot都会产生一个误差值Et, 则总的误差可以表示为:.

则损失函数可以使用交叉熵损失函数也可以使用平方误差损失函数.

由于每一步的输出不仅仅依赖当前步的网络,并且还需要前若干步网络的状态,那么这种BP改版的算法叫做Backpropagation Through Time(BPTT) , 也就是将输出端的误差值反向传递,运用梯度下降法进行更新.(不熟悉BP的可以参考这里)

也就是要求参数的梯度:

首先我们求解W的更新方法, 由前面的W的更新可以看出它是每个时刻的偏差的偏导数之和.

在这里我们以 t = 3时刻为例, 根据链式求导法则可以得到t = 3时刻的偏导数为:

此时, 根据公式我们会发现, S3除了和W有关之外, 还和前一时刻S2有关.

对于S3直接展开得到下面的式子:

对于S2直接展开得到下面的式子:

对于S1直接展开得到下面的式子:

将上述三个式子合并得到:

这样就得到了公式:

这里要说明的是:表示的是S3对W直接求导, 不考虑S2的影响.(也就是例如y = f(x)*g(x)对x求导一样)

其次是对U的更新方法. 由于参数U求解和W求解类似,这里就不在赘述了,最终得到的具体的公式如下:

最后,给出V的更新公式(V只和输出O有关):

三 RNN的一些改进算法
    前面我们介绍了RNN的算法, 它处理时间序列的问题的效果很好, 但是仍然存在着一些问题, 其中较为严重的是容易出现梯度消失或者梯度爆炸的问题(BP算法和长时间依赖造成的). 注意: 这里的梯度消失和BP的不一样,这里主要指由于时间过长而造成记忆值较小的现象.

因此, 就出现了一系列的改进的算法, 这里介绍主要的两种算法: LSTM 和 GRU.

LSTM 和 GRU对于梯度消失或者梯度爆炸的问题处理方法主要是:

对于梯度消失: 由于它们都有特殊的方式存储”记忆”,那么以前梯度比较大的”记忆”不会像简单的RNN一样马上被抹除,因此可以一定程度上克服梯度消失问题。

对于梯度爆炸:用来克服梯度爆炸的问题就是gradient clipping,也就是当你计算的梯度超过阈值c或者小于阈值-c的时候,便把此时的梯度设置成c或-c。

1) LSTM算法(Long Short Term Memory, 长短期记忆网络 ) --- 重要的目前使用最多的时间序列算法

图3 LSTM算法结构图

如图3为LSTM算法的结构图.

和RNN不同的是: RNN中,就是个简单的线性求和的过程. 而LSTM可以通过“门”结构来去除或者增加“细胞状态”的信息,实现了对重要内容的保留和对不重要内容的去除. 通过Sigmoid层输出一个0到1之间的概率值,描述每个部分有多少量可以通过,0表示“不允许任务变量通过”,1表示“运行所有变量通过 ”.

用于遗忘的门叫做"遗忘门", 用于信息增加的叫做"信息增加门",最后是用于输出的"输出门". 这里就不展开介绍了.

此外,LSTM算法的还有一些变种.

如图4所示, 它增加“peephole connections”层 , 让门层也接受细胞状态的输入.

图4 LSTM算法的一个变种

如图5所示为LSTM的另外一种变种算法.它是通过耦合忘记门和更新输入门(第一个和第二个门);也就是不再单独的考虑忘记什么、增加什么信息,而是一起进行考虑。

图5 LSTM算法的一个变种

2) GRU算法

GRU是2014年提出的一种LSTM改进算法. 它将忘记门和输入门合并成为一个单一的更新门, 同时合并了数据单元状态和隐藏状态, 使得模型结构比之于LSTM更为简单.

其各个部分满足关系式如下:

四 基于Tensorflow的基本操作和总结
    使用tensorflow的基本操作如下:

# _*_coding:utf-8_*_

import tensorflow as tf
import numpy as np

'''
TensorFlow中的RNN的API主要包括以下两个路径:
1) tf.nn.rnn_cell(主要定义RNN的几种常见的cell)
2) tf.nn(RNN中的辅助操作)
'''
# 一 RNN中的cell
# 基类(最顶级的父类): tf.nn.rnn_cell.RNNCell()
# 最基础的RNN的实现: tf.nn.rnn_cell.BasicRNNCell()
# 简单的LSTM cell实现: tf.nn.rnn_cell.BasicLSTMCell()
# 最常用的LSTM实现: tf.nn.rnn_cell.LSTMCell()
# RGU cell实现: tf.nn.rnn_cell.GRUCell()
# 多层RNN结构网络的实现: tf.nn.rnn_cell.MultiRNNCell()

# 创建cell
# cell = tf.nn.rnn_cell.BasicRNNCell(num_units=128)
# print(cell.state_size)
# print(cell.output_size)

# shape=[4, 64]表示每次输入4个样本, 每个样本有64个特征
# inputs = tf.placeholder(dtype=tf.float32, shape=[4, 64])

# 给定RNN的初始状态
# s0 = cell.zero_state(4, tf.float32)
# print(s0.get_shape())

# 对于t=1时刻传入输入和state0,获取结果值
# output, s1 = cell.call(inputs, s0)
# print(output.get_shape())
# print(s1.get_shape())

# 定义LSTM cell
lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=128)
# shape=[4, 64]表示每次输入4个样本, 每个样本有64个特征
inputs = tf.placeholder(tf.float32, shape=[4, 48])
# 给定初始状态
s0 = lstm_cell.zero_state(4, tf.float32)
# 对于t=1时刻传入输入和state0,获取结果值
output, s1 = lstm_cell.call(inputs, s0)
print(output.get_shape())
print(s1.h.get_shape())
print(s1.c.get_shape())
    当然, 你可能会发现使用cell.call()每次只能调用一个得到一个状态, 如有多个状态需要多次重复调用较为麻烦, 那么我们怎么解决的呢? 可以参照后面的基于RNN的手写数字识别和单词预测的实例查找解决方法.

本文主要介绍了一种时间序列的RNN神经网络及其基础上衍生出来的变种算法LSTM和GRU算法, 也对RNN算法的使用场景作了介绍.

当然, 由于篇幅限制, 这里对于双向RNNs和多层的RNNs没有介绍. 另外, 对于LSTM的参数更新算法在这里也没有介绍, 后续补上吧!

最后, 如果你发现了任何问题, 欢迎一起探讨, 共同进步!!
---------------------

RNN概述-深度学习 -神经网络的更多相关文章

  1. 【Todo】【转载】深度学习&神经网络 科普及八卦 学习笔记 & GPU & SIMD

    上一篇文章提到了数据挖掘.机器学习.深度学习的区别:http://www.cnblogs.com/charlesblc/p/6159355.html 深度学习具体的内容可以看这里: 参考了这篇文章:h ...

  2. tensorflow模型持久化保存和加载--深度学习-神经网络

    模型文件的保存 tensorflow将模型保持到本地会生成4个文件: meta文件:保存了网络的图结构,包含变量.op.集合等信息 ckpt文件: 二进制文件,保存了网络中所有权重.偏置等变量数值,分 ...

  3. pytorch深度学习神经网络实现手写字体识别

    利用平pytorch搭建简单的神经网络实现minist手写字体的识别,采用三层线性函数迭代运算,使得其具备一定的非线性转化与运算能力,其数学原理如下: 其具体实现代码如下所示:import torch ...

  4. 用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践

    https://zhuanlan.zhihu.com/p/25928551 近来在同时做一个应用深度学习解决淘宝商品的类目预测问题的项目,恰好硕士毕业时论文题目便是文本分类问题,趁此机会总结下文本分类 ...

  5. [转] 用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践

    转自知乎上看到的一篇很棒的文章:用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践 近来在同时做一个应用深度学习解决淘宝商品的类目预测问题的项目,恰好硕士毕业时论文 ...

  6. TensorFlow 2.0 深度学习实战 —— 浅谈卷积神经网络 CNN

    前言 上一章为大家介绍过深度学习的基础和多层感知机 MLP 的应用,本章开始将深入讲解卷积神经网络的实用场景.卷积神经网络 CNN(Convolutional Neural Networks,Conv ...

  7. GitHub 上 57 款最流行的开源深度学习项目

    转载:https://www.oschina.net/news/79500/57-most-popular-deep-learning-project-at-github GitHub 上 57 款最 ...

  8. Computational Network Toolkit (CNTK) 是微软出品的开源深度学习工具包

    Computational Network Toolkit (CNTK) 是微软出品的开源深度学习工具包 用 CNTK 搞深度学习 (一) 入门 Computational Network Toolk ...

  9. Github上Stars最多的53个深度学习项目,TensorFlow遥遥领先

    原文:https://github.com/aymericdamien/TopDeepLearning 项目名称 Stars 项目介绍 TensorFlow 29622 使用数据流图计算可扩展机器学习 ...

随机推荐

  1. 课时53.video标签第二种格式(掌握)

    由于视频数据非常非常的重要,所以五大浏览器厂商都不愿意支持别人都视频格式,所以导致了没有一种视频格式是所有浏览器都支持的,这个时候W3C为了解决这个问题,所以推出了第二种video标签的格式 如何查看 ...

  2. 微信小程序禁止刷新之后苹果端还可以下拉的问题

    一.问题描述 最近在做一个小程序项目,需要禁止下拉刷新,于是在page.json里面添加了这段话 "enablePullDownRefresh":false 全局关闭下拉刷新,这段 ...

  3. 协议类接口 - SPI

     一.SPI概述 SPI(Serial Peripheral Interface,串行外设接口)总线系统是一种同步串行外设接口,它可以使CPU与各种外围设备以串行方式进行通信以交换信息.一般主控SoC ...

  4. 个人开发者即时到账收款方案 BufPay.com

    BufPay 个人即时到账支付平台 前言 作为独立开发者,一般只有一个人独立奋战,做出了产品需要收款是非常麻烦的,接入支付宝微信支付都需要公司公户,而注册公司.开公户等一系列操作非常麻烦,成本也很高一 ...

  5. MySQL必知必会 读书笔记四:数据过滤

    过滤数据 WHERE 只检索所需数据需要指定搜索条件( search criteria) ,搜索条件也称为过滤条件( filtercondition) . 在SELECT语句中,数据根据WHERE子句 ...

  6. thinkphp5配置讲解

    一.thinkphp配置类型有哪些? 1.在thinkphp中,有6种配置.即惯例配置,应用配置.扩展配置.模块配置.场景配置.动态配置. 2.惯例配置就是系统默认的配置. 3.应用配置就是我们自己开 ...

  7. jdk11新特性

    JDK 11主要特性一览 jdk11即将在9月25号发布正式版.确定的新特性包括以下17个 181 嵌套类可见性控制 309 动态文件常量 315 改进 Aarch64 Intrinsics 318 ...

  8. freeswitch对话机320信令在专有网络情况下不生效的处理

    昨天处理客户提出的话机设置呼叫转移不生效的问题, 经过多次测试发现这个问题与freeswitch版本和配置没有关系, 后来分析freeswitch正常转移日志与不转移日志发现不转移的日志少了一行 Re ...

  9. 在线预览word,excel文档

    Google Doc 示例:https://jsfiddle.net/7xr419yb/ Microsoft Office 示例:https://jsfiddle.net/gcuzq343/

  10. python爬xx图代码

    今日 好热,照样是挖洞挖不到,看了几天的python爬虫,学会了xpath解析 撸一个代码玩玩] 不要说什么,优化之类的,刚学完,跑了一阵 ,还可以  挺稳定 # -*- coding:utf-8 - ...