这篇论文:

  1. 提出了prob-and-allocate训练策略,在prob阶段获得样本损失,在allocate阶段分配样本权重。
  2. 以[2]的meta-weight-net为Baseline,取名为CurveNet,进行部分改动。

另外,这篇论文提供的源码结构混乱,复现难度较大。主要的工作也是基于meta-weight-net,创新的内容有限。但是,这篇文章在Introduction对long-tailed data + noisy labels问题的描述非常清晰。

Introduction

Background

分别单独处理long-tailed data和noisy labels的数据偏置时,re-weighting策略是常见且有效的办法:通过loss值分配相应的权重。但如果两类偏置同时出现,re-weighting效果就不佳了。

具体来说,对于包含noisy labels的训练数据,noisy labels的样本往往具有较大的训练损失,因此加权函数应该将大损失映射到小样本权重,以减轻标签噪声的影响。

对于类别不平衡的训练数据,尾类样本通常会因训练不足而遭受较大损失,因此加权函数应该为这些硬正样本分配较大的权重,使网络更加强调尾类以提高整体性能表现。

tailed类loss大,但从noisy labels角度看是噪声标签,需要小权重;从imbalance角度看,需要大权重。

处理两类偏置的关键问题:区分尾部类别的干净样本和标签损坏的样本。

Motivation

观察Figure 1右图。噪声样本的损失在训练开始时保持稳定,而干净样本的损失在开始时急剧上升,然后迅速下降。因此,训练损失曲线实际上包括了有价值的信息,并且可以提供有用的先验来区分尾类的干净样本和噪声样本。

虽然文中提出的CurveNet与meta-weight-net参数更新方法几乎一致,但meta-weight-net论文中,meta-weight-net的输入(也就是loss)随着训练而变化,无法代表样本的整体训练状态。

此外,meta-weight-net只能处理单一的数据偏置,两类偏置一起的情况原文中并未对其测试。而改进的CurveNet可以同时处理两类偏置。

Method

Meta-weight-net部分

在Method一节中,meta-weight-net的部分内容占了几乎半页。。。称之为“Revisiting Meta-Weight-Net”。

  • meta-weight-net基于MLP提出了分类网络\(\mathcal{F}\),其中参数记为\(\omega\)。
  • 带偏置的training-data: \(\mathcal{D}^{tra}=\{x_{i}^{tra},y_{i}^{tra}\}_{i=1}^{N}\),无偏置的meta-data: \(\mathcal{D}^{meta}=\{x_{i}^{meta},y_{i}^{met\hat{a}}\}_{i=1}^{M}\),N,M表示样本数,\(N\gg M\), \(X, Y\)分别表示数据和标签。

meta-data作为无偏置数据来自验证集,有点像Zero-shot的Transductive设置。这种设置我到现在还是觉得莫名其妙。

对于传统的训练,分类器的参数训练通过最小化损失:

\[\omega^*=\arg\min_\omega\mathcal{L}(Y^{tra},\mathcal{F}(X^{tra}|\omega)),\tag{1}
\]

\(\mathcal{F}\)一般是卷积神经网络。接下来为了简化,我们令\(\mathcal{L}_{tra}=\mathcal{L}(Y^{tra},\mathcal{F}(X^{tra}|\omega))\)。然而,数据存在偏置时,公式1可能不能很好地优化参数。这时需要采用re-weighting策略,对损失施加权重\(\mathcal{G}(\mathcal{L}_{tra}|\Theta)\),\(\mathcal{G}\)是输出权重的网络,\(\Theta\)为该网络参数。此时,公式1变为:

\[\omega^*=\arg\min_\omega\mathcal{G}(\mathcal{L}_{tra}|\Theta)\mathcal{L}_{tra}.\tag{2}
\]

具体来说\(\mathcal{G}\)为一个仅含1个隐藏层的MLP,含100个神经元节点,以Sigmoid为激活函数,输出区间为[0,1]。通过元学习进行参数优化:

