Katharopoulos A, Fleuret F. Not All Samples Are Created Equal: Deep Learning with Importance Sampling[J]. arXiv: Learning, 2018.

@article{katharopoulos2018not,

title={Not All Samples Are Created Equal: Deep Learning with Importance Sampling},

author={Katharopoulos, Angelos and Fleuret, F},

journal={arXiv: Learning},

year={2018}}

本文提出一种删选合适样本的方法, 这种方法基于收敛速度的一个上界, 而并非完全基于gradient norm的方法, 使得计算比较简单, 容易实现.

主要内容

设\((x_i,y_i)\)为输入输出对, \(\Psi(\cdot;\theta)\)代表网络, \(\mathcal{L}(\cdot, \cdot)\)为损失函数, 目标为

\[\tag{1}
\theta^* = \arg \min_{\theta} \frac{1}{N} \sum_{i=1}^N\mathcal{L}(\Psi(x_i;\theta),y_i),
\]

其中\(N\)是总的样本个数.

假设在第\(t\)个epoch的时候, 样本(被选中)的概率分布为\(p_1^t,\ldots,p_N^t\), 以及梯度权重为\(w_1^t, \ldots, w_N^t\), 那么\(P(I_t=i)=p_i^t\)且

\[\tag{2}
\theta_{t+1}=\theta_t-\eta w_{I_t}\nabla_{\theta_t} \mathcal{L}(\Psi(x_{I_t};\theta_t),y_{I_t}),
\]

在一般SGD训练中\(p_i=1/N,w_i=1\).

定义\(S\)为SGD的收敛速度为:

\[\tag{3}
S :=-\mathbb{E}_{P_t}[\|\theta_{t+1}-\theta^*\|_2^2-\|\theta_t-\theta^*\|_2^2],
\]

如果我们令\(w_i=\frac{1}{Np_i}\) 则



定义\(G_i=w_i\nabla_{\theta_t} \mathcal{L}(\Psi(x_{i};\theta_t),y_{i})\)



我们自然希望\(S\)能够越大越好, 此时即负项越小越好.

定义\(\hat{G}_i \ge \|\nabla_{\theta_t} \mathcal{L}(\Psi(x_{i};\theta_t),y_{i})\|_2\), 既然



(7)式我有点困惑,我觉得(7)式右端和最小化(6)式的负项(\(\mathrm{Tr}(\mathbb{V}_{P_t}[G_{I_t}])+\|\mathbb{E}_{P_t}[G_{I_t}]\|_2^2\))是等价的.

于是有

最小化右端(通过拉格朗日乘子法)可得\(p_i \propto \hat{G}_i\), 所以现在我们只要找到一个\(\hat{G}_i\)即可.

这个部分需要引入神经网络的反向梯度的公式, 之前有讲过,只是论文的符号不同, 这里不多赘诉了.

注意\(\rho\)的计算是比较复杂的, 但是\(p_i \propto \hat{G}_i\), 所以我们只需要计算\(\|\cdot\|\)部分, 设此分布为\(g\).

另外, 在最开始的时候, 神经网络没有得到很好的训练, 权重大小相差无几, 这个时候是近似正态分布的, 所以作者考虑设计一个指标,来判断是否需要根据样本分布\(g\)来挑选样本. 作者首先衡量



显然当这部分足够大的时候我们可以采用分布\(g\)而非正态分布\(u\), 但是这个指标不易判断, 作者进步除以\(\mathrm{Tr}(\mathbb{V}_u[G_i])\).



显然\(\tau\)越大越好, 我们自然可以人为设置一个\(\tau_{th}\). 算法如下

最后, 个人认为这个算法能减少计算量主要是因为样本少了, 少在一开始用正态分布抽取了一部分, 所以...

"代码"

主要是\(\hat{G}_i\)部分的计算, 因为涉及到中间变量的导数, 所以需要用到retain_grad().

"""
这里只是一个例子
""" import torch
import torch.nn as nn class Net(nn.Module): def __init__(self):
super(Net, self).__init__()
self.dense = nn.Sequential(
nn.Linear(10, 256),
nn.ReLU(),
nn.Linear(256, 10),
)
self.final = nn.ReLU() def forward(self, x):
z = self.dense(x)
z.retain_grad()
out = self.final(z)
return out, z if __name__ == "__main__": net = Net()
criterion = nn.MSELoss() x = torch.rand((2, 10))
y = torch.rand((2, 10)) out, z = net(x)
loss = criterion(out, y)
loss.backward()
print(z.grad) #这便是我们所需要的

