softmax是logisitic regression在多酚类问题上的推广,\(W=[w_1,w_2,...,w_c]\)为各个类的权重因子,\(b\)为各类的门槛值。不要想象成超平面,否则很难理解,如果理解成每个类的打分函数,则会直观许多。预测时我们把样本分配到得分最高的类。

Notations:

  • \(x\):输入向量,\(d\times 1\)列向量,\(d\)是feature数
  • \(W\):权重矩阵,\(c\times d\)矩阵,\(c\)是label数
  • \(b\):每个类对应超平面的偏置组成的向量, \(c\times 1\)列向量
  • \(z=Wx+b\):线性分类器输出, \(c\times 1\)列向量
  • \(\hat{y}\):softmax函数输出, \(c\times 1\)列向量
  • 记\(\vec{e}_j=[0,...,1,...,0]^T\in\mathbb{R}^{c\times 1}\),其中\(1\)出现在第\(j\)个位置
  • \(1_c\)表示一个全\(1\)的\(c\)维列向量
  • \(y\):我们要拟合的目标变量,是一个one-hot vector(只有一个1,其余均为0),也是 \(c\times 1\)列向量 。 我们将其转置,表示为一个列向量:
    \[y=[0,...,1,...,0]^T\]

他们之间的关系:
\[\left\{\begin{aligned}&z=Wx+b\\& \hat{y}=\mathrm{softmax}(z)=\frac{exp(z)}{1_c^Texp(z)} \end{aligned}\right.\]

cross-entropy error定义为:
\[ CE(z) = -y^Tlog(\hat{y}) \]
因为\(y\)是一个one-hot vector(即只有一个位置为1),假设\(y_k=1\),那么上式等于\(-log(\hat{y}_k)=-log(\frac{exp(z_k)}{\sum\limits_i exp(z_i)})=-z_k+log(\sum\limits_i exp(z_i))\)

依据chain rule有:
\[ \begin{aligned}\frac{\partial CE(z)}{\partial W_{ij}}
&=tr\bigg(\big(\frac{\partial CE(z)}{\partial z}\big)^T\frac {\partial z}{\partial W_{ij}}\bigg)\\
&=tr\bigg( \big(\frac{\partial \hat{y}}{\partial z}\cdot\frac{\partial CE(z)}{\partial \hat{y}}\big)^T\frac {\partial z}{\partial W_{ij}} \bigg)\end{aligned}\]
注:这里我用了Denominator layout,因此链式法则是从右往左的。

我们一个一个来求。
\[\begin{equation}\begin{aligned}\frac{\partial \hat{y}}{\partial z}&=\frac{\partial ( \frac{exp(z)}{1_c^Texp(z)})}{\partial z}\\&= \frac{1}{1_c^Texp(z)}\frac{\partial exp(z)}{\partial z}+ \frac{\partial (\frac{1}{1_c^Texp(z)})}{\partial z}( exp(z) )^T\\&= \frac{1}{1_c^Texp(z)}diag(exp(z))-\frac{1}{(1_c^Texp(z))^2}exp(z)exp(z)^T\\&=diag(\frac{exp(z)}{1_c^Texp(z)})-\frac{exp(z)}{1_c^Texp(z)}\cdot (\frac{exp(z)}{1_c^Texp(z)})^T\\&=diag(\mathrm{ softmax}(z))- \mathrm{ softmax}(z) \mathrm{ softmax}(z)^T\\&=diag(\hat{y})-\hat{y}\hat{y}^T \end{aligned}\label{eq1}\end{equation}\]
注:上述求导过程使用了Denominator layout
设$a=a( \boldsymbol{ x}),\boldsymbol{u}= \boldsymbol{u}( \boldsymbol{x}) \(,这里\) \boldsymbol{ x}\(特意加粗表示是列向量,\)a\(没加粗表示是一个标量函数,\) \boldsymbol{u}\(加粗表示是一个向量函数。在`Numerator layout`下,\)\frac{\partial a \boldsymbol{u}}{ \boldsymbol{x}}=a\frac{\partial \boldsymbol{u}}{\partial \boldsymbol{x}}+ \boldsymbol{u}\frac{\partial a}{\partial \boldsymbol{x}} \(,而在`Denominator layout`下,则为\)\frac{\partial a \boldsymbol{u}}{\partial \boldsymbol{x}}=a\frac{\partial \boldsymbol{u}}{\partial \boldsymbol{x}}+\frac{\partial a}{\partial \boldsymbol{x}} \boldsymbol{u}^T$,对比可知上述推导用的实际是Denominator layout
以下推导均采用 Denominator layout,这样的好处是我们用梯度更新权重时不需要对梯度再转置。

\[\begin{equation}\frac{\partial CE(z)}{\partial \hat{y}}=\frac{\partial log(\hat{y})}{\partial \hat{y}}\cdot \frac{\partial (-y^Tlog(\hat{y}))}{\partial log(\hat{y})}=\big(diag(\hat{y})\big)^{-1}\cdot(-y)\label{eq2}\end{equation}\]
\(z\)的第\(k\)个分量可以表示为:\(z_k=\sum\limits_j W_{kj}x_j+b_k\),因此
\[\begin{equation}\frac{\partial z}{\partial W_{ij}} =\begin{bmatrix}\frac{\partial z_1}{\partial W_{ij}}\\\vdots\\\frac{\partial z_c}{\partial W_{ij}}\end{bmatrix}=[0,\cdots, x_j,\cdots, 0]^T=x_j \vec{e}_i\label{eq3}\end{equation}\]
其中\(x_j\)是向量\(x\)的第\(j\)个元素,为标量,它出现在第\(i\)行。
综合\(\eqref{eq1},\eqref{eq2},\eqref{eq3}\),我们有
\[ \begin{aligned}\frac{\partial CE(z)}{\partial W_{ij}}&=tr\bigg(\big( (diag(\hat{y})-\hat{y}\hat{y}^T)\cdot (diag(\hat{y}))^{-1} \cdot (-y) \big)^T\cdot x_j \vec{e}_i \bigg)\\&=tr\bigg(\big( \hat{y}\cdot (1_c^Ty)-y\big)^T\cdot x_j \vec{e}_i \bigg)\\&=(\hat{y}-y)^T\cdot x_j \vec{e}_i={err}_ix_j\end{aligned}\]
其中\({err}_i=(\hat{y}-y)_i\)表示残差向量的第\(i\)项

我们可以把上式改写为
\[ \frac{\partial CE(z)}{\partial W}=(\hat{y}-y)\cdot x^T \]
同理可得
\[ \frac{\partial CE(z)}{\partial b}=(\hat{y}-y) \]
那么在进行随机梯度下降的时候,更新式就是:
\[ \begin{aligned}&W \leftarrow W - \lambda (\hat{y}-y)\cdot x^T \\&b \leftarrow b - \lambda (\hat{y}-y)\end{aligned}\]
其中\(\lambda\)是学习率

softmax分类器+cross entropy损失函数的求导的更多相关文章

  1. 【转载】softmax的log似然代价函数(求导过程)

    全文转载自:softmax的log似然代价函数(公式求导) 在人工神经网络(ANN)中,Softmax通常被用作输出层的激活函数.这不仅是因为它的效果好,而且因为它使得ANN的输出值更易于理解.同时, ...

  2. softmax、cross entropy和softmax loss学习笔记

    之前做手写数字识别时,接触到softmax网络,知道其是全连接层,但没有搞清楚它的实现方式,今天学习Alexnet网络,又接触到了softmax,果断仔细研究研究,有了softmax,损失函数自然不可 ...

  3. 【机器学习基础】对 softmax 和 cross-entropy 求导

    目录 符号定义 对 softmax 求导 对 cross-entropy 求导 对 softmax 和 cross-entropy 一起求导 References 在论文中看到对 softmax 和 ...

  4. 【机器学习基础】交叉熵(cross entropy)损失函数是凸函数吗?

    之所以会有这个问题,是因为在学习 logistic regression 时,<统计机器学习>一书说它的负对数似然函数是凸函数,而 logistic regression 的负对数似然函数 ...

  5. softmax 损失函数求导过程

    前言:softmax中的求导包含矩阵与向量的求导关系,记录的目的是为了回顾. 下图为利用softmax对样本进行k分类的问题,其损失函数的表达式为结构风险,第二项是模型结构的正则化项. 首先,每个qu ...

  6. softmax交叉熵损失函数求导

    来源:https://www.jianshu.com/p/c02a1fbffad6 简单易懂的softmax交叉熵损失函数求导 来写一个softmax求导的推导过程,不仅可以给自己理清思路,还可以造福 ...

  7. softmax,softmax loss和cross entropy的区别

     版权声明:本文为博主原创文章,未经博主允许不得转载. https://blog.csdn.net/u014380165/article/details/77284921 我们知道卷积神经网络(CNN ...

  8. 关于交叉熵损失函数Cross Entropy Loss

    1.说在前面 最近在学习object detection的论文,又遇到交叉熵.高斯混合模型等之类的知识,发现自己没有搞明白这些概念,也从来没有认真总结归纳过,所以觉得自己应该沉下心,对以前的知识做一个 ...

  9. 【机器学习】BP & softmax求导

    目录 一.BP原理及求导 二.softmax及求导 一.BP 1.为什么沿梯度方向是上升最快方向     根据泰勒公式对f(x)在x0处展开,得到f(x) ~ f(x0) + f'(x0)(x-x0) ...

随机推荐

  1. 【ORM】--FluentNHibernate之AutoMapping详解

           上篇文章详细讨论了FluentNHibernate的基本映射的使用方法,它的映射基本用法是跟NHibernate完全一样的,首先要创建数据库链接配置文件,然后编写Table的Mappin ...

  2. cs107

    基本类型:bool,char,short,int,long,float,double 对于char,short,int,long: 多字节类型赋值给少字节类型,对低字节的细节感兴趣,位模式拷贝. 少字 ...

  3. redis集群同步迁移方法(一):通过redis replication实现

           讲到redis的迁移,一般会使用rdb或者aof在主库做自动重载到目标库方法.但该方法有个问题就是无法保证源节点数据和目标节点数据保持一致,一般线上环境也不允许源库停机,所以要在迁移过程 ...

  4. location.hash详解

    一.#的涵义 #代表网页中的一个位置.其右面的字符,就是该位置的标识符.比如, http://www.example.com/index.html#print 就代表网页index.html的prin ...

  5. shell中的语法(1)

    反引号 命令替换.将命令的输出放在命令行的任意位置. eg. [root@gam ~]# echo The Data is `date` The Data is Fri Nov 18 10:13:56 ...

  6. AJAX实现跨域的三种种方法(代理,JSONP,XHR2)

    由于在工作中需要使用AJAX请求其他域名下的请求,但是会出现拒绝访问的情况,这是因为基于安全的考虑,AJAX只能访问本地的资源,而不能跨域访问. 比如说你的网站域名是aaa.com,想要通过AJAX请 ...

  7. CentOS+Apache+mod_wsgi+Python+Django

    前言 网上有关的教程千篇一律,都是无脑抄,自己都不验证一遍就直接复制,毫无意义,我通过官方文档和自己摸索,总结了一套教程. Django自带Web服务功能,但那只是方便开发调试,生产环境中一般将Dja ...

  8. Linux和windows之间通过scp复制文件

    Windows是不支持ssh协议的 需要安装WinSSHD 安装以及设置过程如下: BvSshServer(原名winsshd)官方下载页在这里:https://www.bitvise.com/dow ...

  9. GridView控件RowDataBound事件中获取列字段途径

    今天不知道怎么回事怎么也找不到gridview列中的控件,关键是其为编辑时隐藏域中的控件,取值就很成问题了,网上搜了很到,找到这个比较经典的东东了,可能大家都知道,但很少对比整理到一起,有多种方法可以 ...

  10. app跳转openURL,兼容方法

    - (void)openScheme:(NSString *)scheme {   UIApplication *application = [UIApplication sharedApplicat ...