\[\Theta^*=\underset{\Theta}{\operatorname*{\arg\min}}\mathcal{L}(Y^{meta},\mathcal{F}(X^{met\boldsymbol{a}}|\omega^*(\mathcal{G}(\Theta)))).\tag{3}
\]

总觉得这里和元学习没啥关系。

式3的损失函数用\(\mathcal{L}_{meta}\)表示。由于两种参数\(\omega, \Theta\)都需要更新,所以需要分开更新,更新一种参数时令另一种参数为已知量。

  1. 先更新\(\omega\),这里的\(\omega\)作为临时更新参数,t为当前epoch:
\[\hat{\omega}^t=\omega^t-\alpha\bigtriangledown_{\omega}\mathcal{G}(\mathcal{L}_{tra}^t|\Theta^t)\circ\mathcal{L}_{tra}^t|_{\omega^t},\tag{4}
\]
  1. 临时更新的\(\omega\)用来更新\(\Theta\),更新完就可以丢弃:
\[\Theta^{t+1}=\Theta^t-\beta\bigtriangledown_{\Theta}\mathcal{L}_{meta}^t(\hat{\omega}^t(\Theta^t))|_{\Theta^t}.\tag{5}
\]
  1. 再用更新后的\(\Theta\)更新真正的\(\omega\):
\[\omega^{t+1}=\omega^t-\alpha\bigtriangledown_{\omega}\mathcal{G}(\mathcal{L}_{tra}^t|\Theta^{t+1})\circ\mathcal{L}_{tra}^t|_{\omega^t}.\tag{6}
\]

以上都是作者照搬了meta-weight-net的内容,作者总结了meta-weight-net的缺陷:

  1. meta-weight-net采用当前损失值作为输入,该损失值在整个训练过程中发生巨大变化,并且无法代表样本的状态。
  2. 损失值在每个epoch都不同,并且在训练过程中变得越来越小,这不利于(用于分类的)网络收敛。
  3. 当噪声和tail class 样本呢同时存在时,权重可能很大也可能很小,导致分类器的性能不理想。

作者以此为motivation,提出了prob-and-allocate训练策略,不再随着权重赋值,而是先统一收集损失,在分配权重。

CurNet

把第i个样本,T个epoch内的损失收集起来:\(L_i=[l_{i,0},l_{i,1},\cdots,l_{i,T}]\),由于初始参数随机产生,可以移除前S个损失,结果变为:\(L_i=[l_{i,S},l_{i,S+1},\cdots,l_{i,T}]\)。

对于同一类,计算loss的均值:

\[\mu_{k,t}=\frac{\sum_j^N\mathbb{1}(k,y_j)l_{j,t}}{\sum_j^N\mathbb{1}(k,y_j)},\tag{7}
\]

接下来,对每个类的样本,减去类内的均值:

\[\bar{l}_{i,t}=l_{i,t}-\mu_{y_i,t}.\tag{8}
\]

这里\(k(1\le k\le K)\)表示class,\(\mathbb{1}\)表示Dirac delta函数,输入的两个变量相等时输出1,否则输出0。

归一化损失向量可以表示为 I,然后依次馈送到全连接层,每个层都耦合到 ReLU 激活层。 P为最后一个全连接层的输出神经元数量,这里通过实验设置为64。

作为进一步促进噪声识别的一种方法,我们采用类标签嵌入方法将类信息丰富到损失曲线特征中。这种嵌入方法在自然语言处理领域常用(Cao et al. 2021),这里的嵌入矩阵可以表示为:\(Y^{K\times P}=[Y_{1},\cdots,Y_{K}].\)。再把I 和 Y连接并输入到MLP中。

优化的时候忽略几层

为了加速\(\Theta\)优化,根据FaMUS (Xu et al. 2021)\(\triangledown_\Theta\mathcal{L}_{meta}^t|_{\Theta^t}\)重写为:

