Katharopoulos A, Fleuret F. Not All Samples Are Created Equal: Deep Learning with Importance Sampling[J]. arXiv: Learning, 2018.

@article{katharopoulos2018not,

title={Not All Samples Are Created Equal: Deep Learning with Importance Sampling},

author={Katharopoulos, Angelos and Fleuret, F},

journal={arXiv: Learning},

year={2018}}

本文提出一种删选合适样本的方法, 这种方法基于收敛速度的一个上界, 而并非完全基于gradient norm的方法, 使得计算比较简单, 容易实现.

主要内容

设\((x_i,y_i)\)为输入输出对, \(\Psi(\cdot;\theta)\)代表网络, \(\mathcal{L}(\cdot, \cdot)\)为损失函数, 目标为

\[\tag{1}
\theta^* = \arg \min_{\theta} \frac{1}{N} \sum_{i=1}^N\mathcal{L}(\Psi(x_i;\theta),y_i),
\]

其中\(N\)是总的样本个数.

假设在第\(t\)个epoch的时候, 样本(被选中)的概率分布为\(p_1^t,\ldots,p_N^t\), 以及梯度权重为\(w_1^t, \ldots, w_N^t\), 那么\(P(I_t=i)=p_i^t\)且

\[\tag{2}
\theta_{t+1}=\theta_t-\eta w_{I_t}\nabla_{\theta_t} \mathcal{L}(\Psi(x_{I_t};\theta_t),y_{I_t}),
\]

在一般SGD训练中\(p_i=1/N,w_i=1\).

定义\(S\)为SGD的收敛速度为:

\[\tag{3}
S :=-\mathbb{E}_{P_t}[\|\theta_{t+1}-\theta^*\|_2^2-\|\theta_t-\theta^*\|_2^2],
\]

如果我们令\(w_i=\frac{1}{Np_i}\) 则



定义\(G_i=w_i\nabla_{\theta_t} \mathcal{L}(\Psi(x_{i};\theta_t),y_{i})\)



我们自然希望\(S\)能够越大越好, 此时即负项越小越好.

定义\(\hat{G}_i \ge \|\nabla_{\theta_t} \mathcal{L}(\Psi(x_{i};\theta_t),y_{i})\|_2\), 既然



(7)式我有点困惑,我觉得(7)式右端和最小化(6)式的负项(\(\mathrm{Tr}(\mathbb{V}_{P_t}[G_{I_t}])+\|\mathbb{E}_{P_t}[G_{I_t}]\|_2^2\))是等价的.

于是有

最小化右端(通过拉格朗日乘子法)可得\(p_i \propto \hat{G}_i\), 所以现在我们只要找到一个\(\hat{G}_i\)即可.

这个部分需要引入神经网络的反向梯度的公式, 之前有讲过,只是论文的符号不同, 这里不多赘诉了.

注意\(\rho\)的计算是比较复杂的, 但是\(p_i \propto \hat{G}_i\), 所以我们只需要计算\(\|\cdot\|\)部分, 设此分布为\(g\).

另外, 在最开始的时候, 神经网络没有得到很好的训练, 权重大小相差无几, 这个时候是近似正态分布的, 所以作者考虑设计一个指标,来判断是否需要根据样本分布\(g\)来挑选样本. 作者首先衡量



显然当这部分足够大的时候我们可以采用分布\(g\)而非正态分布\(u\), 但是这个指标不易判断, 作者进步除以\(\mathrm{Tr}(\mathbb{V}_u[G_i])\).



显然\(\tau\)越大越好, 我们自然可以人为设置一个\(\tau_{th}\). 算法如下

最后, 个人认为这个算法能减少计算量主要是因为样本少了, 少在一开始用正态分布抽取了一部分, 所以...

"代码"

主要是\(\hat{G}_i\)部分的计算, 因为涉及到中间变量的导数, 所以需要用到retain_grad().

"""
这里只是一个例子
""" import torch
import torch.nn as nn class Net(nn.Module): def __init__(self):
super(Net, self).__init__()
self.dense = nn.Sequential(
nn.Linear(10, 256),
nn.ReLU(),
nn.Linear(256, 10),
)
self.final = nn.ReLU() def forward(self, x):
z = self.dense(x)
z.retain_grad()
out = self.final(z)
return out, z if __name__ == "__main__": net = Net()
criterion = nn.MSELoss() x = torch.rand((2, 10))
y = torch.rand((2, 10)) out, z = net(x)
loss = criterion(out, y)
loss.backward()
print(z.grad) #这便是我们所需要的

