1.优化器算法简述

首先来看一下梯度下降最常见的三种变形 BGD,SGD,MBGD,这三种形式的区别就是取决于我们用多少数据来计算目标函数的梯度,这样的话自然就涉及到一个 trade-off,即参数更新的准确率和运行时间。

2.Batch Gradient Descent (BGD)

梯度更新规则:

BGD 采用整个训练集的数据来计算 cost function 对参数的梯度:

 

缺点:

由于这种方法是在一次更新中,就对整个数据集计算梯度,所以计算起来非常慢,遇到很大量的数据集也会非常棘手,而且不能投入新数据实时更新模型。

for i in range(nb_epochs):
params_grad = evaluate_gradient(loss_function, data, params)
params = params - learning_rate * params_grad

我们会事先定义一个迭代次数 epoch,首先计算梯度向量 params_grad,然后沿着梯度的方向更新参数 params,learning rate 决定了我们每一步迈多大。

Batch gradient descent 对于凸函数可以收敛到全局极小值,对于非凸函数可以收敛到局部极小值。

3.Stochastic Gradient Descent (SGD)

梯度更新规则:

和 BGD 的一次用所有数据计算梯度相比,SGD 每次更新时对每个样本进行梯度更新,对于很大的数据集来说,可能会有相似的样本,这样 BGD 在计算梯度时会出现冗余,而 SGD 一次只进行一次更新,就没有冗余,而且比较快,并且可以新增样本。

for i in range(nb_epochs):
np.random.shuffle(data)
for example in data:
params_grad = evaluate_gradient(loss_function, example, params)
params = params - learning_rate * params_grad

看代码,可以看到区别,就是整体数据集是个循环,其中对每个样本进行一次参数更新。

随机梯度下降是通过每个样本来迭代更新一次,如果样本量很大的情况,那么可能只用其中部分的样本,就已经将theta迭代到最优解了,对比上面的批量梯度下降,迭代一次需要用到十几万训练样本,一次迭代不可能最优,如果迭代10次的话就需要遍历训练样本10次。缺点是SGD的噪音较BGD要多,使得SGD并不是每次迭代都向着整体最优化方向。所以虽然训练速度快,但是准确度下降,并不是全局最优。虽然包含一定的随机性,但是从期望上来看,它是等于正确的导数的。

缺点:

SGD 因为更新比较频繁,会造成 cost function 有严重的震荡。

BGD 可以收敛到局部极小值,当然 SGD 的震荡可能会跳到更好的局部极小值处。

当我们稍微减小 learning rate,SGD 和 BGD 的收敛性是一样的。

4.Mini-Batch Gradient Descent (MBGD)

梯度更新规则:

MBGD 每一次利用一小批样本,即 n 个样本进行计算,这样它可以降低参数更新时的方差,收敛更稳定,另一方面可以充分地利用深度学习库中高度优化的矩阵操作来进行更有效的梯度计算。

和 SGD 的区别是每一次循环不是作用于每个样本,而是具有 n 个样本的批次。

for i in range(nb_epochs):
np.random.shuffle(data)
for batch in get_batches(data, batch_size=50):
params_grad = evaluate_gradient(loss_function, batch, params)
params = params - learning_rate * params_grad

超参数设定值:  n 一般取值在 50~256

缺点:(两大缺点)

  1. 不过 Mini-batch gradient descent 不能保证很好的收敛性,learning rate 如果选择的太小,收敛速度会很慢,如果太大,loss function 就会在极小值处不停地震荡甚至偏离。(有一种措施是先设定大一点的学习率,当两次迭代之间的变化低于某个阈值后,就减小 learning rate,不过这个阈值的设定需要提前写好,这样的话就不能够适应数据集的特点。)对于非凸函数,还要避免陷于局部极小值处,或者鞍点处,因为鞍点周围的error是一样的,所有维度的梯度都接近于0,SGD 很容易被困在这里。(会在鞍点或者局部最小点震荡跳动,因为在此点处,如果是训练集全集带入即BGD,则优化会停止不动,如果是mini-batch或者SGD,每次找到的梯度都是不同的,就会发生震荡,来回跳动。)
  2. SGD对所有参数更新时应用同样的 learning rate,如果我们的数据是稀疏的,我们更希望对出现频率低的特征进行大一点的更新。LR会随着更新的次数逐渐变小。

5.Adam:Adaptive Moment Estimation

这个算法是另一种计算每个参数的自适应学习率的方法。相当于 RMSprop + Momentum

除了像 Adadelta 和 RMSprop 一样存储了过去梯度的平方 vt 的指数衰减平均值 ,也像 momentum 一样保持了过去梯度 mt 的指数衰减平均值

如果 mt 和 vt 被初始化为 0 向量,那它们就会向 0 偏置,所以做了偏差校正,通过计算偏差校正后的 mt 和 vt 来抵消这些偏差:

梯度更新规则:

超参数设定值:
建议 β1 = 0.9,β2 = 0.999,ϵ = 10e−8

实践表明,Adam 比其他适应性学习方法效果要好。

参考文献:

https://www.cnblogs.com/guoyaohua/p/8542554.html

