Pytorch学习笔记08----优化器算法Optimizer详解(SGD、Adam)
1.优化器算法简述
首先来看一下梯度下降最常见的三种变形 BGD,SGD,MBGD,这三种形式的区别就是取决于我们用多少数据来计算目标函数的梯度,这样的话自然就涉及到一个 trade-off,即参数更新的准确率和运行时间。
2.Batch Gradient Descent (BGD)
梯度更新规则:
BGD 采用整个训练集的数据来计算 cost function 对参数的梯度:
缺点:
由于这种方法是在一次更新中,就对整个数据集计算梯度,所以计算起来非常慢,遇到很大量的数据集也会非常棘手,而且不能投入新数据实时更新模型。
for i in range(nb_epochs):
params_grad = evaluate_gradient(loss_function, data, params)
params = params - learning_rate * params_grad
我们会事先定义一个迭代次数 epoch,首先计算梯度向量 params_grad,然后沿着梯度的方向更新参数 params,learning rate 决定了我们每一步迈多大。
Batch gradient descent 对于凸函数可以收敛到全局极小值,对于非凸函数可以收敛到局部极小值。
3.Stochastic Gradient Descent (SGD)
梯度更新规则:
和 BGD 的一次用所有数据计算梯度相比,SGD 每次更新时对每个样本进行梯度更新,对于很大的数据集来说,可能会有相似的样本,这样 BGD 在计算梯度时会出现冗余,而 SGD 一次只进行一次更新,就没有冗余,而且比较快,并且可以新增样本。

for i in range(nb_epochs):
np.random.shuffle(data)
for example in data:
params_grad = evaluate_gradient(loss_function, example, params)
params = params - learning_rate * params_grad
看代码,可以看到区别,就是整体数据集是个循环,其中对每个样本进行一次参数更新。

随机梯度下降是通过每个样本来迭代更新一次,如果样本量很大的情况,那么可能只用其中部分的样本,就已经将theta迭代到最优解了,对比上面的批量梯度下降,迭代一次需要用到十几万训练样本,一次迭代不可能最优,如果迭代10次的话就需要遍历训练样本10次。缺点是SGD的噪音较BGD要多,使得SGD并不是每次迭代都向着整体最优化方向。所以虽然训练速度快,但是准确度下降,并不是全局最优。虽然包含一定的随机性,但是从期望上来看,它是等于正确的导数的。
缺点:
SGD 因为更新比较频繁,会造成 cost function 有严重的震荡。
BGD 可以收敛到局部极小值,当然 SGD 的震荡可能会跳到更好的局部极小值处。
当我们稍微减小 learning rate,SGD 和 BGD 的收敛性是一样的。
4.Mini-Batch Gradient Descent (MBGD)
梯度更新规则:
MBGD 每一次利用一小批样本,即 n 个样本进行计算,这样它可以降低参数更新时的方差,收敛更稳定,另一方面可以充分地利用深度学习库中高度优化的矩阵操作来进行更有效的梯度计算。

和 SGD 的区别是每一次循环不是作用于每个样本,而是具有 n 个样本的批次。
for i in range(nb_epochs):
np.random.shuffle(data)
for batch in get_batches(data, batch_size=50):
params_grad = evaluate_gradient(loss_function, batch, params)
params = params - learning_rate * params_grad
超参数设定值: n 一般取值在 50~256
缺点:(两大缺点)
- 不过 Mini-batch gradient descent 不能保证很好的收敛性,learning rate 如果选择的太小,收敛速度会很慢,如果太大,loss function 就会在极小值处不停地震荡甚至偏离。(有一种措施是先设定大一点的学习率,当两次迭代之间的变化低于某个阈值后,就减小 learning rate,不过这个阈值的设定需要提前写好,这样的话就不能够适应数据集的特点。)对于非凸函数,还要避免陷于局部极小值处,或者鞍点处,因为鞍点周围的error是一样的,所有维度的梯度都接近于0,SGD 很容易被困在这里。(会在鞍点或者局部最小点震荡跳动,因为在此点处,如果是训练集全集带入即BGD,则优化会停止不动,如果是mini-batch或者SGD,每次找到的梯度都是不同的,就会发生震荡,来回跳动。)
- SGD对所有参数更新时应用同样的 learning rate,如果我们的数据是稀疏的,我们更希望对出现频率低的特征进行大一点的更新。LR会随着更新的次数逐渐变小。
5.Adam:Adaptive Moment Estimation
这个算法是另一种计算每个参数的自适应学习率的方法。相当于 RMSprop + Momentum
除了像 Adadelta 和 RMSprop 一样存储了过去梯度的平方 vt 的指数衰减平均值 ,也像 momentum 一样保持了过去梯度 mt 的指数衰减平均值:

如果 mt 和 vt 被初始化为 0 向量,那它们就会向 0 偏置,所以做了偏差校正,通过计算偏差校正后的 mt 和 vt 来抵消这些偏差:

梯度更新规则:

