如何训练一个 Softmax 分类器

回忆一下之前举的的例子,输出层计算出的\(z^{[l]}\)如下,\(z^{[l]} = \begin{bmatrix} 5 \\ 2 \\ - 1 \\ 3 \\ \end{bmatrix}\)有四个分类\(C=4\),\(z^{[l]}\)可以是4×1维向量,计算了临时变量\(t\),\(t = \begin{bmatrix} e^{5} \\ e^{2} \\ e^{- 1} \\ e^{3} \\ \end{bmatrix}\),对元素进行幂运算,最后,如果的输出层的激活函数\(g^{[L]}()\)是Softmax激活函数,那么输出就会是这样的:

简单来说就是用临时变量\(t\)将它归一化,使总和为1,于是这就变成了\(a^{[L]}\),注意到向量\(z\)中,最大的元素是5,而最大的概率也就是第一种概率。

Softmax这个名称的来源是与所谓hardmax对比,hardmax会把向量\(z\)变成这个向量\(\begin{bmatrix} 1 \\ 0 \\ 0 \\ 0 \\ \end{bmatrix}\),hardmax函数会观察\(z\)的元素,然后在\(z\)中最大元素的位置放上1,其它位置放上0,所这是一个hard max,也就是最大的元素的输出为1,其它的输出都为0。与之相反,Softmax所做的从\(z\)到这些概率的映射更为温和,不知道这是不是一个好名字,但至少这就是softmax这一名称背后所包含的想法,与hardmax正好相反。

有一点没有细讲,但之前已经提到过的,就是Softmax回归或Softmax激活函数将logistic激活函数推广到\(C\)类,而不仅仅是两类,结果就是如果\(C=2\),那么\(C=2\)的Softmax实际上变回了logistic回归,因篇幅这里不会给出证明,但是大致的证明思路是这样的,如果\(C=2\),并且应用了Softmax,那么输出层\(a^{[L]}\)将会输出两个数字,如果\(C=2\)的话,也许输出0.842和0.158,对吧?这两个数字加起来要等于1,因为它们的和必须为1,其实它们是冗余的,也许不需要计算两个,而只需要计算其中一个,结果就是最终计算那个数字的方式又回到了logistic回归计算单个输出的方式。这算不上是一个证明,但可以从中得出结论,Softmax回归将logistic回归推广到了两种分类以上。

接下来来看怎样训练带有Softmax输出层的神经网络,具体而言,先定义训练神经网络使会用到的损失函数。举个例子,来看看训练集中某个样本的目标输出,真实标签是\(\begin{bmatrix} 0 \\ 1 \\ 0 \\ 0 \\ \end{bmatrix}\),用上个例子,这表示这是一张猫的图片,因为它属于类1,现在假设的神经网络输出的是\(\hat y\),\(\hat y\)是一个包括总和为1的概率的向量,\(y = \begin{bmatrix} 0.3 \\ 0.2 \\ 0.1 \\ 0.4 \\ \end{bmatrix}\),可以看到总和为1,这就是\(a^{[l]}\),\(a^{[l]} = y = \begin{bmatrix} 0.3 \\ 0.2 \\ 0.1 \\ 0.4 \\ \end{bmatrix}\)。对于这个样本神经网络的表现不佳,这实际上是一只猫,但却只分配到20%是猫的概率,所以在本例中表现不佳。

那么想用什么损失函数来训练这个神经网络?在Softmax分类中,一般用到的损失函数是\(L(\hat y,y ) = - \sum_{j = 1}^{4}{y_{j}log\hat y_{j}}\),来看上面的单个样本来更好地理解整个过程。注意在这个样本中\(y_{1} =y_{3} = y_{4} = 0\),因为这些都是0,只有\(y_{2} =1\),如果看这个求和,所有含有值为0的\(y_{j}\)的项都等于0,最后只剩下\(-y_{2}t{log}\hat y_{2}\),因为当按照下标\(j\)全部加起来,所有的项都为0,除了\(j=2\)时,又因为\(y_{2}=1\),所以它就等于\(- \ log\hat y_{2}\)。

\(L\left( \hat y,y \right) = - \sum_{j = 1}^{4}{y_{j}\log \hat y_{j}} = - y_{2}{\ log} \hat y_{2} = - {\ log} \hat y_{2}\)

这就意味着,如果的学习算法试图将它变小,因为梯度下降法是用来减少训练集的损失的,要使它变小的唯一方式就是使\(-{\log}\hat y_{2}\)变小,要想做到这一点,就需要使\(\hat y_{2}\)尽可能大,因为这些是概率,所以不可能比1大,但这的确也讲得通,因为在这个例子中\(x\)是猫的图片,就需要这项输出的概率尽可能地大(\(y= \begin{bmatrix} 0.3 \\ 0.2 \\ 0.1 \\ 0.4 \\ \end{bmatrix}\)中第二个元素)。

