一 RNN概述
    前面我们叙述了BP算法, CNN算法, 那么为什么还会有RNN呢?? 什么是RNN, 它到底有什么不同之处? RNN的主要应用领域有哪些呢?这些都是要讨论的问题.

1) BP算法,CNN之后, 为什么还有RNN?

细想BP算法,CNN(卷积神经网络)我们会发现, 他们的输出都是只考虑前一个输入的影响而不考虑其它时刻输入的影响, 比如简单的猫,狗,手写数字等单个物体的识别具有较好的效果. 但是, 对于一些与时间先后有关的, 比如视频的下一时刻的预测,文档前后文内容的预测等, 这些算法的表现就不尽如人意了.因此, RNN就应运而生了.

2) 什么是RNN?

RNN是一种特殊的神经网络结构, 它是根据"人的认知是基于过往的经验和记忆"这一观点提出的. 它与DNN,CNN不同的是: 它不仅考虑前一时刻的输入,而且赋予了网络对前面的内容的一种'记忆'功能.

RNN之所以称为循环神经网路,即一个序列当前的输出与前面的输出也有关。具体的表现形式为网络会对前面的信息进行记忆并应用于当前输出的计算中,即隐藏层之间的节点不再无连接而是有连接的,并且隐藏层的输入不仅包括输入层的输出还包括上一时刻隐藏层的输出。

3) RNN的主要应用领域有哪些呢?

RNN的应用领域有很多, 可以说只要考虑时间先后顺序的问题都可以使用RNN来解决.这里主要说一下几个常见的应用领域:

① 自然语言处理(NLP): 主要有视频处理, 文本生成, 语言模型, 图像处理

② 机器翻译, 机器写小说

③ 语音识别

④ 图像描述生成

⑤ 文本相似度计算

⑥ 音乐推荐、网易考拉商品推荐、Youtube视频推荐等新的应用领域.

二 RNN(循环神经网络)
    1) RNN模型结构
    前面我们说了RNN具有时间"记忆"的功能, 那么它是怎么实现所谓的"记忆"的呢?

图1 RNN结构图

如图1所示, 我们可以看到RNN层级结构较之于CNN来说比较简单, 它主要有输入层,Hidden Layer, 输出层组成.

并且会发现在Hidden Layer 有一个箭头表示数据的循环更新, 这个就是实现时间记忆功能的方法.

如果到这里你还是没有搞懂RNN到底是什么意思,那么请继续往下看!

图2 Hidden Layer的层级展开图

如图2所示为Hidden Layer的层级展开图. t-1, t, t+1表示时间序列. X表示输入的样本. St表示样本在时间t处的的记忆,St = f(W*St-1 +U*Xt). W表示输入的权重, U表示此刻输入的样本的权重, V表示输出的样本权重.

在t =1时刻, 一般初始化输入S0=0, 随机初始化W,U,V, 进行下面的公式计算:

其中,f和g均为激活函数. 其中f可以是tanh,relu,sigmoid等激活函数,g通常是softmax也可以是其他。

时间就向前推进,此时的状态s1作为时刻1的记忆状态将参与下一个时刻的预测活动,也就是:

以此类推, 可以得到最终的输出值为:

注意: 1. 这里的W,U,V在每个时刻都是相等的(权重共享).

2. 隐藏状态可以理解为:  S=f(现有的输入+过去记忆总结)

2) RNN的反向传播
    前面我们介绍了RNN的前向传播的方式, 那么RNN的权重参数W,U,V都是怎么更新的呢?

每一次的输出值Ot都会产生一个误差值Et, 则总的误差可以表示为:.

则损失函数可以使用交叉熵损失函数也可以使用平方误差损失函数.

由于每一步的输出不仅仅依赖当前步的网络,并且还需要前若干步网络的状态,那么这种BP改版的算法叫做Backpropagation Through Time(BPTT) , 也就是将输出端的误差值反向传递,运用梯度下降法进行更新.(不熟悉BP的可以参考这里)

也就是要求参数的梯度:

首先我们求解W的更新方法, 由前面的W的更新可以看出它是每个时刻的偏差的偏导数之和.

在这里我们以 t = 3时刻为例, 根据链式求导法则可以得到t = 3时刻的偏导数为:

此时, 根据公式我们会发现, S3除了和W有关之外, 还和前一时刻S2有关.

对于S3直接展开得到下面的式子:

对于S2直接展开得到下面的式子:

对于S1直接展开得到下面的式子:

将上述三个式子合并得到:

这样就得到了公式:

这里要说明的是:表示的是S3对W直接求导, 不考虑S2的影响.(也就是例如y = f(x)*g(x)对x求导一样)

其次是对U的更新方法. 由于参数U求解和W求解类似,这里就不在赘述了,最终得到的具体的公式如下:

最后,给出V的更新公式(V只和输出O有关):

三 RNN的一些改进算法
    前面我们介绍了RNN的算法, 它处理时间序列的问题的效果很好, 但是仍然存在着一些问题, 其中较为严重的是容易出现梯度消失或者梯度爆炸的问题(BP算法和长时间依赖造成的). 注意: 这里的梯度消失和BP的不一样,这里主要指由于时间过长而造成记忆值较小的现象.

