【炼丹Trick】EMA的原理与实现
在进行深度学习训练时,同一模型往往可以训练出不同的效果,这就是炼丹这件事的玄学所在。使用一些trick能够让你更容易追上目前SOTA的效果,一些流行的开源代码中已经集成了不少trick,值得学习一番。本节介绍EMA这一方法。
1.原理:
EMA也就是指数移动平均(Exponential moving average)。其公式非常简单,如下所示:
\(\theta_{\text{EMA}, t+1} = (1 - \lambda) \cdot \theta_{\text{EMA}, t} + \lambda \cdot \theta_{t}\)
\(\theta_{t}\)是t时刻的网络参数,\(\theta_{\text{EMA}, t}\)是t时刻滑动平均后的网络参数,那么t+1时刻的滑动平均结果就是这两者的加权融合。这里 \(\lambda\)通常会取接近于1的数,比如0.9995,数字越大平均的效果就比较强。
值得注意的是,这里可以看成有两个模型,基础模型其参数按照常规的前后向传播来更新,另外一个模型则是基础模型的滑动平均版本,它并不直接参与前后向传播,仅仅是利用基础模型的参数结果来更新自己。
EMA为什么会有效呢?大概是因为在训练的时候,会使用验证集来衡量模型精度,但其实验证集精度并不和测试集一致,在训练后期阶段,模型可能已经在测试集最佳精度附近波动,所以使用滑动平均的结果会比使用单一结果更加可靠。感兴趣的话可以看看这几篇论文,论文1,论文2,论文3。
2.实现:
Pytorch其实已经为我们实现了这一功能,为了避免自己造轮子可能引入的错误,这里直接学习一下官方的代码。这个类的名称就叫做AveragedModel。代码如下所示。
我们需要做的是提供avg_fn这个函数,avg_fn用来指定以何种方式进行平均。
class AveragedModel(Module):
"""
You can also use custom averaging functions with `avg_fn` parameter.
If no averaging function is provided, the default is to compute
equally-weighted average of the weights.
"""
def __init__(self, model, device=None, avg_fn=None, use_buffers=False):
super(AveragedModel, self).__init__()
self.module = deepcopy(model)
if device is not None:
self.module = self.module.to(device)
self.register_buffer('n_averaged',
torch.tensor(0, dtype=torch.long, device=device))
if avg_fn is None:
def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
return averaged_model_parameter + \
(model_parameter - averaged_model_parameter) / (num_averaged + 1)
self.avg_fn = avg_fn
self.use_buffers = use_buffers
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
def update_parameters(self, model):
self_param = (
itertools.chain(self.module.parameters(), self.module.buffers())
if self.use_buffers else self.parameters()
)
model_param = (
itertools.chain(model.parameters(), model.buffers())
if self.use_buffers else model.parameters()
)
for p_swa, p_model in zip(self_param, model_param):
device = p_swa.device
p_model_ = p_model.detach().to(device)
if self.n_averaged == 0:
p_swa.detach().copy_(p_model_)
else:
p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_,
self.n_averaged.to(device)))
self.n_averaged += 1
@torch.no_grad()
def update_bn(loader, model, device=None):
r"""Updates BatchNorm running_mean, running_var buffers in the model.
It performs one pass over data in `loader` to estimate the activation
statistics for BatchNorm layers in the model.
Args:
loader (torch.utils.data.DataLoader): dataset loader to compute the
activation statistics on. Each data batch should be either a
tensor, or a list/tuple whose first element is a tensor
containing data.
model (torch.nn.Module): model for which we seek to update BatchNorm
statistics.
device (torch.device, optional): If set, data will be transferred to
:attr:`device` before being passed into :attr:`model`.
Example:
>>> loader, model = ...
>>> torch.optim.swa_utils.update_bn(loader, model)
.. note::
The `update_bn` utility assumes that each data batch in :attr:`loader`
is either a tensor or a list or tuple of tensors; in the latter case it
is assumed that :meth:`model.forward()` should be called on the first
element of the list or tuple corresponding to the data batch.
"""
momenta = {}
for module in model.modules():
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
module.running_mean = torch.zeros_like(module.running_mean)
module.running_var = torch.ones_like(module.running_var)
momenta[module] = module.momentum
if not momenta:
return
was_training = model.training
model.train()
for module in momenta.keys():
module.momentum = None
module.num_batches_tracked *= 0
for input in loader:
if isinstance(input, (list, tuple)):
input = input[0]
if device is not None:
input = input.to(device)
model(input)
for bn_module in momenta.keys():
bn_module.momentum = momenta[bn_module]
model.train(was_training)
这里同样参考官方的示例代码,给出滑动平均的实现。ExponentialMovingAverage继承了AveragedModel,并且复写了init方法,其实更直接的方法是将ema_avg函数作为参数传递给AveragedModel,这里可能是为了可读性,避免出现一个孤零零的ema_avg函数。
class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
"""Maintains moving averages of model parameters using an exponential decay.
``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
`torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
is used to compute the EMA.
"""
def __init__(self, model, decay, device="cpu"):
def ema_avg(avg_model_param, model_param, num_averaged):
return decay * avg_model_param + (1 - decay) * model_param
super().__init__(model, device, ema_avg, use_buffers=True)
如何使用呢?方式是比较简单的,首先是利用当前模型创建出一个滑动平均模型。
model_ema = utils.ExponentialMovingAverage(model, device=device, decay=ema_decay)
然后是进行基础模型的前后向传播,更新结束后再对滑动平均版的模型进行参数更新。
output = model(image)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
model_ema.update_parameters(model)
【炼丹Trick】EMA的原理与实现的更多相关文章
- 【优化技巧】指数移动平均EMA的原理
前言 在深度学习中,经常会使用EMA(exponential moving average)方法对模型的参数做平滑或者平均,以求提高测试指标,增加模型鲁棒性. 参考 1. [优化技巧]指数移动平均(E ...
- 炼丹的一些trick
采摘一些大佬的果实: 知乎:如何理解深度学习分布式训练中的large batch size与learning rate的关系? https://blog.csdn.net/shanglianlm/ar ...
- PHP 底层的运行机制与原理
PHP说简单,但是要精通也不是一件简单的事.我们除了会使用之外,还得知道它底层的工作原理. PHP是一种适用于web开发的动态语言.具体点说,就是一个用C语言实现包含大量组件的软件框架.更狭义点看,可 ...
- JSPatch 实现原理详解
原文地址https://github.com/bang590/JSPatch/wiki/JSPatch-%E5%AE%9E%E7%8E%B0%E5%8E%9F%E7%90%86%E8%AF%A6%E8 ...
- PHP的运行机制与原理(底层) [转]
说到php的运行机制还要先给大家介绍php的模块,PHP总共有三个模块:内核.Zend引擎.以及扩展层:PHP内核用来处理请求.文件流.错误处理等相关操作:Zend引擎(ZE)用以将源文件转换成机器语 ...
- Linux进程调度原理
Linux进程调度原理 Linux进程调度机制 Linux进程调度的目标 1.高效性:高效意味着在相同的时间下要完成更多的任务.调度程序会被频繁的执行,所以调度程序要尽可能的高效: 2.加强交互性能: ...
- PHP底层的运行机制与原理
PHP说简单,但是要精通也不是一件简单的事.我们除了会使用之外,还得知道它底层的工作原理. PHP是一种适用于web开发的动态语言.具体点说,就是一个用C语言实现包含大量组件的软件框架.更狭义点看,可 ...
- 单片微机原理P0:80C51结构原理
本来我真的不想让51的东西出现在我的博客上的,因为51这种东西真的太low了,学了最多就所谓的垃圾科创利用一下,但是想一下这门课我也要考试,还是写一点东西顺便放博客上吧. 这一系列主要参考<单片 ...
- Kernel PCA 原理和演示
Kernel PCA 原理和演示 主成份(Principal Component Analysis)分析是降维(Dimension Reduction)的重要手段.每一个主成分都是数据在某一个方向上的 ...
随机推荐
- prometheus监控预警之AlertManager邮箱报警
Alertmanager 主要用于接收 Prometheus 发送的告警信息,它支持丰富的告警通知渠道,例如邮件.微信.钉钉.Slack 等常用沟通工具,而且很容易做到告警信息进行去重,降噪,分组等, ...
- Halo 开源项目学习(三):注册与登录
基本介绍 首次启动 Halo 项目时需要安装博客并注册用户信息,当博客安装完成后用户就可以根据注册的信息登录到管理员界面,下面我们分析一下整个过程中代码是如何执行的. 博客安装 项目启动成功后,我们可 ...
- Golang 实现 Redis(10): 本地原子性事务
为了支持多个命令的原子性执行 Redis 提供了事务机制. Redis 官方文档中称事务带有以下两个重要的保证: 事务是一个单独的隔离操作:事务中的所有命令都会序列化.按顺序地执行.事务在执行的过程中 ...
- 通过代码解释什么是API,什么是SDK?
这个问题说来惭愧,读书时找实习面的第一家公司,问的第一个问题就是这个. 当时我没能说清楚,回去之后就上百度查.结果查了很久还是看不懂,然后就把这个问题搁置了. 谁知道毕业正式工作后,又再一次地面对了这 ...
- 【Java分享客栈】超简洁SpringBoot使用AOP统一日志管理-纯干货干到便秘
前言 请问今天您便秘了吗?程序员坐久了真的会便秘哦,如果偶然点进了这篇小干货,就麻烦您喝杯水然后去趟厕所一边用左手托起对准嘘嘘,一边用右手滑动手机看完本篇吧. 实现 本篇AOP统一日志管理写法来源于国 ...
- 实验:Python图形图像处理
1. 准备一张照片,编写Python程序将该照片进行图像处理,分别输出以下效果的图片:(a)灰度图:(b)轮廓图: (c)变换RGB通道图:(d)旋转45度图. 2. 假设当前文件夹中data.csv ...
- SSO 方案演进
背景介绍 随着业务与技术的发展,现今比以往任何时候都更需要单点登录 SSO 身份验证. 现在几乎每个网站都需要某种形式的身份验证才能访问其功能和内容. 随着网站和服务数量的增加,集中登录系统已成为一种 ...
- 通过CSS让图片变的清楚
image { width: 100%; height: 100%; border-radius: 10upx; //让图片变清楚 image-rendering: -moz-crisp-edges; ...
- python目录索引
python目录索引 python基础数据类型1 目录 part1 part2 运算符 格式化 part3 字符串 字符串常用操作方法 part4 列表 列表的创建: 列表的索引,切片 列表的增删改查 ...
- JavaSE_关键字 接口 代码块 枚举
1 Java中的关键字 1.1 static关键字 static特点 : 静态成员被所在类的所有对象共享 随着类的加载而加载 , 优先于对象存在 可以通过对象调用 , 也可以通过类名调用 , 建议使用 ...