转自:理解滑动平均(exponential moving average)

1. 用滑动平均估计局部均值

  滑动平均(exponential moving average),或者叫做指数加权平均(exponentially weighted moving average),可以用来估计变量的局部均值,使得变量的更新与一段时间内的历史取值有关。

  变量vv在tt时刻记为 vtvt,θtθt 为变量 vv 在 tt 时刻的取值,即在不使用滑动平均模型时 vt=θtvt=θt,在使用滑动平均模型后,vtvt 的更新公式如下:

vt=β⋅vt−1+(1−β)⋅θt(1)(1)vt=β⋅vt−1+(1−β)⋅θt

  上式中,β∈[0,1)β∈[0,1)。β=0β=0 相当于没有使用滑动平均。

  假设起始 v0=0v0=0,β=0.9β=0.9,之后每个时刻,依次对变量 vv 进行赋值,不使用滑动平均和使用滑动平均结果如下:

表 1  三种变量更新方式

t 不使用滑动平均模型,即给vv直接赋值θθ

使用滑动平均模型,
按照公式(1)更新vv

使用滑动平均模型,
按照公式(2)更新v_biasedv_biased

0, 1, 2,  ... ,  35 [0, 10, 20, 10, 0, 10, 20, 30, 5, 0, 10, 20, 10, 0, 10, 20, 30, 5, 0, 10, 20, 10, 0, 10, 20, 30, 5, 0, 10, 20, 10, 0, 10, 20, 30, 5] [0, 1.0, 2.9, 3.61, 3.249, 3.9241, 5.5317, 7.9785, 7.6807, 6.9126, 7.2213, 8.4992, 8.6493, 7.7844, 8.0059, 9.2053, 11.2848, 10.6563, 9.5907, 9.6316, 10.6685, 10.6016, 9.5414, 9.5873, 10.6286, 12.5657, 11.8091, 10.6282, 10.5654, 11.5089, 11.358, 10.2222, 10.2, 11.18, 13.062, 12.2558] [0, 10.0, 15.2632, 13.321, 9.4475, 9.5824, 11.8057, 15.2932, 13.4859, 11.2844, 11.0872, 12.3861, 12.0536, 10.4374, 10.3807, 11.592, 13.8515, 12.7892, 11.2844, 11.1359, 12.145, 11.9041, 10.5837, 10.5197, 11.5499, 13.5376, 12.6248, 11.2844, 11.1489, 12.0777, 11.8608, 10.6276, 10.5627, 11.5365, 13.4357, 12.5704]

图 1:三种变量更新方式

  Andrew Ng在Course 2 Improving Deep Neural Networks中讲到,tt 时刻变量 vv 的滑动平均值大致等于过去 1/(1−β)1/(1−β) 个时刻 θθ 值的平均。这个结论在滑动平均起始时相差比较大,所以有了Bias correction,将 vtvt 除以 (1−βt)(1−βt) 修正对均值的估计。

  加入了Bias correction后,vtvt 和 v_biasedtv_biasedt 的更新公式如下:

vt=β⋅vt−1+(1−β)⋅θtv_biasedt=vt1−βt(2)(2)vt=β⋅vt−1+(1−β)⋅θtv_biasedt=vt1−βt

tt 越大,1−βt1−βt 越接近 1,则公式(1)和(2)得到的结果 (vtvt 和 v_biasedtv_biasedt)将越来越近,如图 1 所示。

  当 ββ 越大时,滑动平均得到的值越和 θθ 的历史值相关。如果 β=0.9β=0.9,则大致等于过去 10 个 θθ 值的平均;如果 β=0.99β=0.99,则大致等于过去 100 个 θθ 值的平均。

  滑动平均的好处:

占内存少,不需要保存过去10个或者100个历史 θθ 值,就能够估计其均值。(当然,滑动平均不如将历史值全保存下来计算均值准确,但后者占用更多内存和计算成本更高)

2. TensorFlow中使用滑动平均来更新变量(参数)

  滑动平均可以看作是变量的过去一段时间取值的均值,相比对变量直接赋值而言,滑动平均得到的值在图像上更加平缓光滑,抖动性更小,不会因为某次的异常取值而使得滑动平均值波动很大,如图 1所示。

  TensorFlow 提供了 tf.train.ExponentialMovingAverage 来实现滑动平均。在初始化 ExponentialMovingAverage 时,需要提供一个衰减率(decay),即公式(1)(2)中的 ββ。这个衰减率将用于控制模型的更新速度。ExponentialMovingAverage 对每一个变量(variable)会维护一个影子变量(shadow_variable),这个影子变量的初始值就是相应变量的初始值,而每次运行变量更新时,影子变量的值会更新为:

