1.优化器算法简述

首先来看一下梯度下降最常见的三种变形 BGD,SGD,MBGD,这三种形式的区别就是取决于我们用多少数据来计算目标函数的梯度,这样的话自然就涉及到一个 trade-off,即参数更新的准确率和运行时间。

2.Batch Gradient Descent (BGD)

梯度更新规则:

BGD 采用整个训练集的数据来计算 cost function 对参数的梯度:

 

缺点:

由于这种方法是在一次更新中,就对整个数据集计算梯度,所以计算起来非常慢,遇到很大量的数据集也会非常棘手,而且不能投入新数据实时更新模型。

for i in range(nb_epochs):
params_grad = evaluate_gradient(loss_function, data, params)
params = params - learning_rate * params_grad

我们会事先定义一个迭代次数 epoch,首先计算梯度向量 params_grad,然后沿着梯度的方向更新参数 params,learning rate 决定了我们每一步迈多大。

Batch gradient descent 对于凸函数可以收敛到全局极小值,对于非凸函数可以收敛到局部极小值。

3.Stochastic Gradient Descent (SGD)

梯度更新规则:

和 BGD 的一次用所有数据计算梯度相比,SGD 每次更新时对每个样本进行梯度更新,对于很大的数据集来说,可能会有相似的样本,这样 BGD 在计算梯度时会出现冗余,而 SGD 一次只进行一次更新,就没有冗余,而且比较快,并且可以新增样本。

for i in range(nb_epochs):
np.random.shuffle(data)
for example in data:
params_grad = evaluate_gradient(loss_function, example, params)
params = params - learning_rate * params_grad

看代码,可以看到区别,就是整体数据集是个循环,其中对每个样本进行一次参数更新。

随机梯度下降是通过每个样本来迭代更新一次,如果样本量很大的情况,那么可能只用其中部分的样本,就已经将theta迭代到最优解了,对比上面的批量梯度下降,迭代一次需要用到十几万训练样本,一次迭代不可能最优,如果迭代10次的话就需要遍历训练样本10次。缺点是SGD的噪音较BGD要多,使得SGD并不是每次迭代都向着整体最优化方向。所以虽然训练速度快,但是准确度下降,并不是全局最优。虽然包含一定的随机性,但是从期望上来看,它是等于正确的导数的。

缺点:

SGD 因为更新比较频繁,会造成 cost function 有严重的震荡。

BGD 可以收敛到局部极小值,当然 SGD 的震荡可能会跳到更好的局部极小值处。

当我们稍微减小 learning rate,SGD 和 BGD 的收敛性是一样的。

4.Mini-Batch Gradient Descent (MBGD)

梯度更新规则:

MBGD 每一次利用一小批样本,即 n 个样本进行计算,这样它可以降低参数更新时的方差,收敛更稳定,另一方面可以充分地利用深度学习库中高度优化的矩阵操作来进行更有效的梯度计算。

和 SGD 的区别是每一次循环不是作用于每个样本,而是具有 n 个样本的批次。

for i in range(nb_epochs):
np.random.shuffle(data)
for batch in get_batches(data, batch_size=50):
params_grad = evaluate_gradient(loss_function, batch, params)
params = params - learning_rate * params_grad

超参数设定值:  n 一般取值在 50~256

缺点:(两大缺点)

  1. 不过 Mini-batch gradient descent 不能保证很好的收敛性,learning rate 如果选择的太小,收敛速度会很慢,如果太大,loss function 就会在极小值处不停地震荡甚至偏离。(有一种措施是先设定大一点的学习率,当两次迭代之间的变化低于某个阈值后,就减小 learning rate,不过这个阈值的设定需要提前写好,这样的话就不能够适应数据集的特点。)对于非凸函数,还要避免陷于局部极小值处,或者鞍点处,因为鞍点周围的error是一样的,所有维度的梯度都接近于0,SGD 很容易被困在这里。(会在鞍点或者局部最小点震荡跳动,因为在此点处,如果是训练集全集带入即BGD,则优化会停止不动,如果是mini-batch或者SGD,每次找到的梯度都是不同的,就会发生震荡,来回跳动。)
  2. SGD对所有参数更新时应用同样的 learning rate,如果我们的数据是稀疏的,我们更希望对出现频率低的特征进行大一点的更新。LR会随着更新的次数逐渐变小。

5.Adam:Adaptive Moment Estimation

这个算法是另一种计算每个参数的自适应学习率的方法。相当于 RMSprop + Momentum

除了像 Adadelta 和 RMSprop 一样存储了过去梯度的平方 vt 的指数衰减平均值 ,也像 momentum 一样保持了过去梯度 mt 的指数衰减平均值

如果 mt 和 vt 被初始化为 0 向量,那它们就会向 0 偏置,所以做了偏差校正,通过计算偏差校正后的 mt 和 vt 来抵消这些偏差:

梯度更新规则:

超参数设定值:
建议 β1 = 0.9,β2 = 0.999,ϵ = 10e−8

实践表明,Adam 比其他适应性学习方法效果要好。

参考文献:

https://www.cnblogs.com/guoyaohua/p/8542554.html

