半监督学习(Semi-Supervised Learning,SSL)的 SOTA 一次次被 Google 刷新,从 MixMatch 开始,到同期的 UDA、ReMixMatch,再到 2020 年的 FixMatch。

这四篇深度半监督学习方面的工作,都是从 consistency regularization 和 entropy minimization 两方面入手:

  • consistency regularization:一致性,给输入图片或者中间层注入 noise,模型的输出应该尽可能保持不变或者近似。
  • entropy minimization:最小化熵,模型在 unlabeled data 上的熵应该尽可能最小化。Pseudo label 也隐含地用到了 entropy minimization。

Consistency Regularization

对于每一个 unlabeled instance,consistency regularization 要求两次随机注入 noise 的输出近似。背后的思想是,如果一个模型是鲁棒的,那么即使输入有扰动,输出也应该近似。

对于 consistency regularization 来说,如何注入 noise 以及如何计算近似,就是每个方法的不同之处。注入 noise 可以通过模型本身的随机性(如 dropout)或者直接加入噪声(如 Gaussian noise),也可以通过 data augmentation;计算一致性的方法,可以使用 L2,也可以使用 KL divergency、cross entropy。

Entropy Minimization

MixMatch、UDA 和 ReMixMatch 通过 temperature sharpening 来间接利用 entropy minimization,而 FixMatch 通过 Pseudo label 来间接利用 entropy minimization。可以认为,只要通过得到 unlabeled data 的人工标签然后按照监督学习的方法(如 cross entropy loss)来训练的,都间接用到了 entropy minimization。因为人工标签都是 one-hot 或者近似 one-hot 的,如果 unlabeled data 的 prediction 近似人工标签,那么此时无标签数据的熵肯定也是较小的。

为什么这里使用人工标签而不是伪标签的称呼?一般而言,在半监督中,伪标签(pseudo label)特指 hard label,即 one-hot 类型的或者通过 argmax 得到的。[4]

Entropy minimization 可以在计算 unlabeled data 部分的 loss 和 consistency regularization 一起实现。

temperature sharpening 和 pseudo label 都得到了 unlabeled data 的人工标签,当前者 temperature=0 时,两者相等。pseudo label 要比 temperature sharpening 要简单,因为少了一个 temperature 超参数。

如果不利用 entropy minimization,那么 temperature sharpening 和 pseudo label 其实都是不需要的,只需要两次随机注入 noise 的 unlabeled instance 输出近似,就可以保证 consistency regularization。

或者说,得到 unlabeled data 的人工标签,可以使得 entropy minimization 和 consistency regularization 通过一项 loss 来完成。

结合 Consistency Regularization 和 Entropy Minimization

一般来说,半监督学习中的 unlabeled data 会使用全部训练数据集,即有标签的样本也会作为无标签样本来使用。

半监督学习中,labeled data 的标签都是给定的,而 unlabeled data 的标签都是不知道的。那么如何获得 unlabeled data 的人工标签(artificial label),MixMatch、UDA、ReMixMatch 和 FixMatch 的做法或多或少都不相同:

  • MixMatch:平均 K 次 weak augmentation(如 shifting 和 flipping)的 predictions ,然后经过 temperature sharpening;
  • UDA:一次 weak augmentation 的 prediction,然后经过 temperature sharpening;
  • ReMixMatch:一次 weak augmentation 的 prediction,然后经过 distribution alignment,最后经过 temperature sharpening;
  • FixMatch:一次 weak augmentation 的 prediction,然后 one-hot 得到 hard label。

得到了人工标签,我们就可以按照监督学习的方式来训练,这种思考方式就利用了 entropy minimization。而从 unlabeled data 的 consistency regularization 角度思考,我们需要注入不同的 noise,使得 unlabeled data 的 predictions 和它们的人工标签一致。

MixMatch、UDA、ReMixMatch 和 FixMatch 都利用 data augmentation 改变输入样本来注入 noise,不同的是 data augmentation 的具体方式和强度:

  • MixMatch:一次 weak augmentation 得到 prediction,这就和正常的监督训练一样,只是 unlabeled loss 用的是 L2 而已;
  • UDA:一次 strong augmentation(RandAugment) 得到 prediction;
  • ReMixMatch:多次 strong augmentation(CTAugment)得到 predictions,然后同时参与 unlabeled loss 的计算,即一个 unlabeled instance 一个 step 多次增强后计算多次 loss;
  • FixMatch:一次 strong augmentation(RandAugment 或 CTAugment)得到 prediction。