因此, 就出现了一系列的改进的算法, 这里介绍主要的两种算法: LSTM 和 GRU.

LSTM 和 GRU对于梯度消失或者梯度爆炸的问题处理方法主要是:

对于梯度消失: 由于它们都有特殊的方式存储”记忆”,那么以前梯度比较大的”记忆”不会像简单的RNN一样马上被抹除,因此可以一定程度上克服梯度消失问题。

对于梯度爆炸:用来克服梯度爆炸的问题就是gradient clipping,也就是当你计算的梯度超过阈值c或者小于阈值-c的时候,便把此时的梯度设置成c或-c。

1) LSTM算法(Long Short Term Memory, 长短期记忆网络 ) --- 重要的目前使用最多的时间序列算法

图3 LSTM算法结构图

如图3为LSTM算法的结构图.

和RNN不同的是: RNN中,就是个简单的线性求和的过程. 而LSTM可以通过“门”结构来去除或者增加“细胞状态”的信息,实现了对重要内容的保留和对不重要内容的去除. 通过Sigmoid层输出一个0到1之间的概率值,描述每个部分有多少量可以通过,0表示“不允许任务变量通过”,1表示“运行所有变量通过 ”.

用于遗忘的门叫做"遗忘门", 用于信息增加的叫做"信息增加门",最后是用于输出的"输出门". 这里就不展开介绍了.

此外,LSTM算法的还有一些变种.

如图4所示, 它增加“peephole connections”层 , 让门层也接受细胞状态的输入.

图4 LSTM算法的一个变种

如图5所示为LSTM的另外一种变种算法.它是通过耦合忘记门和更新输入门(第一个和第二个门);也就是不再单独的考虑忘记什么、增加什么信息,而是一起进行考虑。

图5 LSTM算法的一个变种

2) GRU算法

GRU是2014年提出的一种LSTM改进算法. 它将忘记门和输入门合并成为一个单一的更新门, 同时合并了数据单元状态和隐藏状态, 使得模型结构比之于LSTM更为简单.

其各个部分满足关系式如下:

四 基于Tensorflow的基本操作和总结
    使用tensorflow的基本操作如下:

# _*_coding:utf-8_*_

import tensorflow as tf
import numpy as np

'''
TensorFlow中的RNN的API主要包括以下两个路径:
1) tf.nn.rnn_cell(主要定义RNN的几种常见的cell)
2) tf.nn(RNN中的辅助操作)
'''
# 一 RNN中的cell
# 基类(最顶级的父类): tf.nn.rnn_cell.RNNCell()
# 最基础的RNN的实现: tf.nn.rnn_cell.BasicRNNCell()
# 简单的LSTM cell实现: tf.nn.rnn_cell.BasicLSTMCell()
# 最常用的LSTM实现: tf.nn.rnn_cell.LSTMCell()
# RGU cell实现: tf.nn.rnn_cell.GRUCell()
# 多层RNN结构网络的实现: tf.nn.rnn_cell.MultiRNNCell()

# 创建cell
# cell = tf.nn.rnn_cell.BasicRNNCell(num_units=128)
# print(cell.state_size)
# print(cell.output_size)

# shape=[4, 64]表示每次输入4个样本, 每个样本有64个特征
# inputs = tf.placeholder(dtype=tf.float32, shape=[4, 64])

# 给定RNN的初始状态
# s0 = cell.zero_state(4, tf.float32)
# print(s0.get_shape())

# 对于t=1时刻传入输入和state0,获取结果值
# output, s1 = cell.call(inputs, s0)
# print(output.get_shape())
# print(s1.get_shape())

# 定义LSTM cell
lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=128)
# shape=[4, 64]表示每次输入4个样本, 每个样本有64个特征
inputs = tf.placeholder(tf.float32, shape=[4, 48])
# 给定初始状态
s0 = lstm_cell.zero_state(4, tf.float32)
# 对于t=1时刻传入输入和state0,获取结果值
output, s1 = lstm_cell.call(inputs, s0)
print(output.get_shape())
print(s1.h.get_shape())
print(s1.c.get_shape())
    当然, 你可能会发现使用cell.call()每次只能调用一个得到一个状态, 如有多个状态需要多次重复调用较为麻烦, 那么我们怎么解决的呢? 可以参照后面的基于RNN的手写数字识别和单词预测的实例查找解决方法.

本文主要介绍了一种时间序列的RNN神经网络及其基础上衍生出来的变种算法LSTM和GRU算法, 也对RNN算法的使用场景作了介绍.

当然, 由于篇幅限制, 这里对于双向RNNs和多层的RNNs没有介绍. 另外, 对于LSTM的参数更新算法在这里也没有介绍, 后续补上吧!

最后, 如果你发现了任何问题, 欢迎一起探讨, 共同进步!!
---------------------

