1.前向传播:

template <typename Dtype>
void SoftmaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_data = bottom[]->cpu_data();
Dtype* top_data = top[]->mutable_cpu_data();
Dtype* scale_data = scale_.mutable_cpu_data();
int channels = bottom[]->shape(softmax_axis_);
int dim = bottom[]->count() / outer_num_; //dim表示要分类的类别数,count()得到的是总共的输入Blob数,outer_num_得到的是是每一类的Blob数
caffe_copy(bottom[]->count(), bottom_data, top_data); //先将输入拷贝到输出缓冲区
// We need to subtract the max to avoid numerical issues, compute the exp,
// and then normalize,减去最大值,避免数值问题,计算指数,归一化
for (int i = ; i < outer_num_; ++i) {
// 初始化scale_的data域为第一个平面,其中scale用来存放临时计算结果
caffe_copy(inner_num_, bottom_data + i * dim, scale_data);
for (int j = ; j < channels; j++) {
for (int k = ; k < inner_num_; k++) {
scale_data[k] = std::max(scale_data[k],
bottom_data[i * dim + j * inner_num_ + k]);
}
}
// 输出缓冲区减去最大值
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels, inner_num_,
, -., sum_multiplier_.cpu_data(), scale_data, ., top_data);
// exponentiation
caffe_exp<Dtype>(dim, top_data, top_data);
// sum after exp
caffe_cpu_gemv<Dtype>(CblasTrans, channels, inner_num_, .,
top_data, sum_multiplier_.cpu_data(), ., scale_data);
// division
for (int j = ; j < channels; j++) {
caffe_div(inner_num_, top_data, scale_data, top_data);
top_data += inner_num_;
}
}
}

一般的我们有top[0]来存放数据,top[1]来存放标签(对于bottom也一样)

2.反向传播:

template <typename Dtype>
void SoftmaxLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down,
const vector<Blob<Dtype>*>& bottom) {
const Dtype* top_diff = top[]->cpu_diff();
const Dtype* top_data = top[]->cpu_data();
Dtype* bottom_diff = bottom[]->mutable_cpu_diff();
Dtype* scale_data = scale_.mutable_cpu_data();
int channels = top[]->shape(softmax_axis_);
int dim = top[]->count() / outer_num_;
caffe_copy(top[]->count(), top_diff, bottom_diff); //先用top_diff初始化bottom_diff
for (int i = ; i < outer_num_; ++i) {
// 计算top_diff和top_data的点积,然后从bottom_diff中减去该值
for (int k = ; k < inner_num_; ++k) {
scale_data[k] = caffe_cpu_strided_dot<Dtype>(channels,
bottom_diff + i * dim + k, inner_num_,
top_data + i * dim + k, inner_num_);
}
// 减值
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels, inner_num_, ,
-., sum_multiplier_.cpu_data(), scale_data, ., bottom_diff + i * dim);
}
// 逐点相乘
caffe_mul(top[]->count(), bottom_diff, top_data, bottom_diff);
}

解释:

补充:最后部分,Zi!=Zj和Zi=Zj部分写反了,大家注意一下~