\[\bigtriangledown_\Theta\mathcal{L}_{meta}^t|_{\Theta^t} =\frac{\partial\mathcal{L}_{meta}^t}{\partial\hat{\omega}^t}\bullet\frac{\partial\hat{\omega}^t}{\partial\mathcal{G}(\Theta^t)}\bullet\frac{\partial\mathcal{G}(\Theta^t)}{\partial\Theta^t}
\propto\sum_i^Z\frac{\partial\mathcal{L}_{meta}^t}{\partial\hat{\omega}_i^t}\bullet\frac{\partial\hat{\omega}_i^t}{\partial\mathcal{G}(\Theta^t)}\bullet\frac{\partial\mathcal{G}(\Theta^t)}{\partial\Theta^t},
\tag{9}\]

Z表示分类器的层数。然后冻结SL层,再更新\(\Theta\),此时式9变为:

\[\bigtriangledown_\Theta\mathcal{L}_{meta}^t|_{\Theta^l}\propto\sum_{i=SL}^Z\frac{\partial\mathcal{L}_{meta}^t}{\partial\hat{\omega}_i^t}\bullet\frac{\partial\hat{\omega}_i^t}{\partial\mathcal{G}(\Theta^t)}\bullet\frac{\partial\mathcal{G}(\Theta^t)}{\partial\Theta^t} \tag{10}
\]

这种做法感觉没什么用,作者的消融实验也证实了这一点,此外,作者的没有比较整体的训练时间,因此这个消融实验的结果说服力欠佳。

改变了输入的式4和式6

\[\hat{\omega}^t=\omega^t-\alpha\bigtriangledown_{\omega}\mathcal{G}([I,Y^{tra}]|\Theta^t)\circ\mathcal{L}_{tra}^t|_{\omega^t}\tag{11}
\]
\[\omega^{t+1}=\omega^t-\bigtriangledown_{\omega}\alpha\mathcal{G}([I,Y^{tra}]|\Theta^{t+1})\circ\mathcal{L}_{tra}^t|_{\omega^t.}\tag{12}
\]

整体的训练框架如下:

当学习率改变时,不同类别样本的损失值曲线存在显着差异。因此,在探测阶段采用循环学习率(Smith 2017)来训练分类器\(\mathcal{F}(\omega)\),O2U-Net 也采用了这种方法(Huang et al. 2019)。此外,当分类器的学习率降低时,认为CurveNet已经优化得很好,不再更新CurveNet的参数以加快训练速度。

Experiments

实验部分没有什么亮点,作者主要与Baseline meta-weight-net进行对比。作者放了一张不同类的参数权重与epoch的关系:

可观察到:

  1. clear样本权重大于noisy样本;
  2. 尾部类的权重显著大于头部类

这种结果确实是理想的情况,证明了损失曲线有着有效信息。但观察尾部类的3张图,它们的权重还是靠的有点近,不确定作者的方法在尾部类的精度上如何。

参考文献

  1. Jiang, Shenwang, et al. "Delving into sample loss curve to embrace noisy and imbalanced data." Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 36. No. 6. 2022.
  2. Shu, Jun, et al. "Meta-weight-net: Learning an explicit mapping for sample weighting." Advances in neural information processing systems 32 (2019).