超参数设定值:
建议 β1 = 0.9,β2 = 0.999,ϵ = 10e−8
实践表明,Adam 比其他适应性学习方法效果要好。
参考文献:
https://www.cnblogs.com/guoyaohua/p/8542554.html
Pytorch学习笔记08----优化器算法Optimizer详解(SGD、Adam)的更多相关文章
- 深度学习——优化器算法Optimizer详解(BGD、SGD、MBGD、Momentum、NAG、Adagrad、Adadelta、RMSprop、Adam)
在机器学习.深度学习中使用的优化算法除了常见的梯度下降,还有 Adadelta,Adagrad,RMSProp 等几种优化器,都是什么呢,又该怎么选择呢? 在 Sebastian Ruder 的这篇论 ...
- pytorch学习笔记(十二):详解 Module 类
Module 是 pytorch 提供的一个基类,每次我们要 搭建 自己的神经网络的时候都要继承这个类,继承这个类会使得我们 搭建网络的过程变得异常简单. 本文主要关注 Module 类的内部是怎么样 ...
- 机器学习实战(Machine Learning in Action)学习笔记————08.使用FPgrowth算法来高效发现频繁项集
机器学习实战(Machine Learning in Action)学习笔记————08.使用FPgrowth算法来高效发现频繁项集 关键字:FPgrowth.频繁项集.条件FP树.非监督学习作者:米 ...
- Linux防火墙iptables学习笔记(三)iptables命令详解和举例[转载]
Linux防火墙iptables学习笔记(三)iptables命令详解和举例 2008-10-16 23:45:46 转载 网上看到这个配置讲解得还比较易懂,就转过来了,大家一起看下,希望对您工作能 ...
- (转)live555学习笔记10-h264 RTP传输详解(2)
参考: 1,live555学习笔记10-h264 RTP传输详解(2) http://blog.csdn.net/niu_gao/article/details/6936108 2,H264 sps ...
- 大数据学习笔记——Spark工作机制以及API详解
Spark工作机制以及API详解 本篇文章将会承接上篇关于如何部署Spark分布式集群的博客,会先对RDD编程中常见的API进行一个整理,接着再结合源代码以及注释详细地解读spark的作业提交流程,调 ...
- 学习笔记--Grunt、安装、图文详解
学习笔记--Git安装.图文详解 安装Git成功后,现在安装Gruntjs,官网:http://gruntjs.com/ 一.安装node 参考node.js 安装.图文详解 (最新的node会自动安 ...
- Java8学习笔记(五)--Stream API详解[转]
为什么需要 Stream Stream 作为 Java 8 的一大亮点,它与 java.io 包里的 InputStream 和 OutputStream 是完全不同的概念.它也不同于 StAX 对 ...
- 【官方文档】Nginx模块Nginx-Rtmp-Module学习笔记(一) RTMP 命令详解
源码地址:https://github.com/Tinywan/PHP_Experience 说明: rtmp的延迟主要取决于播放器设置,但流式传输软件,流的比特率和网络速度(以及响应时间“ping” ...
随机推荐
- ArrayList集合底层原理
目录 ArrayList集合特点及源码分析 ArrayList源码分析 成员变量 构造函数 增加方法 add(E e)方法 add(int index, E element)方法 删除方法 remov ...
- JAVA笔记10__Math类、Random类、Arrays类/日期操作类/对象比较器/对象的克隆/二叉树
/** * Math类.Random类.Arrays类:具体查JAVA手册...... */ public class Main { public static void main(String[] ...
- 服务集与AP的配合
一.实验目的 1)掌握添加无线网络配置 2)掌握配置信道和协议使用并配置在一个天线上同时运行两个服务集,即两个无线网络 二.实验仪器设备及软件 仪器设备:一台AC,两台AP,一台AR,一台LSW 软件 ...
- loto仪器_如何模拟输出凸轮轴和曲轴波形_用任意波形信号源SIG852?
loto仪器_如何模拟输出凸轮轴和曲轴波形_用任意波形信号源SIG852? 在汽车传感器的波形检测应用中,有时候需要模拟各种汽车传感器的输出信号,用来驱动和监测对应的执行机构或者电路是否正常,这其中, ...
- ADB WIFI无线调试真正摆脱usb数据线连接,一次也不用!
常见的使用ADB无线调试步骤 手机"开发者模式"菜单中开启"USB调试" 和"无线调试",手机网络与电脑在同一网内; 手机使用USB与电脑进 ...
- FastJson 解析、序列化及反序列化
一.环境准备:使用maven特性在pom.xml中导入fastjson的依赖包 <!-- https://mvnrepository.com/artifact/com.alibaba/fastj ...
- 性能工具之代码级性能测试工具ContiPerf
前言 做性能的同学一定遇到过这样的场景:应用级别的性能测试发现一个操作的响应时间很长,然后要花费很多时间去逐级排查,最后却发现罪魁祸首是代码中某个实现低效的底层算法.这种自上而下的逐级排查定位的方法, ...
- elasticsearch在postman中创建复杂索引
body,所选类型为raw和JSON,写的代码为 { "settings":{ "number_of_shards":1, "number_of_re ...
- Linux——搭建Apache(httpd)服务器
一.基本概念 Apache(或httpd)是Internet上使用最多的Web服务器技术之一,使用的传输协议是http超文本传输协议(一个基于超文本的协议),用于通过网络连接来发送和接受对象. 有两个 ...
- [atARC112E]Rvom and Rsrev
毒瘤分类讨论题 (注:以下情况都有"之前的情况都不满足的"前提条件,并用斜体表示一些说明) Case0:若$|s|\le 2$,直接输出即可,因此假设$|s|>3$ 首先,我 ...