Pytorch学习笔记08----优化器算法Optimizer详解(SGD、Adam)的更多相关文章

  1. 深度学习——优化器算法Optimizer详解(BGD、SGD、MBGD、Momentum、NAG、Adagrad、Adadelta、RMSprop、Adam)

    在机器学习.深度学习中使用的优化算法除了常见的梯度下降,还有 Adadelta,Adagrad,RMSProp 等几种优化器,都是什么呢,又该怎么选择呢? 在 Sebastian Ruder 的这篇论 ...

  2. pytorch学习笔记(十二):详解 Module 类

    Module 是 pytorch 提供的一个基类,每次我们要 搭建 自己的神经网络的时候都要继承这个类,继承这个类会使得我们 搭建网络的过程变得异常简单. 本文主要关注 Module 类的内部是怎么样 ...

  3. 机器学习实战(Machine Learning in Action)学习笔记————08.使用FPgrowth算法来高效发现频繁项集

    机器学习实战(Machine Learning in Action)学习笔记————08.使用FPgrowth算法来高效发现频繁项集 关键字:FPgrowth.频繁项集.条件FP树.非监督学习作者:米 ...

  4. Linux防火墙iptables学习笔记(三)iptables命令详解和举例[转载]

     Linux防火墙iptables学习笔记(三)iptables命令详解和举例 2008-10-16 23:45:46 转载 网上看到这个配置讲解得还比较易懂,就转过来了,大家一起看下,希望对您工作能 ...

  5. (转)live555学习笔记10-h264 RTP传输详解(2)

    参考: 1,live555学习笔记10-h264 RTP传输详解(2) http://blog.csdn.net/niu_gao/article/details/6936108 2,H264 sps ...

  6. 大数据学习笔记——Spark工作机制以及API详解

    Spark工作机制以及API详解 本篇文章将会承接上篇关于如何部署Spark分布式集群的博客,会先对RDD编程中常见的API进行一个整理,接着再结合源代码以及注释详细地解读spark的作业提交流程,调 ...

  7. 学习笔记--Grunt、安装、图文详解

    学习笔记--Git安装.图文详解 安装Git成功后,现在安装Gruntjs,官网:http://gruntjs.com/ 一.安装node 参考node.js 安装.图文详解 (最新的node会自动安 ...

  8. Java8学习笔记(五)--Stream API详解[转]

    为什么需要 Stream Stream 作为 Java 8 的一大亮点,它与 java.io 包里的 InputStream 和 OutputStream 是完全不同的概念.它也不同于 StAX 对 ...

  9. 【官方文档】Nginx模块Nginx-Rtmp-Module学习笔记(一) RTMP 命令详解

    源码地址:https://github.com/Tinywan/PHP_Experience 说明: rtmp的延迟主要取决于播放器设置,但流式传输软件,流的比特率和网络速度(以及响应时间“ping” ...

随机推荐

  1. 最新JS正则表达式验证手机号码(2019)

    根据移动.联通.电信的电话号码号段,实现一个简单的正则表达式来验证手机号码: // 手机号校验 export function isPhoneNumber(phoneNum) { // let reg ...

  2. 转载:10G以太网光口与Aurora接口回环实验

    10G以太网光口与高速串行接口的使用越来越普遍,本文拟通过一个简单的回环实验,来说明在常见的接口调试中需要注意的事项.各种Xilinx FPGA接口学习的秘诀:Example Design.欢迎探讨. ...

  3. 从零开始,无DNS vcenter 6.7 vmotion热迁移,存储集群部署文档。

    1,环境准备 准备:Vmware workstation环境 IP地址段规划 ESXI主机IP地址段 192.168.197.4-192.168.197.10 Vcenter Server集群IP地址 ...

  4. [mysql课程作业]我的大学|作业

    第八周周五 1.将xs表中王元的专业改为"智能建筑". # update xs set 专业名='智能建筑' where 姓名='王元'; # select * from xs w ...

  5. 这一篇 K8S(Kubernetes)集群部署 我觉得还可以!!!

    点赞再看,养成习惯,微信搜索[牧小农]关注我获取更多资讯,风里雨里,小农等你,很高兴能够成为你的朋友. 国内安装K8S的四种途径 Kubernetes 的安装其实并不复杂,因为Kubernetes 属 ...

  6. 攻防世界 Misc 新手练习区 give_you_flag Writeup

    攻防世界 Misc 新手练习区 give_you_flag Writeup 题目介绍 题目考点 gif图片分离 细心的P图 二维码解码 Writeup 下载附件打开,发现是一张gif图片,打开看了一下 ...

  7. 攻防世界 WEB 高手进阶区 csaw-ctf-2016-quals mfw Writeup

    攻防世界 WEB 高手进阶区 csaw-ctf-2016-quals mfw Writeup 题目介绍 题目考点 PHP代码审计 git源码泄露 Writeup 进入题目,点击一番,发现可能出现git ...

  8. Centos8上安装Mysql8.X

    一.下载Mysql 下载地址:https://dev.mysql.com/downloads/mysql/ 二.将压缩包通过ftp软件服务器的目标位置:并解压 1.我的是放在:/root/softwa ...

  9. 菜鸡的Java笔记 第六 - java 方法

    前提:现在所讲解的方法定义格式,只属于JAVA 方法定义的其中一种组成方式.而完整的组成方式将随着学习逐步渗透. 1.方法的基本定义 方法(Method)在一些书中也会有人将其说是  函数(Funct ...

  10. [hdu4388]Stone Game II

    不管是否使用技能,发现操作前后所有堆二进制中1的个数之和不变.那么对于一个堆其实可以等价转换为一个k个石子的堆(k为该数二进制的个数),然后就是个nim游戏. 1 #include<bits/s ...