从 UDA 和 ReMixMatch 开始,strong augmentation 引入了半监督训练。UDA 使用了作者之前提出的 RandAugment 的 strong augmentation 方式,而 ReMixMatch 提出了一种 CTAugment。FixMatch 就把 UDA 和 ReMixMatch 中用到的 strong augmentation 都拿来用了一遍。

对于 unlabeled data 部分的 loss:

  • MixMatch:L2 loss;
  • UDA:KL divergency;
  • ReMixMatch:cross entropy(包括自监督的 rotation loss 和没有使用 mixup 的 pre-mixup unlabeled loss);
  • FixMatch:带阈值的 cross entropy。

FixMatch: Simplifying SSL with Consistency and Confidence

FixMatch 简化了 MixMatch、UDA 和 ReMixMatch,然后获得了更好的效果:

  • 首先,temperature sharpening 换成 pseudo label,这是一个简化;
  • 其次,FixMatch 通过设定一个阈值,在计算 unlabeled loss 时,对 prediction 的 confidence 超过阈值的 unlabeled instance 才算入 unlabeled loss,这样使得 unlabeled loss 的权重可以固定,这是第二个简化。

References

[1] Berthelot, D., Carlini, N., Goodfellow, I., Papernot, N., Oliver, A., Raffel, C. (2019). MixMatch: A Holistic Approach to Semi-Supervised Learning arXiv https://arxiv.org/abs/1905.02249

[2] Berthelot, D., Carlini, N., Cubuk, E., Kurakin, A., Sohn, K., Zhang, H., Raffel, C. (2019). ReMixMatch: Semi-Supervised Learning with Distribution Alignment and Augmentation Anchoring arXiv https://arxiv.org/abs/1911.09785

[3] Xie, Q., Dai, Z., Hovy, E., Luong, M., Le, Q. (2019). Unsupervised Data Augmentation for Consistency Training arXiv https://arxiv.org/abs/1904.12848

[4] Sohn, K., Berthelot, D., Li, C., Zhang, Z., Carlini, N., Cubuk, E., Kurakin, A., Zhang, H., Raffel, C. (2020). FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence arXiv https://arxiv.org/abs/2001.07685

