softmax分类器+cross entropy损失函数的求导
softmax是logisitic regression在多酚类问题上的推广,\(W=[w_1,w_2,...,w_c]\)为各个类的权重因子,\(b\)为各类的门槛值。不要想象成超平面,否则很难理解,如果理解成每个类的打分函数,则会直观许多。预测时我们把样本分配到得分最高的类。
Notations:
- \(x\):输入向量,\(d\times 1\)列向量,\(d\)是feature数
- \(W\):权重矩阵,\(c\times d\)矩阵,\(c\)是label数
- \(b\):每个类对应超平面的偏置组成的向量, \(c\times 1\)列向量
- \(z=Wx+b\):线性分类器输出, \(c\times 1\)列向量
- \(\hat{y}\):softmax函数输出, \(c\times 1\)列向量
- 记\(\vec{e}_j=[0,...,1,...,0]^T\in\mathbb{R}^{c\times 1}\),其中\(1\)出现在第\(j\)个位置
- \(1_c\)表示一个全\(1\)的\(c\)维列向量
- \(y\):我们要拟合的目标变量,是一个one-hot vector(只有一个1,其余均为0),也是 \(c\times 1\)列向量 。 我们将其转置,表示为一个列向量:
\[y=[0,...,1,...,0]^T\]
他们之间的关系:
\[\left\{\begin{aligned}&z=Wx+b\\& \hat{y}=\mathrm{softmax}(z)=\frac{exp(z)}{1_c^Texp(z)} \end{aligned}\right.\]
cross-entropy error定义为:
\[ CE(z) = -y^Tlog(\hat{y}) \]
因为\(y\)是一个one-hot vector(即只有一个位置为1),假设\(y_k=1\),那么上式等于\(-log(\hat{y}_k)=-log(\frac{exp(z_k)}{\sum\limits_i exp(z_i)})=-z_k+log(\sum\limits_i exp(z_i))\)
依据chain rule有:
\[ \begin{aligned}\frac{\partial CE(z)}{\partial W_{ij}}
&=tr\bigg(\big(\frac{\partial CE(z)}{\partial z}\big)^T\frac {\partial z}{\partial W_{ij}}\bigg)\\
&=tr\bigg( \big(\frac{\partial \hat{y}}{\partial z}\cdot\frac{\partial CE(z)}{\partial \hat{y}}\big)^T\frac {\partial z}{\partial W_{ij}} \bigg)\end{aligned}\]
注:这里我用了Denominator layout
,因此链式法则是从右往左的。
我们一个一个来求。
\[\begin{equation}\begin{aligned}\frac{\partial \hat{y}}{\partial z}&=\frac{\partial ( \frac{exp(z)}{1_c^Texp(z)})}{\partial z}\\&= \frac{1}{1_c^Texp(z)}\frac{\partial exp(z)}{\partial z}+ \frac{\partial (\frac{1}{1_c^Texp(z)})}{\partial z}( exp(z) )^T\\&= \frac{1}{1_c^Texp(z)}diag(exp(z))-\frac{1}{(1_c^Texp(z))^2}exp(z)exp(z)^T\\&=diag(\frac{exp(z)}{1_c^Texp(z)})-\frac{exp(z)}{1_c^Texp(z)}\cdot (\frac{exp(z)}{1_c^Texp(z)})^T\\&=diag(\mathrm{ softmax}(z))- \mathrm{ softmax}(z) \mathrm{ softmax}(z)^T\\&=diag(\hat{y})-\hat{y}\hat{y}^T \end{aligned}\label{eq1}\end{equation}\]
注:上述求导过程使用了Denominator layout
。
设$a=a( \boldsymbol{ x}),\boldsymbol{u}= \boldsymbol{u}( \boldsymbol{x}) \(,这里\) \boldsymbol{ x}\(特意加粗表示是列向量,\)a\(没加粗表示是一个标量函数,\) \boldsymbol{u}\(加粗表示是一个向量函数。在`Numerator layout`下,\)\frac{\partial a \boldsymbol{u}}{ \boldsymbol{x}}=a\frac{\partial \boldsymbol{u}}{\partial \boldsymbol{x}}+ \boldsymbol{u}\frac{\partial a}{\partial \boldsymbol{x}} \(,而在`Denominator layout`下,则为\)\frac{\partial a \boldsymbol{u}}{\partial \boldsymbol{x}}=a\frac{\partial \boldsymbol{u}}{\partial \boldsymbol{x}}+\frac{\partial a}{\partial \boldsymbol{x}} \boldsymbol{u}^T$,对比可知上述推导用的实际是Denominator layout
。
以下推导均采用 Denominator layout,这样的好处是我们用梯度更新权重时不需要对梯度再转置。
\[\begin{equation}\frac{\partial CE(z)}{\partial \hat{y}}=\frac{\partial log(\hat{y})}{\partial \hat{y}}\cdot \frac{\partial (-y^Tlog(\hat{y}))}{\partial log(\hat{y})}=\big(diag(\hat{y})\big)^{-1}\cdot(-y)\label{eq2}\end{equation}\]
\(z\)的第\(k\)个分量可以表示为:\(z_k=\sum\limits_j W_{kj}x_j+b_k\),因此
\[\begin{equation}\frac{\partial z}{\partial W_{ij}} =\begin{bmatrix}\frac{\partial z_1}{\partial W_{ij}}\\\vdots\\\frac{\partial z_c}{\partial W_{ij}}\end{bmatrix}=[0,\cdots, x_j,\cdots, 0]^T=x_j \vec{e}_i\label{eq3}\end{equation}\]
其中\(x_j\)是向量\(x\)的第\(j\)个元素,为标量,它出现在第\(i\)行。
综合\(\eqref{eq1},\eqref{eq2},\eqref{eq3}\),我们有
\[ \begin{aligned}\frac{\partial CE(z)}{\partial W_{ij}}&=tr\bigg(\big( (diag(\hat{y})-\hat{y}\hat{y}^T)\cdot (diag(\hat{y}))^{-1} \cdot (-y) \big)^T\cdot x_j \vec{e}_i \bigg)\\&=tr\bigg(\big( \hat{y}\cdot (1_c^Ty)-y\big)^T\cdot x_j \vec{e}_i \bigg)\\&=(\hat{y}-y)^T\cdot x_j \vec{e}_i={err}_ix_j\end{aligned}\]
其中\({err}_i=(\hat{y}-y)_i\)表示残差向量的第\(i\)项
我们可以把上式改写为
\[ \frac{\partial CE(z)}{\partial W}=(\hat{y}-y)\cdot x^T \]
同理可得
\[ \frac{\partial CE(z)}{\partial b}=(\hat{y}-y) \]
那么在进行随机梯度下降的时候,更新式就是:
\[ \begin{aligned}&W \leftarrow W - \lambda (\hat{y}-y)\cdot x^T \\&b \leftarrow b - \lambda (\hat{y}-y)\end{aligned}\]
其中\(\lambda\)是学习率
softmax分类器+cross entropy损失函数的求导的更多相关文章
- 【转载】softmax的log似然代价函数(求导过程)
全文转载自:softmax的log似然代价函数(公式求导) 在人工神经网络(ANN)中,Softmax通常被用作输出层的激活函数.这不仅是因为它的效果好,而且因为它使得ANN的输出值更易于理解.同时, ...
- softmax、cross entropy和softmax loss学习笔记
之前做手写数字识别时,接触到softmax网络,知道其是全连接层,但没有搞清楚它的实现方式,今天学习Alexnet网络,又接触到了softmax,果断仔细研究研究,有了softmax,损失函数自然不可 ...
- 【机器学习基础】对 softmax 和 cross-entropy 求导
目录 符号定义 对 softmax 求导 对 cross-entropy 求导 对 softmax 和 cross-entropy 一起求导 References 在论文中看到对 softmax 和 ...
- 【机器学习基础】交叉熵(cross entropy)损失函数是凸函数吗?
之所以会有这个问题,是因为在学习 logistic regression 时,<统计机器学习>一书说它的负对数似然函数是凸函数,而 logistic regression 的负对数似然函数 ...
- softmax 损失函数求导过程
前言:softmax中的求导包含矩阵与向量的求导关系,记录的目的是为了回顾. 下图为利用softmax对样本进行k分类的问题,其损失函数的表达式为结构风险,第二项是模型结构的正则化项. 首先,每个qu ...
- softmax交叉熵损失函数求导
来源:https://www.jianshu.com/p/c02a1fbffad6 简单易懂的softmax交叉熵损失函数求导 来写一个softmax求导的推导过程,不仅可以给自己理清思路,还可以造福 ...
- softmax,softmax loss和cross entropy的区别
版权声明:本文为博主原创文章,未经博主允许不得转载. https://blog.csdn.net/u014380165/article/details/77284921 我们知道卷积神经网络(CNN ...
- 关于交叉熵损失函数Cross Entropy Loss
1.说在前面 最近在学习object detection的论文,又遇到交叉熵.高斯混合模型等之类的知识,发现自己没有搞明白这些概念,也从来没有认真总结归纳过,所以觉得自己应该沉下心,对以前的知识做一个 ...
- 【机器学习】BP & softmax求导
目录 一.BP原理及求导 二.softmax及求导 一.BP 1.为什么沿梯度方向是上升最快方向 根据泰勒公式对f(x)在x0处展开,得到f(x) ~ f(x0) + f'(x0)(x-x0) ...
随机推荐
- 自己用C语言写PIC32 serial bootloader
了解更多关于bootloader 的C语言实现,请加我QQ: 1273623966 (验证信息请填 bootloader),欢迎咨询或定制bootloader(在线升级程序). 从15年12月份以来我 ...
- IAR调节字体大小
在主面板上点击tools->Options,然后点开Editor,选择下面的Colors and Fonts选项,最后选右上方的Font,选择要设置的字体就OK了.
- IT在线学习网站总结
以下是我自己做软件过程中发现的一些不错的IT学习网站,个人感觉比较受用,故总结出来以供IT爱好者一起学习: www.maiziedu.com 麦子学院 www.jikexueyuan.com 极客学 ...
- NSQ的消息订阅发布测试
在测试NSQ的Quick Start发现这样一个问题,就是同时只能有一个订阅实例 $ nsq_to_file --topic=test --output- 当存在两个实例时则消息会被发送给其中的一个实 ...
- asp.net core 阿里云消息服务(Message Service,原MQS)发送接口的实现
最近在后台处理订单统计等相关功能用到了大力的mqs,由于官方没有实现asp.net core的sdk,这里简单实现了发送信息的功能,有兴趣的可以参考实现其他相关功能 using System;usin ...
- PHP在linux上执行外部命令
PHP在linux上执行外部命令 一.PHP中调用外部命令介绍二.关于安全问题三.关于超时问题四.关于PHP运行linux环境中命令出现的问题 一.PHP中调用外部命令介绍在PHP中调用外部命令,可以 ...
- genymotion启动虚拟机遇到问题解决方法步骤
通过在不做任务设置时启动genymotion,会遇到一些问题: 会弹出类似如下问题: 要解决这样问题,首先要知道是什么问题,一般按提示在VitualBox中启动虚拟机就可以知道是什么问题. “To f ...
- C89和C99区别--简单总结
(1)对数组的增强 可变长数组 C99中,程序员声明数组时,数组的维数可以由任一有效的整型表达式确定,包括只在运行时才能确定其值的表达式,这类数组就叫做可变长数组,但是只有局部数组才可以是变长的.可变 ...
- 【面向对象版】HashMap(增删改查)
前言: 关于什么是HashMap,HashMap可以用来做些什么,这些定义类的描述,请参照[简易版]HashMap(增删改查)的内容. 这章节主要是面向实例,直接进行HashMap(增删改查)的演示. ...
- nosql数据库学习
1.MongoDB 介绍 MongoDB是一个基于分布式文件存储的数据库.由C++语言编写.主要解决的是海量数据的访问效率问题,为WEB应用提供可扩展的高性能数据存储解决方案.当数据量达到50GB以上 ...