在论文中看到对 softmax 和 cross-entropy 的求导,一脸懵逼,故来整理整理。

以 softmax regression 为例来展示求导过程,softmax regression 可以看成一个不含隐含层的多分类神经网络,如 Fig. 1 所示。

Fig. 1 Softmax Regression.

softmax regression 的矩阵形式如 Fig. 2 所示:

Fig. 2 Matrix Form.

符号定义

如 Fig. 1 所示,\(\bm x = [x_1, x_2, x_3]^{\top}\) 表示 softmax regression 的输入,\(\bm y = [y_1, y_2, y_3]^{\top}\) 表示 softmax regression 的输出,\(\bm W\) 为权重,\(\bm b = [b_1, b_2, b_3]^{\top}\) 为偏置。

令 Fig. 2 中 softmax function 的输入为 \(z_i = W_{i, 1}x_1 + W_{i, 2}x_2 + W_{i, 3}x_3 + b_i = W_{i}\bm x + b_i\),其中 \(i= 1, 2, 3\),\(W_{i}\) 表示权重矩阵 \(\bm W\) 的第 \(i\) 行;softmax function 的输出就是整个网络的输出,即 \(\bm y\)。

Note: Fig. 1 和 Fig.2 中权重 \(W_{i, j}\) 表示第 \(i\) 个输出和第 \(j\) 个输入之间的联系,和一般的记法(即 \(W_{i, j}\) 表示第 \(i\) 个输入和第 \(j\) 个输出之间权重)相差一个转置。

用 \(m\) 表示输出的类别数,本文中 \(m = 3\)。

Note: softmax regression 指的是整个网络,softmax function 仅仅指的是激活函数。本文默认 softmax 代指激活函数,当表示整个网络时会明确说明 softmax regression。

对 softmax 求导

softmax 函数的表达式为:
\[
y_i = \frac{e^{z_i}}{\sum_{t = 1}^m e^{z_t}}
\tag{1}
\]

其中 \(i= 1, 2, 3\)。由式(1)可知,\(y_i\) 与 softmax function 所有的输入 \(z_j, j = 1,2,3.\) 都有关。
softmax function 的输出对其输入求偏导:
\[
\frac{\partial y_i}{\partial z_j}
= \frac{\partial \frac{e^{z_i}}{\sum_{t = 1}^m e^{z_t}}}{\partial z_j}
\tag{2}
\]

需要对式(2)中 \(i = j\) 和 \(i \not = j\) 的情况进行分别讨论。因为式(1)分子中仅含第 \(i\) 项,式(2)中如果 \(i = j\),那么导数 \(\frac{\partial e^{z_i}}{\partial z_j} = e^{z_i}\),不为 0;如果 \(i \not = j\),那导数 \(\frac{\partial e^{z_i}}{\partial z_j} = 0\)。

  • \(i = j\),则式(2)为:
    \[
    \begin{split}
    \frac{\partial y_i}{\partial z_j}
    &= \frac{\partial \frac{e^{z_i}}{\sum_{t = 1}^m e^{z_t}}}{\partial z_j}
    \\ &= \frac{e^{z_i} \cdot \sum_{t = 1}^m e^{z_t} - e^{z_i} \cdot e^{z_j} }{(\sum_{t = 1}^m e^{z_t})^2}
    \\ &= \frac{e^{z_i}}{\sum_{t = 1}^m e^{z_t}} - \frac{e^{z_i}}{\sum_{t = 1}^m e^{z_t}} \cdot \frac{e^{z_j}}{\sum_{t = 1}^m e^{z_t}}
    \\ &=y_i(1 - y_j)
    \end{split}
    \tag{3}
    \]

当然,式(3)也可以写成 \(y_i(1 - y_i)\) 或者 \(y_j(1 - y_j)\),因为这里 \(i = j\)。

  • \(i \not = j\),则式(2)为:
    \[
    \begin{split}
    \frac{\partial y_i}{\partial z_j}
    &= \frac{\partial \frac{e^{z_i}}{\sum_{t = 1}^m e^{z_t}}}{\partial z_j}
    \\ &= \frac{0\cdot \sum_{t = 1}^m e^{z_t} - e^{z_i} \cdot e^{z_j} }{(\sum_{t = 1}^m e^{z_t})^2}
    \\ &= - \frac{e^{z_i}}{\sum_{t = 1}^m e^{z_t}} \cdot \frac{e^{z_j}}{\sum_{t = 1}^m e^{z_t}}
    \\ &= -y_iy_j
    \end{split}
    \tag{4}
    \]