caffe中 softmax 函数的前向传播和反向传播的更多相关文章

  1. 机器学习(ML)八之正向传播、反向传播和计算图,及数值稳定性和模型初始化

    正向传播 正向传播的计算图 通常绘制计算图来可视化运算符和变量在计算中的依赖关系.下图绘制了本节中样例模型正向传播的计算图,其中左下角是输入,右上角是输出.可以看到,图中箭头方向大多是向右和向上,其中 ...

  2. 小白学习之pytorch框架(6)-模型选择(K折交叉验证)、欠拟合、过拟合(权重衰减法(=L2范数正则化)、丢弃法)、正向传播、反向传播

    下面要说的基本都是<动手学深度学习>这本花书上的内容,图也采用的书上的 首先说的是训练误差(模型在训练数据集上表现出的误差)和泛化误差(模型在任意一个测试数据集样本上表现出的误差的期望) ...

  3. caffe中的前向传播和反向传播

    caffe中的网络结构是一层连着一层的,在相邻的两层中,可以认为前一层的输出就是后一层的输入,可以等效成如下的模型 可以认为输出top中的每个元素都是输出bottom中所有元素的函数.如果两个神经元之 ...

  4. caffe中softmax源码阅读

    (1) softmax函数                                      (1) 其中,zj 是softmax层的bottom输入, f(zj)是softmax层的top输 ...

  5. BP原理 - 前向计算与反向传播实例

    Outline 前向计算 反向传播 很多事情不是需要聪明一点,而是需要耐心一点,踏下心来认真看真的很简单的. 假设有这样一个网络层: 第一层是输入层,包含两个神经元i1 i2和截距b1: 第二层是隐含 ...

  6. 反向传播算法(前向传播、反向传播、链式求导、引入delta)

    参考链接: 一文搞懂反向传播算法

  7. caffe中softmax loss源码阅读

    (1) softmax loss <1> softmax loss的函数形式为:     (1) zi为softmax的输入,f(zi)为softmax的输出. <2> sof ...

  8. 前向传播和反向传播实战(Tensor)

    前面在mnist中使用了三个非线性层来增加模型复杂度,并通过最小化损失函数来更新参数,下面实用最底层的方式即张量进行前向传播(暂不采用层的概念). 主要注意点如下: · 进行梯度运算时,tensorf ...

  9. caffe中python接口的使用

    下面是基于我自己的接口,我是用来分类一维数据的,可能不具通用性: (前提,你已经编译了caffe的python的接口) 添加 caffe塻块的搜索路径,当我们import caffe时,可以找到. 对 ...

随机推荐

  1. 【POI每日题解 #9】SKA-Piggy Banks

    题目链接 题意: 有一棵环套树 求最少从多少个节点出发能沿边走过整棵树 环套树 并查集求联通块 有几块就砸几个 太简单不发代码了 不过某大佬的环套树找环dfs让我研究了好久… 贴一下以Orz #inc ...

  2. go vendor目录

    参考 https://blog.csdn.net/u010649766/article/details/80327035 那么查找依赖包路径的解决方案如下: 当前包下的vendor目录. 向上级目录查 ...

  3. Pycharm激活、配置以及快捷方式 | 图解

    访问flyai.club,一键创建你的人工智能项目 来源 | Python (python6359) Pycharm可以去官网下载 Pycharm的安装激活 jar包的目的就是让截获截止时间并骗过py ...

  4. 如何将Anaconda更新到想要的python版本

    最近用Anaconda比较多,因为它里面的包很全啊.如果下个原生的python,要用的时候得自己一个个装. 但是有些包又互相依赖,一个个装的时候实在很抓狂.懒人就想到了anaconda这种套装集合了. ...

  5. mysql自定义函数与过程中写法的注意事项

    BEGIN #Routine body goes here... /* update szzx_goods_common set gc_id=i where gc_name=(SELECT gc_na ...

  6. Java_Mybatis_注解代理写法

    Mybatis的开发方式其实有3种: 1. 原始Dao开发(就是把mapper接口.映射文件和实现类都一并开发) 2. xml代理(就是只实现mapper接口和映射文件) 3.注解代理(就是只实现ma ...

  7. 离线方式部署Ambari2.6.0.0

    Hadoop生态圈-离线方式部署Ambari2.6.0.0 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 我现在所在的公司用的是CDH管理Hadoop集群,前端时间去面试时发现很多 ...

  8. PHP composer-setup安装遇到的openssl extension is missing

    问题描述: 安装完成php-7.1.17后,安装composer出现以下错误 [root@localhost src]# curl -sS https://getcomposer.org/instal ...

  9. mysql常用sql汇总

    给一张表新增一个字段 ALTER table student add zz INT() DEFAULT COMMENT '0是授权 1未授权' 给表student 新增一个zz的字段 默认是0 后面是 ...

  10. SQL语句(一)SQL和数据库数据表的创建

    SQL的组成 (1) 数据定义语言DDL(Data Definition Language) 用于数据库和数据表的创建.修改和删除等操作 CREATE (create) 创建数据库.数据表 ALTER ...