Delving into Sample Loss Curve to Embrace Noisy and Imbalanced Data的更多相关文章

  1. caffe的python接口学习(7):绘制loss和accuracy曲线

    使用python接口来运行caffe程序,主要的原因是python非常容易可视化.所以不推荐大家在命令行下面运行python程序.如果非要在命令行下面运行,还不如直接用 c++算了. 推荐使用jupy ...

  2. tensorflow实现svm多分类 iris 3分类——本质上在使用梯度下降法求解线性回归(loss是定制的而已)

    # Multi-class (Nonlinear) SVM Example # # This function wll illustrate how to # implement the gaussi ...

  3. tensorflow实现svm iris二分类——本质上在使用梯度下降法求解线性回归(loss是定制的而已)

    iris二分类 # Linear Support Vector Machine: Soft Margin # ---------------------------------- # # This f ...

  4. Caffe---Pycaffe 绘制loss和accuracy曲线

    Caffe---Pycaffe 绘制loss和accuracy曲线 <Caffe自带工具包---绘制loss和accuracy曲线>:可以看出使用caffe自带的工具包绘制loss曲线和a ...

  5. RFID 读写器 Reader Writer Cloner

    RFID读写器的工作原理 RFID的数据采集以读写器为主导,RFID读写器是一种通过无线通信,实现对标签识别和内存数据的读出和写入操作的装置. 读写器又称为阅读器或读头(Reader).查询器(Int ...

  6. Must Know Tips/Tricks in Deep Neural Networks

    Must Know Tips/Tricks in Deep Neural Networks (by Xiu-Shen Wei)   Deep Neural Networks, especially C ...

  7. Must Know Tips/Tricks in Deep Neural Networks (by Xiu-Shen Wei)

    http://lamda.nju.edu.cn/weixs/project/CNNTricks/CNNTricks.html Deep Neural Networks, especially Conv ...

  8. How to handle Imbalanced Classification Problems in machine learning?

    How to handle Imbalanced Classification Problems in machine learning? from:https://www.analyticsvidh ...

  9. The Model Complexity Myth

    The Model Complexity Myth (or, Yes You Can Fit Models With More Parameters Than Data Points) An oft- ...

  10. (转) Read-through: Wasserstein GAN

    Sorta Insightful Reviews Projects Archive Research About  In a world where everyone has opinions, on ...

随机推荐

  1. Hi3516开发笔记(七):Hi3516虚拟机交叉开发环境搭建之交叉编译Qt

    海思开发专栏 上一篇:<Hi3516开发笔记(六):通过HiTools使用USB/串口将uboot.kernel.rootfs和userdata按照分区表烧写镜像>下一篇:<Hi35 ...

  2. django学习第十一天---django操作cookie和session

    Cookie cookie解析 会话 http协议是无状态的,无连接的 导致每次客户端访问服务端需要登录成功之后才能访问的页面,都需要用户再重新登录一遍,用户体验极差. 客户端想了个办法,cookie ...

  3. 数据抽取平台pydatax介绍--实现和项目使用

    数据抽取平台pydatax实现过程中,有2个关键点: 1.是否能在python3中调用执行datax任务,自己测试了一下可以,代码如下:    这个str1就是配置的shell文件 try: resu ...

  4. 前后端分离项目(八):后端报错Field 'id' doesn't have a default value

    好家伙,又到了修bug的环节,(深叹一口气) 好了,来看报错 2022-10-29 23:27:52.155 WARN 15068 --- [nio-8011-exec-2] o.h.engine.j ...

  5. Spring + JAX-WS : ‘xxx’ is an interface, and JAXB can’t handle interfaces 错误解决方法

    错误栈 Caused by: com.sun.xml.bind.v2.runtime.IllegalAnnotationsException: 2 counts of IllegalAnnotatio ...

  6. 【Azure 媒体服务】在Azure Media Service门户中使用HLS模式传输视频流,播放视频步骤

    问题描述 如何在Azure Media Service门户中使用HLS模式传输视频流,播放视频步骤 问题解决 第一步:在 Media Service 这边点击资产.上传本地视频资源作为Media Se ...

  7. Dapr v1.13 版本已发布

    Dapr是一套开源.可移植的事件驱动型运行时,允许开发人员轻松立足云端与边缘位置运行弹性.微服务.无状态以及有状态等应用程序类型.Dapr能够确保开发人员专注于编写业务逻辑,而不必分神于解决分布式系统 ...

  8. Codeforces Round 920 (Div. 3)(A~F)

    目录 A B C D E F A 按题意模拟即可 #include <bits/stdc++.h> #define int long long #define rep(i,a,b) for ...

  9. C1. Good Subarrays (Easy Version)

    思路:我们枚举每一个左端点,对于每一个左端点,寻找最长的满足条件的区间,这个区间长度就是左端点对答案的贡献,可以发现具有单调性,右端点只会前进不会倒退.所以我们两个指针各扫一遍区间就可以. #incl ...

  10. 最小生成树(二)Prim算法

    一.思想 1.1 基本概念 加权无向图的生成树:一棵含有其所有顶点的无环连通子图. 最小生成树(MST):一棵权值最小(树中所有边的权值之和)的生成树. 1.2 算法原理 1.2.1 切分定理 切分定 ...