Katharopoulos A, Fleuret F. Not All Samples Are Created Equal: Deep Learning with Importance Sampling[J]. arXiv: Learning, 2018.

@article{katharopoulos2018not,

title={Not All Samples Are Created Equal: Deep Learning with Importance Sampling},

author={Katharopoulos, Angelos and Fleuret, F},

journal={arXiv: Learning},

year={2018}}

本文提出一种删选合适样本的方法, 这种方法基于收敛速度的一个上界, 而并非完全基于gradient norm的方法, 使得计算比较简单, 容易实现.

主要内容

设\((x_i,y_i)\)为输入输出对, \(\Psi(\cdot;\theta)\)代表网络, \(\mathcal{L}(\cdot, \cdot)\)为损失函数, 目标为

\[\tag{1}
\theta^* = \arg \min_{\theta} \frac{1}{N} \sum_{i=1}^N\mathcal{L}(\Psi(x_i;\theta),y_i),
\]

其中\(N\)是总的样本个数.

假设在第\(t\)个epoch的时候, 样本(被选中)的概率分布为\(p_1^t,\ldots,p_N^t\), 以及梯度权重为\(w_1^t, \ldots, w_N^t\), 那么\(P(I_t=i)=p_i^t\)且

\[\tag{2}
\theta_{t+1}=\theta_t-\eta w_{I_t}\nabla_{\theta_t} \mathcal{L}(\Psi(x_{I_t};\theta_t),y_{I_t}),
\]

在一般SGD训练中\(p_i=1/N,w_i=1\).

定义\(S\)为SGD的收敛速度为:

\[\tag{3}
S :=-\mathbb{E}_{P_t}[\|\theta_{t+1}-\theta^*\|_2^2-\|\theta_t-\theta^*\|_2^2],
\]

如果我们令\(w_i=\frac{1}{Np_i}\) 则



定义\(G_i=w_i\nabla_{\theta_t} \mathcal{L}(\Psi(x_{i};\theta_t),y_{i})\)



我们自然希望\(S\)能够越大越好, 此时即负项越小越好.

定义\(\hat{G}_i \ge \|\nabla_{\theta_t} \mathcal{L}(\Psi(x_{i};\theta_t),y_{i})\|_2\), 既然



(7)式我有点困惑,我觉得(7)式右端和最小化(6)式的负项(\(\mathrm{Tr}(\mathbb{V}_{P_t}[G_{I_t}])+\|\mathbb{E}_{P_t}[G_{I_t}]\|_2^2\))是等价的.

于是有

最小化右端(通过拉格朗日乘子法)可得\(p_i \propto \hat{G}_i\), 所以现在我们只要找到一个\(\hat{G}_i\)即可.

这个部分需要引入神经网络的反向梯度的公式, 之前有讲过,只是论文的符号不同, 这里不多赘诉了.

注意\(\rho\)的计算是比较复杂的, 但是\(p_i \propto \hat{G}_i\), 所以我们只需要计算\(\|\cdot\|\)部分, 设此分布为\(g\).

另外, 在最开始的时候, 神经网络没有得到很好的训练, 权重大小相差无几, 这个时候是近似正态分布的, 所以作者考虑设计一个指标,来判断是否需要根据样本分布\(g\)来挑选样本. 作者首先衡量



显然当这部分足够大的时候我们可以采用分布\(g\)而非正态分布\(u\), 但是这个指标不易判断, 作者进步除以\(\mathrm{Tr}(\mathbb{V}_u[G_i])\).



显然\(\tau\)越大越好, 我们自然可以人为设置一个\(\tau_{th}\). 算法如下

最后, 个人认为这个算法能减少计算量主要是因为样本少了, 少在一开始用正态分布抽取了一部分, 所以...

"代码"

主要是\(\hat{G}_i\)部分的计算, 因为涉及到中间变量的导数, 所以需要用到retain_grad().

"""
这里只是一个例子
""" import torch
import torch.nn as nn class Net(nn.Module): def __init__(self):
super(Net, self).__init__()
self.dense = nn.Sequential(
nn.Linear(10, 256),
nn.ReLU(),
nn.Linear(256, 10),
)
self.final = nn.ReLU() def forward(self, x):
z = self.dense(x)
z.retain_grad()
out = self.final(z)
return out, z if __name__ == "__main__": net = Net()
criterion = nn.MSELoss() x = torch.rand((2, 10))
y = torch.rand((2, 10)) out, z = net(x)
loss = criterion(out, y)
loss.backward()
print(z.grad) #这便是我们所需要的

