caffe中softmax loss源码阅读
(1) softmax loss
<1> softmax loss的函数形式为:
(1)
zi为softmax的输入,f(zi)为softmax的输出。
<2> softmax loss对其输入zj求导:
(2)
如果j==k,则zk是变量,否则zj是变量。
和的导数等于导数的和,对和中某个元素求导的话有:
(2) softmax_loss_layer.cpp中的Forward_cpu()函数:
template <typename Dtype>
void SoftmaxWithLossLayer<Dtype>::Forward_cpu(
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
// The forward pass computes the softmax prob values.
//调用softmax层的forward函数,得到对应的输出,存到prob_中
softmax_layer_->Forward(softmax_bottom_vec_, softmax_top_vec_);
const Dtype* prob_data = prob_.cpu_data();
//一般loss层有两个输入blob,网络的predict blob(bottom[0])和label blob(bottom[1])
const Dtype* label = bottom[]->cpu_data();
//dim = N*C*H*W / N = C*H*W
int dim = prob_.count() / outer_num_;
//count变量是计算loss时的有效样本数
int count = ;
Dtype loss = ;
for (int i = ; i < outer_num_; ++i) {
for (int j = ; j < inner_num_; j++) {
//读取label
const int label_value = static_cast<int>(label[i * inner_num_ + j]);
//如果该样本的label等于deploy中softmaxWithLoss中设定的参数ignore_label_,则该样本不参与前向和后向计算
if (has_ignore_label_ && label_value == ignore_label_) {
continue;
}
//判断label_value是否大于等于0
DCHECK_GE(label_value, );
//判断label_value是否小于prob_.shape(softmax_axis_)=C
DCHECK_LT(label_value, prob_.shape(softmax_axis_));
//对于softmax的输出channel,计算label_value索引对应的channel中prob的log.对应公式(1)
loss -= log(std::max(prob_data[i * dim + label_value * inner_num_ + j],
Dtype(FLT_MIN)));
//有效样本数加一
++count;
}
}
//最终在训练日志中显示的loss为计算的总loss除以有效样本数
top[]->mutable_cpu_data()[] = loss / get_normalizer(normalization_, count);
if (top.size() == ) {
top[]->ShareData(prob_);
}
}
(3) softmax_loss_layer.cpp中的Backward_cpu函数:
template <typename Dtype>
void SoftmaxWithLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
if (propagate_down[]) {
LOG(FATAL) << this->type()
<< " Layer cannot backpropagate to label inputs.";
}
if (propagate_down[]) {
Dtype* bottom_diff = bottom[]->mutable_cpu_diff();
const Dtype* prob_data = prob_.cpu_data();
//将softmax的输出prob_复制给bottom[0]的diff(梯度) blob
caffe_copy(prob_.count(), prob_data, bottom_diff);
const Dtype* label = bottom[]->cpu_data();
int dim = prob_.count() / outer_num_;
int count = ;
for (int i = ; i < outer_num_; ++i) {
for (int j = ; j < inner_num_; ++j) {
const int label_value = static_cast<int>(label[i * inner_num_ + j]);
if (has_ignore_label_ && label_value == ignore_label_) {
for (int c = ; c < bottom[]->shape(softmax_axis_); ++c) {
bottom_diff[i * dim + c * inner_num_ + j] = ;
}
} else {
//对应公式(2),在反传梯度时,label索引对应的diff减1,其他不变。
bottom_diff[i * dim + label_value * inner_num_ + j] -= ;
++count;
}
}
}
// Scale gradient
//top[0]->cpu_diff()[0] = N
//N / count
Dtype loss_weight = top[]->cpu_diff()[] /
get_normalizer(normalization_, count);
caffe_scal(prob_.count(), loss_weight, bottom_diff);
}
}
caffe中softmax loss源码阅读的更多相关文章
- caffe中batch norm源码阅读
1. batch norm 输入batch norm层的数据为[N, C, H, W], 该层计算得到均值为C个,方差为C个,输出数据为[N, C, H, W]. <1> 形象点说,均值的 ...
- 【源码阅读】Java集合之三 - ArrayDeque源码深度解读
Java 源码阅读的第一步是Collection框架源码,这也是面试基础中的基础: 针对Collection的源码阅读写一个系列的文章,本文是第三篇ArrayDeque. ---@pdai JDK版本 ...
- 【源码阅读】Java集合之二 - LinkedList源码深度解读
Java 源码阅读的第一步是Collection框架源码,这也是面试基础中的基础: 针对Collection的源码阅读写一个系列的文章; 本文是第二篇LinkedList. ---@pdai JDK版 ...
- 【源码阅读】Java集合之一 - ArrayList源码深度解读
Java 源码阅读的第一步是Collection框架源码,这也是面试基础中的基础: 针对Collection的源码阅读写一个系列的文章,从ArrayList开始第一篇. ---@pdai JDK版本 ...
- Caffe源码阅读(1) 全连接层
Caffe源码阅读(1) 全连接层 发表于 2014-09-15 | 今天看全连接层的实现.主要看的是https://github.com/BVLC/caffe/blob/master/src ...
- caffe-windows中classification.cpp的源码阅读
caffe-windows中classification.cpp的源码阅读 命令格式: usage: classification string(模型描述文件net.prototxt) string( ...
- 源码阅读笔记 - 1 MSVC2015中的std::sort
大约寒假开始的时候我就已经把std::sort的源码阅读完毕并理解其中的做法了,到了寒假结尾,姑且把它写出来 这是我的第一篇源码阅读笔记,以后会发更多的,包括算法和库实现,源码会按照我自己的代码风格格 ...
- 源码阅读经验谈-slim,darknet,labelimg,caffe(1)
本文首先谈自己的源码阅读体验,然后给几个案例解读,选的例子都是比较简单.重在说明我琢磨的点线面源码阅读方法.我不是专业架构师,是从一个深度学习算法工程师的角度来谈的,不专业的地方请大家轻拍. 经常看别 ...
- SpringMVC源码阅读:Controller中参数解析
1.前言 SpringMVC是目前J2EE平台的主流Web框架,不熟悉的园友可以看SpringMVC源码阅读入门,它交代了SpringMVC的基础知识和源码阅读的技巧 本文将通过源码(基于Spring ...
随机推荐
- 多线程编程学习十一(ThreadPoolExecutor 详解).
一.ThreadPoolExecutor 参数说明 public ThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keep ...
- ansible-playbook流程控制-loops循环使用
1. ansible-playbook流程控制-loops循环使用 有时你想要多次重复任务.在计算机编程中,这称为循环.common ansible循环包括使用文件模块更改多个文件和/或目录的所 ...
- Day20 磁盘管理2之RAID卡
1.磁盘的基本分区Gdisk 前面我们已经了解到fdisk分区,但fdisk不支持给高于2TB的磁盘进行分区.如果有单块盘高于2TB,建议使用Gdisk进行分区. 1.使用gdisk进行磁盘分区 1. ...
- 制作mysql大数据表验证覆盖索引
昨天跟同事聊起数据表性能的问题,能不能仅用覆盖索引实现数据的汇总统计.找了一个开发环境已有的数据表进行测试,通过explain命令,能看到mysql通过覆盖索引就能实现sum的需求,而无须去读取实际行 ...
- 写论文与PPT汇报时matlab图片的背景透明处理
不少同学在使用Word写论文时,将matlab生成的图保存为jpg格式,然后粘贴到文档中.word背景为纯白色,jpg图的缺点没有显示,实际上会存在很大白边,以及放大后不清晰的问题,很影响PPT展示和 ...
- Android MediaPlayer 音频倍速播放,调整播放速度
本文链接: Android MediaPlayer 倍速播放,调整播放速度 现在市面上的很多音视频App都有倍速播放的功能,例如把播放速度调整为0.5.1.5.2倍等等. 从Android API 2 ...
- 【linux】【jenkins】自动化运维五 整合邮件提醒
1.安装插件 Email Extension Template Plugin 安装教程参考:https://www.cnblogs.com/jxd283465/p/11542680.html 2.系统 ...
- APP功能测试要点(功能测试重点)
APP功能测试要点 1.功能性测试 根据产品需求文档编写测试用例而进行测试,包括客户端的单个功能模块以及功能业务逻辑(功能交互)如:涉及输入的地方需要考虑等价类,边界值,异常或非法等 1.1 安装与卸 ...
- elastic操作-索引重命名,索引副本数修改
目前我们使用的elastic版本为2.3.5 当前版本没有直接的curl操作可以更改索引的名称,索引的副本数. 有直接更改索引副本数的api. curl -XPUT "192.168.1.1 ...
- Apache Thrift 的基本使用
Apache Thrift 的基本使用 可以先看看官网是如何介绍的 The Apache Thrift software framework, for scalable cross-language ...