【论文阅读】Deep Mutual Learning
文章:Deep Mutual Learning
出自CVPR2017(18年最佳学生论文)
文章链接:https://arxiv.org/abs/1706.00384
代码链接:https://github.com/YingZhangDUT/Deep-Mutual-Learning
主要贡献:
提出了一种简单且普遍适用的方法,通过在相同/不同的未预训练的网络中进行相互蒸馏,来提高深层神经网络的性能。通过这种方法,我们可以获得比静态教师从强网络中提取的网络性能更好的紧凑网络.
和有教师指导的蒸馏模型相比,相互学习策略具有以下优点:1)随着学生网络的增加其效率也得到提高;2)它可以应用在各种各样的网络中,包括大小不同的网络;3)即使是非常大的网络采用相互学习策略,其性能也能够得到提升
由于是学生网络相互学习,而不是传统知识萃取,文章也说明了两个以上网络共同学习的策略,并从熵值的角度给出理论支持。
知识蒸馏的内容不再赘述,https://blog.csdn.net/nature553863/article/details/80568658
整理得非常完善。
网络结构及损失函数:

每个网络由常规的有监督学习损失和拟态损失来共同训练。拟态损失是指是每个学生的后验类别要和其他学生的类别概率相一致。(拟态在生物学中是指一个动物在进化的过程中会获得与成功物种相似的特征,成功混淆掠食者的认知,从而靠近拟态物种)
如图一所示,同个输入,分别经过Net1和Net2两个网络(和孪生网络不同,这里权重不共享),得到两个logits记为z1,z2(所说文章中有明确提到这里的z是softmax之后的,但是我整的时候取得却是softmax之前的),经过softmax得到预测的软分布p1,p2。送入KL散度计算这两个分布的相似性作为拟态损失。与label比较计算标签损失。

注意,上式表示的是:假设p2是数据的真实分布,p1是数据的理论分布,即是让p1的分布更接近p2的分布。这个别搞反了,毕竟KL散度有几个性质:(1)不对称性:尽管KL散度从直观上是个度量或距离函数,但它并不是一个真正的度量或者距离,因为它不具有对称性,即D(P||Q)!=D(Q||P)。(恩,还有非负和不满足三角不等,这两个在这没用就不写了)。所幸,KL散度求出的值是在0-1之间(分布完全相同则为0),比一般的L1,L2损失更适合拟合分布的损失。但是,如果两个分配P,Q离得很远,完全没有重叠的时候,那么KL散度值是没有意义的,这在学习算法中是比较致命的,这就意味这这一点的梯度为0。梯度消失了。(说实话,文中没有给出其他度量下的结果,我试了KL变为JS,L1,L2这几种分布度量,结果十分相近,KL或许不是最合适的)
网络Net1的总损失为:
其中:
网络Net2的总损失为:
优化步骤:
相互学习策略在每一个基于小批量的模型更新步骤和整个培训过程中执行。在每次迭代中,我们计算两个模型的预测,并根据另一个模型的预测更新两个网络的参数。两个网络的优化是迭代进行的,直到收敛。优化细节总结在算法1中。

恩,这个很简单,就不翻译了。就是固定一个网络更新另一个,循环交替直至收敛。
扩展到更多的网络(理论上你想要几个,只要机器够牛逼都行):
提出的相互学习策略可以扩展到更多的学生网络,假设有K个学生网络,其损失目标函数变为:

添加了系数k-1,以确保培训主要由对真实标签的监督学习指导.拟态损失为与其他所有网络概率分布相似性的平均值。恩,之后的实验都默认使用这一种哈。
另一个学习策略是将所有其他K-1网络的集合作为单个教师,以提供平均的分布概率,这非常类似于蒸馏方法,θk的目标函数可以写成:

拟态损失为与其他所有网络概率分布平均值的相似性。
~_~。这个大家可以类比多分类问题中一个多分类器与多个二分类器。大家可能下意识的以为后一种策略会有更高的精度,但事实恰恰相反,文章最后倒是有给出理由,后面再说吧。
实验部分:
两个数据集,一个是cifar100,一个是行人再检测的maket1501。参数设置什么的文章有给,我不列出来了,反正只要你设定的不是太离谱,都可以得到好结果。

Independent表示单个网络运行结果,DML-Independent,恩,就是DML的结果减去Independent的结果,普遍提升1.5个点(机器问题只试了resnet的,的确能提升这么多)。

这是和知识蒸馏进行对比,当然了,是最原始的histon2014的那版。

这个,算消融实验,只是结果太好了,作者就和其他方法进行了对比。single-query表示单索引,multi-query表示多索引,百度下或问做再检测的朋友是啥子。

