RPNHead类包含的函数:

(1)_init_():初始化函数

(2)_init_layers():设置Head中的卷积层

(3)forward_single():单尺度特征图的前向传播

(4)loss:Head损失函数计算

(5)_get_bboxes_single():将单个图像的输出转换为bbox预测

(6)_bbox_post-processing_method:bbox后续处理方法

这里介绍的是_init_layers_()函数:

 1 def _init_layers(self):
2 """Initialize layers of the head."""
3 if self.num_convs > 1:
4 rpn_convs = []
5 for i in range(self.num_convs):
6 if i == 0:
7 in_channels = self.in_channels
8 else:
9 in_channels = self.feat_channels
10 # use ``inplace=False`` to avoid error: one of the variables
11 # needed for gradient computation has been modified by an
12 # inplace operation.
13 rpn_convs.append(
14 ConvModule(
15 in_channels,
16 self.feat_channels,
17 3,
18 padding=1,
19 inplace=False))
20 self.rpn_conv = nn.Sequential(*rpn_convs)
21 else:
22 self.rpn_conv = nn.Conv2d(self.in_channels, self.feat_channels, 3, padding=1)
24 self.rpn_cls = nn.Conv2d(self.feat_channels, self.num_base_priors * self.cls_out_channels, 1)
27 self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_base_priors * 4, 1)

in_channels(int):输入特征映射中的通道数。

feat_channels(int):隐藏通道的数量。

函数说明:

这个函数是完成Head中卷积层的设置。代码14行可以看到,这里使用的是MMCV中的ConVModule类来构建卷积层,使用他的方便之处在于,它会在卷积层后自动加上归一化层和激活函数。

RPNHead的卷积层主要由三个部分组成,rpn_conv, rpn_cls, rpn_reg。num_cls的值影响rpn_conv的层数。

num_cls的值可在配置文件中rpn_head的字典里设置,默认是1。

self.num_convs=1,RPNHead的结构是:

如果self.num_convs>1,RPNHead的结构如下:

可以看到,通过3×3的卷积层之后,会再经过分类分支和回归分支,用于完成目标的分类和定位。rpn_cls和rpn_reg都是1×1的卷积层。输入通道是feat_channels,输出通道分别是cls_out_channels*num_base_priors和num_base_priors*4。

cls_out_channels,num_base_priors都是RPNHead继承自父类的参数。

关于cls_out_channels属性值的代码如下:

1 if self.use_sigmoid_cls:
2 self.cls_out_channels = num_classes
3 else:
4 self.cls_out_channels = num_classes + 1

如果use_sigmoid_cls为真,cls_out_channels就是类别数,否则是类别数加一。

use_sigmoid的变量值是从loss_cls的配置字典的获取的。默认为False,可以在配置文件中查看是否设置了真值。

1 self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)

关于num_base_priors属性值的代码如下

1 self.prior_generator = build_prior_generator(anchor_generator)
2 # Usually the numbers of anchors for each level are the same
3 # except SSD detectors. So it is an int in the most dense
4 # heads but a list of int in SSDHead
5 self.num_base_priors = self.prior_generator.num_base_priors[0]

这里看不出来,这个值具体是啥,我上网查了一番后得到,num_base_priors = num(anchor_scales)*num(anchor_ratios)。

num_base_priors是每个特征点产生的锚框的数量。

由此可以知道,分类和回归的输出通道的含义是

cls_out_channels*num_base_priors,所有锚框对应的类别分类(这里的类别指的是,是否是目标,不是具体的目标类别)

num_base_priors*4,所有的锚框对应的回归值的输出。(对应的是中心点的偏移量和宽高的缩放量)

