BatchNorm caffe源码
1、计算的均值和方差是channel的
2、test/predict 或者use_global_stats的时候,直接使用moving average
use_global_stats 表示是否使用全部数据的统计值(该数据实在train 阶段通过moving average 方法计算得到)训练阶段设置为 fasle, 表示通过当前的minibatch 数据计算得到, test/predict 阶段使用 通过全部数据计算得到的统计值
那什么是 moving average 呢:

反向传播:


源码:(注:caffe_cpu_scale 是y=alpha*x ,这里面求滑动均值时候,alpha是滑动系数和的倒数,x是滑动均值和
template <typename Dtype>
void BatchNormLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = top[0]->mutable_cpu_data();
int num = bottom[0]->shape(0);
int spatial_dim = bottom[0]->count()/(bottom[0]->shape(0)*channels_); if (bottom[0] != top[0]) {
caffe_copy(bottom[0]->count(), bottom_data, top_data);
} if (use_global_stats_) {
// use the stored mean/variance estimates.
const Dtype scale_factor = this->blobs_[2]->cpu_data()[0] == 0 ?
0 : 1 / this->blobs_[2]->cpu_data()[0];
caffe_cpu_scale(variance_.count(), scale_factor,
this->blobs_[0]->cpu_data(), mean_.mutable_cpu_data());
caffe_cpu_scale(variance_.count(), scale_factor,
this->blobs_[1]->cpu_data(), variance_.mutable_cpu_data());
} else {
// compute mean
caffe_cpu_gemv<Dtype>(CblasNoTrans, channels_ * num, spatial_dim,
1. / (num * spatial_dim), bottom_data,
spatial_sum_multiplier_.cpu_data(), 0.,
num_by_chans_.mutable_cpu_data());
caffe_cpu_gemv<Dtype>(CblasTrans, num, channels_, 1.,
num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0.,
mean_.mutable_cpu_data());
} // subtract mean
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1,
batch_sum_multiplier_.cpu_data(), mean_.cpu_data(), 0.,
num_by_chans_.mutable_cpu_data());
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels_ * num,
spatial_dim, 1, -1, num_by_chans_.cpu_data(),
spatial_sum_multiplier_.cpu_data(), 1., top_data); if (!use_global_stats_) {
// compute variance using var(X) = E((X-EX)^2)
caffe_powx(top[0]->count(), top_data, Dtype(2),
temp_.mutable_cpu_data()); // (X-EX)^2
caffe_cpu_gemv<Dtype>(CblasNoTrans, channels_ * num, spatial_dim,
1. / (num * spatial_dim), temp_.cpu_data(),
spatial_sum_multiplier_.cpu_data(), 0.,
num_by_chans_.mutable_cpu_data());
caffe_cpu_gemv<Dtype>(CblasTrans, num, channels_, 1.,
num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0.,
variance_.mutable_cpu_data()); // E((X_EX)^2) // compute and save moving average
this->blobs_[2]->mutable_cpu_data()[0] *= moving_average_fraction_;
this->blobs_[2]->mutable_cpu_data()[0] += 1;
caffe_cpu_axpby(mean_.count(), Dtype(1), mean_.cpu_data(),
moving_average_fraction_, this->blobs_[0]->mutable_cpu_data());
int m = bottom[0]->count()/channels_;
Dtype bias_correction_factor = m > 1 ? Dtype(m)/(m-1) : 1;
caffe_cpu_axpby(variance_.count(), bias_correction_factor,
variance_.cpu_data(), moving_average_fraction_,
this->blobs_[1]->mutable_cpu_data());
} // normalize variance
caffe_add_scalar(variance_.count(), eps_, variance_.mutable_cpu_data());
caffe_powx(variance_.count(), variance_.cpu_data(), Dtype(0.5),
variance_.mutable_cpu_data()); // replicate variance to input size
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1,
batch_sum_multiplier_.cpu_data(), variance_.cpu_data(), 0.,
num_by_chans_.mutable_cpu_data());
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels_ * num,
spatial_dim, 1, 1., num_by_chans_.cpu_data(),
spatial_sum_multiplier_.cpu_data(), 0., temp_.mutable_cpu_data());
caffe_div(temp_.count(), top_data, temp_.cpu_data(), top_data);
// TODO(cdoersch): The caching is only needed because later in-place layers
// might clobber the data. Can we skip this if they won't?
caffe_copy(x_norm_.count(), top_data,
x_norm_.mutable_cpu_data());
} template <typename Dtype>
void BatchNormLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down,
const vector<Blob<Dtype>*>& bottom) {
const Dtype* top_diff;
if (bottom[0] != top[0]) {
top_diff = top[0]->cpu_diff();
} else {
caffe_copy(x_norm_.count(), top[0]->cpu_diff(), x_norm_.mutable_cpu_diff());
top_diff = x_norm_.cpu_diff();
}
Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
if (use_global_stats_) {
caffe_div(temp_.count(), top_diff, temp_.cpu_data(), bottom_diff);
return;
}
const Dtype* top_data = x_norm_.cpu_data();
int num = bottom[0]->shape()[0];
int spatial_dim = bottom[0]->count()/(bottom[0]->shape(0)*channels_);
// if Y = (X-mean(X))/(sqrt(var(X)+eps)), then
//
// dE(Y)/dX =
// (dE/dY - mean(dE/dY) - mean(dE/dY \cdot Y) \cdot Y)
// ./ sqrt(var(X) + eps)
//
// where \cdot and ./ are hadamard product and elementwise division,
// respectively, dE/dY is the top diff, and mean/var/sum are all computed
// along all dimensions except the channels dimension. In the above
// equation, the operations allow for expansion (i.e. broadcast) along all
// dimensions except the channels dimension where required. // sum(dE/dY \cdot Y)
caffe_mul(temp_.count(), top_data, top_diff, bottom_diff);
caffe_cpu_gemv<Dtype>(CblasNoTrans, channels_ * num, spatial_dim, 1.,
bottom_diff, spatial_sum_multiplier_.cpu_data(), 0.,
num_by_chans_.mutable_cpu_data());
caffe_cpu_gemv<Dtype>(CblasTrans, num, channels_, 1.,
num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0.,
mean_.mutable_cpu_data()); // reshape (broadcast) the above
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1,
batch_sum_multiplier_.cpu_data(), mean_.cpu_data(), 0.,
num_by_chans_.mutable_cpu_data());
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels_ * num,
spatial_dim, 1, 1., num_by_chans_.cpu_data(),
spatial_sum_multiplier_.cpu_data(), 0., bottom_diff); // sum(dE/dY \cdot Y) \cdot Y
caffe_mul(temp_.count(), top_data, bottom_diff, bottom_diff); // sum(dE/dY)-sum(dE/dY \cdot Y) \cdot Y
caffe_cpu_gemv<Dtype>(CblasNoTrans, channels_ * num, spatial_dim, 1.,
top_diff, spatial_sum_multiplier_.cpu_data(), 0.,
num_by_chans_.mutable_cpu_data());
caffe_cpu_gemv<Dtype>(CblasTrans, num, channels_, 1.,
num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0.,
mean_.mutable_cpu_data());
// reshape (broadcast) the above to make
// sum(dE/dY)-sum(dE/dY \cdot Y) \cdot Y
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1,
batch_sum_multiplier_.cpu_data(), mean_.cpu_data(), 0.,
num_by_chans_.mutable_cpu_data());
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num * channels_,
spatial_dim, 1, 1., num_by_chans_.cpu_data(),
spatial_sum_multiplier_.cpu_data(), 1., bottom_diff); // dE/dY - mean(dE/dY)-mean(dE/dY \cdot Y) \cdot Y
caffe_cpu_axpby(temp_.count(), Dtype(1), top_diff,
Dtype(-1. / (num * spatial_dim)), bottom_diff); // note: temp_ still contains sqrt(var(X)+eps), computed during the forward
// pass.
caffe_div(temp_.count(), bottom_diff, temp_.cpu_data(), bottom_diff);
} #ifdef CPU_ONLY
STUB_GPU(BatchNormLayer);
#endif INSTANTIATE_CLASS(BatchNormLayer);
REGISTER_LAYER_CLASS(BatchNorm);
} // namespace caffe
BatchNorm caffe源码的更多相关文章
- caffe源码学习之Proto数据格式【1】
前言: 由于业务需要,接触caffe已经有接近半年,一直忙着阅读各种论文,重现大大小小的模型. 期间也总结过一些caffe源码学习笔记,断断续续,这次打算系统的记录一下caffe源码学习笔记,巩固一下 ...
- Caffe源码理解2:SyncedMemory CPU和GPU间的数据同步
目录 写在前面 成员变量的含义及作用 构造与析构 内存同步管理 参考 博客:blog.shinelee.me | 博客园 | CSDN 写在前面 在Caffe源码理解1中介绍了Blob类,其中的数据成 ...
- caffe源码阅读
参考网址:https://www.cnblogs.com/louyihang-loves-baiyan/p/5149628.html 1.caffe代码层次熟悉blob,layer,net,solve ...
- Caffe源码中syncedmem文件分析
Caffe源码(caffe version:09868ac , date: 2015.08.15)中有一些重要文件,这里介绍下syncedmem文件. 1. include文件: (1).& ...
- Caffe源码中math_functions文件分析
Caffe源码(caffe version:09868ac , date: 2015.08.15)中有一些重要文件,这里介绍下math_functions文件. 1. include文件: ...
- Caffe源码中caffe.proto文件分析
Caffe源码(caffe version:09868ac , date: 2015.08.15)中有一些重要文件,这里介绍下caffe.proto文件. 在src/caffe/proto目录下有一个 ...
- Caffe源码阅读(1) 全连接层
Caffe源码阅读(1) 全连接层 发表于 2014-09-15 | 今天看全连接层的实现.主要看的是https://github.com/BVLC/caffe/blob/master/src ...
- vscode下调试caffe源码
caffe目录: ├── build -> .build_release // make生成目录,生成各种可执行bin文件,直接调用入口: ├── cmake ├── CMakeLists.tx ...
- Caffe源码中common文件分析
Caffe源码(caffe version:09868ac , date: 2015.08.15)中的一些重要头文件如caffe.hpp.blob.hpp等或者外部调用Caffe库使用时,一般都会in ...
随机推荐
- 洛谷 P4292 [WC2010]重建计划 解题报告
P4292 [WC2010]重建计划 题目描述 \(X\)国遭受了地震的重创, 导致全国的交通近乎瘫痪,重建家园的计划迫在眉睫.\(X\)国由\(N\)个城市组成, 重建小组提出,仅需建立\(N-1\ ...
- Mysql基本的一些查询操作
/*查询选修课程‘3-105’且成绩在60到80之间的所有记录.*/SELECT * FROM result WHERE CNO='3-105' AND GRADE > 60 AND GRADE ...
- MySQL 第七篇:视图、触发器、事务、存储过程、函数
一 视图 视图是一个虚拟表(非真实存在),其本质是[根据SQL语句获取动态的数据集,并为其命名],用户使用时只需使用[名称]即可获取结果集,可以将该结果集当做表来使用. 使用视图我们可以把查询过程中的 ...
- python之旅:并发编程之多线程
一 threading模块介绍 multiprocess模块的完全模仿了threading模块的接口,二者在使用层面,有很大的相似性,因而不再详细介绍 官网链接:https://docs.python ...
- 适用于vue项目的打印插件
此方法只适用于现代浏览器,IE10以下就别用了 // 使用时请尽量在nickTick中调用此方法 //打印 export default (refs, cb) => { let cloneN i ...
- SSO基于cas的登录
概念介绍 1.定义 CAS ( CentralAuthentication Service ) 是 Yale 大学发起的一个企业级的.开源的项目,旨在为 Web 应用系统提供一种可靠的单点登录解决方法 ...
- Java入门:基础算法之获取用户输入
本部分演示如何获取用户输入.我们使用Scanner类来得到用户输入.下面的实例代码中演示了如何获取用户输入的字符串.整数和float数据.主要用到了以下方法: 1)public String next ...
- IOS使用mkdir创建目录
在IOS真机上可以创建目录的位置只有两个Documents和Caches,如果直接在NSHomeDirectory()上创建目录,会失败,返回的errno含义为操作被禁止. 获取Caches中的一个目 ...
- javascript精雕细琢(二):++、--那点事
目录 引言 ++和--在数学运算中的计算规则 ++和--在变量引用时的计算规则 ++和--的数据转换应用 引言 对于接触JS时间不长的前端来说,刚开始要实现诸如轮播图,选项卡等小模块时,肯定会用到in ...
- Vue 嵌套数组 数组更新视图不更新
关于Vue的响应式原理,可以看官方文档或其他资料, https://www.jianshu.com/p/34de360d6035 data里定义了一个数组arr,数组的元素可以是同样格式的数组arrC ...