概括来讲,损失函数所做的就是它找到的训练集中的真实类别,然后试图使该类别相应的概率尽可能地高,如果熟悉统计学中最大似然估计,这其实就是最大似然估计的一种形式。但如果不知道那是什么意思,也不用担心,用刚刚讲过的算法思维也足够了。

这是单个训练样本的损失,整个训练集的损失\(J\)又如何呢?也就是设定参数的代价之类的,还有各种形式的偏差的代价,它的定义大致也能猜到,就是整个训练集损失的总和,把的训练算法对所有训练样本的预测都加起来,

\(J( w^{[1]},b^{[1]},\ldots\ldots) = \frac{1}{m}\sum_{i = 1}^{m}{L( \hat y^{(i)},y^{(i)})}\)

因此要做的就是用梯度下降法,使这里的损失最小化。

最后还有一个实现细节,注意因为\(C=4\),\(y\)是一个4×1向量,\(y\)也是一个4×1向量,如果实现向量化,矩阵大写\(Y\)就是\(\lbrack y^{(1)}\text{}y^{(2)}\ldots\ldots\ y^{\left( m \right)}\rbrack\),例如如果上面这个样本是的第一个训练样本,那么矩阵\(Y =\begin{bmatrix} 0 & 0 & 1 & \ldots \\ 1 & 0 & 0 & \ldots \\ 0 & 1 & 0 & \ldots \\ 0 & 0 & 0 & \ldots \\ \end{bmatrix}\),那么这个矩阵\(Y\)最终就是一个\(4×m\)维矩阵。类似的,\(\hat{Y} = \lbrack{\hat{y}}^{(1)}{\hat{y}}^{(2)} \ldots \ldots\ {\hat{y}}^{(m)}\rbrack\),这个其实就是\({\hat{y}}^{(1)}\)(\(a^{[l](1)} = y^{(1)} = \begin{bmatrix} 0.3 \\ 0.2 \\ 0.1 \\ 0.4 \\ \end{bmatrix}\)),或是第一个训练样本的输出,那么\(\hat{Y} = \begin{bmatrix} 0.3 & \ldots \\ 0.2 & \ldots \\ 0.1 & \ldots \\ 0.4 & \ldots \\ \end{bmatrix}\),\(\hat{Y}\)本身也是一个\(4×m\)维矩阵。



最后来看一下,在有Softmax输出层时如何实现梯度下降法,这个输出层会计算\(z^{[l]}\),它是\(C×1\)维的,在这个例子中是4×1,然后用Softmax激活函数来得到\(a^{[l]}\)或者说\(y\),然后又能由此计算出损失。已经讲了如何实现神经网络前向传播的步骤,来得到这些输出,并计算损失,那么反向传播步骤或者梯度下降法又如何呢?其实初始化反向传播所需要的关键步骤或者说关键方程是这个表达式\(dz^{[l]} = \hat{y} -y\),可以用\(\hat{y}\)这个4×1向量减去\(y\)这个4×1向量,可以看到这些都会是4×1向量,当有4个分类时,在一般情况下就是\(C×1\),这符合对\(dz\)的一般定义,这是对\(z^{[l]}\)损失函数的偏导数(\(dz^{[l]} = \frac{\partial J}{\partial z^{[l]}}\)),如果精通微积分就可以自己推导,或者说如果精通微积分,可以试着自己推导,但如果需要从零开始使用这个公式,它也一样有用。

有了这个,就可以计算\(dz^{[l]}\),然后开始反向传播的过程,计算整个神经网络中所需要的所有导数。

开始使用一种深度学习编程框架,对于这些编程框架,通常只需要专注于把前向传播做对,只要将它指明为编程框架,前向传播,它自己会弄明白怎样反向传播,会帮实现反向传播,所以这个表达式值得牢记(\(dz^{[l]} = \hat{y} -y\))。

