Gradient Optimization
Gradient Optimization
Gradient Descent
- Batch Gradient Descent
- Mini-Batch Gradient Descent
- Stochastic Gradient Descent
Mini-Batch Gradient Descent
- 参数
- Mini-Batch Size: 一个Batch样本所含的样本数
- 参数效果
- 通过设置Mini-Batch Size可以将Mini-Batch转为Stochastic Gradient Descent和Bath Gradient Descent
- 当Mini-Batch Size == m时, Mini-Batch Gradient Descent为Batch Gradient Descent; 当Mini-Batch Size == 1时, Mini-Batch Gradient Descent为Stochastic Gradient Descent; 一般Mini-Batch Size的大小为2的幂次方, 主要考虑到与计算机内存对齐, 一般Mini-Batch Size设置在64-512
- 特点对比
- Batch Gradient Descent: 当数据量大的时候程序运行很慢
- Stochastic Gradient Descent: 一般不采用此方法, 因为此Gradient Descent方法每迭代一个样本就会更新参数\(W\)和\(b\), 在梯度下降的时候有很多噪音; 但是它可以应用到在线学习上
- Mini-Batch Gradient Descent: 当数据量较大的时候可以加快收敛速度, 但是当在梯度下降的时候, 容易产生震荡(oscillate), 如图
- 其中, +表示\(J_{min}\), 在使用Mini-Batch Gradient Descent的时候, 容易在数值方向产生震荡, 我们期望的是缩小竖直方向上的震荡, 在水平方向上加快收敛的速率, 对于这个问题, 解决方案是在Update Parameters的时候, 采用Momentum, RMSProp或者Adam的方法更新参数\(W\)和\(b\), 在下面就会提到
处理震荡
- 指数权重均值(Exponentially Weighted Average, 简称EMA), 后面的Momentum, RMSProp和Adam都需要EMA
- 以一年的中所有天数的温度为例, 如图
- 由上图可知, \(\theta\)为气温, \(t\)为天数, 总体来说中间时刻气温低一点, 两侧高一点
- 定义
- \(v_t=\beta{v_{t-1}}+(1-\beta)\theta_t\), 其中\(\beta\)为EMA中的一个参数, 一般他的取值范围在\(0.9\le \beta \le 0.99\); \(v_t\)表示的就是EMA; \(\theta_t\)为第\(t\)天的气温; \(\beta{v_t}\)表示的是前\(t\)天的关注度, 后面的\((1-\beta)\theta_t\)是对当前天气温的关注度, 最左侧的\(v_t\)才是我们对当前天的EMA; EMA公式有递归的感觉
Momentum
- 公式
- \(v_{dW^{[l]}}=\beta{v_{dW^{[l]}}}+(1-\beta)dW^{[l]}\), 其中, \(dW^{[l]}\)是第\(l\)层的梯度矩阵, 其他与EMA中的是一样的
- 返现与EMA中不同的是这里的\(\beta{v_{dW^{[l]}}}\)不是\(\beta{v_{dW^{[l-1]}}}\), 因为我们在实现该算法的时候采用先默认赋予0值, 再在每一次迭代时累加, 下面的RMSProp和Adam也是如此
- 更新参数
- \(W^{[l]}=W^{[l]}-\alpha{v_{dW^{[l]}}}\)
- 公式
RMSProp
- 公式
- \(s_{dW^{[l]}}=\beta{v_{dW^{[l]}}}+(1-\beta)(dW^{[l]})^2\), 其中, 与Momentum不同的就是此处\((dW^{[l]})^2\)
- 更新参数
- \(W^{[l]}=W^{[l]}-\alpha{dW^{[l]}\over{\sqrt{s_{dW^{[l]}}+\epsilon}}}\), 其中\(\epsilon \approx 10^{-8}\)
- 公式
Adam
- Adam算法是Momentum与RMSProp的结合
- 公式
- \(v_{dW^{[l]}}=\beta_1{v_{dW^{[l]}}}+(1-\beta_1)dW^{[l]}\)
- \(v^{correct}_{dW^{[l]}}={v_{dW^{[l]}}\over{1-(\beta_1)^t}}\), 其中t表示深度学习算法迭代到第t次, 这一步是\(v_{dW^{[l]}}\)的修正, 在后面即使使用\(v^{correct}_{dW^{[l]}}\)
- \(s_{dW^{[l]}}=\beta_2{v_dW^{[l]}}+(1-\beta_1)(dW^{[l]})^2\)
- \(s^{correct}_{dW^{[l]}}={s_{dW^{[l]}}}\over{1-(\beta_2)^t}\),其中t表示深度学习算法迭代到第t次, 这一步是\(s_{dW^{[l]}}\)的修正, 在后面即使使用\(s^{correct}_{dW^{[l]}}\)
- Adam结合了之前的Momentum与RMSProp算法, 同时增加了校正EMA的步骤, 因为在Momentum和RMSProp算法都有\(\beta\)和\(s\), 所有在这里为了区分, 使用了\(v\)与\(s\), \(\beta_1\)与\(\beta_2\)
- 更新参数
- \(W^{[l]}=W^{[l]}-\alpha{v_{dW^{[l]}}^{[l]}\over{\sqrt{s_{dW^{[l]}}+\epsilon}}}\), 其中\(\epsilon \approx 10^{-8}\)
使用代码实现的大致思路
- 选择Mini-Batch Gradient Descent
- Shuffle原始数据
- 选择Mini-Batch Size进行Gradient Descent
- 在迭代Update Parameters时, 先为Momentum, RMSProp或者Adam需要的\(v\), \(s\)变量赋予0值, 维度与对应的\(dW\)一致
- 迭代即可
学习率\(\alpha\)的衰减
- 一般来说我们只需要直接固定\(\alpha\)的值, 随后根据结果进行调整, 但是在数据量很大的时候就会比较浪费时间, 于是使用到了\(alpha\)的衰减
- 定义
- \(\alpha={1\over{1+decay\_rate\times{epoch}}}\alpha_0\)
Gradient Optimization的更多相关文章
- ( 转) Awesome Image Captioning
Awesome Image Captioning 2018-12-03 19:19:56 From: https://github.com/zhjohnchan/awesome-image-capti ...
- ICCV 2017论文分析(文本分析)标题词频分析 这算不算大数据 第一步:数据清洗(删除作者和无用的页码)
IEEE International Conference on Computer Vision, ICCV 2017, Venice, Italy, October 22-29, 2017. IEE ...
- 近年Recsys论文
2015年~2017年SIGIR,SIGKDD,ICML三大会议的Recsys论文: [转载请注明出处:https://www.cnblogs.com/shenxiaolin/p/8321722.ht ...
- SciPy和Numpy处理能力
1.SciPy和Numpy的处理能力: numpy的处理能力包括: a powerful N-dimensional array object N维数组: advanced array slicing ...
- [CS231n-CNN] Training Neural Networks Part 1 : activation functions, weight initialization, gradient flow, batch normalization | babysitting the learning process, hyperparameter optimization
课程主页:http://cs231n.stanford.edu/ Introduction to neural networks -Training Neural Network ________ ...
- (转) An overview of gradient descent optimization algorithms
An overview of gradient descent optimization algorithms Table of contents: Gradient descent variants ...
- An overview of gradient descent optimization algorithms
原文地址:An overview of gradient descent optimization algorithms An overview of gradient descent optimiz ...
- 【论文翻译】An overiview of gradient descent optimization algorithms
这篇论文最早是一篇2016年1月16日发表在Sebastian Ruder的博客.本文主要工作是对这篇论文与李宏毅课程相关的核心部分进行翻译. 论文全文翻译: An overview of gradi ...
- [CS231n-CNN] Linear classification II, Higher-level representations, image features, Optimization, stochastic gradient descent
课程主页:http://cs231n.stanford.edu/ loss function: -Multiclass SVM loss: 表示实际应该属于的类别的score.因此,可以发现,如果实际 ...
随机推荐
- Linux常用命令,学的时候自己记的常用的保存下来方便以后使用 o(∩_∩)o 哈哈
service httpd restart 重启Apache service mysqld restart 重启mysql [-][rwx][r-x][r--] 1 234 567 890 421 4 ...
- VMware下拷过来的Linux系统ifconfig下没有网络问题
拷了同事的Linux系统,拷过来时还可以用,今天再打开发现找不到ip了,于是就在网上找解决方法,因本人从没接触过Linux所以查的挺多的但解决的方法试了好几个就是不行,后面找到的有效的解决方法有: L ...
- php swoole扩展安装
一波三折. 首先下载swoole安装包(由于我这里php是7,所以说应该去官网下载最新的swoole包,否则会发生意想不到的错误) wget https://github.com/swoole/swo ...
- 支付机构MRC模
一.电商RFM模型 RFM模型是一个简单的根据客户的活跃程度和交易金额贡献所做的分类.因为操作简单,所以较为常用. 近度R:R代表客户最近的活跃时间距离数据采集点的时间距离,R越大,表示客户越久未发生 ...
- Thinkpad R400无线网络一个都不见了!
今天,我的Thinkpad R400无线网络一个都不见了,用手机测试,家里的无线网络正常.反复开关无线网络控制硬件开关,还是不好使,无线网络那个图标里面,并没有无线网络.上网搜索一下,发现小黑居然还有 ...
- “全栈2019”Java第八十二章:嵌套接口能否访问外部类中的成员?
难度 初级 学习时间 10分钟 适合人群 零基础 开发语言 Java 开发环境 JDK v11 IntelliJ IDEA v2018.3 文章原文链接 "全栈2019"Java第 ...
- 洛谷P4705 玩游戏(生成函数+多项式运算)
题面 传送门 题解 妈呀这辣鸡题目调了我整整三天--最后发现竟然是因为分治\(NTT\)之后的多项式长度不是\(2\)的幂导致把多项式的值存下来的时候发生了一些玄学错误--玄学到了我\(WA\)的点全 ...
- scrapy框架基于CrawlSpider的全站数据爬取
引入 提问:如果想要通过爬虫程序去爬取”糗百“全站数据新闻数据的话,有几种实现方法? 方法一:基于Scrapy框架中的Spider的递归爬取进行实现(Request模块递归回调parse方法). 方法 ...
- iOS 键盘的监听 调整view的位置
iOS在处理键盘的出现和消失时需要监听UIKeyboardWillChangeFrameNotifications/UIKeyboardDidHideNotifications - (void)vie ...
- myeclipse部署项目到tomcat-custom_location 方式
在想要部署的路径下:1.新建一个在tomcat--->server.xml文件夹下设置的文件名 2.在新建的文件夹下新建一个 ROOT文件夹, 3.在myeclipse里面吧项目部署到 ROO ...