url: https://arxiv.org/abs/1503.02531

year: NIPS 2014



简介

将大模型的泛化能力转移到小模型的一种显而易见的方法是使用由大模型产生的类概率作为训练小模型的“软目标”

其中, T(temperature, 蒸馏温度), 通常设置为1的。使用较高的T值可以产生更软的类别概率分布。 也就是, 较高的 T 值, 让学生的概率分布可以更加的接近与老师的概率分布,

下面通过一个直观的例子来感受下

def softmax_with_T(logits, temperature):

    for t in temperature:
total = 0
prob = []
for logit in logits:
total += np.exp(logit/t)
for logit in logits:
prob.append(np.exp(logit/t) / total)
print('T={:<4d}'.format(t), end=' ')
for p in prob:
print('{:0.3f}'.format(p), end=' ')
print()

可以看出, softmax 输出的项比例与 logits原始比例之间的关系与 logits 本身的模长以及 T 值大小相关, 感觉 T 值需要仔细调整下, 至少能反应 logits 之间的大致关系, 而且可以看出, softmax_with_T 受两个变量的影响, 直接来比较的话, 比较难分析. 当 T 远大于 logits 的模长时, softmax 的输出尺度在相同的数量级下(如logits=[6,3,1], T=25), 这样看的话, 即使老师和学生的 logit 相差很远, 经过具有很大 T 的 softamx 之后, 数量级几乎相同, 这样是不合理的. 但是, 下面的公式推导结果加上实验结果表明, 认真看梯度才是王道, 看输出的话, 完全找不到感觉, 对于软标签交叉熵损失

梯度推导

softmax+cross entropy梯度求导

\[\bf{{\frac{\partial{C}}{\partial{z_i}} = \frac{1}{T}(q_i-p_i)
= \frac{1}{T} \left( \frac{e^{z_i/T}}{\sum_je^{z_j/T}} - \frac{e^{v_i/T}}{\sum_je^{v_j/T}}\right)}}
\]

\(e^x\)泰勒展开

\[\bf {e^x \approx 1 + x + \frac{x^2}{2!} + \frac{x^3}{3!} + \cdots + \frac{x^n}{n!} \\
x\rightarrow 0, \quad e^x \approx 1+x}
\]

\(T\rightarrow \infty\)时, \(\frac{Z_i}{T}\rightarrow 0\)

\[\bf {{\frac{\partial{C}}{\partial{z_i}} \approx \frac{1}{T} \left(\frac{{1+{z_i/T}}}{N+\sum {z_j/T}} -\frac{{1+{v_i/T}}}{N+\sum {v_j/T}}\right)}}
\]

假设logits已经单独进行了zero-center中心化处理,那么,

\[\bf{\sum_jz_j=\sum_jv_j=0}\\
\Downarrow \\
\bf{\frac{\partial{C}}{\partial{z_i}} \approx \frac{1}{NT^2}{(z_i-v_i)}}
\]

这样的话, 当T值最够大, 方法就变为求老师和学生的 logits 的 L2 距离了.

术语 说明
\(q^{soft}\) 老师模型的 softmax 输出软标签
\(q^{hard}\) 训练集 one-hot 硬标签
\(p^{soft}\) 学生模型的 softmax 输出软标签
\(p^{hard}\) 学生模型的 softmax 输出硬标签(T=1)

\[\bf {\text{loss_cross_entpopy} = \alpha \cdot T^2 \cdot q^{soft}\cdot \ln \left(p^\text{soft} \right) \\ \quad \quad \quad \quad + (1-\alpha) \cdot q^{hard}\cdot \ln \left(p^\text{hard} \right)}
\]

论文中发现通常给予硬标签损失函数 \(\color{red}{可忽略不计的较低权重}\) 可以获得最佳结果。 由于软目标产生的梯度的大小为 \(\frac{1}{T^2}\),因此当使用硬目标和软目标时,将它们乘以 \(T^2\) 是很重要的, 这确保软硬标签对梯度相对贡献在一个数量级。

实验结果

思考

软标签交叉熵函数与 KL 散度的联系



上式中, 由于 p 为老师的预测结果, 模型蒸馏时候, 老师模型被冻结, 从梯度反传来看, 软标签交叉熵函数 等价于 KL 散度.

对于我而言, 这篇论文相对于 Do Deep Nets Really Need to be Deep? 贡献就在于, 将 L2距离 和 KL 散度统一到一个公式中了, 由于到 T 足够大, KL 散度的梯度与 L2 距离的一样. 这篇论文中其他部分没有读懂, 没有看到其他想要的东西. 后面知识积累了有机会在看看有没有新感受吧.

蒸馏入门的话, 推荐 Do Deep Nets Really Need to be Deep? 这篇论文. 从实验分析来说, 各种分析都很到位, 分析的方式也是易读的, 容易理解. 就工程效果来看, 实际上Distilling the Knowledge in a Neural Network 这篇论文有效时候, T一般都挺大的, 那么KL 散度的实际的效果就是 L2 距离, 不如直接用 L2 距离, 理解上简单, 调节超参少, 效果也非常好.

