梯度优化算法Adam
最近读一个代码发现用了一个梯度更新方法, 刚开始还以为是什么奇奇怪怪的梯度下降法, 最后分析一下是用一阶梯度及其二次幂做的梯度更新。网上搜了一下, 果然就是称为Adam的梯度更新算法, 全称是:自适应矩估计(adaptive moment estimation)
国际惯例, 参考博文:
一文看懂各种神经网络优化算法:从梯度下降到Adam方法
Adam:一种随机优化方法
An overview of gradient descent optimization algorithms
梯度下降优化算法综述
Hinton的神经网络课程第六课
理论
由于参考博客介绍的很清晰, 我就直接撸公式了:
假设tt时刻, 目标函数对于参数的一阶导数是gtgt,那么我们可以先计算
mtvt=β1mt−1+(1−β1)gt=β2vt−1+(1−β2)g2t
mt=β1mt−1+(1−β1)gtvt=β2vt−1+(1−β2)gt2
接下来计算
mt^=mt1−βt1vt^=vt1−βt2
mt^=mt1−β1tvt^=vt1−β2t
最后我们的梯度更新方法就是
θt+1=θt−η⋅mt^vt^−−√+ϵ
θt+1=θt−η⋅mt^vt^+ϵ
注意几个量, ηη是学习步长, 剩下的三个参数取值的建议是β1=0.9,β2=0.999,ϵ=10−8β1=0.9,β2=0.999,ϵ=10−8, 分母中的ϵϵ是为了防止除零. 其实这个步长的话,一般来说是建议选η=0.001η=0.001之类的, 注意βt1,βt2β1t,β2t中的tt是参与指数运算的
其实再看一下公式,其实就是当前时刻的梯度更新利用了上一时刻的平方梯度vtvt的指数衰减均值vt^vt^和上一时刻的梯度mtmt的指数衰减均值mt^mt^
代码实现
以下非一个神经网络的完整实现, 主要在于看看定义网络参数以后怎么去使用Adam去更新每一时刻的梯度, 在theano中的实现方法如下:
先看看神经网络的参数
self.layers = [
self.W0, self.W1, self.W2,
self.b0, self.b1, self.b2]
self.params = sum([layer.params for layer in self.layers], [])
1
2
3
4
5
然后初始化一开始时候的mt,vtmt,vt,分别对应代码中的m0params,m1paramsm0params,m1params
self.params = network.params
self.m0params = [theano.shared(np.zeros(p.shape.eval(), dtype=theano.config.floatX), borrow=True) for p in self.params]
self.m1params = [theano.shared(np.zeros(p.shape.eval(), dtype=theano.config.floatX), borrow=True) for p in self.params]
self.t = theano.shared(np.array([1], dtype=theano.config.floatX))
1
2
3
4
定义目标函数=损失函数+正则项:
cost = self.cost(network, input, output) + network.cost(input)
1
计算当前梯度
gparams = T.grad(cost, self.params)
1
计算m0params,m1paramsm0params,m1params
m0params = [self.beta1 * m0p + (1-self.beta1) * gp for m0p, gp in zip(self.m0params, gparams)]
m1params = [self.beta2 * m1p + (1-self.beta2) * (gp*gp) for m1p, gp in zip(self.m1params, gparams)]
1
2
使用Adam梯度更新
params = [p - self.alpha *
((m0p/(1-(self.beta1**self.t[0]))) /
(T.sqrt(m1p/(1-(self.beta2**self.t[0]))) + self.eps))
for p, m0p, m1p in zip(self.params, m0params, m1params)]
1
2
3
4
然后更新下一时刻网络中的梯度值,m0paramsm0params,m1paramsm1params,tt
updates = ([( p, pn) for p, pn in zip(self.params, params)] +
[(m0, m0n) for m0, m0n in zip(self.m0params, m0params)] +
[(m1, m1n) for m1, m1n in zip(self.m1params, m1params)] +
[(self.t, self.t+1)])
---------------------
作者:风翼冰舟
来源:CSDN
原文:https://blog.csdn.net/zb1165048017/article/details/78392623
版权声明:本文为博主原创文章,转载请附上博文链接!
梯度优化算法Adam的更多相关文章
- 梯度优化算法总结以及solver及train.prototxt中相关参数解释
参考链接:http://sebastianruder.com/optimizing-gradient-descent/ 如果熟悉英文的话,强烈推荐阅读原文,毕竟翻译过程中因为个人理解有限,可能会有谬误 ...
- 深度学习必备:随机梯度下降(SGD)优化算法及可视化
补充在前:实际上在我使用LSTM为流量基线建模时候,发现有效的激活函数是elu.relu.linear.prelu.leaky_relu.softplus,对应的梯度算法是adam.mom.rmspr ...
- zz:一个框架看懂优化算法之异同 SGD/AdaGrad/Adam
首先定义:待优化参数: ,目标函数: ,初始学习率 . 而后,开始进行迭代优化.在每个epoch : 计算目标函数关于当前参数的梯度: 根据历史梯度计算一阶动量和二阶动量:, 计算当前时刻的下降 ...
- Adam那么棒,为什么还对SGD念念不忘 (3)—— 优化算法的选择与使用策略
在前面两篇文章中,我们用一个框架梳理了各大优化算法,并且指出了以Adam为代表的自适应学习率优化算法可能存在的问题.那么,在实践中我们应该如何选择呢? 本文介绍Adam+SGD的组合策略,以及一些比较 ...
- DeepLearning.ai学习笔记(二)改善深层神经网络:超参数调试、正则化以及优化--Week2优化算法
1. Mini-batch梯度下降法 介绍 假设我们的数据量非常多,达到了500万以上,那么此时如果按照传统的梯度下降算法,那么训练模型所花费的时间将非常巨大,所以我们对数据做如下处理: 如图所示,我 ...
- Coursera Deep Learning笔记 改善深层神经网络:优化算法
笔记:Andrew Ng's Deeping Learning视频 摘抄:https://xienaoban.github.io/posts/58457.html 本章介绍了优化算法,让神经网络运行的 ...
- 数值优化(Numerical Optimization)学习系列-无梯度优化(Derivative-Free Optimization)
数值优化(Numerical Optimization)学习系列-无梯度优化(Derivative-Free Optimization) 2015年12月27日 18:51:19 下一步 阅读数 43 ...
- 改善深层神经网络_优化算法_mini-batch梯度下降、指数加权平均、动量梯度下降、RMSprop、Adam优化、学习率衰减
1.mini-batch梯度下降 在前面学习向量化时,知道了可以将训练样本横向堆叠,形成一个输入矩阵和对应的输出矩阵: 当数据量不是太大时,这样做当然会充分利用向量化的优点,一次训练中就可以将所有训练 ...
- 跟我学算法-吴恩达老师(mini-batchsize,指数加权平均,Momentum 梯度下降法,RMS prop, Adam 优化算法, Learning rate decay)
1.mini-batch size 表示每次都只筛选一部分作为训练的样本,进行训练,遍历一次样本的次数为(样本数/单次样本数目) 当mini-batch size 的数量通常介于1,m 之间 当 ...
随机推荐
- java-继承进阶_抽象类_接口
概要图 一, 继承的进阶 1.1,成员变量 重点明确原理. 特殊情况: 子父类中定义了一模一样的成员变量. 都存在于子类对象中. 如何在子类中直接访问同名的父类中的变量呢? 通过关键字 super来完 ...
- angular可以做的小功能 未完成
1,网上购物满多少减多少 思路: 效果图,满500减10元邮费 1,html部分有基本布局, <div> <h3>化妆品</h3> 单价 <input typ ...
- 从0开始学习 GitHub 系列之「01.初识 GitHub
转载:http://blog.csdn.net/googdev/article/details/52787516 1. 写在前面 我一直认为 GitHub 是程序员必备技能,程序员应该没有不知道 Gi ...
- 【JZOJ4709】【NOIP2016提高A组模拟8.17】Matrix
题目描述 输入 输出 样例输入 4 3 5 4 1 7 3 4 7 4 8 样例输出 59716 数据范围 解法 40%暴力即可: 60%依然暴力: 100%依次计算第一行和第一列对答案的贡献即可: ...
- Python学习笔记(三)字符串类型及其操作(2)
1.字符串的表示 字符串是字符的序列表示,可以由一对单引号(‘).双引号(“)或三引号(’‘’)构成.其中,单引号.双引号和三引号都可以表示单行字符串,但是只有三引号可以表示多行字符串 在使用双引号时 ...
- 完整版unity安卓发布流程(包括SDK有原生系统依赖关系的工程)
要3个东西!NDS,SDK,JDK, NDK官网下载:https://developer.android.google.cn/ndk/downloads/index.html(注意系统是不是64位) ...
- JDK的KEYTOOL的应用,以及签署文件的应用(原创)
首先,我是这样的情况下学到这部分知识的: 我们公司同事把自己的unity生成的APK包查出MD5值直接拿出去微信那边申请,当然这样本来是没毛病,毕竟当时只有他一个人开发这个游戏, 然而我们几个前端过去 ...
- StringUtils常用方式留存
StringUtils是org.apache.commons.lang下的一个工具包.主要用途从名字可以看出是针对于String的一些操作工具,里面包含的方法非常多,英语水平尚可以的人可以前往它的官方 ...
- laravel后台扩展包
https://github.com/the-control-group/voyager
- (续)使用Django搭建一个完整的项目(Centos7+Nginx)
django-admin startproject web cd web 2.配置数据库(使用Mysql) vim web/settings.py #找到以下并按照实际情况修改 DATABASES = ...