Not All Samples Are Created Equal: Deep Learning with Importance Sampling的更多相关文章

  1. Accelerating Deep Learning by Focusing on the Biggest Losers

    目录 概 相关工作 主要内容 代码 Accelerating Deep Learning by Focusing on the Biggest Losers 概 思想很简单, 在训练网络的时候, 每个 ...

  2. (转) The major advancements in Deep Learning in 2016

    The major advancements in Deep Learning in 2016 Pablo Tue, Dec 6, 2016 in MACHINE LEARNING DEEP LEAR ...

  3. Deep Learning in R

    Introduction Deep learning is a recent trend in machine learning that models highly non-linear repre ...

  4. Summary on deep learning framework --- PyTorch

    Summary on deep learning framework --- PyTorch  Updated on 2018-07-22 21:25:42  import osos.environ[ ...

  5. [C3] Andrew Ng - Neural Networks and Deep Learning

    About this Course If you want to break into cutting-edge AI, this course will help you do so. Deep l ...

  6. Deep Learning 5_深度学习UFLDL教程:PCA and Whitening_Exercise(斯坦福大学深度学习教程)

    前言 本文是基于Exercise:PCA and Whitening的练习. 理论知识见:UFLDL教程. 实验内容:从10张512*512自然图像中随机选取10000个12*12的图像块(patch ...

  7. (转)WHY DEEP LEARNING IS SUDDENLY CHANGING YOUR LIFE

    Main Menu Fortune.com       E-mail Tweet Facebook Linkedin Share icons By Roger Parloff Illustration ...

  8. (转)Deep Learning Research Review Week 1: Generative Adversarial Nets

    Adit Deshpande CS Undergrad at UCLA ('19) Blog About Resume Deep Learning Research Review Week 1: Ge ...

  9. (转)The 9 Deep Learning Papers You Need To Know About (Understanding CNNs Part 3)

    Adit Deshpande CS Undergrad at UCLA ('19) Blog About The 9 Deep Learning Papers You Need To Know Abo ...

随机推荐

  1. 日常Java 2021/10/17

    今天开始Javaweb编译环境调试,从tomcat容器开始,然后mysql的下载,连接工具datagrip,navicat for mysql,然后就是编写自己的sql,安装jdbc,eclipse连 ...

  2. 学习java 7.14

    学习内容: 标准输入输出流 输出语言的本质:是一个标准的输出流 字节打印流 字符打印流 对象序列化流 明天内容: 进程和线程 遇到问题: 用对象序列化流序列化一个对象后,假如我们修改了对象所属的类文件 ...

  3. Spark(七)【RDD的持久化Cache和CheckPoint】

    RDD的持久化 1. RDD Cache缓存 ​ RDD通过Cache或者Persist方法将前面的计算结果缓存,默认情况下会把数据以缓存在JVM的堆内存中.但是并不是这两个方法被调用时立即缓存,而是 ...

  4. 2021广东工业大学新生赛决赛 L-歪脖子树下的灯

    题目:L-歪脖子树下的灯_2021年广东工业大学第11届腾讯杯新生程序设计竞赛(同步赛) (nowcoder.com) 比赛的时候没往dp这方面想(因为之前初赛和月赛数学题太多了啊),因此只往组合数学 ...

  5. 接口测试 python+PyCharm 环境搭建

    1.配置Python环境变量 a:我的电脑->属性->高级系统设置->环境变量->系统变量中的PATH变量. 变量名:PATH      修改变量值为:;C:\Python27 ...

  6. activiti工作流引擎

    参考文章 Activiti-5.18.0与springMvc项目集成和activiti-explorer单独部署Web项目并与业务数据库关联方法(AutoEE_V2实现方式) https://blog ...

  7. OSGi系列 - 使用Eclipse查看Bundle源码

    使用Eclipse开发OSGi Bundle时,会发现有很多现成的Bundle可以用.但如何使用这些Bundle呢?除了上网搜索查资料外,阅读这些Bundle的源码也是一个很好的方法. 本文以org. ...

  8. 【Linux】【Commands】文件管理工具

    文件管理工具:cp, mv, rm cp命令:copy 源文件:目标文件 单源复制:cp [OPTION]... [-T] SOURCE DEST 多源复制:cp [OPTION]... SOURCE ...

  9. JpaRepository 增删改查

    Jpa查询 JpaRepository简单查询 基本查询也分为两种,一种是spring data默认已经实现,一种是根据查询的方法来自动解析成SQL. 预先生成方法 spring data jpa 默 ...

  10. java 输入输出IO流 IO异常处理try(IO流定义){IO流使用}catch(异常){处理异常}finally{死了都要干}

    IO异常处理 之前我们写代码的时候都是直接抛出异常,但是我们试想一下,如果我们打开了一个流,在关闭之前程序抛出了异常,那我们还怎么关闭呢?这个时候我们就要用到异常处理了. try-with-resou ...