mmdetection RPNHead--_init_layers()的更多相关文章

  1. 在mmdetection中跑通MaskRCNN

    1.将数据集转化成COCO格式数据集 Kaggle->COCO: https://github.com/pascal1129/airbus_rle_to_coco/blob/master/1_s ...

  2. anaconda中安装mmdetection

    1.新建conda环境(有则跳过)     conda create -n py36 python=3.6 && source activate py36 2.安装pytorch    ...

  3. mmdetection安装教程

    如果官方教程不行再参考我的吧,我的环境如下: ubuntu cuda10 cudnn7.5 步骤: 1.使用conda创建一个虚拟环境 conda create -n mmdetection pyth ...

  4. 商汤开源的mmdetection技术报告

    目录 1. 简介 2. 支持的算法 3. 框架与架构 6. 相关链接 前言:让我惊艳的几个库: ultralytics的yolov3,在一众yolov3的pytorch版本实现算法中脱颖而出,收到开发 ...

  5. 【AI-人工智能-mmdetection】ModuleNotFoundError: No module named 'mmdet.version'

    在集成 mmdetection 框架时遇到这样的问题. ModuleNotFoundError: No module named 'mmdet.version' mmdetection 框架搭建过程很 ...

  6. mmdetection源码剖析(1)--NMS

    mmdetection源码剖析(1)--NMS 熟悉目标检测的应该都清楚NMS是什么算法,但是如果我们要与C++和cuda结合直接写成Pytorch的操作你们清楚怎么写吗?最近在看mmdetectio ...

  7. MMDetection 快速开始,训练自定义数据集

    本文将快速引导使用 MMDetection ,记录了实践中需注意的一些问题. 环境准备 基础环境 Nvidia 显卡的主机 Ubuntu 18.04 系统安装,可见 制作 USB 启动盘,及系统安装 ...

  8. 安装mmdetection,运行报错Segmentation fault

    具体安装过程详见https://github.com/open-mmlab/mmdetection/blob/master/docs/INSTALL.md 在安装完成mmdetection后运行tes ...

  9. mmdetection训练出现nan

    训练出现nan 在使用MMDetection训练模型时,发现打印信息中出现了很多nan.现象是,loss在正常训练下降的过程中,突然变为nan. 梯度裁减 在模型配置中加上grad_clip: opt ...

  10. mmdetection源码阅读

    2021-11-23号更新 mmdetection中的hook函数 参考: 重难点总结: # step1: 根据官方文档,getattr(self,'name')等同于self.name # sept ...

随机推荐

  1. Jmeter 实现Json格式接口测试

    接口Request Headers中的Content-Type和和charset 在"HTTP请求"中添加UTF-8 在"HTTP信息头管理器"中添加Conte ...

  2. Windows Defender 实时防护打不开,你的IT管理员已经限制对此应用一些区域的访问

    最近在使用电脑的时候,Windows Defender实时防护不能使用,一打开就自动关闭,并且显示 该页面不可用 你的IT管理员已经限制对此应用一些区域的访问,实时防护页面显示 正在使用其他防护软件. ...

  3. tp5上传图片常规

    前端不多说,就是使用input标签的file格式. tp5用request()->file('input的名字')接收图片,是binary格式的数据: $file = request()-> ...

  4. esxi虚拟机定时创建快照

    1.vim-cmd vmsvc/getallvms  列出所有虚拟机信息 2.获取需要备份的虚拟机的Vmid 3.执行快照  vim-cmd vmsvc/snapshot.create Vmid $( ...

  5. 常见的hash数据结构

    遍历 hash表是一种比较简单和直观的数据结构,在查找时也有很好的性能.但是hash表不能提供有序遍历,这个是其特性决定,所以不足为奇.但是,更为实际的一个问题是如果遍历整个hash表中的所有元素? ...

  6. Java-面向对象基础 this& 重载

    1.this表示当前对象 获取当前对象的属性 使用this调用当前属性 2.重载 如果两个方法的方法名相同,但参数不一致,那么可以说一个方法是另一个方法的重载

  7. 怎么理解超几何分布概率公式:p=C(M,k)C(N-M,n-k)/C(N,n)

    怎么理解超几何分布概率公式:p=C(M,k)C(N-M,n-k)/C(N,n) 前言:重在记录,可能出错. 超几何分布概率公式:p=C(M,k)C(N-M,n-k)/C(N,n),也就是: 到底要怎么 ...

  8. 三种方式实现RPC调用

    一:RabbitMQ实现RPC调用 客户端: import pika import uuid class FibonacciRpcClient(object): def __init__(self): ...

  9. git 代码强制回滚操作整理(线上线下一起)

    线上代码强制回滚操作,这边整理了一下 1.到线上 执行 git reset --hard xxxxxxxxxxx(更新前的一个版本)2.本地执行 和上面一样 git reset --hard xxxx ...

  10. py打包工具

    库地址: auto-py-to-exe https://pypi.org/project/auto-py-to-exe/ Gooey https://pypi.org/project/Gooey/ 为 ...