这个是扩展到N个网络的结果,恩,之前不是说了两种扩展方法吗,这里都是第一种。左边这个图呢,不是有N个网络嘛,N个网络分别得到N个精度,然后取平均,就得到左图了(应该是凑字数用的吧,应该没人会这么做吧,好浪费的,投票不行吗)。右图表示基于所有成员的连接特征进行匹配,精度明显随网络的增多而增加。
理论:
最后作者回答了为什么这样联合学习能提升精度。《Entropy-sgd:
Biasing gradient descent into wide
valleys》这篇文章大家有兴趣的可以去瞅瞅,个人认为解释了我的某些疑惑。通常情况下有很多的解决方案能够让训练误差变为0,然而有些解决方法的泛化能力要强一点。因为梯度下降法找打的最优点不是在狭小的谷底内,而是在宽阔的峡谷中,当我们加入相对熵以后能够让网络找到更小的值,从而实现更优的结果。

对于DML模型和独立模型,比较了在每个模型参数中加入独立高斯噪声和可变标准差σ后,学习模型的训练损失。发现两个损失基本是相同的,但是在加入这个扰动后,独立模型的训练损失相比DML模型的损失会增加得多。这表明DML模型已经找到了一个更广泛的极小值,这有望提供更好的泛化性能。
在提出的DML策略中,每个学生都是由队列中的所有其他学生单独教授的,不管队列中有多少学生。一种替代的DML策略称为ensemble策略,要求每个学生匹配队列中所有其他学生的联合预测。人们可以合理地期望这种ensemble方法会更好。由于集合预测比单个预测更好,因此它应该提供一个更清晰、更强的教学信号——更像常规蒸馏。在实践中,联合教学(ensemble)比依次单独教学(peer)的效果更差。通过对联合ensemble教学信号的分析,与依次单独peer教学相比,ensemble目标在真标签上的峰值要比peer目标明显得多,从而使DML的预测熵值大于DML-E。因此,虽然集成的噪声平均特性对于做出正确的预测是有效的,但它实际上不利于提供一个教学信号,其中secondary probabilities在总体信号中是突出的,具有高熵的后验代表这模型训练的更鲁棒的解决方案。(大家联想一下软标签与硬标签的不同。)

当然,文中也有说明How
a Better Minima is Found,但是没有看明白,对前置知识(《Entropy-sgd:
Biasing gradient descent into wide valleys》与《Regularizing
neural networks by penalizing confident output
distributions》)没有足够的了解,翻译一下原文:当要求每个网络匹配其对等网络的概率估计时,如果给定网络预测为零,而其教师/对等网络预测为非零,则会受到严重惩罚。因此,DML的总体效果是,当每个网络独立地将一个小的mass放在一组小的secondary
probabilities上时,DML中的所有网络都倾向于聚合它们对secondary
probabilities的预测,i)将更多的mass放在secondary
probabilities上,以及ii)将非零mass置于更明显的secondary
probabilities上。通过比较图中由DML培训的CIFAR-100上的resnet-32获得的top5与独立培训的resnet-32模型的概率来说明这一影响。对于每个训练样本,根据模型产生的后验概率对top5进行排序。在这里,我们可以看到,对于独立学习来说,将mass分配到top1下的概率比DML学习要快得多。这可以用DML训练模型和独立训练模型的所有训练样本的平均熵值来量化,分别为1.7099和0.2602。因此,我们的方法与基于熵正则化的方法有联系,以寻找宽的极小值,但通过对“合理”的选择的相互概率匹配,而不是盲目的高熵偏好。

