本文介绍Softmax运算、Softmax损失函数及其反向传播梯度计算, 内容上承接前两篇博文 损失函数 & 手推反向传播公式

Softmax 梯度

设有K类, 那么期望标签y形如\([0,0,...0,1,0...0]^T\)的one-hot的形式. softmax层的输出为\([a_1,a_2,...,a_j,...a_K]^T\), 其中第j类的softmax输出为:

\[\begin{align}
a_{j} &= \frac{\exp(z_{j})}{\sum_{k=1}^K \exp(z_{k})} \forall j\in 1...K \\
{\partial a_{j}\over \partial z_{j} } &= {\exp(z_{j})\cdot(\Sigma - \exp(z_{j}) )\over \Sigma^2} = a_j(1 - a_j) \\
{\partial a_{k}\over \partial z_{j} } &= { - \exp(z_{k}) \cdot \exp(z_{j}) \over \Sigma^2} = -a_j a_k \tag{$k\ne j$}
\end{align}
\]

如果是全连接的DNN,那么有: \(z_{j}^{l+1}=\sum_i w_{ij} a_{i}^{l}+b_j^{l+1}\)

\(a_j^{l+1}\)可以解释成观察到的数据 \(a^l\) 属于类别 j 的概率,或者称作似然 (Likelihood)。

求输出对输入的梯度\(\partial a\over \partial z\):

\[\begin{align}
{\partial a\over \partial z_k}=
\begin{bmatrix}
{\partial a_1\over \partial z_k} \\
\vdots \\
{\partial a_k\over \partial z_k} \\
\vdots \\
{\partial a_K\over \partial z_k}
\end{bmatrix}
=
\begin{bmatrix}
-a_1 \\
\vdots \\
(1-a_k) \\
\vdots \\
-a_K
\end{bmatrix}a_k
=
(\begin{bmatrix}
0 \\
\vdots \\
1 \\
\vdots \\
0
\end{bmatrix}
-a)a_k
\end{align}
\]

因此损失对输入的梯度为\({\partial E\over \partial z}\):

\[{\partial E\over \partial z_k}={\partial E\over \partial a}{\partial a\over \partial z_k}=({\partial E\over \partial a_k} - [{\partial E\over \partial a}]^T a)a_k \\
{\partial E\over \partial z}={\partial E\over \partial a}{\partial a\over \partial z}=({\partial E\over \partial a} - [{\partial E\over \partial a}]^T a)⊙ a
\]

对应的 Caffe 中的SoftmaxLayer的梯度反向传播计算实现代码为:

# dot 表示矩阵乘法, * 表示按对应元素相乘
bottom_diff = (top_diff - dot(top_diff, top_data)) * top_data

Softmax loss 梯度

单样本的损失函数为:

\[E = -\sum^K_{k}y_k\log(a_{k}) \\
{\partial E\over \partial a_j} = -\sum^K_{k}{y_k\over a_k}\cdot {\partial a_k\over \partial a_j}=-{y_j\over a_j}
\]

接下来求E对w,b的梯度, 过程与反向传播的通用梯度计算公式相同, 这里指定了具体的激活函数(softmax)与损失函数:

\[\begin{align}
{\partial E\over \partial b_j^{l+1}} &= {\partial E\over \partial z_j^{l+1}} = \sum_k{\partial E\over \partial a_k^{l+1}} \cdot {\partial a_k^{l+1}\over \partial z_j^{l+1}} \\
&=-{y_j^{l+1}\over a_j^{l+1}} \cdot a_j^{l+1}(1 - a_j^{l+1})+\sum_{k\ne j}[-{y_k^{l+1}\over a_k^{l+1}} \cdot -a_j^{l+1} a_k^{l+1}] \\
&= -y_j^{l+1}+y_j^{l+1} a_j^{l+1} +\sum_{k\ne j}y_k^{l+1}a_j^{l+1} \\
&= a_j^{l+1}-y_j^{l+1} \\
{\partial E\over \partial w_{ij}^{l+1}} &= {\partial E\over \partial z_j^{l+1}} \cdot {\partial z_j^{l+1}\over w_{ij}^{l+1}}=(a_j^{l+1}-y_j^{l+1})a_i^l
\end{align}
\]

对应的 Caffe 中的SoftmaxWithLossLayer的梯度反向传播计算实现为(\({\partial E\over \partial z}\)):

# prob_data 为前向传播时softmax的结果, label_data 是标签的one-hot表示
bottom_diff = prob_data - label_data

参考[1][2][3]


  1. softmax的log似然代价函数(公式求导) https://blog.csdn.net/u014313009/article/details/51045303 ↩︎

  2. Softmax与SoftmaxWithLoss原理及代码详解 https://blog.csdn.net/u013010889/article/details/76343758 ↩︎

  3. 数值计算稳定性 http://freemind.pluskid.org/machine-learning/softmax-vs-softmax-loss-numerical-stability/ ↩︎

