A Loss Function for Learning Region Proposals

训练RPN时,只对两种anchor给予正标签:和gt_box有着最高的IoU && IoU超过0.7。如果对于

所有的gt_box,其IoU都小于0.3,则标记为负。损失函数定义如下:

其中i为一个mini-batch中某anchor的索引,pi表示其为目标的预测概率,pi*表示gt_box(正为1,否则为0)。

ti和ti*分别表示预测框的位置和gt_box框的位置。Lreg如下:

bound-box regression中各参数的计算方式为:

   (4)

其对应的SmoothL1LossLayer代码如下,整个过程分为两部分:前向计算以及后向计算(1)式的后半部分:

// ------------------------------------------------------------------
// Fast R-CNN
// Copyright (c) 2015 Microsoft
// Licensed under The MIT License [see fast-rcnn/LICENSE for details]
// Written by Ross Girshick
// ------------------------------------------------------------------ #include "caffe/fast_rcnn_layers.hpp" namespace caffe {
//SmoothL1前向计算(3)式
template <typename Dtype>
__global__ void SmoothL1Forward(const int n, const Dtype* in, Dtype* out,
Dtype sigma2) {
// f(x) = 0.5 * (sigma * x)^2 if |x| < 1 / sigma / sigma
// |x| - 0.5 / sigma / sigma otherwise
CUDA_KERNEL_LOOP(index, n) {
Dtype val = in[index];
Dtype abs_val = abs(val);
if (abs_val < 1.0 / sigma2) {
out[index] = 0.5 * val * val * sigma2;
} else {
out[index] = abs_val - 0.5 / sigma2;
}
}
}
//
template <typename Dtype>
void SmoothL1LossLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
int count = bottom[0]->count();
caffe_gpu_sub(
count,
bottom[0]->gpu_data(), //ti
bottom[1]->gpu_data(), //ti*
diff_.mutable_gpu_data()); // d := ti-ti*
if (has_weights_) { //乘上相关的权重,对应于(1)式中的pi*,有目标时为1
// apply "inside" weights
caffe_gpu_mul(
count,
bottom[2]->gpu_data(), //pi*
diff_.gpu_data(),
diff_.mutable_gpu_data()); // d := w_in * (b0 - b1)
}
//代入计算SmoothL1
SmoothL1Forward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, diff_.gpu_data(), errors_.mutable_gpu_data(), sigma2_);
CUDA_POST_KERNEL_CHECK; if (has_weights_) { //乘上相关的权重
// apply "outside" weights
caffe_gpu_mul(
count,
bottom[3]->gpu_data(), // 1/Nreg
errors_.gpu_data(),
errors_.mutable_gpu_data()); // d := w_out * SmoothL1(w_in * (b0 - b1))
} Dtype loss;
caffe_gpu_dot(count, ones_.gpu_data(), errors_.gpu_data(), &loss);
top[0]->mutable_cpu_data()[0] = loss / bottom[0]->num();
}
//反向计算,对smoothLoss求导
template <typename Dtype>
__global__ void SmoothL1Backward(const int n, const Dtype* in, Dtype* out,
Dtype sigma2) {
// f'(x) = sigma * sigma * x if |x| < 1 / sigma / sigma
// = sign(x) otherwise
CUDA_KERNEL_LOOP(index, n) {
Dtype val = in[index];
Dtype abs_val = abs(val);
if (abs_val < 1.0 / sigma2) {
out[index] = sigma2 * val;
} else {
out[index] = (Dtype(0) < val) - (val < Dtype(0));
}
}
}
//
template <typename Dtype>
void SmoothL1LossLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
// after forwards, diff_ holds w_in * (b0 - b1)
int count = diff_.count();
//调用反向smoothloss,diff_.gpu_data()表示x,diff_.mutable_gpu_data()表示smoothloss的导数
SmoothL1Backward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, diff_.gpu_data(), diff_.mutable_gpu_data(), sigma2_); //类似于前向
CUDA_POST_KERNEL_CHECK;
for (int i = 0; i < 2; ++i) {
if (propagate_down[i]) {
const Dtype sign = (i == 0) ? 1 : -1;
const Dtype alpha = sign * top[0]->cpu_diff()[0] / bottom[i]->num();
caffe_gpu_axpby(
count, // count
alpha, // alpha
diff_.gpu_data(), // x
Dtype(0), // beta
bottom[i]->mutable_gpu_diff()); // y
if (has_weights_) {
// Scale by "inside" weight
caffe_gpu_mul(
count,
bottom[2]->gpu_data(),
bottom[i]->gpu_diff(),
bottom[i]->mutable_gpu_diff());
// Scale by "outside" weight
caffe_gpu_mul(
count,
bottom[3]->gpu_data(),
bottom[i]->gpu_diff(),
bottom[i]->mutable_gpu_diff());
}
}
}
} INSTANTIATE_LAYER_GPU_FUNCS(SmoothL1LossLayer); } // namespace caffe
 