Distilling the Knowledge in a Neural Network的更多相关文章

  1. 【DKNN】Distilling the Knowledge in a Neural Network 第一次提出神经网络的知识蒸馏概念

    原文链接 小样本学习与智能前沿 . 在这个公众号后台回复"DKNN",即可获得课件电子资源. 文章已经表明,对于将知识从整体模型或高度正则化的大型模型转换为较小的蒸馏模型,蒸馏非常 ...

  2. 【论文考古】知识蒸馏 Distilling the Knowledge in a Neural Network

    论文内容 G. Hinton, O. Vinyals, and J. Dean, "Distilling the Knowledge in a Neural Network." 2 ...

  3. 1503.02531-Distilling the Knowledge in a Neural Network.md

    原来交叉熵还有一个tempature,这个tempature有如下的定义: \[ q_i=\frac{e^{z_i/T}}{\sum_j{e^{z_j/T}}} \] 其中T就是tempature,一 ...

  4. 论文笔记:蒸馏网络(Distilling the Knowledge in Neural Network)

    Distilling the Knowledge in Neural Network Geoffrey Hinton, Oriol Vinyals, Jeff Dean preprint arXiv: ...

  5. 论文笔记之:Progressive Neural Network Google DeepMind

    Progressive Neural Network  Google DeepMind 摘要:学习去解决任务的复杂序列 --- 结合 transfer (迁移),并且避免 catastrophic f ...

  6. Recurrent Neural Network[survey]

    0.引言 我们发现传统的(如前向网络等)非循环的NN都是假设样本之间无依赖关系(至少时间和顺序上是无依赖关系),而许多学习任务却都涉及到处理序列数据,如image captioning,speech ...

  7. [Tensorflow] Cookbook - Neural Network

    In this chapter, we'll cover the following recipes: Implementing Operational Gates Working with Gate ...

  8. (zhuan) Recurrent Neural Network

    Recurrent Neural Network 2016年07月01日  Deep learning  Deep learning 字数:24235   this blog from: http:/ ...

  9. 课程一(Neural Networks and Deep Learning),第四周(Deep Neural Networks)——2.Programming Assignments: Building your Deep Neural Network: Step by Step

    Building your Deep Neural Network: Step by Step Welcome to your third programming exercise of the de ...

随机推荐

  1. Node接口实现HTTPS版的

    最近由于自己要做一个微信小程序,接口地址只能是https的,这就很难受了 于是乎,我租了个服务器,搞了个免费的ssl认证 可是呢,我不会搞https接口怎样实现 今天特意花了一天时间来学,来学习 &q ...

  2. Logstash filter 插件之 date

    使用 date 插件解析字段中的日期,然后使用该日期或时间戳作为事件的 logstash 时间戳.对于排序事件和导入旧数据,日期过滤器尤其重要.如果您在事件中没有得到正确的日期,那么稍后搜索它们可能会 ...

  3. kettle抽取数据发送邮件Linux调度

    kettle抽取数据发送邮件Linux调度 #1.进入kettle安装目录 然后执行sqoop.sh文件启动kettlecd /app/pdi-ce-7.1.0.0-12/data-integrati ...

  4. Python 从入门到进阶之路(六)

    之前的文章我们简单介绍了一下 Python 的面向对象,本篇文章我们来看一下 Python 中异常处理. 我们在写程序时,有可能会出现程序报错,但是我们想绕过这个错误执行操作.即使我们的程序写的没问题 ...

  5. C#以对象为成员的例子

    using System; using System.Collections.Generic; using System.Text; namespace test { class Program { ...

  6. (转)RocketMQ工作原理

    原文:https://blog.csdn.net/lyly4413/article/details/80838716 1.消息中间件的发展: 第一代以ActiveMQ为代表,遵循JMS(java消息服 ...

  7. SpringBoot启动项目时提示:Error:(3, 32) java: 程序包org.springframework.boot不存在

    场景 在IDEA中新建SpringBoot项目,后启动项目时提示: Error:(3, 32) java: 程序包org.springframework.boot不存在 实现 将pom.xml中par ...

  8. 上传图片到七牛云(客户端 js sdk)

    大体思路 上一篇我们讲了如何通过服务器生成一个upToken,那前端拿到这个token后又该如何操作?在这里我给出一个相当简洁的版本. 首先我们来看一下上传的思路:调用七牛模块的upload方法,生成 ...

  9. C# 新特性 操作符单?与??和 ?. 的使用

    1.单问号(?) 1.1 单问号运算符可以表示:可为Null类型,C#2.0里面实现了Nullable数据类型 //A.比如下面一句,直接定义int为null是错误的,错误提示为无法将null转化成i ...

  10. 关于如何获取项目所部署的本机IP和端口的问题

    关于如何获取项目所部署的本机IP和端口的问题 今天在写一个需求的时候碰到一个不常见的问题,在没有继承或者实现服务器提供的接口或者实现类的时候,比如说部署在tomacat上,某个类不去继承servelt ...