LSTM是什么

LSTM即Long Short Memory Network,长短时记忆网络。它其实是属于RNN的一种变种,可以说它是为了克服RNN无法很好处理远距离依赖而提出的。

我们说RNN不能处理距离较远的序列是因为训练时很有可能会出现梯度消失,即通过下面的公式训练时很可能会发生指数缩小,让RNN失去了对较远时刻的感知能力。

∂E∂W=∑t∂Et∂W=∑tk=0∂Et∂nett∂nett∂st(∏tj=k+1∂st∂sk)∂sk∂W

解决思路

RNN梯度消失不应该是由我们学习怎么去避免,而应该通过改良让循环神经网络自己具备避免梯度消失的特性,从而让循环神经网络自身具备处理长期序列依赖的能力。

RNN的状态计算公式为St=f(St−1,xt),根据链式求导法则会导致梯度变为连乘的形式,而sigmoid小于1会让连乘小得很快。为了解决这个问题,科学家采用了累加的形式,St=∑tτ=1ΔSτ,其导数也为累加,从而避免梯度消失。LSTM即是使用了累加形式,但它的实现较复杂,下面进行介绍。

LSTM模型

回顾一下RNN的模型,如下图,展开后多个时刻隐层互相连接,而所有循环神经网络都有一个重复的网络模块,RNN的重复网络模块很简单,如下下图,比如只有一个tanh层。

而LSTM的重复网络模块的结构则复杂很多,它实现了三个门计算,即遗忘门、输入门和输出门。每个门负责是事情不一样,遗忘门负责决定保留多少上一时刻的单元状态到当前时刻的单元状态;输入门负责决定保留多少当前时刻的输入到当前时刻的单元状态;输出门负责决定当前时刻的单元状态有多少输出。

每个LSTM包含了三个输入,即上时刻的单元状态、上时刻LSTM的输出和当前时刻输入。

LSTM的机制

根据上图咱们一步一步来看LSTM神经网络是怎么运作的。

首先看遗忘门,用来计算哪些信息需要忘记,通过sigmoid处理后为0到1的值,1表示全部保留,0表示全部忘记,于是有

ft=σ(Wf⋅[ht−1,xt]+bf)

其中中括号表示两个向量相连合并,Wf是遗忘门的权重矩阵,σ为sigmoid函数,bf为遗忘门的偏置项。设输入层维度为dx,隐藏层维度为dh,上面的状态维度为dc,则Wf的维度为dc×(dh+dx)。

其次看输入门,输入门用来计算哪些信息保存到状态单元中,分两部分,第一部分为

it=σ(Wi⋅[ht−1,xt]+bi)

该部分可以看成当前输入有多少是需要保存到单元状态的。第二部分为

c~t=tanh(Wc⋅[ht−1,xt]+bc)

该部分可以看成当前输入产生的新信息来添加到单元状态中。结合这两部分来创建一个新记忆。

而当前时刻的单元状态由遗忘门输入和上一时刻状态的积加上输入门两部分的积,即

ct=ft∗ct−1+it∗c~t

最后看看输出门,通过sigmoid函数计算需要输出哪些信息,再乘以当前单元状态通过tanh函数的值,得到输出。

ot=σ(Wo⋅[ht−1,xt]+bo)

ht=ot∗tanh(ct)

LSTM的训练

化繁为简,这里只讨论包含一个LSTM层的三层神经网络(如果有多个层则误差项除了沿时间反向传播外,还会向上一层传播),LSTM向前传播时与三个门相关的公式如下,

ft=σ(Wf⋅[ht−1,xt]+bf)

it=σ(Wi⋅[ht−1,xt]+bi)

c~t=tanh(Wc⋅[ht−1,xt]+bc)

ct=ft∗ct−1+it∗c~t

ot=σ(Wo⋅[ht−1,xt]+bo)

ht=ot∗tanh(ct)

需要学习的参数挺多的,同时也可以看到LSTM的输出ht有四个输入分量加权影响,即三个门相关的ftitc~tot,而且其中权重W都是拼接的,所以在学习时需要分割出来,即

Wf=Wfx+Wfh

Wi=Wix+Wih

Wc~=Wc~x+Wc~h

Wo=Wox+Woh

输出层的输入yit=Wyiht,输出为yot=σ(yit)。

设某时刻的损失函数为Et=12(yd−yot)2,则某样本的损失为

E=∑Tt=1Et