Not All Samples Are Created Equal: Deep Learning with Importance Sampling的更多相关文章

  1. Accelerating Deep Learning by Focusing on the Biggest Losers

    目录 概 相关工作 主要内容 代码 Accelerating Deep Learning by Focusing on the Biggest Losers 概 思想很简单, 在训练网络的时候, 每个 ...

  2. (转) The major advancements in Deep Learning in 2016

    The major advancements in Deep Learning in 2016 Pablo Tue, Dec 6, 2016 in MACHINE LEARNING DEEP LEAR ...

  3. Deep Learning in R

    Introduction Deep learning is a recent trend in machine learning that models highly non-linear repre ...

  4. Summary on deep learning framework --- PyTorch

    Summary on deep learning framework --- PyTorch  Updated on 2018-07-22 21:25:42  import osos.environ[ ...

  5. [C3] Andrew Ng - Neural Networks and Deep Learning

    About this Course If you want to break into cutting-edge AI, this course will help you do so. Deep l ...

  6. Deep Learning 5_深度学习UFLDL教程:PCA and Whitening_Exercise(斯坦福大学深度学习教程)

    前言 本文是基于Exercise:PCA and Whitening的练习. 理论知识见:UFLDL教程. 实验内容:从10张512*512自然图像中随机选取10000个12*12的图像块(patch ...

  7. (转)WHY DEEP LEARNING IS SUDDENLY CHANGING YOUR LIFE

    Main Menu Fortune.com       E-mail Tweet Facebook Linkedin Share icons By Roger Parloff Illustration ...

  8. (转)Deep Learning Research Review Week 1: Generative Adversarial Nets

    Adit Deshpande CS Undergrad at UCLA ('19) Blog About Resume Deep Learning Research Review Week 1: Ge ...

  9. (转)The 9 Deep Learning Papers You Need To Know About (Understanding CNNs Part 3)

    Adit Deshpande CS Undergrad at UCLA ('19) Blog About The 9 Deep Learning Papers You Need To Know Abo ...

随机推荐

  1. Spring 文档汇总

    Spring Batch - Reference Documentation Spring Batch 参考文档中文版 Spring Batch 中文文档 Table 2. JdbcCursorIte ...

  2. spring boot springMVC扩展配置 。WebMvcConfigurer ,WebMvcConfigurerAdapter

    摘要: 在spring boot中 MVC这部分也有默认自动配置,也就是说我们不用做任何配置,那么也是OK的,这个配置类就是 WebMvcAutoConfiguration,但是也时候我们想设置自己的 ...

  3. uni-app使用腾讯地图注意点

    地图map组件使用腾讯地图自定义样式: 1:在使用地图map组件腾讯地图时,获取本地定位,经纬度转地址与地址转经纬度解析时,小程序可以直接使用.但是h5版本会报跨域问题,目前前端没有找到更好的解决方法 ...

  4. Redis图形管理 redis-browser

    目录 一.介绍 二.部署 三.启动 监听单台 听多台 四.报错合集 一.介绍 redis-browser是redis的web端图形化管理工具.利用它可以查看和管理redis的数据,界面简洁,能和ral ...

  5. antd动态的表格合并(包含排序功能)

    主要是两个步骤, 1.处理接口返回数据,给其添加两个属性,一个是合并行数(列数),一个是当前数据的序号 2.在columns结合antd官网的处理方法合并表格 3.尽可能得减少计算量 数据处理函数 / ...

  6. Hive实战UDF 外部依赖文件找不到的问题

    目录 关于外部依赖文件找不到的问题 为什么要使用外部依赖 为什么idea 里面可以运行上线之后不行 依赖文件直接打包在jar 包里面不香吗 学会独立思考并且解决问题 继承DbSearcher 读取文件 ...

  7. QT QApplication干了啥?

    ------------恢复内容开始------------ QCoreApplicationPrivate 会取得current thread; 在windows平台创建TLS变量,记录线程信息,并 ...

  8. Table.Group分组…Group(Power Query 之 M 语言)

    数据源: 10列55行数据,其中包括含有重复项的"部门"列和可求和的"金额"列. 目标: 按"部门"列进行分组,显示各部门金额小计. 操作过 ...

  9. 延时间隔(Project)

    <Project2016 企业项目管理实践>张会斌 董方好 编著 像"饭前洗手"这种事,绝大部分情况下,是洗完手马上就可以动筷子开吃,但总会有意外,比如手都洗好了,突然 ...

  10. 小迪安全 Web安全 基础入门 - 第三天 - 抓包&封包&协议&APP&小程序&PC应用&WEB应用

    一.抓包工具 1.Fiddler.Fiddler是一个用于HTTP调试的代理服务器应用程序,能捕获HTTP和HTTPS流量,并将其记录下来供用户查看.它通过使用自签名证书实现中间人攻击来进行日志记录. ...