RNN 与 LSTM 的原理详解
原文地址:https://blog.csdn.net/happyrocking/article/details/83657993
什么是序列呢?序列是一串有顺序的数据,比如某一条数据为 [x1,x2,x3,x4][x1,x2,x3,x4] [x_1, x_2, x_3, x_4][x1,x2,x3,x4],其中每个元素可以是一个字符、一个单词、一个向量,甚至是一个声音。比如:
语音处理。此时,每个元素是每帧的声音信号。
时间序列问题。例如每天的股票价格等。
RNN 的结构
我们从基础的神经网络中知道,神经网络包含输入层、隐层、输出层,通过激活函数控制输出,层与层之间通过权值连接。激活函数是事先确定好的,那么神经网络模型通过训练“学“到的东西就蕴含在“权值“中。单层的神经网络如图:
基础的神经网络只在层与层之间建立了权连接,RNN最大的不同之处就是在层之间的神经元之间也建立的权连接。如图。
在展开结构中我们可以观察到,在标准的RNN结构中,隐层的神经元之间也是带有权值的。也就是说,随着序列的不断推进,前面的隐层将会影响后面的隐层。图中O代表输出,y代表样本给出的确定值,L代表损失函数,我们可以看到,“损失“也是随着序列的推荐而不断积累的。
除上述特点之外,标准RNN的还有以下特点:
每一个输入值都只与它本身的那条路线建立权连接,不会和别的神经元连接。
然而在实际中这一种结构并不能解决所有问题,常见的变种有:
1、多输入单输出
有的时候,我们要处理的问题输入是一个序列,输出是一个单独的值而不是序列,应该怎样建模呢?实际上,我们只在最后一个h上进行输出变换就可以了:
2、单输入多输出
输入不是序列而输出为序列的情况怎么处理?我们可以只在序列开始进行输入计算,其余只需要隐层状态进行传递。
从类别生成语音或音乐等
实际中,还有另外一种多输入多输出的结构,其输入与输出并不是一一对应的,如图:
Encoder-Decoder结构先将输入数据编码成一个上下文向量c。得到c有多种方式,最简单的方法就是把Encoder的最后一个隐状态赋值给c,还可以对最后的隐状态做一个变换得到c,也可以对所有的隐状态做变换。
拿到c之后,就用另一个RNN网络对其进行解码,这部分RNN网络被称为Decoder。具体做法就是将c当做之前的初始状态h0输入到Decoder中。
还有另外一种 Decoder,是将c当做每一步的输入:
文本摘要。输入是一段文本序列,输出是这段文本序列的摘要序列。
阅读理解。将输入的文章和问题分别编码,再对其进行解码得到问题的答案。
语音识别。输入是语音信号序列,输出是文字序列。
下面对多输入多输出(一一对应)的经典结构作分析:
前向传播算法其实非常简单,对于t时刻,隐层单元为:
h(t)=f(Ux(t)+Wh(t−1)+b)h(t)=f(Ux(t)+Wh(t−1)+b) h^{(t)}=f(Ux^{(t)}+Wh^{(t-1)}+b)h(t)=f(Ux(t)+Wh(t−1)+b)
其中,f 为激活函数,如 sigmoid、tanh 等,b 为偏置。
t时刻的输出为:
o(t)=Vh(t)+co(t)=Vh(t)+c o^{(t)}=Vh^{(t)}+co(t)=Vh(t)+c
RNN的训练方法
BPTT(back-propagation through time)算法是常用的训练RNN的方法,其实本质还是BP算法,只不过RNN处理时间序列数据,所以要基于时间反向传播,故叫随时间反向传播。BPTT的中心思想和BP算法相同,沿着需要优化的参数的负梯度方向不断寻找更优的点直至收敛。综上所述,BPTT算法本质还是BP算法,BP算法本质还是梯度下降法,那么求各个参数的梯度便成了此算法的核心。
∂L(t)∂V=∂L(t)∂o(t)∂o(t)∂V∂L(t)∂V=∂L(t)∂o(t)∂o(t)∂V \frac{\partial L^{(t)}}{\partial V}=\frac{\partial L^{(t)}}{\partial o^{(t)}} \frac{\partial o^{(t)}}{\partial V}∂V∂L(t)=∂o(t)∂L(t)∂V∂o(t)
RNN的损失也是会随着时间累加的,所以需要求出所有时刻的偏导然后求和:
L=∑nt=1L(t)L=∑t=1nL(t) L=\sum_{t=1}^n L^{(t)}L=t=1∑nL(t)
∂L∂V=∑nt=1∂L(t)∂o(t)∂o(t)∂V∂L∂V=∑t=1n∂L(t)∂o(t)∂o(t)∂V \frac{\partial L}{\partial V}=\sum_{t=1}^n\frac{\partial L^{(t)}}{\partial o^{(t)}} \frac{\partial o^{(t)}}{\partial V}∂V∂L=t=1∑n∂o(t)∂L(t)∂V∂o(t)
W和U的偏导的求解由于需要涉及到历史数据,其偏导求起来相对复杂,我们先假设只有三个时刻,那么在第三个时刻 L对W的偏导数为:
∂L(3)∂W=∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂W+∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂h(2)∂h(2)∂W+∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂h(2)∂h(2)∂h(1)∂h(1)∂W∂L(3)∂W=∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂W+∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂h(2)∂h(2)∂W+∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂h(2)∂h(2)∂h(1)∂h(1)∂W \frac{\partial L^{(3)}}{\partial W}=\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial W}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial W}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial h^{(1)}}\frac{\partial h^{(1)}}{\partial W}∂W∂L(3)=∂o(3)∂L(3)∂h(3)∂o(3)∂W∂h(3)+∂o(3)∂L(3)∂h(3)∂o(3)∂h(2)∂h(3)∂W∂h(2)+∂o(3)∂L(3)∂h(3)∂o(3)∂h(2)∂h(3)∂h(1)∂h(2)∂W∂h(1)
同理,对U的偏导为:
∂L(3)∂W=∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂U+∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂h(2)∂h(2)∂U+∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂h(2)∂h(2)∂h(1)∂h(1)∂U∂L(3)∂W=∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂U+∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂h(2)∂h(2)∂U+∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂h(2)∂h(2)∂h(1)∂h(1)∂U \frac{\partial L^{(3)}}{\partial W}=\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial U}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial U}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial h^{(1)}}\frac{\partial h^{(1)}}{\partial U}∂W∂L(3)=∂o(3)∂L(3)∂h(3)∂o(3)∂U∂h(3)+∂o(3)∂L(3)∂h(3)∂o(3)∂h(2)∂h(3)∂U∂h(2)+∂o(3)∂L(3)∂h(3)∂o(3)∂h(2)∂h(3)∂h(1)∂h(2)∂U∂h(1)
可以看到,在某个时刻的对W或是U的偏导数,需要追溯这个时刻之前所有时刻的信息,这还仅仅是一个时刻的偏导数,上面说过损失也是会累加的,那么整个损失函数对W和U的偏导数将会非常繁琐。虽然如此但好在规律还是有迹可循,我们根据上面两个式子可以写出L在t时刻对W和U偏导数的通式:
∂L(t)∂W=∂L(t)∂o(t)∂o(t)∂h(t)∑tk=1(∏ti=k+1∂h(i)∂h(i−1))∂h(k)∂W∂L(t)∂W=∂L(t)∂o(t)∂o(t)∂h(t)∑k=1t(∏i=k+1t∂h(i)∂h(i−1))∂h(k)∂W \frac{\partial L^{(t)}}{\partial W}= \frac{\partial L^{(t)}}{\partial o^{(t)}} \frac{\partial o^{(t)}}{\partial h^{(t)}}\sum_{k=1}^t(\prod_{i=k+1}^{t}\frac{\partial h^{(i)}}{\partial h^{(i-1)}})\frac{\partial h^{(k)}}{\partial W}∂W∂L(t)=∂o(t)∂L(t)∂h(t)∂o(t)k=1∑t(i=k+1∏t∂h(i−1)∂h(i))∂W∂h(k)
∂L(t)∂U=∂L(t)∂o(t)∂o(t)∂h(t)∑tk=1(∏ti=k+1∂h(i)∂h(i−1))∂h(k)∂U∂L(t)∂U=∂L(t)∂o(t)∂o(t)∂h(t)∑k=1t(∏i=k+1t∂h(i)∂h(i−1))∂h(k)∂U \frac{\partial L^{(t)}}{\partial U}= \frac{\partial L^{(t)}}{\partial o^{(t)}} \frac{\partial o^{(t)}}{\partial h^{(t)}}\sum_{k=1}^t(\prod_{i=k+1}^{t}\frac{\partial h^{(i)}}{\partial h^{(i-1)}})\frac{\partial h^{(k)}}{\partial U}∂U∂L(t)=∂o(t)∂L(t)∂h(t)∂o(t)k=1∑t(i=k+1∏t∂h(i−1)∂h(i))∂U∂h(k)
与V相同,对W或U的整体偏导,也是将所有t时刻的偏导相加。
前面说过激活函数是嵌套在里面的,如果我们把激活函数放进去,拿出中间累乘的那部分:
∏ti=k+1∂h(i)∂h(i−1)=∏ti=k+1tanh′⋅W∏i=k+1t∂h(i)∂h(i−1)=∏i=k+1ttanh′⋅W \prod_{i=k+1}^{t}\frac{\partial h^{(i)}}{\partial h^{(i-1)}}=\prod_{i=k+1}^{t}tanh'·Wi=k+1∏t∂h(i−1)∂h(i)=i=k+1∏ttanh′⋅W
或
∏ti=k+1∂h(i)∂h(i−1)=∏ti=k+1sigmoid′⋅W∏i=k+1t∂h(i)∂h(i−1)=∏i=k+1tsigmoid′⋅W \prod_{i=k+1}^{t}\frac{\partial h^{(i)}}{\partial h^{(i-1)}}=\prod_{i=k+1}^{t}sigmoid'·Wi=k+1∏t∂h(i−1)∂h(i)=i=k+1∏tsigmoid′⋅W
我们会发现累乘会导致激活函数导数和权重矩阵的累乘,进而会导致“梯度消失“和“梯度爆炸“现象的发生。
为什么会出现“梯度消失“?我们先来看看这两个激活函数的图像:
同理,由于权重矩阵的累乘,可能会导致“梯度爆炸”的发生。
RNN的特点本来就是能“追根溯源“利用历史数据,现在告诉我可利用的历史数据竟然是有限的,这就令人非常难受,解决“梯度消失“是非常必要的。解决“梯度消失“的方法主要有:
改变传播结构
关于第二点,LSTM结构可以解决这个问题。
总结一下,sigmoid函数的缺点:
sigmoid函数不是0中心对称,tanh函数是,可以使网络收敛的更好。
下面来了解一下LSTM(long short-term memory)。长短期记忆网络是RNN的一种变体,RNN由于梯度消失的原因只能有短期记忆,LSTM网络通过精妙的门控制将短期记忆与长期记忆结合起来,并且一定程度上解决了梯度消失的问题。
长期依赖(Long-Term Dependencies)问题
RNN 的关键点之一就是他们可以用来连接先前的信息到当前的任务上,例如使用过去的视频段来推测对当前段的理解。如果 RNN 可以做到这个,他们就变得非常有用。但是真的可以么?答案是,还有很多依赖因素。
有时候,我们仅仅需要知道先前的信息来执行当前的任务。例如,我们有一个语言模型用来基于先前的词来预测下一个词。如果我们试着预测 “the clouds are in the sky” 最后的词,我们并不需要任何其他的上下文 —— 因此下一个词很显然就应该是 sky。在这样的场景中,相关的信息和预测的词位置之间的间隔是非常小的,RNN 可以学会使用先前的信息。
不幸的是,在这个间隔不断增大时,RNN 会丧失学习到连接如此远的信息的能力。
LSTM网络
Long Short Term 网络—— 一般就叫做 LSTM ——是一种 RNN 特殊的类型,可以学习长期依赖信息。LSTM 由Hochreiter & Schmidhuber (1997)提出,并在近期被Alex Graves进行了改良和推广。在很多问题,LSTM 都取得相当巨大的成功,并得到了广泛的使用。
LSTM 通过刻意的设计来避免长期依赖问题。记住长期的信息在实践中是 LSTM 的默认行为,而非需要付出很大代价才能获得的能力!
所有 RNN 都具有一种重复神经网络模块的链式的形式。在标准的 RNN 中,这个重复的模块只有一个非常简单的结构,例如一个 tanh 层。
粉色的圆形表示一些运算操作,诸如加法乘法
黑色的单箭头表示向量的传输
两个箭头合成一个表示向量的连接
一个箭头分开表示向量的复制
细胞状态类似于传送带。直接在整个链上运行,只有一些少量的线性交互。信息在上面流传保持不变会很容易。
LSTM 拥有三个门,来保护和控制细胞状态。
理解LSTM的三个门
遗忘门
在我们 LSTM 中的第一步是决定我们会从细胞状态中丢弃什么信息。这个决定通过一个称为遗忘门完成。该门会读取ht−1ht−1 h_{t-1}ht−1和xtxt x_txt,输出一个在 0 到 1 之间的数值给每个在细胞状态Ct−1Ct−1 C_{t-1}Ct−1中的数字。1 表示“完全保留”,0 表示“完全舍弃”。
让我们回到语言模型的例子中来基于已经看到的预测下一个词。在这个问题中,细胞状态可能包含当前主语的性别,因此正确的代词可以被选择出来。当我们看到新的主语,我们希望忘记旧的主语。
对于第一个问题,“遗忘“可以理解为“之前的内容记住多少“,其精髓在于只能输出(0,1)小数的sigmoid函数和粉色圆圈的乘法,LSTM网络经过学习决定让网络记住以前百分之多少的内容。对于第二个问题就更好理解,决定记住什么遗忘什么,其中新的输入肯定要产生影响。
输入门
下一步是确定什么样的新信息被存放在细胞状态中。这里包含两个部分。第一,sigmoid 层称 “输入门层” 决定什么值我们将要更新。然后,一个 tanh 层创建一个新的候选值向量,Ct˜Ct~ \tilde{C_{t}}Ct~,会被加入到状态中。下一步,我们会讲这两个信息来产生对状态的更新。
在我们语言模型的例子中,我们希望增加新的主语的性别到细胞状态中,来替代旧的需要忘记的主语。
我们把旧状态与ftft f_tft相乘,丢弃掉我们确定需要丢弃的信息。接着加上it⋅Ct˜it⋅Ct~ i_t·\tilde{C_t}it⋅Ct~。这就是新的候选值,根据我们决定更新每个状态的程度进行变化。
有了上面的理解基础输入门,输入门理解起来就简单多了。tanh函数创建新的输入值,sigmoid函数决定可以输入进去的比例。
最终,我们需要确定输出什么值。这个输出将会基于我们的细胞状态,但是也是一个过滤后的版本。首先,我们运行一个 sigmoid 层来确定细胞状态的哪个部分将输出出去。接着,我们把细胞状态通过 tanh 进行处理(得到一个在 -1 到 1 之间的值)并将它和 sigmoid 门的输出相乘,最终我们仅仅会输出我们确定输出的那部分。
LSTM的变体
我们到目前为止都还在介绍正常的 LSTM。但是不是所有的 LSTM 都长成一个样子的。实际上,几乎所有包含 LSTM 的论文都采用了微小的变体。差异非常小,但是也值得拿出来讲一下。
窥视孔连接
一个流行的LSTM变种,由Gers & Schmidhuber (2002提出,加入了“窥视孔连接(peephole connection)”。也就是说我们让各种门可以接受到细胞状态的输入。
对偶忘记门和输入门
另一个变体是通过使用对偶(coupled)忘记门和输入门。不同于之前是分开确定什么忘记和需要添加什么新的信息,这里是一同做出决定。我们仅仅会当我们将要输入在当前位置时忘记。我们仅仅输入新的值到那些我们已经忘记旧的信息的那些状态 。
另一个改动较大的变体是 Gated Recurrent Unit (GRU),这是由 Cho, et al. (2014) 提出。它将忘记门和输入门合成了一个单一的 更新门。同样还混合了细胞状态和隐藏状态,和其他一些改动。最终的模型比标准的 LSTM 模型要简单,也是非常流行的变体。
要问哪个变体是最好的?其中的差异性真的重要吗?Greff, et al. (2015) 给出了流行变体的比较,结论是他们基本上是一样的。Jozefowicz, et al. (2015) 则在超过 1 万种 RNN 架构上进行了测试,发现一些架构在某些任务上也取得了比 LSTM 更好的结果。
————————————————
版权声明:本文为CSDN博主「HappyRocking」的原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/happyrocking/article/details/83657993
RNN 与 LSTM 的原理详解的更多相关文章
- (数据科学学习手札39)RNN与LSTM基础内容详解
一.简介 循环神经网络(recurrent neural network,RNN),是一类专门用于处理序列数据(时间序列.文本语句.语音等)的神经网络,尤其是可以处理可变长度的序列:在与传统的时间序列 ...
- I2C 基础原理详解
今天来学习下I2C通信~ I2C(Inter-Intergrated Circuit)指的是 IC(Intergrated Circuit)之间的(Inter) 通信方式.如上图所以有很多的周边设备都 ...
- Zigbee组网原理详解
Zigbee组网原理详解 来源:互联网 作者:佚名2015年08月13日 15:57 [导读] 组建一个完整的zigbee网状网络包括两个步骤:网络初始化.节点加入网络.其中节点加入网络又包括两个 ...
- 块级格式化上下文(block formatting context)、浮动和绝对定位的工作原理详解
CSS的可视化格式模型中具有一个非常重要地位的概念——定位方案.定位方案用以控制元素的布局,在CSS2.1中,有三种定位方案——普通流.浮动和绝对定位: 普通流:元素按照先后位置自上而下布局,inli ...
- SSL/TLS 原理详解
本文大部分整理自网络,相关文章请见文后参考. SSL/TLS作为一种互联网安全加密技术,原理较为复杂,枯燥而无味,我也是试图理解之后重新整理,尽量做到层次清晰.正文开始. 1. SSL/TLS概览 1 ...
- 锁之“轻量级锁”原理详解(Lightweight Locking)
大家知道,Java的多线程安全是基于Lock机制实现的,而Lock的性能往往不如人意. 原因是,monitorenter与monitorexit这两个控制多线程同步的bytecode原语,是JVM依赖 ...
- [转]js中几种实用的跨域方法原理详解
转自:js中几种实用的跨域方法原理详解 - 无双 - 博客园 // // 这里说的js跨域是指通过js在不同的域之间进行数据传输或通信,比如用ajax向一个不同的域请求数据,或者通过js获取页面中不同 ...
- 节点地址的函数list_entry()原理详解
本节中,我们继续讲解,在linux2.4内核下,如果通过一些列函数从路径名找到目标节点. 3.3.1)接下来查看chached_lookup()的代码(namei.c) [path_walk()> ...
- WebActivator的实现原理详解
WebActivator的实现原理详解 文章内容 上篇文章,我们分析如何动态注册HttpModule的实现,本篇我们来分析一下通过上篇代码原理实现的WebActivator类库,WebActivato ...
随机推荐
- linux中安装jdk+jmeter-
--------------linux中安装jdk+jmeter-------------------- 一.安装JDK7.0版本 .先卸载服务器自带的jdk软件包 # java -version # ...
- Ansible简单介绍(一)
一 :ansible简单介绍 此名取自 Ansible 作者最喜爱的<安德的游戏> 小说,而这部小说更被后人改编成电影 -<战争游戏>. 官网地址:https://www.an ...
- bootstrap和JS实现下拉菜单
// bootstrap下拉菜单 <div class="btn-group"> <button id="button_text" type= ...
- CNN for NLP
卷积神经网络在自然语言处理任务中的应用.参考链接:Understanding Convolutional Neural Networks for NLP(2015.11) Instead of ima ...
- Codeforces 1082 毛毛虫图构造&最大权闭合子图
A #include<bits/stdc++.h> using namespace std; typedef long long ll; , MAXM = ; //int to[MAXM ...
- pycharm中如何让两个项目并存
之前总是打开一个,另外一个没有了,来回切换还要找最近的project.十分麻烦. 1.File下拉项中选择Settings 2.Settings设置界面打开Project下拉列表,选择“Project ...
- Linux下mount存储盘遇到的错误
一.注意点 1.超过1T的盘,创建的分区要查看是否初始化为了GPT格式. 2.如果新添加的盘是从存储上挂载的,涉及到多路径的问题,挂载的是多路径的盘符,比如:/dev/mapper/mpatha(对应 ...
- Oracle之:Function :getcurrdate()
getdate()函数连接请戳这里 create or replace function getcurrdate(i_date date) return date is v_date date; v_ ...
- Acwing-167-木棒(搜索, 剪枝)
链接: https://www.acwing.com/problem/content/169/ 题意: 乔治拿来一组等长的木棒,将它们随机地砍断,使得每一节木棍的长度都不超过50个长度单位. 然后他又 ...
- Swagger2常用注解和使用方法
一 引入maven依赖 <!--整合Swagger2--> <dependency> <groupId>com.spring4all</groupId&g ...