设当前时刻t的误差项δt=∂E∂ht,那么误差沿着时间反向传递则需要计算t-1时刻的误差项δt−1,则

δt−1=∂E∂ht−1=∂E∂ht∂ht∂ht−1=δt∂ht∂ht−1

LSTM的输出ht可看成是一个复合函数,f[ft(ht−1),it(ht−1),c~t(ht−1),ot(ht−1)],由全导数公式有,

∂ht∂ht−1=∂ht∂ct∂ct∂ft∂ft∂netf,t∂netf,t∂ht−1+∂ht∂ct∂ct∂it∂it∂neti,t∂neti,t∂ht−1+∂ht∂ct∂ct∂c~t∂c~t∂netc~,t∂netc~,t∂ht−1+∂ht∂ot∂ot∂neto,t∂neto,t∂ht−1

其中netf,tneti,tnetc~,tneto,t表示对应函数的输入。将上述所有偏导都求出来,

∂ht∂ct=ot∗(1−tanh(ct)2)∂ct∂ft=ct−1∂ft∂netf,t=ft∗(1−ft)∂netf,t∂ht−1=Wfh

同样地,其他也可以求出来,最后得到t时刻和t-1时刻之间的关系。再设

δf,t=∂E∂netf,tδi,t=∂E∂neti,tδc~,t=∂E∂netc~,tδo,t=∂E∂neto,t

得到,

δt−1=δf,tWfh+δi,tWih+δc~,tWch+δo,tWoh

接着对某时刻t的所有权重进行求偏导,

∂E∂Wfh,t=∂E∂netf,t∂netf,t∂Wfh,t=δf,tht−1
∂E∂Wih,t=∂E∂neti,t∂neti,t∂Wih,t=δi,tht−1
∂E∂Wch,t=∂E∂netc~,t∂netc~,t∂Wch,t=δc~,tht−1
∂E∂Woh,t=∂E∂neto,t∂neto,t∂Woh,t=δo,tht−1
∂E∂Wfx=∂E∂netf,t∂netf,t∂Wfx=δf,txt
∂E∂Wix=∂E∂neti,t∂neti,t∂Wix=δi,txt
∂E∂Wcx=∂E∂netc~,t∂netc~,t∂Wcx=δc~,txt
∂E∂Wox=∂E∂neto,t∂neto,t∂Wox=δo,txt
∂E∂bo,t=∂E∂neto,t∂neto,t∂bo,t=δo,t
∂E∂bf,t=∂E∂netf,t∂netf,t∂bf,t=δf,t
∂E∂bi,t=∂E∂neti,t∂neti,t∂bi,t=δi,t
∂E∂bc,t=∂E∂netc~,t∂netc~,t∂bc,t=δc~,t

对于整个样本,它的误差是所有时刻的误差之和,而与上个时刻相关的权重的梯度等于所有时刻的梯度之和,其他权重则不必累加,最终得到

∂E∂Wfh=∑j=1tδf,jhj−1
∂E∂Wih=∑j=1tδi,jhj−1
∂E∂Wch=∑j=1tδc~,jhj−1
∂E∂Woh=∑j=1tδo,jhj−1
∂E∂bf=∑j=1tδf,j
∂E∂bi=∑j=1tδi,j
∂E∂bc=∑j=1tδc~,j
∂E∂bo=∑j=1tδo,j
∂E∂Wfx=∂E∂netf,t∂netf,t∂Wfx=δf,txt
∂E∂Wix=∂E∂neti,t∂neti,t∂Wix=δi,txt
∂E∂Wcx=∂E∂netc~,t∂netc~,t∂Wcx=δc~,txt
∂E∂Wox=∂E∂neto,t∂neto,t∂Wox=δo,txt

相关阅读:

循环神经网络

卷积神经网络

机器学习之神经网络

机器学习之感知器

神经网络的交叉熵损失函数

========广告时间========

鄙人的新书《Tomcat内核设计剖析》已经在京东销售了,有需要的朋友可以到 https://item.jd.com/12185360.html 进行预定。感谢各位朋友。

为什么写《Tomcat内核设计剖析》

=========================

欢迎关注:

LSTM神经网络的更多相关文章

  1. (转) 干货 | 图解LSTM神经网络架构及其11种变体(附论文)

    干货 | 图解LSTM神经网络架构及其11种变体(附论文) 2016-10-02 机器之心 选自FastML 作者:Zygmunt Z. 机器之心编译  参与:老红.李亚洲 就像雨季后非洲大草原许多野 ...

  2. (转)LSTM神经网络介绍

    原文链接:http://www.atyun.com/16821.html 扩展阅读: https://machinelearningmastery.com/time-series-prediction ...

  3. LSTM 神经网络输入输出层

    今天终于弄明白,TensorFlow和Keras中LSTM神经网络的输入输出层到底应该怎么设置和连接了.写个备忘. https://machinelearningmastery.com/how-to- ...

  4. tensorflow学习之(十一)RNN+LSTM神经网络的构造

    #RNN 循环神经网络 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data tf.se ...

  5. 深入浅出LSTM神经网络

    转自:https://www.csdn.net/article/2015-06-05/2824880 LSTM递归神经网络RNN长短期记忆   摘要:根据深度学习三大牛的介绍,LSTM网络已被证明比传 ...

  6. Tensorflow之基于LSTM神经网络写唐诗

    最近看了不少关于写诗的博客,在前人的基础上做了一些小的改动,因比较喜欢一次输入很长的开头句,所以让机器人输出压缩为一个开头字生成两个诗句,写五言和七言诗,当然如果你想写更长的诗句是可以继续改动的. 在 ...

  7. LSTM神经网络走读

      0设计概述 RNN梯度爆炸和消失比较严重,RNN隐层只有一个状态h记录短期记忆,增加一个长期记忆状态c似乎就可以解决问题.

  8. LSTM神经网络输入输出究竟是怎样的?

    LSTM图和词向量输入分析

  9. 深度神经网络在量化交易里的应用 之二 -- 用深度网络(LSTM)预测5日收盘价格

        距离上一篇文章,正好两个星期. 这边文章9月15日 16:30 开始写. 可能几个小时后就写完了.用一句粗俗的话说, "当你怀孕的时候,别人都知道你怀孕了, 但不知道你被日了多少回 ...

随机推荐

  1. annovar积累

    20170222 ANNOVAR简介 ANNOVAR是由王凯编写的一个注释软件,可以对SNP和indel进行注释,也可以进行变异的过滤筛选. ANNOVAR能够利用最新的数据来分析各种基因组中的遗传变 ...

  2. ubuntu server 16.04(amd 64) 配置网桥,多网卡使用激活

    安装了Ubuntu16.04的server版本,结果进入系统输入ifconfig后发现,只有一个网卡enp1s0,还有一个网络回路lo,ifconfig -a 发现其实一共有四个网卡,enp1s0,e ...

  3. SQLServer中获取所有数据库名、所有表名、所有字段名的SQL语句

    ----1. 获取所有的数据库名----- SELECT NAME FROM MASTER.DBO.SYSDATABASES ORDER BY NAME -----2. 获取所有的表名------ S ...

  4. python sort、sorted

    1. (1).sorted()方法返回一个新列表(默认升序). list.sort() (2).另一个不同:list.sort()方法仅被定义在list中,sorted()方法对所有的可迭代序列都有效 ...

  5. 【转】DrawDibDraw

    http://blog.csdn.net/normallife/article/details/53177315 BMP位图文件结构及平滑缩放 用普通方法显示BMP位图,占内存大,速度慢,在图形缩小时 ...

  6. STL_算法_06_遍历算法

    ◆ 常用的遍历算法: 1.1.用指定函数依次对指定范围内所有元素进行迭代访问.该函数不得修改序列中的元素 functor for_each(iteratorBegin, iteratorEnd, fu ...

  7. Java JDK5新特性-静态导入

    2017-10-31 00:10:50 静态导入格式:import static 包名 ...类名.方法名: 也就说可以直接导入到方法名. 注意: 方法必须是静态的 如果有多个同名的静态方法,容易不知 ...

  8. 雷林鹏分享:C# 集合(Collection)

    C# 集合(Collection) 集合(Collection)类是专门用于数据存储和检索的类.这些类提供了对栈(stack).队列(queue).列表(list)和哈希表(hash table)的支 ...

  9. English trip -- Review Unit 9 Daily living 日常生活

    主要讲了一个时态:现在进行时   Be动词+Ving  需要记住的有6种规律 1.直接单词后面 + ing    e.g.     watch -> watching 2.是ie结尾的单词,变y ...

  10. py to exe —— pywin32

    xu言: 最近研究python,觉得做些windows小工具还挺好玩,就研究了下py代码如何转成exe 这里用到一个工具 pywin32 https://sourceforge.net/project ...