Softmax 损失-梯度计算的更多相关文章

  1. 实现属于自己的TensorFlow(二) - 梯度计算与反向传播

    前言 上一篇中介绍了计算图以及前向传播的实现,本文中将主要介绍对于模型优化非常重要的反向传播算法以及反向传播算法中梯度计算的实现.因为在计算梯度的时候需要涉及到矩阵梯度的计算,本文针对几种常用操作的梯 ...

  2. 多类 SVM 的损失函数及其梯度计算

    CS231n Convolutional Neural Networks for Visual Recognition -- optimization 1. 多类 SVM 的损失函数(Multicla ...

  3. Theano学习-梯度计算

    1. 计算梯度 创建一个函数 \(y\) ,并且计算关于其参数 \(x\) 的微分. 为了实现这一功能,将使用函数 \(T.grad\) . 例如:计算 \(x^2\) 关于参数 \(x\) 的梯度. ...

  4. 机器学习进阶-图像梯度计算-scharr算子与laplacian算子(拉普拉斯) 1.cv2.Scharr(使用scharr算子进行计算) 2.cv2.laplician(使用拉普拉斯算子进行计算)

    1. cv2.Scharr(src,ddepth, dx, dy), 使用Scharr算子进行计算 参数说明:src表示输入的图片,ddepth表示图片的深度,通常使用-1, 这里使用cv2.CV_6 ...

  5. pytorch 反向梯度计算问题

    计算如下\begin{array}{l}{x_{1}=w_{1} * \text { input }} \\ {x_{2}=w_{2} * x_{1}} \\ {x_{3}=w_{3} * x_{2} ...

  6. 优化梯度计算的改进的HS光流算法

    前言 在经典HS光流算法中,图像中两点间的灰度变化被假定为线性的,但实际上灰度变化是非线性的.本文详细分析了灰度估计不准确造成的偏差并提出了一种改进HS光流算法,这种算法可以得到较好的计算结果,并能明 ...

  7. TensorFlow 学习(八)—— 梯度计算(gradient computation)

    maxpooling 的 max 函数关于某变量的偏导也是分段的,关于它就是 1,不关于它就是 0: BP 是反向传播求关于参数的偏导,SGD 则是梯度更新,是优化算法: 1. 一个实例 relu = ...

  8. 理解自动梯度计算autograd

    理解自动求导 例子 def f(x): a = x * x b = x * a c = a + b return c 基于图理解 代码实现 def df(x): # forward pass a = ...

  9. [图解tensorflow源码] MatMul 矩阵乘积运算 (前向计算,反向梯度计算)

随机推荐

  1. C_数据结构_栈

    # include <stdio.h> # include <malloc.h> # include <stdlib.h> typedef struct Node ...

  2. php7与之前的区别和更新【转】

    http://blog.csdn.net/u011957758/article/details/73320083 本文是一篇讲座听后+后续研究的总结. 话说当年追时髦,php7一出就给电脑立马装上了, ...

  3. 【个人阅读】软件工程M1/M2阶段总结

    这次作业是好久以前布置的,由于学期末课程设计任务比较重,我在完善M2阶段的代码的同时又忙于数据库的实现和编译器的实现,一度感觉忙得透不过气来....到这些都基本完成的时候,会看自己以前的阅读心得,觉得 ...

  4. 实验十一 团队作业7—团队项目设计完善&编码测试

    实验十一 团队作业7—团队项目设计完善&编码测试 实验时间 2018-6-8 Deadline: 2018-6-20 10:00,以团队随笔博文提交至班级博客的时间为准. 评分标准: 按时交 ...

  5. JEECG--去掉(增加)登陆页面验证码功能 - CSDN博客

    JEECG--去掉(增加)登陆页面验证码功能 - CSDN博客https://blog.csdn.net/KooKing_L/article/details/79711379

  6. 转载 linux常用的监控命令工具

    工具 简单介绍top 查看进程活动状态以及一些系统状况vmstat 查看系统状态.硬件和系统信息等iostat 查看CPU 负载,硬盘状况sar 综合工具,查看系统状况mpstat 查看多处理器状况n ...

  7. SQLSERVER 创建对Oracle数据库的DBlink以及查询使用

    1. 与针对oracle数据库一样, 在sqlserver中创建对oracle数据库的dblink 安全性上面也进行定义(貌似不需要跟访问字符串只需要填一个即可) 发现有的版本改注册表不管用 还得修改 ...

  8. ionic2/3注册安卓返回

    如果使用了 this.app.getRootNav().push()以及this.navCtrl.push();   则在注册安卓返回键的时候   registerBackButtonAction() ...

  9. JavaScript从入门到精通

    第一(基本语法) if(condition1){ expression1; }else if(condition2){ expression2; }else{ expression3; } switc ...

  10. python之pygal:掷一个骰子统计次数并以直方图形式显示

    源码如下: # pygal包:生成可缩放的矢量图形文件,可自适应不同尺寸的屏幕显示 # 安装:python -m pip intall pygal-2.4.0-py2.py3-none-any.whl ...