Adam和学习率衰减(learning learning decay)
本文先介绍一般的梯度下降法是如何更新参数的,然后介绍 Adam 如何更新参数,以及 Adam 如何和学习率衰减结合。
梯度下降法更新参数
梯度下降法参数更新公式:
\[
\theta_{t+1} = \theta_{t} - \eta \cdot \nabla J(\theta_t)
\]
其中,\(\eta\) 是学习率,\(\theta_t\) 是第 \(t\) 轮的参数,\(J(\theta_t)\) 是损失函数,\(\nabla J(\theta_t)\) 是梯度。
在最简单的梯度下降法中,学习率 \(\eta\) 是常数,是一个需要实现设定好的超参数,在每轮参数更新中都不变,在一轮更新中各个参数的学习率也都一样。
为了表示简便,令 \(g_t = \nabla J(\theta_t)\),所以梯度下降法可以表示为:
\[
\theta_{t+1} = \theta_{t} - \eta \cdot g_t
\]
Adam 更新参数
Adam,全称 Adaptive Moment Estimation,是一种优化器,是梯度下降法的变种,用来更新神经网络的权重。
Adam 更新公式:
\[
\begin{aligned}
m_{t} &=\beta_{1} m_{t-1}+\left(1-\beta_{1}\right) g_{t} \\
v_{t} &=\beta_{2} v_{t-1}+\left(1-\beta_{2}\right) g_{t}^{2} \\
\hat{m}_{t} &=\frac{m_{t}}{1-\beta_{1}^{t}} \\
\hat{v}_{t} &=\frac{v_{t}}{1-\beta_{2}^{t}} \\
\theta_{t+1}&=\theta_{t}-\frac{\eta}{\sqrt{\hat{v}_{t}}+\epsilon} \hat{m}_{t}
\end{aligned}
\]
在 Adam 原论文以及一些深度学习框架中,默认值为 \(\eta = 0.001\),\(\beta_1 = 0.9\),\(\beta_2 = 0.999\),\(\epsilon = 1e-8\)。其中,\(\beta_1\) 和 \(\beta_2\) 都是接近 1 的数,\(\epsilon\) 是为了防止除以 0。\(g_{t}\) 表示梯度。
咋一看很复杂,接下一一分解:
- 前两行:
\[
\begin{aligned}
m_{t} &=\beta_{1} m_{t-1}+\left(1-\beta_{1}\right) g_{t} \\
v_{t} &=\beta_{2} v_{t-1}+\left(1-\beta_{2}\right) g_{t}^{2}
\end{aligned}
\]
这是对梯度和梯度的平方进行滑动平均,即使得每次的更新都和历史值相关。
中间两行:
\[
\begin{aligned}
\hat{m}_{t} &=\frac{m_{t}}{1-\beta_{1}^{t}} \\
\hat{v}_{t} &=\frac{v_{t}}{1-\beta_{2}^{t}}
\end{aligned}
\]
这是对初期滑动平均偏差较大的一个修正,叫做 bias correction,当 \(t\) 越来越大时,\(1-\beta_{1}^{t}\) 和 \(1-\beta_{2}^{t}\) 都趋近于 1,这时 bias correction 的任务也就完成了。最后一行:
\[
\theta_{t+1}=\theta_{t}-\frac{\eta}{\sqrt{\hat{v}_{t}}+\epsilon} \hat{m}_{t}
\]
这是参数更新公式。
学习率为 \(\frac{\eta}{\sqrt{\hat{v}_{t}}+\epsilon}\),每轮的学习率不再保持不变,在一轮中,每个参数的学习率也不一样了,这是因为 \(\eta\) 除以了每个参数 \(\frac{1}{1- \beta_2} = 1000\) 轮梯度均方和的平方根,即 \(\sqrt{\frac{1}{1000}\sum_{k = t-999}^{t}g_k^2}\)。而每个参数的梯度都是不同的,所以每个参数的学习率即使在同一轮也就不一样了。(可能会有疑问,\(t\) 前面没有 999 轮更新怎么办,那就有多少轮就算多少轮,这个时候还有 bias correction 在。)
而参数更新的方向也不只是当前轮的梯度 \(g_t\) 了,而是当前轮和过去共 \(\frac{1}{1- \beta_1} = 10\) 轮梯度的平均。
有关滑动平均的理解,可以参考我之前的博客:理解滑动平均(exponential moving average)。
Adam + 学习率衰减
在 StackOverflow 上有一个问题 Should we do learning rate decay for adam optimizer - Stack Overflow,我也想过这个问题,对 Adam 这些自适应学习率的方法,还应不应该进行 learning rate decay?
我简单的做了个实验,在 cifar-10 数据集上训练 LeNet-5 模型,一个采用学习率衰减 tf.keras.callbacks.ReduceLROnPlateau(patience=5),另一个不用。optimizer 为 Adam 并使用默认的参数,\(\eta = 0.001\)。结果如下:


加入学习率衰减和不加两种情况在 test 集合上的 accuracy 分别为: 0.5617 和 0.5476。(实验结果取了两次的平均,实验结果的偶然性还是有的)
通过上面的小实验,我们可以知道,学习率衰减还是有用的。(当然,这里的小实验仅能代表一小部分情况,想要说明学习率衰减百分之百有效果,得有理论上的证明。)
当然,在设置超参数时就可以调低 \(\eta\) 的值,使得不用学习率衰减也可以达到不错的效果。
将学习率从默认的 0.001 改成 0.0001,epoch 增大到 120,实验结果如下所示:


加入学习率衰减和不加两种情况在 test 集合上的 accuracy 分别为: 0.5636 和 0.5688。(三次实验平均,实验结果仍具有偶然性)
这个时候,使用学习率衰减带来的影响可能很小。
那么问题来了,Adam 做不做学习率衰减呢?
我个人会选择做学习率衰减。(仅供参考吧。)在初始学习率设置较大的时候,做学习率衰减比不做要好;而当初始学习率设置就比较小的时候,做学习率衰减似乎有点多余,但从 val set 上的效果看,做了学习率衰减还是可以有丁点提升的。
ReduceLROnPlateau 在 val_loss 正常下降的时候,对学习率是没有影响的,只有在 patience(默认为 10)个 epoch 内,val_loss 都不下降 1e-4 或者直接上升了,这个时候降低学习率确实是可以很明显提升模型训练效果的,在 val_acc 曲线上看到一个快速上升的过程。对于其它类型的学习率衰减,这里没有过多地介绍。
Adam 衰减的学习率
从上述学习率曲线来看,Adam 做学习率衰减,是对 \(\eta\) 进行,而不是对 \(\frac{\eta}{\sqrt{\hat{v}_{t}}+\epsilon}\) 进行,但有区别吗?
学习率衰减一般如下:
exponential_decay:
decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)natural_exp_decay:
decayed_learning_rate = learning_rate * exp(-decay_rate * global_step / decay_steps)ReduceLROnPlateau
如果被监控的值(如‘val_loss’)在 patience 个 epoch 内都没有下降,那么学习率衰减,乘以一个 factor
decayed_learning_rate = learning_rate * factor
这些学习率衰减都是直接在原学习率上乘以一个 factor ,对 \(\eta\) 或对 \(\frac{\eta}{\sqrt{\hat{v}_{t}}+\epsilon}\) 操作,结果都是一样的。
References
An overview of gradient descent optimization algorithms -- Sebastian Ruder
Should we do learning rate decay for adam optimizer - Stack Overflow
Tensorflow中learning rate decay的奇技淫巧 -- Elevanth
Adam和学习率衰减(learning learning decay)的更多相关文章
- 权重衰减(weight decay)与学习率衰减(learning rate decay)
本文链接:https://blog.csdn.net/program_developer/article/details/80867468“微信公众号” 1. 权重衰减(weight decay)L2 ...
- 吴恩达深度学习笔记(五) —— 优化算法:Mini-Batch GD、Momentum、RMSprop、Adam、学习率衰减
主要内容: 一.Mini-Batch Gradient descent 二.Momentum 四.RMSprop 五.Adam 六.优化算法性能比较 七.学习率衰减 一.Mini-Batch Grad ...
- 梯度下降法】三:学习率衰减因子(decay)的原理与Python
http://www.41443.com/HTML/Python/20161027/512492.html
- ubuntu之路——day8.5 学习率衰减learning rate decay
在mini-batch梯度下降法中,我们曾经说过因为分割了baby batch,所以迭代是有波动而且不能够精确收敛于最小值的 因此如果我们将学习率α逐渐变小,就可以使得在学习率α较大的时候加快模型训练 ...
- 跟我学算法-吴恩达老师(mini-batchsize,指数加权平均,Momentum 梯度下降法,RMS prop, Adam 优化算法, Learning rate decay)
1.mini-batch size 表示每次都只筛选一部分作为训练的样本,进行训练,遍历一次样本的次数为(样本数/单次样本数目) 当mini-batch size 的数量通常介于1,m 之间 当 ...
- TensorFlow之DNN(二):全连接神经网络的加速技巧(Xavier初始化、Adam、Batch Norm、学习率衰减与梯度截断)
在上一篇博客<TensorFlow之DNN(一):构建“裸机版”全连接神经网络>中,我整理了一个用TensorFlow实现的简单全连接神经网络模型,没有运用加速技巧(小批量梯度下降不算哦) ...
- 改善深层神经网络_优化算法_mini-batch梯度下降、指数加权平均、动量梯度下降、RMSprop、Adam优化、学习率衰减
1.mini-batch梯度下降 在前面学习向量化时,知道了可以将训练样本横向堆叠,形成一个输入矩阵和对应的输出矩阵: 当数据量不是太大时,这样做当然会充分利用向量化的优点,一次训练中就可以将所有训练 ...
- pytorch learning rate decay
关于learning rate decay的问题,pytorch 0.2以上的版本已经提供了torch.optim.lr_scheduler的一些函数来解决这个问题. 我在迭代的时候使用的是下面的方法 ...
- [深度学习] pytorch学习笔记(3)(visdom可视化、正则化、动量、学习率衰减、BN)
一.visdom可视化工具 安装:pip install visdom 启动:命令行直接运行visdom 打开WEB:在浏览器使用http://localhost:8097打开visdom界面 二.使 ...
随机推荐
- PDW中的Split Querying Process
最近看了关于 SQL Server 的分布式处理方面的论文,觉得它提出的 Polybase 跟之前看过的 HadoopDB 有些神似,这里做个小总结(抽空再把 HadoopDB 的总结贴出来). 不算 ...
- 【cocos2d-js官方文档】五、Cocos2d-JS v3.0的新Action API
新增action中的方法 曾经,当我们须要反复一个action的时候,我们须要: sprite.runAction(cc.Repeat.create(action, 2)); 上面代码中创建了一个新的 ...
- PHP采集类:Snoopy.class.php
Snoopy是一个php采集类,用来模拟浏览器获取网页内容和发送表单. 下面是一些Snoopy特性: 容易抓取网页内容 容易抓取页面文本(去除HTML标签) 容易抓取网页内链接 支持代理抓取 支持基本 ...
- 对于ado.net dataProvider的介绍
学习刘皓的 ADO.NET入门教程(二)了解.NET数据提供程序 而来 这篇文章很一般,主要是对dataProvider做了个简单的介绍.因为在该系列文章中提到,ado.net主要有两部分 dataP ...
- PowerDesigner模型分类
原文:PowerDesigner模型分类 版权声明:本文为博主原创文章,未经博主允许不得转载. https://blog.csdn.net/zjws23786/article/details/8005 ...
- 狄利克雷过程(Dirichlet Process)
0. 引入 现观察得到两个样本 θ1,θ2,来推测它们可能来自的分布: 假设来自于连续型概率密度函数, θ1,θ2∼H(θ) 则 θ1,θ2 相等的概率为 0,p(θ1=θ2)=0 概率为 0,不代表 ...
- springboot 上传图片
1. 创建多层目录 创建多层目录要使用File的mkdirs()方法,其他可以使用mkdir()方法. 2. 文件大小限制 配置文件中,在spring1.4以后要使用 ` spring.http.mu ...
- C#彩色艺术化二维码样式设计(仅说思路)
原文:C#彩色艺术化二维码样式设计(仅说思路) 仅讲思路,想要源码的请绕道. 一.样式 1.先看各种二维码的样式吧: (1)最简单的样式--黑白样式,如下图: 图1 最平常见到的二维码样式(如果 ...
- NoSQL Manager for Cassandra 3.2.0.1 带Key
NoSQL Manager for Cassandra 3.2.0.1 是一个Windows平台下Cassandra 数据库的高级管理工具.请低调使用. NoSQLManagerforCassandr ...
- MyBatis 问题 & 解决
# 问题 Invalid bound statement (not found) # 解决 <mappers> 标签的包括的是 SQL 语句存在的地方,此外 <mapper> ...