cascade DecodeBBox层
https://zhuanlan.zhihu.com/p/36095768
我的推断,第二第三阶段应该不是把所有anchor进行bounding box regression,然后再选取当前条件下的所有roi,而是第一阶段选取512个roi,然后把在第一阶段匹配好的roi送到第二、三阶段
layer {
name: "proposals_2nd"
type: "DecodeBBox"
bottom: "bbox_pred"
bottom: "rois"
bottom: "match_gt_boxes"
top: "proposals_2nd"
bbox_reg_param {
bbox_mean: 0 bbox_mean: 0 bbox_mean: 0 bbox_mean: 0
bbox_std: 0.1 bbox_std: 0.1 bbox_std: 0.2 bbox_std: 0.2
}
propagate_down: 0
propagate_down: 0
propagate_down: 0
}
这段代码就证明了这个想法:rois来自于第一阶段proposal_info,这些rois也是在第一阶段做roi-pooling用来训练的。
个人感觉cascade的模型就是4张图提取512个roi进行训练,然后经过第一阶段训练后,把这512个roi经过回归精修然后去除回归后x1大于x2和y1大于y2的和回归后和gt的iou大于0.95的,这样roi可能就没有512个了.把这些输入给第二阶段的proposal_info_2nd,让这个层再去决定训练样本,这样大可能训练的数据是不足512,并且3个阶段其实都是训练的同一个批roi,也就是说第一阶段进去的那些roi,后面几个阶段实际上也在训练他们,而不是新出来的框DecodeBbox层的输入是bbox_pred,rois和match_gt_boxes.首先明确一点,rpn网络会输出很多proposals出来,ProposalTarget层将这些proposals和gt算iou,确定正负样本并选取1:3的比例,然后输出rois,rois就是拿来具体训练的从rpn中获得那部分预提取框.DecodeBbox层就是将这些原本的rois回归成更精准的框,也就是在原始的rois的坐标上增加经过训练得到的回归的值,这个是通过DecodeBBoxesWithPrior函数实现.DecodeBbox层分为大致3个步骤:1.回归得到更精准的rois 2.去掉回归后x1大于x2和y1大于y2的框 3.去掉回归后和gt的iou大于0.95的框
这部分的结果
#include <cfloat>
#include <vector> #include "caffe/util/bbox_util.hpp"
#include "caffe/layers/decode_bbox_layer.hpp" namespace caffe { template <typename Dtype>
void DecodeBBoxLayer<Dtype>::LayerSetUp(
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
// bbox mean and std
BBoxRegParameter bbox_reg_param = this->layer_param_.bbox_reg_param();
bbox_mean_.Reshape(,,,); bbox_std_.Reshape(,,,);
if (bbox_reg_param.bbox_mean_size() > && bbox_reg_param.bbox_std_size() > ) {
int num_means = this->layer_param_.bbox_reg_param().bbox_mean_size();
int num_stds = this->layer_param_.bbox_reg_param().bbox_std_size();
CHECK_EQ(num_means,); CHECK_EQ(num_stds,);
for (int i = ; i < ; i++) {
bbox_mean_.mutable_cpu_data()[i] = bbox_reg_param.bbox_mean(i);
bbox_std_.mutable_cpu_data()[i] = bbox_reg_param.bbox_std(i);
CHECK_GT(bbox_std_.mutable_cpu_data()[i],);
}
} else {
caffe_set(bbox_mean_.count(), Dtype(), bbox_mean_.mutable_cpu_data());
caffe_set(bbox_std_.count(), Dtype(), bbox_std_.mutable_cpu_data());
}
} template <typename Dtype>
void DecodeBBoxLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
// bottom: bbox_blob, prior_blob, (match_gt_boxes)
CHECK_EQ(bottom[]->num(),bottom[]->num());
if (bottom.size()>=) {
CHECK_EQ(bottom[]->num(),bottom[]->num());
CHECK(this->phase_ == TRAIN);
}
CHECK_EQ(bottom[]->channels(),);
CHECK_EQ(bottom[]->channels(),);
bbox_pred_.ReshapeLike(*bottom[]);
top[]->ReshapeLike(*bottom[]);
} template <typename Dtype>
void DecodeBBoxLayer<Dtype>::Forward_cpu(
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
const int num = bottom[]->num();
const int bbox_dim = bottom[]->channels();
const int prior_dim = bottom[]->channels(); //decode prior box [img_id x1 y1 x2 y2]
const Dtype* prior_data = bottom[]->cpu_data();
vector<BBox> prior_bboxes;
for (int i = ; i < num; i++) {
BBox bbox;
bbox.xmin = prior_data[i*prior_dim + ];
bbox.ymin = prior_data[i*prior_dim + ];
bbox.xmax = prior_data[i*prior_dim + ];
bbox.ymax = prior_data[i*prior_dim + ];
prior_bboxes.push_back(bbox);
} // decode bbox predictions
const Dtype* bbox_data = bottom[]->cpu_data();
Dtype* bbox_pred_data = bbox_pred_.mutable_cpu_data(); DecodeBBoxesWithPrior(bbox_data,prior_bboxes,bbox_dim,bbox_mean_.cpu_data(),
bbox_std_.cpu_data(),bbox_pred_data); vector<bool> valid_bbox_flags(num,true);
// screen out mal-boxes
if (this->phase_ == TRAIN) {
for (int i = ; i < num; i++) {
const int base_index = i*bbox_dim+;
if (bbox_pred_data[base_index] > bbox_pred_data[base_index+]
|| bbox_pred_data[base_index+] > bbox_pred_data[base_index+]) {
valid_bbox_flags[i] = false;
}
}
}
// screen out high IoU boxes, to remove redundant gt boxes
if (bottom.size()== && this->phase_ == TRAIN) {
const Dtype* match_gt_boxes = bottom[]->cpu_data();
const int gt_dim = bottom[]->channels();
const float gt_iou_thr = this->layer_param_.decode_bbox_param().gt_iou_thr();
for (int i = ; i < num; i++) {
const float overlap = match_gt_boxes[i*gt_dim+gt_dim-];
if (overlap >= gt_iou_thr) {
valid_bbox_flags[i] = false;
}
}
} vector<int> valid_bbox_ids;
for (int i = ; i < num; i++) {
if (valid_bbox_flags[i]) {
valid_bbox_ids.push_back(i);
}
}
const int keep_num = valid_bbox_ids.size();
CHECK_GT(keep_num,); top[]->Reshape(keep_num, prior_dim, , );
Dtype* decoded_bbox_data = top[]->mutable_cpu_data();
for (int i = ; i < keep_num; i++) {
const int keep_id = valid_bbox_ids[i];
const int base_index = keep_id*bbox_dim+;
decoded_bbox_data[i*prior_dim] = prior_data[keep_id*prior_dim];
decoded_bbox_data[i*prior_dim+] = bbox_pred_data[base_index];
decoded_bbox_data[i*prior_dim+] = bbox_pred_data[base_index+];
decoded_bbox_data[i*prior_dim+] = bbox_pred_data[base_index+];
decoded_bbox_data[i*prior_dim+] = bbox_pred_data[base_index+];
}
} INSTANTIATE_CLASS(DecodeBBoxLayer);
REGISTER_LAYER_CLASS(DecodeBBox); } // namespace caffe
DecodeBBoxesWithPrior函数在bbox_util.cpp里实现,完成的功能就是把bounding box regression的结果对输入的prior_bbox(其实就是faster中的输入的region proposal)进行回归获得更精确的框坐标,然后存储在pred_data
template <typename Dtype>
void DecodeBBoxesWithPrior(const Dtype* bbox_data, const vector<BBox> prior_bboxes,
const int bbox_dim, const Dtype* means, const Dtype* stds,
Dtype* pred_data) {
const int num = prior_bboxes.size();
const int cls_num = bbox_dim/;
for (int i = ; i < num; i++) {
Dtype pw, ph, cx, cy;
pw = prior_bboxes[i].xmax-prior_bboxes[i].xmin+;
ph = prior_bboxes[i].ymax-prior_bboxes[i].ymin+;
cx = 0.5*(prior_bboxes[i].xmax+prior_bboxes[i].xmin);
cy = 0.5*(prior_bboxes[i].ymax+prior_bboxes[i].ymin);
for (int c = ; c < cls_num; c++) {
Dtype bx, by, bw, bh;
// bbox de-normalization
bx = bbox_data[i*bbox_dim+*c]*stds[]+means[];
by = bbox_data[i*bbox_dim+*c+]*stds[]+means[];
bw = bbox_data[i*bbox_dim+*c+]*stds[]+means[];
bh = bbox_data[i*bbox_dim+*c+]*stds[]+means[]; Dtype tx, ty, tw, th;
tx = bx*pw+cx; ty = by*ph+cy;
tw = pw*exp(bw); th = ph*exp(bh);
tx -= (tw-)/; ty -= (th-)/;
pred_data[i*bbox_dim+*c] = tx;
pred_data[i*bbox_dim+*c+] = ty;
pred_data[i*bbox_dim+*c+] = tx+tw-;
pred_data[i*bbox_dim+*c+] = ty+th-;
}
}
}
cascade DecodeBBox层的更多相关文章
- cascade rcnn论文总结
1.bouding box regression总结: rcnn使用l2-loss 首先明确l2-loss的计算规则: L∗=(f∗(P)−G∗)2,∗代表x,y,w,h 整个loss : L= ...
- Django---进阶3
目录 无名有名分组反向解析 路由分发 名称空间(了解) 伪静态(了解) 虚拟环境(了解) django版本区别 视图层 三板斧 JsonResponse对象 form表单上传文件及后端如何操作 req ...
- Django学习day04随堂笔记
每日测验 """ 今日考题 1.列举你知道的orm数据的增删改查方法 2.表关系如何判定,django orm中如何建立表关系,有什么特点和注意事项 3.请画出完整的dj ...
- 66、django之模型层(model)--多表相关操作(图书管理小练习)
前面几篇随笔的数据库增删改查操作都是在单表的操作上的,然而现实中不可能都是单表操作,更多的是多表操作,一对一,一对多,多对多的表结构才是我们经常需要处理的,本篇将带我们了解多表操作的一些相关操作.也会 ...
- 【Django】模型层说明
[Django模型层] 之前大概介绍Django的文章居然写了两篇..这篇是重点关注了Django的模型层来进行学习. ■ 模型定义 众所周知,Django中的模型定义就是定义一个类,其基本结构是这样 ...
- 一 Django模型层简介(一)
模型 django提供了一个强大的orm(关系映射模型)系统. 模型包含了你要在数据库中创建的字段信息及对数据表的一些操作 使用模型 定义好模型后,要告诉django使用这些模型,你要做的就是在配置文 ...
- 二 Djano模型层之模型字段选项
字段选项 以下参数是全部字段类型都可用的,而且是可选的 null 如果为True,Django将在数据库中将空值存储为NULL.默认值为False 对于字符串字段,如果设置了null=True意味着& ...
- WEB框架-Django框架学习(二)- 模型层
今日份整理为模型层 1.ORM简介 MVC或者MVC框架中包括一个重要的部分,就是ORM,它实现了数据模型与数据库的解耦,即数据模型的设计不需要依赖于特定的数据库,通过简单的配置就可以轻松更换数据库, ...
- django 模型层(2)
Django 模型层(2) 多表操作---模型之间的关系 1 一对一:作者----作者详细信息 2 一对多:书籍----出版社 3 多对多:书籍----作者 一 创建模型(主键(id)自动创建) 没 ...
随机推荐
- GreenPlum 大数据平台--并行备份(四)
01,并行备份(gp_dump) 1) GP同时备份Master和所有活动的Segment实例 2) 备份消耗的时间与系统中实例的数量没有关系 3) 在Master主机上备份所有DDL文件和GP相关的 ...
- java集合常用操作
收集一些常用集合操作的代码,用于治疗健忘症,:) set转list //构造Map数据 Map<String, String> map = new HashMap<String, S ...
- React.js 小书 Lesson15 - 实战分析:评论功能(二)
作者:胡子大哈 原文链接:http://huziketang.com/books/react/lesson15 转载请注明出处,保留原文链接和作者信息. 上一节我们构建了基本的代码框架,现在开始完善其 ...
- 吴恩达《Machine Learning Yearning》总结(1-10章)
1.为什么选择机器学习策略 案例:建立猫咪图像识别app 系统的优化可以有很多的方向: (1)获取更多的数据集,即更多的图片: (2)收集更多多样数据,如处于不常见的位置的猫的图,颜色奇异的猫的照片等 ...
- IE9下JQuery发送ajax请求失效
最近在做项目的时候,测试PC端网页,在IE9下会失效,不能正常的发送POST请求,经过仔细的排查,发现是IE9下JQuery发送ajax存在跨域问题. 目前有两种解决方案: 解决方案一: 设置浏览 ...
- mysql主从复制报错 :Incorrect usage of DB GRANT and GLOBAL PRIVILEGES
在配置mysql主从复制时,想通过 grant replication slave on bbs.* to 'bbs'@'192.168.1.3' identified by '123456'; 来限 ...
- Paxos、ZAB、RAFT协议
这三个都是分布式一致性协议,ZAB基于Paxos修改后用于ZOOKEEPER协议,RAFT协议出现在ZAB协议之后,与ZAB差不多,也有很大区别. 1. Paxos 分布式节点分为3种角色, Prop ...
- js系列之js简介
该系列教程都来源于:廖雪峰老师的博客 JavaScript是世界上最流行的脚本语言,因为你在电脑.手机.平板上浏览的所有的网页,以及无数基于HTML5的手机App,交互逻辑都是由JavaScript驱 ...
- 概述File i/o
1.File对象既可表示文件,也可表示目录(文件夹). 2. 创建一个File对象 File file = new File (String pathName[文件路径名]); 3.在Windows操 ...
- IMG标签与before,after伪类
在CSS中总有一些你不用不知道,用到才知道的“坑”.比如今天要谈的,把 before, after 伪类用在 <img> 标签上.嗯,实际上你用你会发现,在大多数浏览器这是无效的,dom中 ...