Last we learned Recurrent Neural Netwoks (RNN) and why they'er great for Language Modeling (LM) 就之前整理 RNN 作为语言模型的神经网络, 理解上还是可以, 关键点在于 W 的复用, 和 上一个状态的输出, 作为下一状态的输入. 如果对语言模型也稍微了解的话, 对于 RNN 就能很自然过渡, 网络的逻辑, 也并不复杂.

这里呢就想来讨论下, 最基础的 RNN 所待解决的问题, 比如标题所谈的梯度消失, 或者梯度爆炸. 然后如何去 fix them. 然后再引入一些 More complex RNN variants (其他的 RNN 变体 如 LSTM, GRU) 等, RNN 我感觉在应用上, 还是蛮不错的.

然后呢就是关于Vanishing Gradient Problem (梯度消失) 从而引出 LSTM 和 GRU ... 还有各种变体, 如 Bidirectional -RNN; Multi - layer - RNN...

梯度消失与爆炸

这是神经网络都可能会存在问题, 因为训练大多基于BP 算法的, 从数学上看就是 多元函数求偏导, 以及求导过程中应用链式法则. 中间就是很多项相乘嘛, 如果都是 非常小的数相乘, 那整个结果就接近 0 了呀, 没有梯度了.

就像这张图所表示的那样, 如果中间的项很小...那整体的结果, 就造成了梯度消失的问题 (Vanishing gradient) .

Vanishing gradient proof sketch

直接从隐含层来看出端倪.

\(h^{(t)} = \sigma(W_hh^{(t-1)} + W_xx^{(t)} + b_t )\)

隐含层取决于, 给定的输入 x 先onehot 在 embedding, 再和复用的权值矩阵 W 相乘 and 上一个时间点的 W_h 和 h 的乘积. 再作为激活函数的输入. 最后到一个值都在[0, 1] 之间的向量.

然后对 \(h^{(t-1)}\) 来求导 (链式法则哦)