RNN概述-深度学习 -神经网络的更多相关文章

  1. 【Todo】【转载】深度学习&神经网络 科普及八卦 学习笔记 & GPU & SIMD

    上一篇文章提到了数据挖掘.机器学习.深度学习的区别:http://www.cnblogs.com/charlesblc/p/6159355.html 深度学习具体的内容可以看这里: 参考了这篇文章:h ...

  2. tensorflow模型持久化保存和加载--深度学习-神经网络

    模型文件的保存 tensorflow将模型保持到本地会生成4个文件: meta文件:保存了网络的图结构,包含变量.op.集合等信息 ckpt文件: 二进制文件,保存了网络中所有权重.偏置等变量数值,分 ...

  3. pytorch深度学习神经网络实现手写字体识别

    利用平pytorch搭建简单的神经网络实现minist手写字体的识别,采用三层线性函数迭代运算,使得其具备一定的非线性转化与运算能力,其数学原理如下: 其具体实现代码如下所示:import torch ...

  4. 用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践

    https://zhuanlan.zhihu.com/p/25928551 近来在同时做一个应用深度学习解决淘宝商品的类目预测问题的项目,恰好硕士毕业时论文题目便是文本分类问题,趁此机会总结下文本分类 ...

  5. [转] 用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践

    转自知乎上看到的一篇很棒的文章:用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践 近来在同时做一个应用深度学习解决淘宝商品的类目预测问题的项目,恰好硕士毕业时论文 ...

  6. TensorFlow 2.0 深度学习实战 —— 浅谈卷积神经网络 CNN

    前言 上一章为大家介绍过深度学习的基础和多层感知机 MLP 的应用,本章开始将深入讲解卷积神经网络的实用场景.卷积神经网络 CNN(Convolutional Neural Networks,Conv ...

  7. GitHub 上 57 款最流行的开源深度学习项目

    转载:https://www.oschina.net/news/79500/57-most-popular-deep-learning-project-at-github GitHub 上 57 款最 ...

  8. Computational Network Toolkit (CNTK) 是微软出品的开源深度学习工具包

    Computational Network Toolkit (CNTK) 是微软出品的开源深度学习工具包 用 CNTK 搞深度学习 (一) 入门 Computational Network Toolk ...

  9. Github上Stars最多的53个深度学习项目,TensorFlow遥遥领先

    原文:https://github.com/aymericdamien/TopDeepLearning 项目名称 Stars 项目介绍 TensorFlow 29622 使用数据流图计算可扩展机器学习 ...

随机推荐

  1. MySql is marked as crashed and should be repaired问题

    在一次电脑不知道为什么重启之后数据库某表出现了 is marked as crashed and should be repaired这个错误,百度了一下,很多都是去找什么工具然后输入命令之类的,因为 ...

  2. nopCommerce电子商务平台 安装教程(图文)

    nopCommerce是一个通用的电子商务平台,适合每个商家的需要:它强大的企业和小型企业网站遍布世界各地的公司销售实体和数字商品.nopCommerce是一个透明且结构良好的解决方案,它结合了开源和 ...

  3. Wireshark工具抓包的数据包分析

    Wireshark(前称Ethereal)是一个网络封包分析软件.网络封包分析软件的功能是撷取网络封包,并尽可能显示出最为详细的网络封包资料. Wireshark使用WinPCAP作为接口,直接与网卡 ...

  4. webpack新建项目

    记录如何搭建一个最简单的能跑的项目! 1.首先,需要下载安装nodejs环境,可以直接百度搜索nodejs去官网下载符合你操作系统的环境. 安装完nodejs后,在控制台输入命令: npm -vers ...

  5. PHP基础2--基本语法

    主要: 标记符,注释 变量 常量 数据类型 运算符 流程控制 标记符,注释 4种标记符号: 1.  默认形式:  <?php    php语句      ?> 如果<?php ... ...

  6. Laravel 入门笔记

    1.MVC简介 MVC全名是Model View Controller,是模型-视图-控制器的缩写 Model是应用程序中用于处理应用程序数据逻辑的部分 View是应用程序中处理数据显示的部分 Con ...

  7. Java使用zxing生成解读QRcode二维码

    1.maven的pom配置jar包,如果不实用maven请手动下载jar包 <dependency> <groupId>com.google.zxing</groupId ...

  8. 第7章 YARN HA配置

    目录 7.1 yarn-site.xm文件配置 7.2 测试YARN自动故障转移 ResourceManager (RM)负责跟踪集群中的资源,以及调度应用程序(例如,MapReduce作业).在Ha ...

  9. 关于485通信不稳定问题解决方案[STM32产品问题]

    485通讯不稳定的问题(具体表现为有时能通讯上,有时通讯不上) RS485在连接设备过多.通讯距离过长.双绞线质量差,接线不规范等,都会导致通讯不稳定的问题. 解决方案: 一.关于485总线的几个概念 ...

  10. Python学习 :迭代器&生成器

    列表生成式 列表生成式的操作顺序: 1.先依次来读取元素 for x 2.对元素进行操作 x*x 3.赋予变量 Eg.列表生成式方式一 a = [x*x for x in range(10)] pri ...