对 cross-entropy 求导

令 \(\bm {\hat y} = [\hat{y}_1, \hat{y}_2, \hat{y}_3]^{\top}\) 为输入 \(\bm x\) 真实类别的 one-hot encoding。

cross entropy 的定义如下:
\[
H(\bm {\hat y}, \bm y)
= - \bm {\hat y}^{\top} \log \bm y
= - \sum_{t = 1}^m \hat{y}_t\log y_t
\tag{5}
\]

对 cross entropy 求偏导:(\(\log\) 底数为 \(e\))
\[
\frac{\partial H(\bm {\hat y}, \bm y) }{\partial y_i}
= \frac{\partial [- \sum_{t = 1}^m \hat{y}_t\log y_t ]}{\partial y_i}
= - \frac{\hat{y}_i}{y_i}
\tag{6}
\]

\(\bm {\hat y}\) 是确定的值,可以理解为样本的真实 one-hot 标签,不受模型预测标签 \(\bm y\) 的影响。

对 softmax 和 cross-entropy 一起求导

\[
\begin{split}
\frac{\partial H(\bm {\hat y}, \bm y) }{\partial z_j}
&= \sum_{i = 1}^{m} \frac{\partial H(\bm {\hat y}, \bm y) }{\partial y_i} \frac{\partial y_i }{\partial z_j}
\\ &= \sum_{i = 1}^{m} -\frac{\hat{y}_i}{y_i} \cdot \frac{\partial y_i }{\partial z_j}
\\ &= \left(-\frac{\hat{y}_i}{y_i} \cdot \frac{\partial y_i }{\partial z_j}\right )_{i = j} + \sum_{i = 1 , i \not = j}^{m} -\frac{\hat{y}_i}{y_i} \cdot \frac{\partial y_i }{\partial z_j}
\\ &= -\frac{\hat{y}_j}{y_i} \cdot y_i(1-y_j) + \sum_{i = 1 , i \not = j}^{m} -\frac{\hat{y}_i}{y_i} \cdot -y_iy_j
\\ &= - \hat{y}_j + \hat{y}_jy_j + \sum_{i = 1 , i \not = j}^{m} \hat{y}_iy_j
\\ & = - \hat{y}_j + y_j\sum_{i = 1}^{m} \hat{y}_i
\\ &= y_j - \hat{y}_j
\end{split}
\tag{7}
\]

交叉熵 loss function 对 softmax function 输入 \(z_j\) 的求导结果相当简单,在 tensorflow 中,softmax 和 cross entropy 也合并成了一个函数,tf.nn.softmax_cross_entropy_with_logits,从导数求解方面看,也是有道理的。

在实际使用时,推荐使用 tensorflow 中实现的 API 去实现 softmax 和 cross entropy,而不是自己写,原因如下:

  • 都已经有 API 了,干嘛还得自己写,懒就是最好的理由;
  • softmax 因为计算了 exp(x),很容易就溢出了,比如 np.exp(800) = inf,需要做一些缩放,而 tensorflow 会帮我们处理这种数值不稳定的问题。

References

TensorFlow MNIST Dataset and Softmax Regression - Data Flair
链式法则 - 维基百科
Softmax函数与交叉熵 - 知乎