Pytorch学习笔记08----优化器算法Optimizer详解(SGD、Adam)的更多相关文章

  1. 深度学习——优化器算法Optimizer详解(BGD、SGD、MBGD、Momentum、NAG、Adagrad、Adadelta、RMSprop、Adam)

    在机器学习.深度学习中使用的优化算法除了常见的梯度下降,还有 Adadelta,Adagrad,RMSProp 等几种优化器,都是什么呢,又该怎么选择呢? 在 Sebastian Ruder 的这篇论 ...

  2. pytorch学习笔记(十二):详解 Module 类

    Module 是 pytorch 提供的一个基类,每次我们要 搭建 自己的神经网络的时候都要继承这个类,继承这个类会使得我们 搭建网络的过程变得异常简单. 本文主要关注 Module 类的内部是怎么样 ...

  3. 机器学习实战(Machine Learning in Action)学习笔记————08.使用FPgrowth算法来高效发现频繁项集

    机器学习实战(Machine Learning in Action)学习笔记————08.使用FPgrowth算法来高效发现频繁项集 关键字:FPgrowth.频繁项集.条件FP树.非监督学习作者:米 ...

  4. Linux防火墙iptables学习笔记(三)iptables命令详解和举例[转载]

     Linux防火墙iptables学习笔记(三)iptables命令详解和举例 2008-10-16 23:45:46 转载 网上看到这个配置讲解得还比较易懂,就转过来了,大家一起看下,希望对您工作能 ...

  5. (转)live555学习笔记10-h264 RTP传输详解(2)

    参考: 1,live555学习笔记10-h264 RTP传输详解(2) http://blog.csdn.net/niu_gao/article/details/6936108 2,H264 sps ...

  6. 大数据学习笔记——Spark工作机制以及API详解

    Spark工作机制以及API详解 本篇文章将会承接上篇关于如何部署Spark分布式集群的博客,会先对RDD编程中常见的API进行一个整理,接着再结合源代码以及注释详细地解读spark的作业提交流程,调 ...

  7. 学习笔记--Grunt、安装、图文详解

    学习笔记--Git安装.图文详解 安装Git成功后,现在安装Gruntjs,官网:http://gruntjs.com/ 一.安装node 参考node.js 安装.图文详解 (最新的node会自动安 ...

  8. Java8学习笔记(五)--Stream API详解[转]

    为什么需要 Stream Stream 作为 Java 8 的一大亮点,它与 java.io 包里的 InputStream 和 OutputStream 是完全不同的概念.它也不同于 StAX 对 ...

  9. 【官方文档】Nginx模块Nginx-Rtmp-Module学习笔记(一) RTMP 命令详解

    源码地址:https://github.com/Tinywan/PHP_Experience 说明: rtmp的延迟主要取决于播放器设置,但流式传输软件,流的比特率和网络速度(以及响应时间“ping” ...

随机推荐

  1. Machine learning(4-Linear Regression with multiple variables )

    1.Multiple features So what the form of the hypothesis should be ? For convenience, define x0=1 At t ...

  2. 转载: XILINX GT的基本概念

    https://zhuanlan.zhihu.com/p/46052855 本来写了一篇关于高速收发器的初步调试方案的介绍,给出一些遇到问题时初步的调试建议.但是发现其中涉及到很多概念.逐一解释会导致 ...

  3. Nginx多种安装方式

    不指定参数配置的Nginx编译安装 ./configuremake make install wget下载或浏览器下载上传.解压进入目录[root@mcw1 nginx-1.10.2]# ls #查看 ...

  4. Nessus home版插件更新

    1,进入服务器停止服务 service nessusd stop 2,进入目录执行命令获取Challenge code cd /opt/nessus/sbin/ ./nessuscli fetch - ...

  5. Flink 实践教程 - 入门(4):读取 MySQL 数据写入到 ES

    ​作者:腾讯云流计算 Oceanus 团队 流计算 Oceanus 简介 流计算 Oceanus 是大数据产品生态体系的实时化分析利器,是基于 Apache Flink 构建的具备一站开发.无缝连接. ...

  6. 微信小程序小窗无效

    这里算是踩过一个坑吧 1.自己的调试版本库是否在这个版本或者以上 2.编辑器是不能看到小窗效果的,只能在真机运行 3.播放的内容是否有效,是否能播放 4.跳转页面时内容是否处于播放状态 5.当前页面是 ...

  7. Oracle中对数字加汉字的排序

    需求:有一列NAME, varchar2类型,内容如下 以上就是已经按order by name进行排序的,但不是我们想要的结果 现在需要只按数字进行排序 第一步:抽取数字由于数字有是一位的有是两位的 ...

  8. 一个开源的C#和cefsharp项目:逐浪字体大师pc版上线(附源码开源)

    z01逐浪字体大师,是一款基于C#和web引擎开发的字体设计软件,可以打开直接写字,也可以链接官方资源 ,附Github开源库,欢迎大家下载.客户端技术是基于wpf设计的,整个界面精美,与逐浪CMS技 ...

  9. .NET Protobuf包装器库

    Wodsoft Protobuf Wrapper 内容 关于 需求 安装 用法 序列化 反序列化 字段定义 字段排序 非空构造函数对象 获取Protobuf包装器 高级 支持的属性类型与Protobu ...

  10. 团队内部密码共享方案:KeePassXC+微盘(企业微信)

    目录 需求描述 适用场景 安装使用 KeePassXC初始化 浏览器插件安装设置 1.火狐 2.Edge 3.Chrome 软件-插件的链接 登陆网站并保存密码 (企业微信)微盘共享内部数据库 其他 ...