【深度学习】Focal Loss 与 GHM——解决样本不平衡问题
Focal Loss 与 GHM
Focal Loss
Focal Loss 的提出主要是为了解决难易样本数量不平衡(注意:这有别于正负样本数量不均衡问题)问题。下面以目标检测应用场景来说明。
一些 one-stage 的目标检测器通常会产生很多数量的 anchor box,但是只有极少数是正样本,导致正负样本数量不均衡。这里假设我们计算分类损失函数为交叉熵公式。
由于在目标检测中,大量的候选目标都是易分样本,这些样本的损失很低,但是由于数量极不平衡,易分样本数量相对来说太多,最终主导了总的损失,但是模型也应该关注那些难分样本(难分样本又分为普通难分样本和特别难分样本,后面即将讲到的GHM就是为了解决特别难分样本的问题)。

基于以上两个场景中的问题,Focal Loss 给出了很好的解决方法:

GHM
Focal Loss存在一些问题:
- 如果让模型过多关注
难分样本会引发一些问题,比如样本中的离群点(outliers),已经收敛的模型可能会因为这些离群点还是被判别错误,总而言之,我们不应该过多关注易分样本,但也不应该过多关注难分样本; - \(\alpha\) 与 \(\gamma\) 的取值全从实验得出,且两者要联合一起实验,因为它们的取值会相互影响。
几个概念:
梯度模长g:\(g\) 正比于检测的难易程度,\(g\) 越大则检测难度越大,\(g\) 从交叉熵损失求梯度得来
\[
g=|p-p^*|=
\begin{cases}
1-p, & \text{if p* = 1} \\
p, & \text{if p* = 0}
\end{cases}
\]
\(p\) 是模型预测的概率,\(p^*\) 是 Ground-Truth 的标签(取值为1或者0);\(g\) 正比于检测的难易程度,\(g\) 越大则检测难度越大;
梯度模长与样本数量的关系:梯度模长接近于 0 时样本数量最多(这些可归类为易分样本),随着梯度模长的增长,样本数量迅速减少,但是当梯度模长接近于 1 时样本数量也挺多(这些可归类为难分样本)。如果过多关注难分样本,由于其梯度模长比一般样本大很多,可能会降低模型的准确度。因此,要同时抑制易分样本和难分样本!

抑制方法之梯度密度 \(G(D)\): 因为易分样本和特别难分样本数量都要比一般样本多一些,而我们要做的就是衰减
单位区间数量多的那类样本,也就是物理学上的密度概念。
\[
GD(g) = \frac{1}{l_{\epsilon}}\sum_{k=1}^{N}\delta_{\epsilon}(g_k, g)
\]
\(\delta_{\epsilon}(g_k, g)\) 表示样本 \(1 \sim N(样本数量)\) 中,梯度模长分布在 \((g-\frac{\epsilon}{2}, g+\frac{\epsilon}{2} )\) 范围内的样本个数,\(l_{\epsilon}(g)\) 代表了 \((g-\frac{\epsilon}{2}, g+\frac{\epsilon}{2} )\) 区间的长度;最后对每个样本,用交叉熵 \(CE\) \(\times\) 该样本梯度密度的倒数即可。
分类问题的GHM损失:
\[
L_{GHM-C} = \sum_{i=1}^{N}\frac{L_{CE}(p_i, p_i^*)}{GD(g_i)}
\]
回归问题的GHM损失:
\[
L_{GHM-R} = \sum_{i=1}^N \frac{ASL_1(d_i)}{GD(gr_i)}
\]
其中,\(ASL_1(d_i)\) 为修正的 smooth L1 Loss。
抑制效果:

参考资料:
5分钟理解Focal Loss与GHM-解决样本不平衡利器——知乎
【深度学习】Focal Loss 与 GHM——解决样本不平衡问题的更多相关文章
- 焦点损失函数 Focal Loss 与 GHM
文章来自公众号[机器学习炼丹术] 1 focal loss的概述 焦点损失函数 Focal Loss(2017年何凯明大佬的论文)被提出用于密集物体检测任务. 当然,在目标检测中,可能待检测物体有10 ...
- 从极大似然估计的角度理解深度学习中loss函数
从极大似然估计的角度理解深度学习中loss函数 为了理解这一概念,首先回顾下最大似然估计的概念: 最大似然估计常用于利用已知的样本结果,反推最有可能导致这一结果产生的参数值,往往模型结果已经确定,用于 ...
- AI佳作解读系列(一)——深度学习模型训练痛点及解决方法
1 模型训练基本步骤 进入了AI领域,学习了手写字识别等几个demo后,就会发现深度学习模型训练是十分关键和有挑战性的.选定了网络结构后,深度学习训练过程基本大同小异,一般分为如下几个步骤 定义算法公 ...
- 深度学习中loss总结
一.分类损失 1.交叉熵损失函数 公式: 交叉熵的原理 交叉熵刻画的是实际输出(概率)与期望输出(概率)的距离,也就是交叉熵的值越小,两个概率分布就越接近.假设概率分布p为期望输出,概率分布q为实际输 ...
- 论文阅读|Focal loss
原文标题:Focal Loss for Dense Object Detection 概要 目标检测主要有两种主流框架,一级检测器(one-stage)和二级检测器(two-stage),一级检测器, ...
- 处理样本不平衡的LOSS—Focal Loss
0 前言 Focal Loss是为了处理样本不平衡问题而提出的,经时间验证,在多种任务上,效果还是不错的.在理解Focal Loss前,需要先深刻理一下交叉熵损失,和带权重的交叉熵损失.然后我们从样本 ...
- Focal Loss 损失函数简述
Focal Loss 摘要 Focal Loss目标是解决样本类别不平衡以及样本分类难度不平衡等问题,如目标检测中大量简单的background,很少量较难的foreground样本.Focal Lo ...
- 用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践
https://zhuanlan.zhihu.com/p/25928551 近来在同时做一个应用深度学习解决淘宝商品的类目预测问题的项目,恰好硕士毕业时论文题目便是文本分类问题,趁此机会总结下文本分类 ...
- [转] 用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践
转自知乎上看到的一篇很棒的文章:用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践 近来在同时做一个应用深度学习解决淘宝商品的类目预测问题的项目,恰好硕士毕业时论文 ...
随机推荐
- GIT和SVN教程
各种版本控制工具的简单比较 特性 CVS SVN GIT 并发修改 支持 支持 支持 并发提交 不支持 支持 支持 历史轨迹 不支持更名 支持更名 支持更名 分布式 不支持 不支持 支持 SVN SV ...
- 编程题及解题思路(1,String)
题目描述 请实现一个算法,确定一个字符串的所有字符是否全都不同.这里我们要求不允许使用额外的存储结构. 给定一个string iniString,请返回一个bool值,True代表所有字符全都不同,F ...
- arcgis三维球中加载2000坐标系出现错误(The tiling scheme of this layer is not supported by SceneView)
目前我们国家测绘地理信息的坐标体系基准是国家2000坐标系CGCS2000.各类地图组件如OpenLayers.Mapbox.Cesuim和ArcGIS Javascrip等都主要是支持WGS84(w ...
- python 23 继承
目录 继承--inheritance 1. 面向对象继承: 2. 单继承 2.1 类名执行父类的属性.方法 2.2 子类对象执行父类的属性.方法 2.3 执行顺序 2.4 既要执行子类的方法,又要执行 ...
- 从SpringBoot构建十万博文聊聊限流特技
前言 在开发十万博客系统的的过程中,前面主要分享了爬虫.缓存穿透以及文章阅读量计数等等.爬虫的目的就是解决十万+问题:缓存穿透是为了保护后端数据库查询服务:计数服务解决了接近真实阅读数以及数据库服务的 ...
- Python AttributeError: 'Module' object has no attribute 'STARTF_USESHOWINDOW'
夫学须志也,才须学也,非学无以广才,非志无以成学.--诸葛亮 生活有度,自得慈铭 --杜锦阳 今天新来的同事安装环境遇到个莫名其妙的问题: AttributeError: 'Module' objec ...
- CodeForces 988 F Rain and Umbrellas
Rain and Umbrellas 题意:某同学从x=0的点走到x=a的点,路上有几段路程是下雨的, 如果他需要经过这几段下雨的路程, 需要手上有伞, 每一把伞有一个重量, 求走到重点重量×路程的最 ...
- Atcode B - Colorful Hats(思维)
题目链接:http://agc016.contest.atcoder.jp/tasks/agc016_b 题解:挺有意思的题目主要还是模拟出最多有几种不可能的情况,要知道ai的差距不能超过1这个想想就 ...
- codeforces 459 E. Pashmak and Graph(dp)
题目链接:http://codeforces.com/contest/459/problem/E 题意:给出m条边n个点每条边都有权值问如果两边能够相连的条件是边权值是严格递增的话,最长能接几条边. ...
- Linux基础提高_系统性能相关命令
w 看系统的负载信息 用于显示已经登陆系统的用户列表,并显示用户正在执行的指令 uptime [root@localhost]#uptime 17:26:07 up 9:02, 3 users, lo ...