神经网络优化篇:详解如何训练一个 Softmax 分类器(Training a Softmax classifier)的更多相关文章

  1. 走向DBA[MSSQL篇] 详解游标

    原文:走向DBA[MSSQL篇] 详解游标 前篇回顾:上一篇虫子介绍了一些不常用的数据过滤方式,本篇详细介绍下游标. 概念 简单点说游标的作用就是存储一个结果集,并根据语法将这个结果集的数据逐条处理. ...

  2. Scala进阶之路-Scala函数篇详解

    Scala进阶之路-Scala函数篇详解 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 一.传值调用和传名调用 /* @author :yinzhengjie Blog:http: ...

  3. PHP函数篇详解十进制、二进制、八进制和十六进制转换函数说明

    PHP函数篇详解十进制.二进制.八进制和十六进制转换函数说明 作者: 字体:[增加 减小] 类型:转载   中文字符编码研究系列第一期,PHP函数篇详解十进制.二进制.八进制和十六进制互相转换函数说明 ...

  4. Cookie 详解以及实现一个 cookie 操作库

    Cookie 详解以及实现一个 cookie 操作库 cookie 在前端有着大量的应用,但有时我们对它还是一知半解.下面来看看它的一些具体的用法 Set-Cookie 服务器通过设置响应头来设置客户 ...

  5. 深度学习基础(CNN详解以及训练过程1)

    深度学习是一个框架,包含多个重要算法: Convolutional Neural Networks(CNN)卷积神经网络 AutoEncoder自动编码器 Sparse Coding稀疏编码 Rest ...

  6. 目标检测之Faster-RCNN的pytorch代码详解(模型训练篇)

    本文所用代码gayhub的地址:https://github.com/chenyuntc/simple-faster-rcnn-pytorch  (非本人所写,博文只是解释代码) 好长时间没有发博客了 ...

  7. 2、Spring的 IoC详解(第一个Spring程序)

    Spring是为了解决企业应用开发的复杂性而创建的一个轻量级的控制反转(IoC)和面向切面(AOP)的容器框架.在这句话中重点有两个,一个是IoC,另一个是AOP.今天我们讲第一个IoC. IoC概念 ...

  8. Canal:同步mysql增量数据工具,一篇详解核心知识点

    老刘是一名即将找工作的研二学生,写博客一方面是总结大数据开发的知识点,一方面是希望能够帮助伙伴让自学从此不求人.由于老刘是自学大数据开发,博客中肯定会存在一些不足,还希望大家能够批评指正,让我们一起进 ...

  9. KestrelServer详解[3]: 自定义一个迷你版的KestrelServer

    和所有的服务器一样,KestrelServer最终需要解决的是网络传输的问题.在<网络连接的创建>,我们介绍了KestrelServer如何利用连接接听器的建立网络连接,并再次基础上演示了 ...

  10. java提高篇-----详解java的四舍五入与保留位

    转载:http://blog.csdn.net/chenssy/article/details/12719811 四舍五入是我们小学的数学问题,这个问题对于我们程序猿来说就类似于1到10的加减乘除那么 ...

随机推荐

  1. ECharts大屏数据可视化展板项目 适配rem

    1.在utils文件夹里新建一个rem.js 2.main.js中引入rem.js 3.vscode中下载cssrem插件,配置Root Font Size大小,为1920/20 = 96. 重启vs ...

  2. ES5新增语法

    ES5中给我们新增了一些方法,可以很方便的操作数组或者字符串,这些方法主要包括:数组方法.字符串方法.对象方法. 1. 数组方法 迭代(遍历)方法:forEach() .map().filter(). ...

  3. 2023全国大学生电子设计竞赛H题全解 [原创www.cnblogs.com/helesheng]

    2023年又是全国大学生电子设计竞赛年,一如既往的指导学生死磕H题.8月2日看到公布的赛题,我自己还沾沾自喜,觉得今年学生用嵌入式系统和数字信号处理知识就可以完成这题,赛前都辅导过,应该成绩不差.哪想 ...

  4. 介绍几种OPTIONS检测的方法

    概述 日常的VOIP开发中,OPTIONS检测是常用的网络状态检测工具. OPTIONS原本是作为获取对方能力的消息,也可以检测当前服务状态.正常情况下,UAS收到OPTIONS心跳,直接回复200即 ...

  5. 玩转 Helm 之 upgrade

    0. 前言 在 玩转 Helm 一文中,简略提到了 Helm upgrade 的策略. 在实际项目开发上,upgrade 多是调研的重点.基于此,这里对 upgrade 继续展开. 1. basic ...

  6. idea 查看类的继承结构及其子类

    转载请注明出处: 在idea中通过查看一个类或接口的继承结构,可以了解到整个相关功能设计的流程 idea中查看一个类或接口的继承结构的方法如下: 1.选中一个类:右键进入继承结构视图: 效果图如下:

  7. .NET技术面试题系列(2) -sql server数据库优化规范

    1.数据库优化规范 a.索引 每个表格都要求建立主键,主键上不一定需要强制建立聚集索引. 聚集索引,表中存储的数据按照索引的顺序存储,即逻辑顺序决定了表中相应行的物理顺序,因此聚集索引的字段值应是不会 ...

  8. [转帖]一份完整的阿里云 Redis 开发规范,值得收藏!

    https://blog.csdn.net/NicolasLearner/article/details/117449847 作者:付磊-起扬 http://yq.aliyun.com/article ...

  9. [转帖]3--二进制安装k8s

    https://www.cnblogs.com/caodan01/p/15104491.html 目录 一.节点规划 二.插件规划 三.系统优化(所有master节点) 1.关闭swap分区 2.关闭 ...

  10. [转帖]【我和CloudQuery 的故事】安装部署CloudQuery 初体验—-前篇

    https://www.modb.pro/db/1694256553947910144 一.前言 在日常数据库运维中,为连接多种数据库,经常要安装不同的客户端,非常繁琐,且占用大量存储空间.如果能有一 ...