【机器学习基础】对 softmax 和 cross-entropy 求导的更多相关文章

  1. softmax交叉熵损失函数求导

    来源:https://www.jianshu.com/p/c02a1fbffad6 简单易懂的softmax交叉熵损失函数求导 来写一个softmax求导的推导过程,不仅可以给自己理清思路,还可以造福 ...

  2. softmax分类器+cross entropy损失函数的求导

    softmax是logisitic regression在多酚类问题上的推广,\(W=[w_1,w_2,...,w_c]\)为各个类的权重因子,\(b\)为各类的门槛值.不要想象成超平面,否则很难理解 ...

  3. softmax、cross entropy和softmax loss学习笔记

    之前做手写数字识别时,接触到softmax网络,知道其是全连接层,但没有搞清楚它的实现方式,今天学习Alexnet网络,又接触到了softmax,果断仔细研究研究,有了softmax,损失函数自然不可 ...

  4. 简单易懂的softmax交叉熵损失函数求导

    参考: https://blog.csdn.net/qian99/article/details/78046329

  5. softmax求导、cross-entropy求导及label smoothing

    softmax求导 softmax层的输出为 其中,表示第L层第j个神经元的输入,表示第L层第j个神经元的输出,e表示自然常数. 现在求对的导数, 如果j=i,   1 如果ji, 2 cross-e ...

  6. OO_多项式求导_单元总结

    概述: 面向对象第一单元的作业是三次难度依次递增的多项式求导.第一次作业是仅包含带符号整数和幂函数的多项式求导,例如:-1+xˆ233-xˆ06:第二次是在前面的基础上增加了三角函数的求导,例如:-1 ...

  7. OO第一单元作业——魔幻求导

    简介 本单元作业分为三次 第一次作业:需要完成的任务为简单多项式导函数的求解. 第二次作业:需要完成的任务为包含简单幂函数和简单正余弦函数的导函数的求解. 第三次作业:需要完成的任务为包含简单幂函数和 ...

  8. 【机器学习基础】交叉熵(cross entropy)损失函数是凸函数吗?

    之所以会有这个问题,是因为在学习 logistic regression 时,<统计机器学习>一书说它的负对数似然函数是凸函数,而 logistic regression 的负对数似然函数 ...

  9. softmax,softmax loss和cross entropy的区别

     版权声明:本文为博主原创文章,未经博主允许不得转载. https://blog.csdn.net/u014380165/article/details/77284921 我们知道卷积神经网络(CNN ...

随机推荐

  1. MySql foreach属性

    foreach属性 属性 描述 item 循环体中的具体对象.支持属性的点路径访问,如item.age,item.info.details.具体说明:在list和数组中是其中的对象,在map中是val ...

  2. 一篇文章带你了解Cloud Native

    背景 Cloud Native表面看起来比较容易理解,但是细思好像又有些模糊不清:Cloud Native和Cloud关系是啥?它用来解决什么问题?它是一个新技术还是一个新的方法?什么样的APP符合“ ...

  3. PyCharm中HTML页面CSS class名称自动完成功能失效的问题

    如果这个HTML页面带有style元素的CSS定义,那class name自动完成功能就失效了 Pycharm Version:5.03

  4. J2EE--常见面试题总结 -- (二)

    1 Spring拦截器的基本功能是什么? 拦截器是基于Java的反射机制的,是在面向切面编程的就是在你的service或者一个方法,前调用一个方法,或者在方法后调用一个方法比如动态代理就是拦截器的简单 ...

  5. java深入浅出之数据结构

    1.整形数据 byte.short.int.long,分别是1248个字节的存储量,取值范围也是依次增大的,其中int是正负21亿多: long a = 1111222233334444L:记住后面要 ...

  6. add two nums

    问题描述: 给定两个链表,计算出链表对应位置相加的和,如果和大于10要往后进位.用链表返回结果.其实上是一种大数加法.可以把一个大数倒着写存入链表,然后两个链表相加就是所需要的大数相加的和 输入 2 ...

  7. HTML 学习笔记 day one

    HTML学习笔记 day one Chapter one 网站开发基础 1.2网站的基本架构 网站的基本要素:内容,页面,超链接 动态网页和静态网页的区别在于:动态网页会自动更新,后缀名是.asp或者 ...

  8. 如何优雅的关闭Java线程池

    面试中经常会问到,创建一个线程池需要哪些参数啊,线程池的工作原理啊,却很少会问到线程池如何安全关闭的. 也正是因为大家不是很关注这块,即便是工作三四年的人,也会有因为线程池关闭不合理,导致应用无法正常 ...

  9. flask模板

    做为python web开发领域的一员,flask跟Django在很多地方用法以都是相似的,比如flask的模板 模板就是服务器端的页面,在模板中可以使用服务端的语法进行输出控制 1.模板的工作原理 ...

  10. .NET之IOC控制反转运用

    当前场景: 如果有不同的用户.使用同一个系统.而不同的客户有某些不同的需求.在不改变系统主体的情况下,可以直接使用IOC控制反转依赖搭建项目 1.添加接口层 目前里面只有一个会员的类.里面有一个登录接 ...