写在最后:个人认为适用范围挺广的,效果提升也很明显,也还留了许多可以改进的点。知识蒸馏的想法是通过教师网络提供的软标签,给予学生网络硬标签所不能表达的新的信息量(类似与猫更像狗而不是蛋糕)。而文中说的联合学习的方式也提供了更多的信息量。或许有的同学会认为,这是学习了更鲁棒的特征,而不是蕴含了更多的信息量,恩,不相信的朋友不妨试一下把KL(logits)换成KL(features)看下结果。全连接fc之后为logits,之前为features。
【论文阅读】Deep Mutual Learning的更多相关文章
- [论文阅读] Deep Residual Learning for Image Recognition(ResNet)
ResNet网络,本文获得2016 CVPR best paper,获得了ILSVRC2015的分类任务第一名. 本篇文章解决了深度神经网络中产生的退化问题(degradation problem). ...
- Paper | Deep Mutual Learning
目录 1. 动机详述和方法简介 2. 相关工作 3. 方法 3.1 Formulation 3.2 实现 3.3 弱监督学习 4. 实验 4.1 基本实验 4.2 深入实验 [算法和公式很simple ...
- [论文阅读笔记] Adversarial Learning on Heterogeneous Information Networks
[论文阅读笔记] Adversarial Learning on Heterogeneous Information Networks 本文结构 解决问题 主要贡献 算法原理 参考文献 (1) 解决问 ...
- Deep Mutual Learning
论文地址: https://arxiv.org/abs/1706.00384 论文简介 该论文探讨了一种与模型蒸馏(model distillation)相关却不同的模型---即相互学习(mutual ...
- 论文笔记——Deep Residual Learning for Image Recognition
论文地址:Deep Residual Learning for Image Recognition ResNet--MSRA何凯明团队的Residual Networks,在2015年ImageNet ...
- [论文理解]Deep Residual Learning for Image Recognition
Deep Residual Learning for Image Recognition 简介 这是何大佬的一篇非常经典的神经网络的论文,也就是大名鼎鼎的ResNet残差网络,论文主要通过构建了一种新 ...
- 论文阅读:Multi-task Learning for Multi-modal Emotion Recognition and Sentiment Analysis
论文标题:Multi-task Learning for Multi-modal Emotion Recognition and Sentiment Analysis 论文链接:http://arxi ...
- 论文阅读《End-to-End Learning of Geometry and Context for Deep Stereo Regression》
端到端学习几何和背景的深度立体回归 摘要 本文提出一种新型的深度学习网络,用于从一对矫正过的立体图像回归得到其对应的视差图.我们利用问题(对象)的几何知识,形成一个使用深度特征表示的代价量(c ...
- 论文阅读 Inductive Representation Learning on Temporal Graphs
12 Inductive Representation Learning on Temporal Graphs link:https://arxiv.org/abs/2002.07962 本文提出了时 ...
随机推荐
- lamda和匿名内部类
匿名内部类 匿名内部类在日常编程中还是经常会使用的.比如 ArrayList<String> list=new ArrayList<>(); list.add(new Stri ...
- Robot Framework——对时间操作的datetime库常用关键字
1.对固定日期进行操作,增加或减去单位时间或者时间段 2.对两个时间段进行操作 3.对时间格式转化,获取时间戳 4.从完整时间中取指定年月日等 5.对时间类型进行格式化 6.获取当前时间或者指定时区时 ...
- Leetcode Tags(6)Math
一.204. Count Primes Count the number of prime numbers less than a non-negative number, n. Input: 10 ...
- 【Spring Boot】java.lang.NoSuchMethodError: org.springframework.web.util.UrlPathHelper.getLookupPathForRequest(Ljavax/servlet/http/HttpServletRequest;Ljava/lang/String;)Ljava/lang/String;
Digest:今天Spring Boot 应用启动成功,访问接口出现如下错误,不知到导致问题关键所在,记录一下这个问题. 浏览器报500错误 项目代码如下 Controller.java packag ...
- bzoj1004 card
明知是burnside然而根本不会然后无耻地颓了题解后一脸傻气的我: 直接套公式???为啥方案数==等价类数量啊??? skyh:显然啊(狂笑)(hey wxy!他问为啥方案书等于等价类数量!) wx ...
- 腾讯新闻构建高性能的 react 同构直出方案
在腾讯新闻抢金达人活动 node 同构直出渲染方案的总结文章中我们整体了解了下同构直出渲染方案在我们项目中的使用.正如我在上篇文章结尾所说的: 应用型技术的难点不是在克服技术问题,而是在于能够不断的结 ...
- JavaScript权威指南----一个JavaScript贷款计算器
废话不多说上例子代码: <!DOCTYPE html> <html> <head> <meta charset="UTF-8"> & ...
- [springboot 开发单体web shop] 5. 用户登录及首页展示
用户登录及前端展示 用户登录 在之前的文章中我们实现了用户注册和验证功能,接下来我们继续实现它的登录,以及登录成功之后要在页面上显示的信息. 接下来,我们来编写代码. 实现service 在com.l ...
- 学习 Java 应该关注哪些网站?
经常有一些读者问我:"二哥,学习 Java 应该关注哪些网站?",我之前的态度一直是上知乎.上搜索引擎搜一下不就知道了.但读者对我这个态度很不满意,他们说,"我在问你,又 ...
- MapReduce任务提交源码分析
为了测试MapReduce提交的详细流程.需要在提交这一步打上断点: F7进入方法: 进入submit方法: 注意这个connect方法,它在连接谁呢?我们知道,Driver是作为客户端存在的,那么客 ...