$\frac {\partial h^{(t)} {1}} {\partial h^{(t-1)}} = $ \(diag \ (\sigma ' (W_hh^{(t-1)} + W_xx^{(t)} + b_t ) )W_h\)

  • 链式法则而已. 相当于是 y = h(z), z = ax; 要求 y 对 x 的偏导, 即: h(z)' * 偏z 对 偏 x 的值 ....
  • 对于 sigmoid 函数(简记 \(\sigma(x)\) 求导的结果是, \(\sigma(x) [1-\sigma(x)]\)

这个 latex 写得让人头疼... 贴个图算了, 尤其是这种复杂的上下标啥的.

当 \(W_h\) 非常小的时候, 它的 (i-j) 次方, 这个值就会变得非常小了呀. 注意 W_h 是个矩阵啥. 我们通过说的 矩阵的小, 指其 模 非常小 \(|W_h|\) 或者是说, 对这个矩阵 进行 特征分解 (eigenvalue, eigenvector) , 它最大的特征值 的绝对值 如果最大的特征值, 小于 1 则 \(||W_h||\) 这个行列式的值会变很小. (前人已经证明了, 我也没懂, 就先记一个结论来用着) 如果最大的特征值大于1, 则可能会带来梯度爆炸的问题 (exploding gradients).

Why is vanishing gradient a problem

解释一: 从导数的意义上.

梯度消失的表现, 如下图表示的那样, 回归到 导数的意义, 用来衡量 "变化率" \(\frac {dy} {dt}\)

出现梯度为零, 则表示, 对于 h 的一个微小增量, 而 j 并未受到啥影响. 从图上来看, 就是隔得太远, 如 j(4) 基本不会受到 h(1) 的影响了呗. 或者详细一点可以这样说:

Gradient signal from faraway is lost because it's much smaller than gradient signal from close-by

So model weights are only updated only with respect to near effects, not long-term effects.

**解释二: **

Gradient can be viewed ans measure (测量) of the effect of the past on the future. 字面意思就是, 梯度, 可以看成是, 未来对现在的衡量. 梯度小, 则表示未来对现在的影响小.

未来影响现在?? 我感觉这个时间线, 似乎不太理解哦

总之哈, 梯度很小所反映的基本事实是:

  • 在第 t step 和 t + n step , 如果 n 比较大, 则 这两个状态的 单词的 "关联度" 比较小
  • 因而我们所计算出来的参数就不正确了哦.

Why is exploding gradient a problem

同样的, 梯度爆炸, 也是一个大问题. (从用梯度下降法来更新参数, 能能直观看出)

If the gradient becomes too big, then the SGD (随机梯度下降法) update step become too big.

\(\theta^{new} = \theta^{old} - \alpha \nabla _\theta J(\theta)\)

这一块, 基本了解 ML 的都贼熟悉哈. 本质上是对参数向量的一个调整嘛, 当梯度 \(\nabla _\theta J(\theta)\) 特别大的时候, 然后整个参数被都这波节奏给带崩了.

This can cause bad updates : we take too large a step and reach a bad paramenter configuration (with large loss)

从代码运行角度看,

In the worst case (更加糟糕的是) , this will result in inf or NaN in your network. Then you have to restart training from an earlier checkpoint. 就是代码报错, 要重写搞, 一重写运行, 几个小时又过去了...这也是我不太想学深度学习的原因之一.

解决 - 梯度消失和爆炸

solve vanishing

先 pass 下

solve exploding

有种方法叫做, Gradient clipping: If the norm of the gradient is greater than some threshhold (阀值), scale it down before applying SGD update.

如下图所示, 对参数向量 g 取模, 如果它大于某个阀值, 就 将其更新为 (下图) 相当于对 g 进行了一个缩放 (变小了)

向量缩放的特点是, 没有改变其原来的方向, SGD 中, 就还是沿着 梯度方向来调整参数 哦.

即, take a step in the same direction, but a small step. 有点东西哦.

小结

  • 熟练 RNN 的网络结构和特性, 如 W 复用, 输出 -> 输入
  • 梯度消失, BP的参数训练, 求导的链式法则, 可能会有项直接乘积非常小, 整个式子没有梯度, 表 词间的关联性弱
  • 梯度爆炸, 也是在参数更新这块, 调整步伐太大, 产生 NaN 或 Inf, 代码就搞崩了直接
  • 解决梯度消失...
  • 解决梯度爆炸, 可以采用 clipping 的方式, 对向量进行缩放, 而不改变其方向.

RNN - 梯度消失与爆炸的更多相关文章

  1. RNN梯度消失和爆炸的原因 以及 LSTM如何解决梯度消失问题

    RNN梯度消失和爆炸的原因 经典的RNN结构如下图所示: 假设我们的时间序列只有三段,  为给定值,神经元没有激活函数,则RNN最简单的前向传播过程如下: 假设在t=3时刻,损失函数为  . 则对于一 ...

  2. LSTM如何解决梯度消失或爆炸的?

    from:https://zhuanlan.zhihu.com/p/44163528 哪些问题? 梯度消失会导致我们的神经网络中前面层的网络权重无法得到更新,也就停止了学习. 梯度爆炸会使得学习不稳定 ...

  3. 讨论LSTM和RNN梯度消失问题

      1RNN为什么会有梯度消失问题 (1)沿时间反向方向:t-n时刻梯度=t时刻梯度* π(W*激活函数的导数)  

  4. [ DLPytorch ] 循环神经网络进阶&拟合问题&梯度消失与爆炸

    循环神经网络进阶 BPTT 反向传播过程中,训练模型通常需要模型参数的梯度. \[ \frac{\partial L}{\partial \boldsymbol{W}_{qh}} = \sum_{t= ...

  5. 梯度消失&&梯度爆炸

    转载自: https://blog.csdn.net/qq_25737169/article/details/78847691 前言 本文主要深入介绍深度学习中的梯度消失和梯度爆炸的问题以及解决方案. ...

  6. 神经网络优化算法:Dropout、梯度消失/爆炸、Adam优化算法,一篇就够了!

    1. 训练误差和泛化误差 机器学习模型在训练数据集和测试数据集上的表现.如果你改变过实验中的模型结构或者超参数,你也许发现了:当模型在训练数据集上更准确时,它在测试数据集上却不⼀定更准确.这是为什么呢 ...

  7. RNN神经网络产生梯度消失和梯度爆炸的原因及解决方案

    1.RNN模型结构 循环神经网络RNN(Recurrent Neural Network)会记忆之前的信息,并利用之前的信息影响后面结点的输出.也就是说,循环神经网络的隐藏层之间的结点是有连接的,隐藏 ...

  8. Recurrent Neural Network系列3--理解RNN的BPTT算法和梯度消失

    作者:zhbzz2007 出处:http://www.cnblogs.com/zhbzz2007 欢迎转载,也请保留这段声明.谢谢! 这是RNN教程的第三部分. 在前面的教程中,我们从头实现了一个循环 ...

  9. 机器学习 —— 基础整理(八)循环神经网络的BPTT算法步骤整理;梯度消失与梯度爆炸

    网上有很多Simple RNN的BPTT(Backpropagation through time,随时间反向传播)算法推导.下面用自己的记号整理一下. 我之前有个习惯是用下标表示样本序号,这里不能再 ...

  10. Backpropagation Through Time (BPTT) 梯度消失与梯度爆炸

    Backpropagation Through Time (BPTT) 梯度消失与梯度爆炸 下面的图显示的是RNN的结果以及数据前向流动方向 假设有 \[ \begin{split} h_t & ...

随机推荐

  1. 并发编程 - 线程同步(九)之信号量Semaphore

    前面对自旋锁SpinLock进行了详细学习,今天我们将学习另一个种同步机制--信号量Semaphore. 01.信号量是什么? 在 C# 中,信号量(Semaphore)是一种用于线程同步的机制,能够 ...

  2. autMan奥特曼机器人-代理池配置教程

    一.优势: 全可视化 稳如老牛(从2.8.6开始) 隧道代理和接口获取,使用灵活 代理池运行状态指令可查:代理池 二.启用代理池并设置服务端口 代理池的启用与关闭,均为重启autMan生效 设置隧道代 ...

  3. NetPad:一个.NET开源、跨平台的C#编辑器

    前言 今天大姚给大家分享一个基于.NET开源.跨平台的C#编辑器和游乐场:NetPad. 项目介绍 NetPad是一个基于.NET开源(MIT License).跨平台的C#编辑器和游乐场,它允许用户 ...

  4. Springboot - [05] 彩蛋~

    题记部分 彩蛋一:如何更换Springboot启动时的logo (1)访问 https://www.bootschool.net/ascii-art/search,搜索到佛祖的ASCII艺术字(图)集 ...

  5. Task VS ValueTask

    在 C# 中,异步编程是构建响应式应用程序的基础.Task 是表示异步操作的首选类型.但是,在某些高性能场景中,与 Task 相关的开销可能会达到一个瓶颈.ValueTask 是 .NET Core ...

  6. Elasticsearch搜索引擎学习笔记(四)

    分词器 内置分词器 standard:默认分词,单词会被拆分,大小会转换为小写. simple:按照非字母分词.大写转为小写. whitespace:按照空格分词.忽略大小写. stop:去除无意义单 ...

  7. 青岛oj集训1

    2025/3/4 内容:有向无环图(DAG) 优点:DAG有很多良好性质 拓扑排序 用处:可以根据拓扑序进行dp 这次计算所用的所有边的权值都是有计算过的 一张DAG图肯定有拓扑序(bfs序,dfs序 ...

  8. Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露

    一:背景 1. 讲故事 前面跟大家分享过一篇 C# 调用 C代码引发非托管内存泄露 的文章,这是一个故意引发的正向泄露,这一篇我们从逆向的角度去洞察引发泄露的祸根代码,这东西如果在 windows 上 ...

  9. 【基础知识笔记】004 matlab-矩阵和数组的关系

    之前以为是两种东西,今天看了mathworks的官网才知道 所有 MATLAB 量都是多维数组,与数据类型无关.矩阵是指通常用来进行线性代数运算的二维数组 1.数组创建 要创建每行包含四个元素的数组, ...

  10. TPC-H 研究和优化尝试

    TPC-H测试提供了8张表,最近做这个测试,记录下过程中的关键点备忘. 1.整体理解TPC-H 8张表 2.建立主外键约束后测试22条SQL 3.分区表改造,确认分区字段 4.重新测试22条SQL 5 ...