【半监督学习】MixMatch、UDA、ReMixMatch、FixMatch的更多相关文章

  1. 利用DP-SSL对少量的标记样本进行有效的半监督学习

    作者 | Doreen 01 介绍 深度学习之所以能在图像分类.自然语言处理等方面取得巨大成功的原因在于大量的训练数据得到了高质量的标注. 然而在一些极其复杂的场景(例如:无人驾驶)中会产生海量的数据 ...

  2. 基于PU-Learning的恶意URL检测——半监督学习的思路来进行正例和无标记样本学习

    PU learning问题描述 给定一个正例文档集合P和一个无标注文档集U(混合文档集),在无标注文档集中同时含有正例文档和反例文档.通过使用P和U建立一个分类器能够辨别U或测试集中的正例文档 [即想 ...

  3. sklearn半监督学习

    标签: 半监督学习 作者:炼己者 欢迎大家访问 我的简书 以及 我的博客 本博客所有内容以学习.研究和分享为主,如需转载,请联系本人,标明作者和出处,并且是非商业用途,谢谢! --- 摘要:半监督学习 ...

  4. python大战机器学习——半监督学习

    半监督学习:综合利用有类标的数据和没有类标的数据,来生成合适的分类函数.它是一类可以自动地利用未标记的数据来提升学习性能的算法 1.生成式半监督学习 优点:方法简单,容易实现.通常在有标记数据极少时, ...

  5. 吴裕雄 python 机器学习——半监督学习LabelSpreading模型

    import numpy as np import matplotlib.pyplot as plt from sklearn import metrics from sklearn import d ...

  6. 吴裕雄 python 机器学习——半监督学习标准迭代式标记传播算法LabelPropagation模型

    import numpy as np import matplotlib.pyplot as plt from sklearn import metrics from sklearn import d ...

  7. 虚拟对抗训练(VAT):一种用于监督学习和半监督学习的正则化方法

    正则化 虚拟对抗训练是一种正则化方法,正则化在深度学习中是防止过拟合的一种方法.通常训练样本是有限的,而对于深度学习来说,搭设的深度网络是可以最大限度地拟合训练样本的分布的,从而导致模型与训练样本分布 ...

  8. 【论文解读】【半监督学习】【Google教你水论文】A Simple Semi-Supervised Learning Framework for Object Detection

    题记:最近在做LLL(Life Long Learning),接触到了SSL(Semi-Supervised Learning)正好读到了谷歌今年的论文,也是比较有点开创性的,浅显易懂,对比实验丰富, ...

  9. AI之强化学习、无监督学习、半监督学习和对抗学习

    1.强化学习 @ 目录 1.强化学习 1.1 强化学习原理 1.2 强化学习与监督学习 2.无监督学习 3.半监督学习 4.对抗学习 强化学习(英语:Reinforcement Learning,简称 ...

随机推荐

  1. org.mybatis.spring.MyBatisSystemException: nested exception is org.apache.ibatis.binding.BindingException: Parameter 'employeeId' not found. Available parameters are [page, map, param1, param2] 解决方法

    原因很简单就是没映射到接口添加 @Param 注解 ->@Param("map") 然后在mapper.xml map.employeeId 再次测试 已经解决 ->

  2. Hive常用命令及作用

    1-创建表 -- 内部表 create table aa(col1 string,col2 int) partitioned by(statdate int) ROW FORMAT DELIMITED ...

  3. 从ISTIO熔断说起-轻舟网关熔断

    最近大家经常被熔断洗脑,股市的动荡,让熔断再次出现在大家眼前.微服务中的熔断即服务提供方在一定时间内,因为访问压力太大或依赖异常等原因,而出现异常返回或慢响应,熔断即停止该服务的访问,防止发生雪崩效应 ...

  4. 【SQL SERVER】锁机制

    锁定是 SQL Server 数据库引擎用来同步多个用户同时对同一个数据块的访问的一种机制. 基本概念 利用SQL Server Profiler观察锁 死锁产生的原因及避免 总结 基本概念 数据库引 ...

  5. (note)从小白到产品经理之路

    学习了云课堂的产品课程,整理出部分笔记,以作备用参考,方便实际运用过程中查看巩固. 1.产品工具:Axure.mindmanager.viso.办公软件wps 2.产品人需要具备的品格 富有同理心,习 ...

  6. 事务框架之声明事务(自动开启,自动提交,自动回滚)Spring AOP 封装

    利用Spring AOP 封装事务类,自己的在方法前begin 事务,完成后提交事务,有异常回滚事务 比起之前的编程式事务,AOP将事务的开启与提交写在了环绕通知里面,回滚写在异常通知里面,找到指定的 ...

  7. 《带你装B,带你飞》pytest成魔之路4 - fixture 之大解剖

    1. 简介 fixture是pytest的一个闪光点,pytest要精通怎么能不学习fixture呢?跟着我一起深入学习fixture吧.其实unittest和nose都支持fixture,但是pyt ...

  8. Spring中的设计模式:工厂方法模式

    导读 工厂方法模式是所有设计模式中比较常用的一种模式,但是真正能搞懂用好的少之又少,Spring底层大量的使用该设计模式来进行封装,以致开发者阅读源代码的时候晕头转向. 文章首发于微信公众号[码猿技术 ...

  9. extend()和append()的区别

    append()方法用于在列表末尾添加新的对象(对象可以是值或列表),一般用于添加列表项. extend()方法用于在列表末尾追加另一个序列中的多个值.

  10. CCF2018 12 2题,小明终于到家了

    最近在愁着备考,拿CCF刷题,就遇到这个难题,最后搜索了一下大佬们的方法,终于解决, 问题描述 一次放学的时候,小明已经规划好了自己回家的路线,并且能够预测经过各个路段的时间.同时,小明通过学校里安装 ...