目标检测 | RetinaNet:Focal Loss for Dense Object Detection
论文分析了one-stage网络训练存在的类别不平衡问题,提出能根据loss大小自动调节权重的focal loss,使得模型的训练更专注于困难样本。同时,基于FPN设计了RetinaNet,在精度和速度上都有不俗的表现
论文:Focal Loss for Dense Object Detection
Introduction
目前state-of-the-art的目标检测算法大都是two-stage、proposal-driven的网络,如R-CNN架构。而one-stage检测器一直以速度为特色,在精度上始终不及two-stage检测器。因此,论文希望研究出一个精度能与two-stage检测器媲美的one-stage检测器
通过分析,论文认为阻碍one-stage精度主要障碍是类别不平衡问题(class imbalance):
- 在R-CNN架构检测器中,通过two-stage级联和抽样探索法(sampling heuristics)来解决类别不平衡问题。proposal阶段能迅速地将bndbox的数量缩小到很小的范围(1-2k),过滤了大部分背景。而第二阶段,则通过抽样探索法来保持正负样本的平衡,如固定的正负样本比例(1:3)和OHEM
- one-stage检测器通常需要处理大量的bndbox(~100k),密集地覆盖着各位置、尺度和长宽比。然而大部分bndbox都是不含目标的,即easy background。尽管可以使用类似的抽样探索法(如hard example mining)来补救,但这样的效率不高,因为训练过程仍然被简单的背景样本主导,导致模型更多地学习了背景而没有很好地学习检测的目标
在解决以上问题的同时,论文产出了两个成果:
- 新的损失函数focal loss,该函数能够动态地调整交叉熵大小。当类别的置信度越大,权重就逐渐减少,最后变为0。反之,置信度低的类别则得到大的权重
- 设计了一个简单的one-stage检测器RetinaNet来演示focal loss的有效性。该网络包含高效的特征金字塔和特别的anchor设定,结合一些多种近期的one-stage detectgor的trick(DNN/FPN/YOLO/SSD),达到39.1的AP精度和5fps的速度,超越了所有的单模型,如图2所示
FocalLoss
Balanced Cross Entropy
交叉熵损失函数如图1最上曲线,当置信度大于0.5时,loss的值也不小。若存在很多简单样本时,这些不小的loss堆积起来会对少样本的类别训练造成影响
一种简单的做法是赋予不同的类不同的权重$\alpha$,即$\alpha$-balanced 交叉熵。在实际操作中,$\alpha$属于一个预设的超参,类别的样本数越多,$\alpha$则设置越小
Focal Loss Definition
$\alpha$-balanced 交叉熵仅根据正负样本的数量进行权重的平衡,没有考虑样本的难易程度。因此,focal loss降低了容易样本的损失,从而让模型更专注于难的负样本
focal loss在交叉熵的基础上添加了调节因子$(1-p_t)^{\gamma}$,其中$\gamma\ge0$是超参数。$\gamma\in[0,5]$的loss曲线如图1所示,focal loss有两个特性:
- 当一个样本被误分且置信度很低时,调节因子会接近1,整体的loss都很小。当置信度接近1的时候,调节因子会接近于0,整体的loss也被降权了
- 超参数$\gamma$平滑地调整了简单样本的降权比例。当$\gamma=0$,Focal loss与交叉熵一致,随着$\gamma$增加,调节因子的影响也相应增加。当$\gamma=2$时,置信度为0.9的样本的loss将有100倍下降,而0.968的则有1000倍下降,这变相地增加了误分样本的权重
实际使用时中,focal loss会添加$\alpha$-balanced,这是从后面的实验中总结出来的
Class Imbalance and Model Initialization
二分类模型初始化时对于正负样本预测是均等的,而在训练时,样本数多的类别会主导网络的学习,导致训练初期不稳定。为了解决这问题,论文在模型初始化的时候设置先验值$\pi$(如0.01),使模型初始输出$\pi$偏向于低置信度来加大少数(正)样本的学习。在样本不平衡情况下,这种方法对于提高focal loss和 cross entropy训练稳定性有很大帮助
RetinaNet Detector
Architecture
RetinaNet是one-stage架构,由主干网络和两个task-specific子网组成。主干网络用于提取特征,第一个子网用于类别分类,第二个子网用于bndbox回归
- Feature Pyramid Network Backbone
RetinaNet采用FPN作为主干,FPN通过自上而下的路径以及横行连接来增强卷积网络的特征提取能力,能够从一张图片中构造出丰富的以及多尺度特征金字塔,结构如图3(a)-(b)。
FPN构建在ResNet架构上,分别在level $p_3$-$p_7$,每个level l意味着$2^l$的尺度缩放,且每个level包含256通道
- Anchors
level$p_3$到$p_7$对应的anchor尺寸为$322$到$5122$,每个金字塔层级的的长宽比均为${1:2, 1:1, 2:1 }$,为了能够预测出更密集的目标,每个长宽比的anchor添加原设定尺寸的${2^0, 2^{1/3}, 2^{2/3} }$大小的尺寸,每个level总共有9个anchor
每个anchor赋予长度为K的one-hot向量和长度为4的向量,K为类别数,4为box的坐标,与RPN类似。IoU大于0.5的anchor视为正样本,设定其one-host向量的对应值为1,$[0, 0.4)$的anchor视为背景,$[0.4, 0.5)$的anchor不参与训练
- Classification Subnet
分类子网是一个FCN连接FPN的每一level,分类子网是权值共享的,即共用一个FPN。子网由4xCx(3x3卷积+ReLU激活层)+KxA(3x3卷积)构成,如图3(c),C=256,A=9
- Box Regression Subnet
定位子网结构与分类子网类似,只是将最后的卷积大小改为4xAx3x3,如图3(d所示)。每个anchor学习4个参数,代表当前bndbox与GT间的偏移量,这个与R-CNN类似。这里的定位子网是类不可知的(class-agnostic),这样能大幅减少参数量
Inference and Training
- Inference
由于RetinaNet结构简单,在推理的时候只需要直接前向推算即可以得到结果。为了加速预测,每一个FPN level只取置信度top-1k bndbox($\ge0.05$),之后再对所有的结果进行NMS($\ge0.5$)
- Focal Loss
训练时,focal loss直接应用到所有~100k anchor中,最后将所有的loss相加再除以正样本的数量。这里不除以achor数,是由于大部分的bndbox都是easy样本,在focal loss下仅会产生很少loss。权值$\alpha$的设定与$\lambda$存在一定的关系,当$\lambda$增加时,$\alpha$则需要减少,($\alpha=0.25, \lambda=2$表现最好)
- Initialization
Backbone是在ImageNet 1k上预训练的模型,FPN的新层则是根据论文进行初始化,其余的新的卷积层(除了最后一层)则偏置$b=0$,权重为$\sigma=0.01$的高斯分布
$$\pi=\frac{1}{1+e^{-b}}$$
最后一层卷积的权重为$\sigma=0.01$的高斯分布,偏置$b=-log(1-\pi)/\pi$(偏置值的计算是配合最后的激活函数来推),使得训练初期的前景置信度输出为$\pi=0.01$,即认为大概率都是背景。这样背景就会输出很小的loss,前景会输出很大的loss,从而阻止背景在训练前期产生巨大的干扰loss
- Optimization
RetinaNet使用SGD作为优化算法,8卡,每卡batchSize=2。learning rate=0.01,60k和80k轮下降10倍,共进行90k迭代,Weight decay=0.0001,momentum=0.9,
training loss为focal loss与bndbox的smooth L1 loss
Experiments
Training Dense Detection
- Network Initialization
论文首先尝试直接用标准交叉熵进行RetinaNet的训练,不添加任何修改和特殊初始化,结果在训练时模型不收敛。接着论文使用先验概率$\pi=0.01$对模型进行初始化,模型开始正常训练,并且最终达到30.2AP,训练对$\pi$的值不敏感
- Balanced Cross Entropy
接着论文进行平衡交叉熵的实验,结果如Table1a,当$\alpha=0.75$时,模型获得0.9的AP收益
- Focal Loss
接着论文进行了focal loss实验,结果如Table 1b,当$\gamma=2$时,模型在$\alpha$-balanced交叉熵上获得2.9AP收益
论文观察到,$\gamma$与$\alpha$成反向关。整体而言,$\gamma$带来的收益更大,此外,$\alpha$的值一般为$[0.25, 0.75]$(从$\alpha\in[0.01, 0.999]$中实验得出)
- Analysis of the Focal Loss
为了进一步了解focal loss,论文分析了一个收敛模型($\gamma=2$,ResNet-101)的loss经验分布。首先在测试集的预测结果中随机取$10^5$个正样本和$10^7$个负样本,计算其FL值,再对其进行归一化令他们的和为1,最后根据归一化后的loss进行排序,画出正负样本的累积分布函数(CDF),如图4
不同的$\gamma$值下,正样本的CDF曲线大致相同,大约20%的难样本占据了大概一半的loss,随着$\gamma$的增大,更多的loss集中中在top20%中,但变化比较小
不同的$\gamma$值下,负样本的CDF曲线截然不同。当$\gamma=0$时,正负样本的CDF曲线大致相同。当$\gamma$增大时,更大的loss集中在难样本中。当$\gamma=2$时,很大一部分的loss集中在很小比例的负样本中。可以看出,focal loss可以很有效的减少容易样本的影响,让模型更专注于难样本
- Online Hard Example Mining (OHEM)
OHEM用于优化two-stage检测器的训练,首先根据loss对样本进行NMS,再挑选hightest-loss样本组成minibatches,其中NMS的阈值和batch size都是可调的。与FL不同,OHEM直接去除了简单样本,论文也对比了OHEM的变种,在NMS后,构建minibatch时保持1:3的正负样本比。实验结果如Table 1d,无论是原始的OHEM还是变种的OHEM,实验结果都没有FL的性能好,大约有3.2的AP差异。因此,FL更适用于dense detector的训练
Model Architecture Design
- Anchor Density
one-stage检测器使用固定的网格进行预测,一个提高预测性能的方法是使用多尺度/多长宽比的anchro进行。实验结果如Table 1c,单anchor能达到30.3AP,而使用9 anchors能收获4AP的性能提升。最后,当增加到9anchors时,性能法儿下降了,这说明,当anchor密度已经饱和了
- Speed versus Accuracy
更大Backbone和input size意味着更高准确率和更慢的推理速度,Table 1e展示了这两者的影响,图2展示了RetinaNet与其它主流检测器的性能和速度对比。大尺寸的RetinaNet比大部分的two-stage性能要好,而且速度也更快
- Comparison to State of the Art
与当前的主流one-stage算法对比,RetinaNet大概有5.9的AP提升,而与当前经典的two-stage算法对比,大约有2.3的AP提升,而使用ResNeXt32x8d-101-FPN作为backbone则能进一步提升1.7AP
Conclusion
论文认为类别不平衡问题是阻碍one-stage检测器性能提升的主要问题,为了解决这个问题,提出了focal loss,在交叉熵的基础上添加了调节因子,让模型更集中于难样本的训练。另外,论文设计了one-stage检测器RetinaNet并给出了相当充足的实验结果
创作不易,未经允许不得转载~
更多内容请关注个人微信公众号【晓飞的算法工程笔记】
目标检测 | RetinaNet:Focal Loss for Dense Object Detection的更多相关文章
- 论文阅读笔记四十四:RetinaNet:Focal Loss for Dense Object Detection(ICCV2017)
论文原址:https://arxiv.org/abs/1708.02002 github代码:https://github.com/fizyr/keras-retinanet 摘要 目前,具有较高准确 ...
- [论文理解]Focal Loss for Dense Object Detection(Retina Net)
Focal Loss for Dense Object Detection Intro 这又是一篇与何凯明大神有关的作品,文章主要解决了one-stage网络识别率普遍低于two-stage网络的问题 ...
- focal loss for dense object detection
温故知新 focal loss for dense object detection,知乎上一人的评论很经典.hard negative sampling, 就是只挑出来男神(还是最难追的),而foc ...
- Focal Loss for Dense Object Detection 论文阅读
何凯明大佬 ICCV 2017 best student paper 作者提出focal loss的出发点也是希望one-stage detector可以达到two-stage detector的准确 ...
- 目标检测--Rich feature hierarchies for accurate object detection and semantic segmentation(CVPR 2014)
Rich feature hierarchies for accurate object detection and semantic segmentation 作者: Ross Girshick J ...
- 目标检测比赛---Google AI Open Images - Object Detection Track
https://www.kaggle.com/c/google-ai-open-images-object-detection-track#Evaluation Submissions are eva ...
- Generalized Focal Loss: Learning Qualified and Distributed Bounding Boxes for Dense Object Detection
目录 Generalized Focal Loss: Learning Qualified and Distributed Bounding Boxes for Generalized Focal L ...
- Comparison of SIFT Encoded and Deep Learning Features for the Classification and Detection of Esca Disease in Bordeaux Vineyards(分类MobileNet,目标检测 RetinaNet)
识别葡萄的一种虫害,比较了传统SIFT和深度学习分类,最后还做了目标检测 分类用的 MobileNet,目标检测 RetinaNet MobileNet 是将传统深度可分离卷积分成了两步,深度卷积和逐 ...
- 【论文解读】[目标检测]retinanet
作为单阶段网络,retinanet兼具速度和精度(精度是没问题,速度我持疑问),是非常耐用的一个检测器,现在很多单阶段检测器也是以retinanet为baseline,进行各种改进,足见retinan ...
随机推荐
- Android Studio那些错误的问题们
本片博客会记录关于Android开发工具Android Studio出错的那些问题,包括导入项目编译失败.时间过长,莫名其妙的歇菜等等... 问题 3facets cannot be loaded.Y ...
- 初入 Ubuntu 的一些操作 · Lei's blog
查看系统版本 cat /etc/os-release 修改 root 密码 passwd 新建用户 新建用户: adduser username 将新用户加入 sudo 组,这样就可以用 sudo 命 ...
- find_in_set 函数的语法
find_in_set 函数的语法: FIND_IN_SET(str,strList) str 要查询的字符串 strList 字段名,参数以“,”分隔,如(1,2,6,8) 查询字段(strList ...
- spring入门-整合junit和web
整合Junit 导入jar包 基本 :4+1 测试:spring-test-5.1.3.RELEASE.jar 让Junit通知spring加载配置文件 让spring容器自动进行注入 1234567 ...
- async/await实现图片的串行、并行加载
<!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8&quo ...
- Python爬虫-scrapyd
1.什么是scrapyd Scrapyd是一个服务,用来运行scrapy爬虫的. 它允许你部署你的scrapy项目以及通过HTTP JSON的方式控制你的爬虫. 官方文档:http://scrapyd ...
- 使用Properties配置文件进行配置读取
#使用Properties配置文件进行配置读取: 例如:有一个配置文件的内容如下: # setting.properties last_open_file=/data/hello.txt auto_s ...
- 有点长的博客:Redis不是只有get set那么简单
我以前还没接触Redis的时候,听到大数据组的小伙伴在讨论Redis,觉得这东西好高端,要是哪天我们组也可以使用下Redis就好了,好长一段时间后,我们项目中终于引入了Redis这个技术,我用了几下, ...
- GO - if判断,for循环,switch语句,数组的使用
1.if - else if - else的使用 package main import "fmt" func main() { // 1.简单使用 var a=10 if a== ...
- JAVA 中的反射(reflact)
获取反射加载类(获取类的字节码)的3种方式: Class class1=Class.forName("lession_svc.lession_svc.reflact.Person" ...