Not All Samples Are Created Equal: Deep Learning with Importance Sampling的更多相关文章

  1. Accelerating Deep Learning by Focusing on the Biggest Losers

    目录 概 相关工作 主要内容 代码 Accelerating Deep Learning by Focusing on the Biggest Losers 概 思想很简单, 在训练网络的时候, 每个 ...

  2. (转) The major advancements in Deep Learning in 2016

    The major advancements in Deep Learning in 2016 Pablo Tue, Dec 6, 2016 in MACHINE LEARNING DEEP LEAR ...

  3. Deep Learning in R

    Introduction Deep learning is a recent trend in machine learning that models highly non-linear repre ...

  4. Summary on deep learning framework --- PyTorch

    Summary on deep learning framework --- PyTorch  Updated on 2018-07-22 21:25:42  import osos.environ[ ...

  5. [C3] Andrew Ng - Neural Networks and Deep Learning

    About this Course If you want to break into cutting-edge AI, this course will help you do so. Deep l ...

  6. Deep Learning 5_深度学习UFLDL教程:PCA and Whitening_Exercise(斯坦福大学深度学习教程)

    前言 本文是基于Exercise:PCA and Whitening的练习. 理论知识见:UFLDL教程. 实验内容:从10张512*512自然图像中随机选取10000个12*12的图像块(patch ...

  7. (转)WHY DEEP LEARNING IS SUDDENLY CHANGING YOUR LIFE

    Main Menu Fortune.com       E-mail Tweet Facebook Linkedin Share icons By Roger Parloff Illustration ...

  8. (转)Deep Learning Research Review Week 1: Generative Adversarial Nets

    Adit Deshpande CS Undergrad at UCLA ('19) Blog About Resume Deep Learning Research Review Week 1: Ge ...

  9. (转)The 9 Deep Learning Papers You Need To Know About (Understanding CNNs Part 3)

    Adit Deshpande CS Undergrad at UCLA ('19) Blog About The 9 Deep Learning Papers You Need To Know Abo ...

随机推荐

  1. .Net 下高性能分表分库组件-连接模式原理

    ShardingCore ShardingCore 一款ef-core下高性能.轻量级针对分表分库读写分离的解决方案,具有零依赖.零学习成本.零业务代码入侵. Github Source Code 助 ...

  2. day02 web主流框架

    day02 web主流框架 今日内容概要 手写简易版本web框架 借助于wsgiref模块 动静态网页 jinja2模板语法 前端.web框架.数据库三种结合 Python主流web框架 django ...

  3. Flume(四)【配置文件总结】

    目录 一.Agent 二.Source taildir arvo netstat exec spooldir 三.Sink hdfs kafka(待续) hbase(待续) arvo logger 本 ...

  4. linux 常用清空文件方法

    1.vim 编辑器 vim /tmp/file :1,$d  或 :%d 2.cat 命令 cat /dev/null > /tmp/file

  5. class.getName()和class.getSimpleName()的区别

    根据API中的定义: Class.getName():以String的形式,返回Class对象的"实体"名称: Class.getSimpleName():获取源代码中给出的&qu ...

  6. 【JAVA今法修真】 第三章 关系非关系 redis法器

    您好,我是南橘,万法仙门的掌门,刚刚从九州世界穿越到地球,因为时空乱流的影响导致我的法力全失,现在不得不通过这个平台向广大修真天才们借去力量.你们的每一个点赞,每一个关注都是让我回到九州世界的助力,兄 ...

  7. vm16虚拟机安装win11

    vm16虚拟机安装win11 参考https://baijiahao.baidu.com/s?id=1712702900207158969&wfr=spider&for=pc win1 ...

  8. hibernate多对多单向(双向)关系映射

    n-n(多对多)的关联关系必须通过连接表实现.下面以商品种类和商品之间的关系,即一个商品种类下面可以有多种商品,一种商品又可以属于多个商品种类,分别介绍单向的n-n关联关系和双向的n-n关联关系. 单 ...

  9. 2021 中国.NET开发者峰会近50场热点技术专题揭秘

    01 大会介绍  .NET Conf China 2021 是面向开发人员的社区峰会,基于 .NET Conf 2021的活动,庆祝 .NET 6 的发布和回顾过去一年来 .NET 在中国的发展成果展 ...

  10. ORA-31633:unable to create master table "DP.SYS_EXPORT_FULL_11" ORA-01658

    问题描述:在进行数据泵进行数据库备份的时候,但是导出命令报错,环境是19C 4节点的rac 一体机.目前磁盘空间需要清理,清理之前先备份一下数据库 ORA-31626:job does not exi ...