faster-rcnn的整体流程比较复杂,尤其是数据的预处理部分,流程比较繁琐。我写faster-rcnn系列文章的目的是对该算法的原始版本有个整体的把握,如果需要使用该算法做一些具体的任务,推荐使用mmdetection框架,该框架使用PyTorch写成,相比于原始的基于caffe python接口的版本就简洁优雅多了。下面对改算法的整体过程做一个梳理,分为训练过程(端到端的训练)和测试过程两部分。本文注重算法原理,因此直接从如下的数据输入开始:

layer {
name: 'input-data'
type: 'Python'
top: 'data'
top: 'im_info'
top: 'gt_boxes'
python_param {
module: 'roi_data_layer.layer'
layer: 'RoIDataLayer'
param_str: "'num_classes': 2"
}
}

一、训练过程

1、输入数据经过一个ConvNet得到一个feature map(backbone的输出),记为bk_feat;

2、bk_feat经过1个3x3的Conv得到rpn/output;

3、从rpn/output分出两个分支:一个分支经过1个1x1的Conv,输出通道数为2*num_anchor,得到输出rpn_cls_score,它表示在feature map的每一个位置,每一个anchor内含有物体的概率。另一个分支经过1个1x1的Conv,输出通道数为4*num_anchor,得到输出rpn_bbox_pred,它表示网络预测出的,在feature map的每一个位置,每一个anchor相对于它“负责”的物体的真实边界框的偏移量dx, dy, dw, dh

4、接下来是rpn-data层,该层的工作首先是生成anchor,然后按照一定的规则为anchor打上0(背景),1(前景),-1(忽略)标签,最后计算出每一个anchor相对于它“负责”的gt的偏移量dx, dy, dw, dh(bbox_targets),这个偏移量就是第一阶段的回归目标

5、利用第4步得到的bbox_targets和第3步得到的rpn_bbox_pred产生第一阶段的SmoothL1Loss损失rpn_loss_bbox,同时用第4步得到的label和第3步得到的rpn_cls_score产生第一阶段的SoftmaxWithLoss损失rpn_loss_cls

6、然后是proposal层,该层的工作首先也是生成anchor,然后将anchor和rpn_bbox_pred“相加”,得到proposal。然后经过NMS,从结果中选取分数最大的前2000个作为RPN网络产生的rois

7、接着是roi-data层,该层的工作是从上面的rois中,按照前景:背景=1:3的比例选出总共128个rois,每个rois的label是它“负责”的gt(和该rois交并比最大的gt)的类别,背景rois的label为0。接下来与第4步一样,计算出每一个rois相对于它“负责”的gt的偏移量bbox_targets,作为第二阶段的回归目标。,所不同的是这个偏移量是和类别一一对应的,其它类别的偏移量都为0;

8、利用第6步得到的rois,从bk_feat中截取相应区域的feature,并对这个feature做7x7的RoIPooling得到固定大小的feature,然后该feature经过2个FC,最后分出两个分支:一个分支经过1个FC,输出维度为num_class+1(num_class是数据集的类别个数,不包含背景),得到输出cls_score,它表示每个rois内含有的物体属于每个类别(包含背景)的概率。另一个分支经过1个FC,输出维度为4*(num_class+1),得到输出bbox_pred,它表示每个rois相对于它“负责”的每个类别(包含背景)的物体的真实边界框的偏移量dx, dy, dw, dh

9、用第7步得到的bbox_targets和第8步得到的bbox_pred产生第二阶段的SmoothL1Loss损失loss_bbox,同时用第7步得到的label和第8步得到的cls_score产生第二阶段的SoftmaxWithLoss损失loss_cls

二、测试过程

1、输入数据经过一个ConvNet得到一个feature map(backbone的输出),记为bk_feat;

2、bk_feat经过1个3x3的Conv得到rpn/output;

3、从rpn/output分出两个分支:一个分支经过1个1x1的Conv,输出通道数为2*num_anchor,得到输出rpn_cls_score,它表示在feature map的每一个位置,每一个anchor内含有物体的概率。另一个分支经过1个1x1的Conv,输出通道数为4*num_anchor,得到输出rpn_bbox_pred,它表示网络预测出的,在feature map的每一个位置,每一个anchor相对于它“负责”的物体的真实边界框的偏移量dx, dy, dw, dh

4、接下来是proposal层,该层的工作首先是生成anchor,然后将anchor和rpn_bbox_pred“相加”,得到proposal。之后取分数最大的前6000个(训练时取前12000个)proposal进行NMS操作,overlap阈值是0.7,最后从结果中选取分数最大的前300个(训练时取前2000个)作为RPN网络产生的rois;

5、利用第4步得到的rois,从bk_feat中截取相应区域的feature,并对这个feature做7x7的RoIPooling得到固定大小的feature,然后该feature经过2个FC,最后分出两个分支:一个分支经过1个FC,输出维度为num_class+1(num_class是数据集的类别个数,不包含背景),得到输出cls_prob,它表示每个rois内含有的物体属于每个类别(包含背景)的概率。另一个分支经过1个FC,输出维度为4*(num_class+1),得到输出bbox_pred,它表示每个rois相对于它“负责”的每个类别(包含背景)的物体的真实边界框的偏移量dx, dy, dw, dh

6、将第4步得到的rois和第5步得到的bbox_pred“相加”得到bbox,它的shape为(num_rois, 4*(num_class+1))。因此,每一个bbox都和唯一的一个类别相对应,bbox的分数即为第5步得到的cls_prob。接下来对每一类物体,筛选出分数大于一定阈值(0.05)的bbox做NMS(overlap阈值是0.3),最后把所有类别的结果合并起来,得到的所有bbox和它们相应的分数即为最终的检测结果。该过程的部分代码摘录如下:

rois = net.blobs['rois'].data.copy()
# unscale back to raw image space
boxes = rois[:, 1:5] / im_scales[0]
scores = blobs_out['cls_prob']
box_deltas = blobs_out['bbox_pred']
pred_boxes = bbox_transform_inv(boxes, box_deltas) im = cv2.imread(imdb.image_path_at(i))
scores, boxes = im_detect(net, im, box_proposals)
# skip j = 0, because it's the background class
for j in xrange(1, imdb.num_classes):
inds = np.where(scores[:, j] > thresh)[0]
cls_scores = scores[inds, j]
cls_boxes = boxes[inds, j*4:(j+1)*4]
cls_dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis])) \
.astype(np.float32, copy=False)
keep = nms(cls_dets, cfg.TEST.NMS)
cls_dets = cls_dets[keep, :]
if vis:
vis_detections(im, imdb.classes[j], cls_dets)
all_boxes[j][i] = cls_dets

详细过程见函数test_netim_detect

faster-rcnn算法总结的更多相关文章

  1. 第三十一节,目标检测算法之 Faster R-CNN算法详解

    Ren, Shaoqing, et al. “Faster R-CNN: Towards real-time object detection with region proposal network ...

  2. 【目标检测】Faster RCNN算法详解

    Ren, Shaoqing, et al. “Faster R-CNN: Towards real-time object detection with region proposal network ...

  3. 目标检测算法之Faster R-CNN算法详解

    Fast R-CNN存在的问题:选择性搜索,非常耗时. 解决:加入一个提取边缘的神经网络,将候选框的选取交给神经网络. 在Fast R-CNN中引入Region Proposal Network(RP ...

  4. faster rcnn算法及源码及论文解析相关博客

    1. 通过代码理解faster-RCNN中的RPN http://blog.csdn.net/happyflyy/article/details/54917514 2. faster rcnn详解 R ...

  5. Faster RCNN算法训练代码解析(3)

    四个层的forward函数分析: RoIDataLayer:读数据,随机打乱等 AnchorTargetLayer:输出所有anchors(这里分析这个) ProposalLayer:用产生的anch ...

  6. Faster RCNN算法训练代码解析(1)

    这周看完faster-rcnn后,应该对其源码进行一个解析,以便后面的使用. 那首先直接先主函数出发py-faster-rcnn/tools/train_faster_rcnn_alt_opt.py ...

  7. Faster RCNN算法demo代码解析

    一. Faster-RCNN代码解释 先看看代码结构: Data: This directory holds (after you download them): Caffe models pre-t ...

  8. Faster RCNN算法训练代码解析(2)

    接着上篇的博客,我们获取imdb和roidb的数据后,就可以搭建网络进行训练了. 我们回到trian_rpn()函数里面,此时运行完了roidb, imdb = get_roidb(imdb_name ...

  9. 目标检测-Faster R-CNN

    [目标检测]Faster RCNN算法详解 Ren, Shaoqing, et al. “Faster R-CNN: Towards real-time object detection with r ...

  10. Paper Reading:Faster RCNN

    Faster R-CNN 论文:Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks 发表时间: ...

随机推荐

  1. 批量更新数据(BatchUpdate)

    批量更新数据(BatchUpdate) /// <summary> /// 批量更新数据,注意:如果有timestamp列,要移除 /// </summary> /// < ...

  2. A*启发式搜索基础

    A*启发式搜索基础 传统的搜索方式是盲目搜索,即到每一步的时候并没有对每种情况进行有效的区分,这样的结果是浪费了大量的时间,对很多没有必要的数据进行了搜索. 而A*算法则在搜索的过程中会选取认为“最优 ...

  3. bzoj1046题解

    [解题思路] 先倒着求一遍LIS,然后对于每个询问L从左到右找到第一个大于等于L的上升序列即可.复杂度O(N(log2N+M)). [参考代码] #pragma GCC optimize(2) #in ...

  4. cf期望概率专题

    cf1009E:求到第i段期望和的比较困难,但是单独求每段的期望是比较容易的,所以单独对每段求和,然后累计总和 E[i]=1/2*a1+1/4*a2+...+1/2^(i-1)*ai-1+1/2^(i ...

  5. php linux下安装xml扩展

    1.进入PHP安装源码包,找到ext下的ftp,进入 cd /home/local/php-5.6.25/ext/xml 2./usr/local/php/bin/phpize 3../configu ...

  6. python Pool并行执行

    # -*- coding: utf-8 -*- import time from multiprocessing import Pool def run(fn): #fn: 函数参数是数据列表的一个元 ...

  7. TLS/SSL 协议 - ServerKeyExchange、ServerHelloDone

    ServerKeyExchange ServerKeyExchange消息的目的是携带密钥交换的额外数据.消息内容对于不同的协商算法套件都会存在差异.在某些场景中,服务器不需要发送任何内容,这意味着在 ...

  8. 1、Monkey环境搭建

    步骤: 1.下载adb压缩包: 32位计算机,用这个包:64位计算机,用这个包: 2.把对应的adb压缩包在本地解压,然后把解压后的文件里面的文件夹拷贝到D盘(当然随便你放在哪个目录)根目录,注意路径 ...

  9. 使用Netfilter进行数据包分析

    #include <linux/init.h>#include <linux/module.h>#include <linux/skbuff.h>#include ...

  10. 并发编程(五)——GIL全局解释器锁、死锁现象与递归锁、信号量、Event事件、线程queue

    GIL.死锁现象与递归锁.信号量.Event事件.线程queue 一.GIL全局解释器锁 1.什么是全局解释器锁 GIL本质就是一把互斥锁,相当于执行权限,每个进程内都会存在一把GIL,同一进程内的多 ...