shadow_variable=decay⋅shadow_variable+(1−decay)⋅variable(3)(3)shadow_variable=decay⋅shadow_variable+(1−decay)⋅variable

公式(3)中的 shadow_variable 就是公式(1)中的 vtvt,公式(3)中的 variable 就是公式(1)中的 θtθt,公式(3)中的 decay 就是公式(1)中的 ββ。

  公式(3)中,decay 决定了影子变量的更新速度,decay 越大影子变量越趋于稳定。在实际运用中,decay一般会设成非常接近 1 的数(比如0.999或0.9999)。为了使得影子变量在训练前期可以更新更快,ExponentialMovingAverage 还提供了 num_updates 参数动态设置 decay 的大小。如果在初始化 ExponentialMovingAverage 时提供了 num_updates 参数,那么每次使用的衰减率将是:

min{decay,1+num_updates10+num_updates}(4)(4)min{decay,1+num_updates10+num_updates}

这一点其实和 Bias correction 很像。

  TensorFlow 中使用 ExponentialMovingAverage 的例子:code (如果 GitHub 无法加载 .ipynb 文件,则将 .ipynb 文件的 URL 复制到网站 https://nbviewer.jupyter.org/

3. 滑动平均为什么在测试过程中被使用?

  滑动平均可以使模型在测试数据上更健壮(robust)。“采用随机梯度下降算法训练神经网络时,使用滑动平均在很多应用中都可以在一定程度上提高最终模型在测试数据上的表现。”

  对神经网络边的权重 weights 使用滑动平均,得到对应的影子变量 shadow_weights。在训练过程仍然使用原来不带滑动平均的权重 weights,不然无法得到 weights 下一步更新的值,又怎么求下一步 weights 的影子变量 shadow_weights。之后在测试过程中使用 shadow_weights 来代替 weights 作为神经网络边的权重,这样在测试数据上效果更好。因为 shadow_weights 的更新更加平滑,对于随机梯度下降而言,更平滑的更新说明不会偏离最优点很远;对于梯度下降 batch gradient decent,我感觉影子变量作用不大,因为梯度下降的方向已经是最优的了,loss 一定减小;对于 mini-batch gradient decent,可以尝试滑动平均,毕竟 mini-batch gradient decent 对参数的更新也存在抖动。

  设 decay=0.999decay=0.999,一个更直观的理解,在最后的 1000 次训练过程中,模型早已经训练完成,正处于抖动阶段,而滑动平均相当于将最后的 1000 次抖动进行了平均,这样得到的权重会更加 robust。

References

Course 2 Improving Deep Neural Networks by Andrew Ng

《TensorFlow实战Google深度学习框架》 4.4.3

(转)理解滑动平均(exponential moving average)的更多相关文章

  1. 理解滑动平均(exponential moving average)

    1. 用滑动平均估计局部均值 滑动平均(exponential moving average),或者叫做指数加权平均(exponentially weighted moving average),可以 ...

  2. EMA计算的C#实现(c# Exponential Moving Average (EMA) indicator )

    原来国外有个源码(TechnicalAnalysisEngine src 1.25)内部对EMA的计算是: var copyInputValues = input.ToList(); for (int ...

  3. (转)滑动平均法、滑动平均模型算法(Moving average,MA)

    原文链接:https://blog.csdn.net/qq_39521554/article/details/79028012 什么是移动平均法? 移动平均法是用一组最近的实际数据值来预测未来一期或几 ...

  4. 一文详解滑动平均法、滑动平均模型法(Moving average,MA)

    任何关于算法.编程.AI行业知识或博客内容的问题,可以随时扫码关注公众号「图灵的猫」,加入”学习小组“,沙雕博主在线答疑~此外,公众号内还有更多AI.算法.编程和大数据知识分享,以及免费的SSR节点和 ...

  5. [leetcode]346. Moving Average from Data Stream滑动窗口平均值

    Given a stream of integers and a window size, calculate the moving average of all integers in the sl ...

  6. Tensorflow滑动平均模型tf.train.ExponentialMovingAverage解析

    觉得有用的话,欢迎一起讨论相互学习~Follow Me 移动平均法相关知识 移动平均法又称滑动平均法.滑动平均模型法(Moving average,MA) 什么是移动平均法 移动平均法是用一组最近的实 ...

  7. 『TensorFlow』滑动平均

    滑动平均会为目标变量维护一个影子变量,影子变量不影响原变量的更新维护,但是在测试或者实际预测过程中(非训练时),使用影子变量代替原变量. 1.滑动平均求解对象初始化 ema = tf.train.Ex ...

  8. tensorflow入门笔记(二) 滑动平均模型

    tensorflow提供的tf.train.ExponentialMovingAverage 类利用指数衰减维持变量的滑动平均. 当训练模型的时候,保持训练参数的滑动平均是非常有益的.评估时使用取平均 ...

  9. deep_learning_Function_tf.train.ExponentialMovingAverage()滑动平均

    近来看batch normalization的代码时,遇到tf.train.ExponentialMovingAverage()函数,特此记录. tf.train.ExponentialMovingA ...

随机推荐

  1. 【CHRIS RICHARDSON 微服务系列】微服务架构中的进程间通信-3

    编者的话 |本文来自 Nginx 官方博客,是微服务系列文章的第三篇,在第一篇文章中介绍了微服务架构模式,与单体模式进行了比较,并且讨论了使用微服务架构的优缺点.第二篇描述了采用微服务架构的应用客户端 ...

  2. Leetcode题解 - 贪心算法部分简单题目代码+思路(860、944、1005、1029、1046、1217、1221)

    leetcode真的是一个学习阅读理解的好地方 860. 柠檬水找零 """ 因为用户支付的只会有5.10.20 对于10元的用户必须找一个5 对于20元的用户可以找(三 ...

  3. JDK1.8新特性-Lambda表达式

    Lambda 表达式 Lambda 表达式,也可称为闭包,它是推动 Java 8 发布的最重要新特性. Lambda 允许把函数作为一个方法的参数(函数作为参数传递进方法中). 使用 Lambda 表 ...

  4. MySQL MHA /usr/share/perl5/vendor_perl/MHA/ServerManager.pm, ln301] install_driver(mysql) failed: Attempt to reload DBD/mysql.pm aborted

    在公司随便找3台测试机搭个MHA,下面这个问题折腾了三天,之前没遇到过,查了OS版本发现一致,可能是不同人弄的OS吧,知道是cpan的问题就是搞不定,郁闷...[root@test247 ~]# ma ...

  5. Bash脚本编程之脚本基础和bash配置文件

    脚本基础 参考资料:Shell Scripts (Bash Reference Manual) 不严谨地说,编程语言根据代码运行的方式,可以分为两种方式: 编译运行:需要先将人类可识别的代码文件编译成 ...

  6. 看懂 游戏《Minecraft》的崩溃报告 服务端/客户端

    如何看懂Minecraft报错的关键信息. 让你如何看懂Minecraft报错 前言 一些俏皮话 寻找崩溃日志 打开崩溃日志 重要的事说三遍 下载文本编辑器 开始分析 深度分析 得出结论 修复报错 解 ...

  7. HTML DOM的创建,删除及替换

    创建HTML元素 document.appendChild() 将新元素作为父元素的最后一个子元素进行添加 如需向HTML DOM添加新元素,首先必须创建该元素,然后把它追加到已有的元素上 var n ...

  8. Java题库——Chapter15 事件驱动编程和动画

    Chapter 15 Event-Driven Programming and Animations Section 15.2 Events and Event Sources1.    A Java ...

  9. 使用kubernetes的cronjob定时备份mysql数据库

    1.创建cronjob的文件 CronJob所描述的,正是定时任务. 在给定时间点只运行一次 在给定时间点周期性地运行 一个 CronJob 对象类似于 crontab (cron table)文件中 ...

  10. 【重要更新】Senparc.Weixin SDK v6.5 升级说明(支持 .NET Core 3.0 及分布式消息上下文)

    Senparc.Weixin SDK v6.5 开始支持 .NET Core 3.0,并将微信消息上下文进行了大幅度的重构,支持了使用分布式缓存存储上下文信息,这意味着在分布式系统中,现在 Senpa ...