深度学习笔记(八)Focal Loss
论文:Focal Loss for Dense Object Detection
论文链接:https://arxiv.org/abs/1708.02002
一. 提出背景
object detection的算法主要可以分为两大类:two-stage detector和one-stage detector。前者是指类似Faster RCNN,RFCN这样需要region proposal的检测算法,这类算法可以达到很高的准确率,但是速度较慢。虽然可以通过减少proposal的数量或降低输入图像的分辨率等方式达到提速,但是速度并没有质的提升。后者是指类似YOLO,SSD这样不需要region proposal,直接回归的检测算法,这类算法速度很快,但是准确率不如前者。作者提出focal loss的出发点也是希望one-stage detector可以达到two-stage detector的准确率,同时不影响原有的速度。
作者认为one-stage detector的准确率不如two-stage detector的原因是:样本的类别不均衡导致的。我们知道在object detection领域,一张图像可能生成成千上万的candidate locations,但是其中只有很少一部分是包含object的,这就带来了类别不均衡。那么类别不均衡会带来什么后果呢?引用原文讲的两个后果:(1) training is inefficient as most locations are easy negatives that contribute no useful learning signal; (2) en masse, the easy negatives can overwhelm training and lead to degenerate models. 什么意思呢?负样本数量太大,占总的loss的大部分,而且多是容易分类的,因此使得模型的优化方向并不是我们所希望的那样。其实先前也有一些算法来处理类别不均衡的问题,比如OHEM(online hard example mining),OHEM的主要思想可以用原文的一句话概括:In OHEM each example is scored by its loss, non-maximum suppression (nms) is then applied, and a minibatch is constructed with the highest-loss examples。OHEM算法虽然增加了错分类样本的权重,但是OHEM算法忽略了容易分类的样本。
因此针对类别不均衡问题,作者提出一种新的损失函数:focal loss,这个损失函数是在标准交叉熵损失基础上修改得到的。这个函数可以通过减少易分类样本的权重,使得模型在训练时更专注于难分类的样本。为了证明focal loss的有效性,作者设计了一个dense detector:RetinaNet,并且在训练时采用focal loss训练。实验证明RetinaNet不仅可以达到one-stage detector的速度,也能有two-stage detector的准确率。
二. focal loss
1.Cross Entropy
对于二分类来说:标准的交叉熵损失:
$CrossEntropy= -\frac{1}{n} \sum_{i=1}^{n} [y_i log(p_i) + (1-y_i) log(1 - log(p_i))]$
或
这里$y$是GT=1/0,$p$是预测输出为1的概率。
我们知道,当$y=1$时:
这时候,$L$与预测输出的关系如下左图所示:很显然:对于正样本的预测,预测输出越接近真实样本标签$y=1$, 损失函数$L$越小;预测输出越接近0,$L$越大。
而当$y=0$时:
这时候,$L$与预测输出的关系如上右图:同样,预测输出越接近真实样本标签0($p$值越小),损失函数$L$越小;预测输出越接近1,$L$越大。函数的变化趋势也完全符合实际需要的情况。
无论真实样本标签 $y$ 是 0 还是 1,$L$ 都表征了预测输出与 $y$ 的差距。从图形中我们可以发现:预测输出与 $y$ 差得越多,$L$ 的值越大,也就是说对当前模型的 “ 惩罚 ” 越大,而且是非线性增大,是一种类似指数增长的级别。这是由 log 函数本身的特性所决定的。这样的好处是模型会倾向于让预测输出更接近真实样本标签$ y$。
为了方便,用pt代替p,如下公式2:
接下来介绍一个最基本的对交叉熵的改进,也将作为本文实验的baseline。
2.Balanced Cross Entropy
什么意思呢?增加了一个系数at,跟pt的定义类似,当label=1的时候,at=a;当label=-1的时候,at=1-a,a的范围也是0到1。因此可以通过设定a的值(一般而言假如1这个类的样本数比-1这个类的样本数多很多,那么a会取0到0.5来增加-1这个类的样本的权重)来控制正负样本对总的loss的共享权重。这里当a=0.5时就和标准交叉熵一样了(系数是个常数)。
显然前面的公式3虽然可以控制正负样本的权重,但是没法控制容易分类和难分类样本的权重。
3.Focal Loss
这里的$\gamma$ 称作focusing parameter,$\gamma>=0$。
$(1- p_t)^\gamma$ 称为调制系数(modulating factor)
这里介绍下focal loss的两个重要性质:1、当一个样本被分错的时候,pt是很小的(请结合公式2,比如当y=1时,p<0.5才是错分类,此时pt就比较小,反之当y=-1时,p>0.5是错分了),因此调制系数就趋于1,也就是说相比原来的loss是没有什么大的改变的。当pt趋于1的时候(此时分类正确而且是易分类样本),调制系数趋于0,也就是对于总的loss的贡献很小。2、当 $γ=0$ 的时候,focal loss就是传统的交叉熵损失,当 $γ$ 增加的时候,调制系数也会增加。
focal loss的两个性质算是核心,其实就是用一个合适的函数去度量难分类和易分类样本对总的损失的贡献。
作者在实验中采用的是公式5的focal loss(结合了公式3和公式4,这样既能调整正负样本的权重,又能控制难易分类样本的权重):
PS: 实际我在使用中,选择的是以下方式
FL(p_t) =( - p_t) ^ 1.0 * log(p_t) if p_t 来自正样本
FL(p_t) =( - p_t) ^ gamma * log(p_t) if p_t 来自负样本
即给易分的负样本更大的惩罚。
三. 实验
在实验中a的选择范围也很广,一般而言当γ增加的时候,a需要减小一点(实验中γ=2,a=0.25的效果最好)
实验结果:
Table1是关于RetinaNet和Focal Loss的一些实验结果。(a)是在交叉熵的基础上加上参数a,a=0.5就表示传统的交叉熵,可以看出当a=0.75的时候效果最好,AP值提升了0.9。(b)是对比不同的参数γ和a的实验结果,可以看出随着γ的增加,AP提升比较明显。(d)通过和OHEM的对比可以看出最好的Focal Loss比最好的OHEM提高了3.2AP。这里OHEM1:3表示在通过OHEM得到的minibatch上强制positive和negative样本的比例为1:3,通过对比可以看出这种强制的操作并没有提升AP。(e)加入了运算时间的对比,可以和前面的Figure2结合起来看,速度方面也有优势!注意这里RetinaNet-101-800的AP是37.8,当把训练时间扩大1.5倍同时采用scale jitter,AP可以提高到39.1,这就是全文和table2中的最高的39.1AP的由来。
Figure4是对比forground和background样本在不同γ情况下的累积误差。纵坐标是归一化后的损失,横坐标是总的foreground或background样本数的百分比。可以看出γ的变化对正(forground)样本的累积误差的影响并不大,但是对于负(background)样本的累积误差的影响还是很大的(γ=2时,将近99%的background样本的损失都非常小)。
三. 总结
原文的这段话概括得很好:In this work, we identify class imbalance as the primary obstacle preventing one-stage object detectors from surpassing top-performing, two-stage methods, such as Faster R-CNN variants. To address this, we propose the focal loss which applies a modulating term to the cross entropy loss in order to focus learning on hard examples and down-weight the numerous easy negatives.
深度学习笔记(八)Focal Loss的更多相关文章
- Google TensorFlow深度学习笔记
Google Deep Learning Notes Google 深度学习笔记 由于谷歌机器学习教程更新太慢,所以一边学习Deep Learning教程,经常总结是个好习惯,笔记目录奉上. Gith ...
- Learning ROS forRobotics Programming Second Edition学习笔记(八)indigo rviz gazebo
中文译著已经出版,详情请参考:http://blog.csdn.net/ZhangRelay/article/category/6506865 Learning ROS forRobotics Pro ...
- 深度学习笔记:优化方法总结(BGD,SGD,Momentum,AdaGrad,RMSProp,Adam)
深度学习笔记:优化方法总结(BGD,SGD,Momentum,AdaGrad,RMSProp,Adam) 深度学习笔记(一):logistic分类 深度学习笔记(二):简单神经网络,后向传播算法及实现 ...
- python3.4学习笔记(八) Python第三方库安装与使用,包管理工具解惑
python3.4学习笔记(八) Python第三方库安装与使用,包管理工具解惑 许多人在安装Python第三方库的时候, 经常会为一个问题困扰:到底应该下载什么格式的文件?当我们点开下载页时, 一般 ...
- Go语言学习笔记八: 数组
Go语言学习笔记八: 数组 数组地球人都知道.所以只说说Go语言的特殊(奇葩)写法. 我一直在想一个人参与了两种语言的设计,但是最后两种语言的语法差异这么大.这是自己否定自己么,为什么不与之前统一一下 ...
- UFLDL深度学习笔记 (二)SoftMax 回归(矩阵化推导)
UFLDL深度学习笔记 (二)Softmax 回归 本文为学习"UFLDL Softmax回归"的笔记与代码实现,文中略过了对代价函数求偏导的过程,本篇笔记主要补充求偏导步骤的详细 ...
- UFLDL深度学习笔记 (一)反向传播与稀疏自编码
UFLDL深度学习笔记 (一)基本知识与稀疏自编码 前言 近来正在系统研究一下深度学习,作为新入门者,为了更好地理解.交流,准备把学习过程总结记录下来.最开始的规划是先学习理论推导:然后学习一两种开源 ...
- UFLDL深度学习笔记 (七)拓扑稀疏编码与矩阵化
UFLDL深度学习笔记 (七)拓扑稀疏编码与矩阵化 主要思路 前面几篇所讲的都是围绕神经网络展开的,一个标志就是激活函数非线性:在前人的研究中,也存在线性激活函数的稀疏编码,该方法试图直接学习数据的特 ...
- UFLDL深度学习笔记 (六)卷积神经网络
UFLDL深度学习笔记 (六)卷积神经网络 1. 主要思路 "UFLDL 卷积神经网络"主要讲解了对大尺寸图像应用前面所讨论神经网络学习的方法,其中的变化有两条,第一,对大尺寸图像 ...
- UFLDL深度学习笔记 (五)自编码线性解码器
UFLDL深度学习笔记 (五)自编码线性解码器 1. 基本问题 在第一篇 UFLDL深度学习笔记 (一)基本知识与稀疏自编码中讨论了激活函数为\(sigmoid\)函数的系数自编码网络,本文要讨论&q ...
随机推荐
- python入门(二):isinstance、内置函数、常用运算等
1. isinstance(变量名,类型) #判断什么类型 ps: 只支持输入两个参数,输入3个参数会报错 >>> isin ...
- java集合框架(1) hashMap 简单使用以及深度分析(转)
java.util 类 HashMap<K,V>java.lang.Object java.util.AbstractMap<K,V> java.util.Hash ...
- oracle in和exist的区别 not in 和not exist的区别
in 是把外表和内表作hash join,而exists是对外表作loop,每次loop再对内表进行查询.一般大家都认为exists比in语句的效率要高,这种说法其实是不准确的,这个是要区分环境的. ...
- The Moon and Sixpence摘抄
I had not yet learnt how contradictory is human nature; I did not know how much pose there is in the ...
- [SpringBoot]Web综合开发-笔记
Web开发 Json接口开发 @RestController 给类添加 @RestController 即可,默认类中的方法都会以 json 的格式返回. 自定义filter filter作用: 用于 ...
- [原]OpenStreetMap数据瓦片服务性能篇
上文说到如何利用node-mapnik架设OpenStreetMap瓦片服务,解决了有没有的问题.然而这个服务还是比较孱弱,主要表现在以下几个方面: 1. Node.js只能使用CPU的一个核,不能有 ...
- js原生的节点操作API
// yi获取元素节点 //一 :过id的方式( 通过id查找元素,大小写敏感,如果有多个id只找到第一个) document.getElementById('div1'); // 通过类名查找元素, ...
- virtual关键字
出于多态的考虑,为了覆盖, 子类同名覆盖函数(函数名.参数.返回值都相同) virtual void print(): 这也许会使人联想到函数的重载,但稍加对比就会发现两者是完全不同的:(1)重载的几 ...
- electron、vue.js、vuex、element-ui、sqlite3
总结一下这两周的入门之路. 1.安装node.js 过程就是下载:https://nodejs.org/en/,安装,完了在命令行窗口,在任何目录下都可录入node -v应能看到类似反馈 如果提示&q ...
- Codeforces Round #541--1131F. Asya And Kittens(基础并查集)
https://codeforces.com/contest/1131/problem/F #include<bits/stdc++.h> using namespace std; int ...