目标检测 | RetinaNet:Focal Loss for Dense Object Detection
论文分析了one-stage网络训练存在的类别不平衡问题,提出能根据loss大小自动调节权重的focal loss,使得模型的训练更专注于困难样本。同时,基于FPN设计了RetinaNet,在精度和速度上都有不俗的表现
论文:Focal Loss for Dense Object Detection
Introduction
目前state-of-the-art的目标检测算法大都是two-stage、proposal-driven的网络,如R-CNN架构。而one-stage检测器一直以速度为特色,在精度上始终不及two-stage检测器。因此,论文希望研究出一个精度能与two-stage检测器媲美的one-stage检测器
通过分析,论文认为阻碍one-stage精度主要障碍是类别不平衡问题(class imbalance):
- 在R-CNN架构检测器中,通过two-stage级联和抽样探索法(sampling heuristics)来解决类别不平衡问题。proposal阶段能迅速地将bndbox的数量缩小到很小的范围(1-2k),过滤了大部分背景。而第二阶段,则通过抽样探索法来保持正负样本的平衡,如固定的正负样本比例(1:3)和OHEM
- one-stage检测器通常需要处理大量的bndbox(~100k),密集地覆盖着各位置、尺度和长宽比。然而大部分bndbox都是不含目标的,即easy background。尽管可以使用类似的抽样探索法(如hard example mining)来补救,但这样的效率不高,因为训练过程仍然被简单的背景样本主导,导致模型更多地学习了背景而没有很好地学习检测的目标
在解决以上问题的同时,论文产出了两个成果:
- 新的损失函数focal loss,该函数能够动态地调整交叉熵大小。当类别的置信度越大,权重就逐渐减少,最后变为0。反之,置信度低的类别则得到大的权重
- 设计了一个简单的one-stage检测器RetinaNet来演示focal loss的有效性。该网络包含高效的特征金字塔和特别的anchor设定,结合一些多种近期的one-stage detectgor的trick(DNN/FPN/YOLO/SSD),达到39.1的AP精度和5fps的速度,超越了所有的单模型,如图2所示
FocalLoss
Balanced Cross Entropy
交叉熵损失函数如图1最上曲线,当置信度大于0.5时,loss的值也不小。若存在很多简单样本时,这些不小的loss堆积起来会对少样本的类别训练造成影响
一种简单的做法是赋予不同的类不同的权重$\alpha$,即$\alpha$-balanced 交叉熵。在实际操作中,$\alpha$属于一个预设的超参,类别的样本数越多,$\alpha$则设置越小
Focal Loss Definition
$\alpha$-balanced 交叉熵仅根据正负样本的数量进行权重的平衡,没有考虑样本的难易程度。因此,focal loss降低了容易样本的损失,从而让模型更专注于难的负样本
focal loss在交叉熵的基础上添加了调节因子$(1-p_t)^{\gamma}$,其中$\gamma\ge0$是超参数。$\gamma\in[0,5]$的loss曲线如图1所示,focal loss有两个特性:
- 当一个样本被误分且置信度很低时,调节因子会接近1,整体的loss都很小。当置信度接近1的时候,调节因子会接近于0,整体的loss也被降权了
- 超参数$\gamma$平滑地调整了简单样本的降权比例。当$\gamma=0$,Focal loss与交叉熵一致,随着$\gamma$增加,调节因子的影响也相应增加。当$\gamma=2$时,置信度为0.9的样本的loss将有100倍下降,而0.968的则有1000倍下降,这变相地增加了误分样本的权重
实际使用时中,focal loss会添加$\alpha$-balanced,这是从后面的实验中总结出来的
Class Imbalance and Model Initialization
二分类模型初始化时对于正负样本预测是均等的,而在训练时,样本数多的类别会主导网络的学习,导致训练初期不稳定。为了解决这问题,论文在模型初始化的时候设置先验值$\pi$(如0.01),使模型初始输出$\pi$偏向于低置信度来加大少数(正)样本的学习。在样本不平衡情况下,这种方法对于提高focal loss和 cross entropy训练稳定性有很大帮助
RetinaNet Detector
Architecture
RetinaNet是one-stage架构,由主干网络和两个task-specific子网组成。主干网络用于提取特征,第一个子网用于类别分类,第二个子网用于bndbox回归
- Feature Pyramid Network Backbone
RetinaNet采用FPN作为主干,FPN通过自上而下的路径以及横行连接来增强卷积网络的特征提取能力,能够从一张图片中构造出丰富的以及多尺度特征金字塔,结构如图3(a)-(b)。
FPN构建在ResNet架构上,分别在level $p_3$-$p_7$,每个level l意味着$2^l$的尺度缩放,且每个level包含256通道
- Anchors
level$p_3$到$p_7$对应的anchor尺寸为$322$到$5122$,每个金字塔层级的的长宽比均为${1:2, 1:1, 2:1 }$,为了能够预测出更密集的目标,每个长宽比的anchor添加原设定尺寸的${2^0, 2^{1/3}, 2^{2/3} }$大小的尺寸,每个level总共有9个anchor
每个anchor赋予长度为K的one-hot向量和长度为4的向量,K为类别数,4为box的坐标,与RPN类似。IoU大于0.5的anchor视为正样本,设定其one-host向量的对应值为1,$[0, 0.4)$的anchor视为背景,$[0.4, 0.5)$的anchor不参与训练
- Classification Subnet
分类子网是一个FCN连接FPN的每一level,分类子网是权值共享的,即共用一个FPN。子网由4xCx(3x3卷积+ReLU激活层)+KxA(3x3卷积)构成,如图3(c),C=256,A=9
- Box Regression Subnet
定位子网结构与分类子网类似,只是将最后的卷积大小改为4xAx3x3,如图3(d所示)。每个anchor学习4个参数,代表当前bndbox与GT间的偏移量,这个与R-CNN类似。这里的定位子网是类不可知的(class-agnostic),这样能大幅减少参数量
Inference and Training
- Inference
由于RetinaNet结构简单,在推理的时候只需要直接前向推算即可以得到结果。为了加速预测,每一个FPN level只取置信度top-1k bndbox($\ge0.05$),之后再对所有的结果进行NMS($\ge0.5$)
- Focal Loss
训练时,focal loss直接应用到所有~100k anchor中,最后将所有的loss相加再除以正样本的数量。这里不除以achor数,是由于大部分的bndbox都是easy样本,在focal loss下仅会产生很少loss。权值$\alpha$的设定与$\lambda$存在一定的关系,当$\lambda$增加时,$\alpha$则需要减少,($\alpha=0.25, \lambda=2$表现最好)
- Initialization
Backbone是在ImageNet 1k上预训练的模型,FPN的新层则是根据论文进行初始化,其余的新的卷积层(除了最后一层)则偏置$b=0$,权重为$\sigma=0.01$的高斯分布
$$\pi=\frac{1}{1+e^{-b}}$$
最后一层卷积的权重为$\sigma=0.01$的高斯分布,偏置$b=-log(1-\pi)/\pi$(偏置值的计算是配合最后的激活函数来推),使得训练初期的前景置信度输出为$\pi=0.01$,即认为大概率都是背景。这样背景就会输出很小的loss,前景会输出很大的loss,从而阻止背景在训练前期产生巨大的干扰loss
- Optimization
RetinaNet使用SGD作为优化算法,8卡,每卡batchSize=2。learning rate=0.01,60k和80k轮下降10倍,共进行90k迭代,Weight decay=0.0001,momentum=0.9,
training loss为focal loss与bndbox的smooth L1 loss
Experiments
Training Dense Detection
- Network Initialization
论文首先尝试直接用标准交叉熵进行RetinaNet的训练,不添加任何修改和特殊初始化,结果在训练时模型不收敛。接着论文使用先验概率$\pi=0.01$对模型进行初始化,模型开始正常训练,并且最终达到30.2AP,训练对$\pi$的值不敏感
- Balanced Cross Entropy
接着论文进行平衡交叉熵的实验,结果如Table1a,当$\alpha=0.75$时,模型获得0.9的AP收益
- Focal Loss
接着论文进行了focal loss实验,结果如Table 1b,当$\gamma=2$时,模型在$\alpha$-balanced交叉熵上获得2.9AP收益
论文观察到,$\gamma$与$\alpha$成反向关。整体而言,$\gamma$带来的收益更大,此外,$\alpha$的值一般为$[0.25, 0.75]$(从$\alpha\in[0.01, 0.999]$中实验得出)
- Analysis of the Focal Loss
为了进一步了解focal loss,论文分析了一个收敛模型($\gamma=2$,ResNet-101)的loss经验分布。首先在测试集的预测结果中随机取$10^5$个正样本和$10^7$个负样本,计算其FL值,再对其进行归一化令他们的和为1,最后根据归一化后的loss进行排序,画出正负样本的累积分布函数(CDF),如图4
不同的$\gamma$值下,正样本的CDF曲线大致相同,大约20%的难样本占据了大概一半的loss,随着$\gamma$的增大,更多的loss集中中在top20%中,但变化比较小
不同的$\gamma$值下,负样本的CDF曲线截然不同。当$\gamma=0$时,正负样本的CDF曲线大致相同。当$\gamma$增大时,更大的loss集中在难样本中。当$\gamma=2$时,很大一部分的loss集中在很小比例的负样本中。可以看出,focal loss可以很有效的减少容易样本的影响,让模型更专注于难样本
- Online Hard Example Mining (OHEM)
OHEM用于优化two-stage检测器的训练,首先根据loss对样本进行NMS,再挑选hightest-loss样本组成minibatches,其中NMS的阈值和batch size都是可调的。与FL不同,OHEM直接去除了简单样本,论文也对比了OHEM的变种,在NMS后,构建minibatch时保持1:3的正负样本比。实验结果如Table 1d,无论是原始的OHEM还是变种的OHEM,实验结果都没有FL的性能好,大约有3.2的AP差异。因此,FL更适用于dense detector的训练
Model Architecture Design
- Anchor Density
one-stage检测器使用固定的网格进行预测,一个提高预测性能的方法是使用多尺度/多长宽比的anchro进行。实验结果如Table 1c,单anchor能达到30.3AP,而使用9 anchors能收获4AP的性能提升。最后,当增加到9anchors时,性能法儿下降了,这说明,当anchor密度已经饱和了
- Speed versus Accuracy
更大Backbone和input size意味着更高准确率和更慢的推理速度,Table 1e展示了这两者的影响,图2展示了RetinaNet与其它主流检测器的性能和速度对比。大尺寸的RetinaNet比大部分的two-stage性能要好,而且速度也更快
- Comparison to State of the Art
与当前的主流one-stage算法对比,RetinaNet大概有5.9的AP提升,而与当前经典的two-stage算法对比,大约有2.3的AP提升,而使用ResNeXt32x8d-101-FPN作为backbone则能进一步提升1.7AP
Conclusion
论文认为类别不平衡问题是阻碍one-stage检测器性能提升的主要问题,为了解决这个问题,提出了focal loss,在交叉熵的基础上添加了调节因子,让模型更集中于难样本的训练。另外,论文设计了one-stage检测器RetinaNet并给出了相当充足的实验结果
创作不易,未经允许不得转载~
更多内容请关注个人微信公众号【晓飞的算法工程笔记】
目标检测 | RetinaNet:Focal Loss for Dense Object Detection的更多相关文章
- 论文阅读笔记四十四:RetinaNet:Focal Loss for Dense Object Detection(ICCV2017)
论文原址:https://arxiv.org/abs/1708.02002 github代码:https://github.com/fizyr/keras-retinanet 摘要 目前,具有较高准确 ...
- [论文理解]Focal Loss for Dense Object Detection(Retina Net)
Focal Loss for Dense Object Detection Intro 这又是一篇与何凯明大神有关的作品,文章主要解决了one-stage网络识别率普遍低于two-stage网络的问题 ...
- focal loss for dense object detection
温故知新 focal loss for dense object detection,知乎上一人的评论很经典.hard negative sampling, 就是只挑出来男神(还是最难追的),而foc ...
- Focal Loss for Dense Object Detection 论文阅读
何凯明大佬 ICCV 2017 best student paper 作者提出focal loss的出发点也是希望one-stage detector可以达到two-stage detector的准确 ...
- 目标检测--Rich feature hierarchies for accurate object detection and semantic segmentation(CVPR 2014)
Rich feature hierarchies for accurate object detection and semantic segmentation 作者: Ross Girshick J ...
- 目标检测比赛---Google AI Open Images - Object Detection Track
https://www.kaggle.com/c/google-ai-open-images-object-detection-track#Evaluation Submissions are eva ...
- Generalized Focal Loss: Learning Qualified and Distributed Bounding Boxes for Dense Object Detection
目录 Generalized Focal Loss: Learning Qualified and Distributed Bounding Boxes for Generalized Focal L ...
- Comparison of SIFT Encoded and Deep Learning Features for the Classification and Detection of Esca Disease in Bordeaux Vineyards(分类MobileNet,目标检测 RetinaNet)
识别葡萄的一种虫害,比较了传统SIFT和深度学习分类,最后还做了目标检测 分类用的 MobileNet,目标检测 RetinaNet MobileNet 是将传统深度可分离卷积分成了两步,深度卷积和逐 ...
- 【论文解读】[目标检测]retinanet
作为单阶段网络,retinanet兼具速度和精度(精度是没问题,速度我持疑问),是非常耐用的一个检测器,现在很多单阶段检测器也是以retinanet为baseline,进行各种改进,足见retinan ...
随机推荐
- flutter实践 - plsy
项目背景 项目需要从钉钉微应用跳转 WPS 打开 word 文档,但是 WPS 只提供了 StartActivity 方式携带参数跳转应用,deeplink 只能打开应用,而钉钉微应用只支持 deep ...
- MatterTrack Route Of Network Traffic :: Matter
Python 1.1 基础 while语句 字符串边缘填充 列出文件夹中的指定文件类型 All Combinations For A List Of Objects Apply Operations ...
- SpringMVC之添加照片并修改照片名字
@RequestMapping(value="/addIdcardsSubmit",method={RequestMethod.POST,RequestMethod.GET}) p ...
- 在python中连接mysql数据库,并进行增删改查
数据库在开发过程中是最常见的,基本上在服务端的编程过程中都会使用到,mysql是较常见的一种数据库,这里介绍python如果连接到数据库中,并对数据库进行增删改查. 安装mysql的python扩展 ...
- MFC中文件对话框类CFileDialog详解及文件过滤器说明
当前位置 : 首页 » 文章分类 : 开发 » MFC中文件对话框类CFileDialog详解及文件过滤器说明 上一篇 利用OpenCV从摄像头获得图像的坐标原点是在左下角 下一篇 Word中为 ...
- Rong's Portfolio
車架貼標設計 Velocite SYN frame decals 以簡潔設計的原則,分別依公路車.登山車.電動車的屬性設計表面塗裝曲線,針對車架特殊造型設計貼標突顯其功能,並搭配品牌基本色與市場偏好色 ...
- api安全规范
1. API签名的目的 校验API调用者的身份,是否有权访问 校验请求的数据完整性,防止被中间人篡改 防止重放攻击 2.基本概念 AccessKey: API使用者向API提供方申请的Ac ...
- C++走向远洋——58(项目二3、动物这样叫、改进版)
*/ * Copyright (c) 2016,烟台大学计算机与控制工程学院 * All rights reserved. * 文件名:text.cpp * 作者:常轩 * 微信公众号:Worldhe ...
- python爬虫-MongoDB安装配置
MongoDB安装配置: 在安装配置MongoDB的过程中遇到了很多问题,现在重新梳理一遍安装流程.遇到的问题及其解决方法 系统版本:Windows 10 MongoDB版本:4.2.1 1.下载地址 ...
- java反序列化-ysoserial-调试分析总结篇(4)
1.前言 这篇文章继续分析commoncollections4利用链,这篇文章是对cc2的改造,和cc3一样,cc3是对cc1的改造,cc4则是对cc2的改造,里面chained的invoke变成了i ...