r-cnn学习(五):SmoothL1LossLayer论文与代码的结合理解的更多相关文章

  1. 10K+,深度学习论文、代码最全汇总!

    我们大部分人是如何查询和搜集深度学习相关论文的?绝大多数情况是根据关键字在谷歌.百度搜索.想寻找相关论文的复现代码又会去 GitHub 上搜索关键词.浪费了很多时间不说,论文.代码通常也不够完整.怎么 ...

  2. R语言学习笔记(五)绘图(1)

      R是一个惊艳的图形构建平台,这也是R语言的强大之处.本文将分享R语言简单的绘图命令.   本文所使用的数据或者来自R语言自带的数据(mtcars)或者自行创建.   首先,让我们来看一个简单例子: ...

  3. [ZZ]计算机视觉、机器学习相关领域论文和源代码大集合

    原文地址:[ZZ]计算机视觉.机器学习相关领域论文和源代码大集合作者:计算机视觉与模式 注:下面有project网站的大部分都有paper和相应的code.Code一般是C/C++或者Matlab代码 ...

  4. Context Encoder论文及代码解读

    经过秋招和毕业论文的折磨,提交完论文終稿的那一刻总算觉得有多余的时间来搞自己的事情. 研究论文做的是图像修复相关,这里对基于深度学习的图像修复方面的论文和代码进行整理,也算是研究生方向有一个比较好的结 ...

  5. Android JNI学习(五)——Demo演示

    本系列文章如下: Android JNI(一)——NDK与JNI基础 Android JNI学习(二)——实战JNI之“hello world” Android JNI学习(三)——Java与Nati ...

  6. TweenMax动画库学习(五)

    目录            TweenMax动画库学习(一)            TweenMax动画库学习(二)            TweenMax动画库学习(三)            Tw ...

  7. R基础学习

    R基础学习 The Art of R Programming 1.seq 产生等差数列:seq(from,to,by) seq(from,to,length) for(i in 1:length(x) ...

  8. SVG 学习<五> SVG动画

    目录 SVG 学习<一>基础图形及线段 SVG 学习<二>进阶 SVG世界,视野,视窗 stroke属性 svg分组 SVG 学习<三>渐变 SVG 学习<四 ...

  9. R语言学习 第四篇:函数和流程控制

    变量用于临时存储数据,而函数用于操作数据,实现代码的重复使用.在R中,函数只是另一种数据类型的变量,可以被分配,操作,甚至把函数作为参数传递给其他函数.分支控制和循环控制,和通用编程语言的风格很相似, ...

随机推荐

  1. C++的vector对象

    C++的vector使用 标签(空格分隔): C++ 标准库类型vector表示对象的集合,其中所有对象的类型都相同.集合中的每个对象都有一个与之对应的索引,索引用于访问对象,因为vector容纳着其 ...

  2. concat() 方法用于连接两个或多个数组。

    我们创建了三个数组,然后使用 concat() 把它们连接起来: <script type="text/javascript"> var arr = new Array ...

  3. ASP.NET Boilerplate

    I want it to be a start point for all we .NET developers, so, it will be good to develop it together ...

  4. LeetCode "448. Find All Numbers Disappeared in an Array"

    My first reaction is to have an unlimited length of bit-array, to mark existence. But if no extra me ...

  5. 错误"ORA-04091: table is mutating, trigger/function may not see it"的原因以及解决办法

    错误的原因该错误是在编写trigger时常遇到的问题,其根本原因是由于对本表的操作造成的.对于使用了for each row 的触发器,做了DML操作(delete,update,insert),还没 ...

  6. Matlab && C-Mex Round 1

    前言:本篇文章主要通过一个简单的例子程序对C-Mex进行一个初步的说明.前期的环境搭建(包括安装Matlab和gcc编译器)就不在这里赘述了. 在看文章之前,建议初学者先检查一下Matlab的mex配 ...

  7. CPU使用率终于正常了——记一次订餐系统事故处理

    引子 经过漫长的等待,儿子终于出生了.欣喜之余,就是各种手足无措,顾此失彼了.因为不懂,心里总是慌慌的,有点小毛病,恨不得一步就到医院. 婆媳育儿观念的差异,让心乱如麻的我,又成了风箱里的老鼠,两个不 ...

  8. 处理 pcap 中的 mac 二进制字节流为可读格式

    import struct # 利用 struct 处理字节流中的mac地址 适用于小端地址操作系统 def mac2str(bi_mac): mac = "" for i in ...

  9. SharePoint下载服务器资源

    使用IE浏览器

  10. Redis集群(四):主从配置二

    一.本文目的        主要介绍redis主从模式下各种情况 二.说明 主从的基本概念:Master用于写入,Slaver用于读取,不能写入或修改,一个Master可以对应多个Slaver Mas ...