系列博客,原文在笔者所维护的github上:https://aka.ms/beginnerAI
点击star加星不要吝啬,星越多笔者越努力。

3.2 交叉熵损失函数

交叉熵(Cross Entropy)是Shannon信息论中一个重要概念,主要用于度量两个概率分布间的差异性信息。在信息论中,交叉熵是表示两个概率分布 \(p,q\) 的差异,其中 \(p\) 表示真实分布,\(q\) 表示非真实分布,那么\(H(p,q)\)就称为交叉熵:

\[H(p,q)=\sum_i p_i \cdot \ln {1 \over q_i} = - \sum_i p_i \ln q_i \tag{1}\]

交叉熵可在神经网络中作为损失函数,\(p\) 表示真实标记的分布,\(q\) 则为训练后的模型的预测标记分布,交叉熵损失函数可以衡量 \(p\) 与 \(q\) 的相似性。

交叉熵函数常用于逻辑回归(logistic regression),也就是分类(classification)。

3.2.1 交叉熵的由来

信息量

信息论中,信息量的表示方式:

\[I(x_j) = -\ln (p(x_j)) \tag{2}\]

\(x_j\):表示一个事件

\(p(x_j)\):表示\(x_j\)发生的概率

\(I(x_j)\):信息量,\(x_j\)越不可能发生时,它一旦发生后的信息量就越大

假设对于学习神经网络原理课程,我们有三种可能的情况发生,如表3-2所示。

表3-2 三种事件的概论和信息量

事件编号 事件 概率 \(p\) 信息量 \(I\)
\(x_1\) 优秀 \(p=0.7\) \(I=-\ln(0.7)=0.36\)
\(x_2\) 及格 \(p=0.2\) \(I=-\ln(0.2)=1.61\)
\(x_3\) 不及格 \(p=0.1\) \(I=-\ln(0.1)=2.30\)

WoW,某某同学不及格!好大的信息量!相比较来说,“优秀”事件的信息量反而小了很多。

\[H(p) = - \sum_j^n p(x_j) \ln (p(x_j)) \tag{3}\]

则上面的问题的熵是:

\[
\begin{aligned}
H(p)&=-[p(x_1) \ln p(x_1) + p(x_2) \ln p(x_2) + p(x_3) \ln p(x_3)] \\
&=0.7 \times 0.36 + 0.2 \times 1.61 + 0.1 \times 2.30 \\
&=0.804
\end{aligned}
\]

相对熵(KL散度)

相对熵又称KL散度,如果我们对于同一个随机变量 \(x\) 有两个单独的概率分布 \(P(x)\) 和 \(Q(x)\),我们可以使用 KL 散度(Kullback-Leibler (KL) divergence)来衡量这两个分布的差异,这个相当于信息论范畴的均方差。

KL散度的计算公式:

\[D_{KL}(p||q)=\sum_{j=1}^n p(x_j) \ln{p(x_j) \over q(x_j)} \tag{4}\]

\(n\) 为事件的所有可能性。\(D\) 的值越小,表示 \(q\) 分布和 \(p\) 分布越接近。

交叉熵

把上述公式变形:

\[
\begin{aligned}
D_{KL}(p||q)&=\sum_{j=1}^n p(x_j) \ln{p(x_j)} - \sum_{j=1}^n p(x_j) \ln q(x_j) \\
&=- H(p(x)) + H(p,q)
\end{aligned}
\tag{5}
\]

等式的前一部分恰巧就是p的熵,等式的后一部分,就是交叉熵:

\[H(p,q) =- \sum_{j=1}^n p(x_j) \ln q(x_j) \tag{6}\]

在机器学习中,我们需要评估label和predicts之间的差距,使用KL散度刚刚好,即\(D_{KL}(y||a)\),由于KL散度中的前一部分\(H(y)\)不变,故在优化过程中,只需要关注交叉熵就可以了。所以一般在机器学习中直接用交叉熵做损失函数来评估模型。

\[loss =- \sum_{j=1}^n y_j \ln a_j \tag{7}\]

其中,\(n\) 并不是样本个数,而是分类个数。所以,对于批量样本的交叉熵计算公式是:

\[J =- \sum_{i=1}^m \sum_{j=1}^n y_{ij} \ln a_{ij} \tag{8}\]

\(m\) 是样本数,\(n\) 是分类数。

有一类特殊问题,就是事件只有两种情况发生的可能,比如“学会了”和“没学会”,称为\(0/1\)分布或二分类。对于这类问题,由于\(n=2\),所以交叉熵可以简化为:

\[loss =-[y \ln a + (1-y) \ln (1-a)] \tag{9}\]

二分类对于批量样本的交叉熵计算公式是:

\[J= - \sum_{i=1}^m [y_i \ln a_i + (1-y_i) \ln (1-a_i)] \tag{10}\]

3.2.2 二分类问题交叉熵

把公式10分解开两种情况,当\(y=1\)时,即标签值是1,是个正例,加号后面的项为0:

\[loss = -\ln(a) \tag{11}\]

横坐标是预测输出,纵坐标是损失函数值。y=1意味着当前样本标签值是1,当预测输出越接近1时,损失函数值越小,训练结果越准确。当预测输出越接近0时,损失函数值越大,训练结果越糟糕。

当y=0时,即标签值是0,是个反例,加号前面的项为0:

\[loss = -\ln (1-a) \tag{12}\]

此时,损失函数值如图3-10。

图3-10 二分类交叉熵损失函数图

假设学会了课程的标签值为1,没有学会的标签值为0。我们想建立一个预测器,对于一个特定的学员,根据出勤率、课堂表现、作业情况、学习能力等等来预测其学会课程的概率。

对于学员甲,预测其学会的概率为0.6,而实际上该学员通过了考试,真实值为1。所以,学员甲的交叉熵损失函数值是:

\[
loss_1 = -(1 \times \ln 0.6 + (1-1) \times \ln (1-0.6)) = 0.51
\]

对于学员乙,预测其学会的概率为0.7,而实际上该学员也通过了考试。所以,学员乙的交叉熵损失函数值是:

\[
loss_2 = -(1 \times \ln 0.7 + (1-1) \times \ln (1-0.7)) = 0.36
\]

由于0.7比0.6更接近1,是相对准确的值,所以 \(loss2\) 要比 \(loss1\) 小,反向传播的力度也会小。

3.2.3 多分类问题交叉熵

当标签值不是非0即1的情况时,就是多分类了。假设期末考试有三种情况:

  1. 优秀,标签值OneHot编码为\([1,0,0]\)
  2. 及格,标签值OneHot编码为\([0,1,0]\)
  3. 不及格,标签值OneHot编码为\([0,0,1]\)

假设我们预测学员丙的成绩为优秀、及格、不及格的概率为:\([0.2,0.5,0.3]\),而真实情况是该学员不及格,则得到的交叉熵是:

\[
loss_1 = -(0 \times \ln 0.2 + 0 \times \ln 0.5 + 1 \times \ln 0.3) = 1.2
\]

假设我们预测学员丁的成绩为优秀、及格、不及格的概率为:\([0.2,0.2,0.6]\),而真实情况是该学员不及格,则得到的交叉熵是:

\[
loss_2 = -(0 \times \ln 0.2 + 0 \times \ln 0.2 + 1 \times \ln 0.6) = 0.51
\]

可以看到,0.51比1.2的损失值小很多,这说明预测值越接近真实标签值(0.6 vs 0.3),交叉熵损失函数值越小,反向传播的力度越小。

3.2.4 为什么不能使用均方差做为分类问题的损失函数?

  1. 回归问题通常用均方差损失函数,可以保证损失函数是个凸函数,即可以得到最优解。而分类问题如果用均方差的话,损失函数的表现不是凸函数,就很难得到最优解。而交叉熵函数可以保证区间内单调。

  2. 分类问题的最后一层网络,需要分类函数,Sigmoid或者Softmax,如果再接均方差函数的话,其求导结果复杂,运算量比较大。用交叉熵函数的话,可以得到比较简单的计算结果,一个简单的减法就可以得到反向误差。

[ch03-02] 交叉熵损失函数的更多相关文章

  1. 【转载】深度学习中softmax交叉熵损失函数的理解

    深度学习中softmax交叉熵损失函数的理解 2018-08-11 23:49:43 lilong117194 阅读数 5198更多 分类专栏: Deep learning   版权声明:本文为博主原 ...

  2. 深度学习原理与框架-神经网络结构与原理 1.得分函数 2.SVM损失函数 3.正则化惩罚项 4.softmax交叉熵损失函数 5. 最优化问题(前向传播) 6.batch_size(批量更新权重参数) 7.反向传播

    神经网络由各个部分组成 1.得分函数:在进行输出时,对于每一个类别都会输入一个得分值,使用这些得分值可以用来构造出每一个类别的概率值,也可以使用softmax构造类别的概率值,从而构造出loss值, ...

  3. 关于交叉熵损失函数Cross Entropy Loss

    1.说在前面 最近在学习object detection的论文,又遇到交叉熵.高斯混合模型等之类的知识,发现自己没有搞明白这些概念,也从来没有认真总结归纳过,所以觉得自己应该沉下心,对以前的知识做一个 ...

  4. softmax交叉熵损失函数求导

    来源:https://www.jianshu.com/p/c02a1fbffad6 简单易懂的softmax交叉熵损失函数求导 来写一个softmax求导的推导过程,不仅可以给自己理清思路,还可以造福 ...

  5. 吴裕雄--天生自然 pythonTensorFlow自然语言处理:交叉熵损失函数

    import tensorflow as tf # 1. sparse_softmax_cross_entropy_with_logits样例. # 假设词汇表的大小为3, 语料包含两个单词" ...

  6. BCE和CE交叉熵损失函数的区别

    首先需要说明的是PyTorch里面的BCELoss和CrossEntropyLoss都是交叉熵,数学本质上是没有区别的,区别在于应用中的细节. BCE适用于0/1二分类,计算公式就是 " - ...

  7. 简单易懂的softmax交叉熵损失函数求导

    参考: https://blog.csdn.net/qian99/article/details/78046329

  8. 交叉熵损失函数,以及pytorch CrossEntropyLoss的理解

    实际运用例子: https://zhuanlan.zhihu.com/p/35709485 pytorch CrossEntropyLoss,参考博客如下: https://mathpretty.co ...

  9. 【联系】二项分布的对数似然函数与交叉熵(cross entropy)损失函数

    1. 二项分布 二项分布也叫 0-1 分布,如随机变量 x 服从二项分布,关于参数 μ(0≤μ≤1),其值取 1 和取 0 的概率如下: {p(x=1|μ)=μp(x=0|μ)=1−μ 则在 x 上的 ...

随机推荐

  1. JVM(7) Java内存模型与线程

    衡量一个服务性能的高低好坏,每秒事务处理数(Transactions Per Second,TPS)是最重要的指标之一,它代表着一秒内服务端平均能响应的请求总数,而 TPS 值与程序的并发能力又有非常 ...

  2. Python+Keras+TensorFlow车牌识别

    这个是我使用的车牌识别开源项目的地址:https://github.com/zeusees/HyperLPR Python 依赖 Anaconda for Python 3.x on Win64 Ke ...

  3. 【java基础之异常】死了都要try,不淋漓尽致地catch我不痛快!

    目录 1.异常 1.1 异常概念 1.2 异常体系 1.3 异常分类 1.4 异常的产生过程解析 2. 异常的处理 2.1 抛出异常throw 2.2 Objects非空判断 2.3 声明异常thro ...

  4. 讲一讲快速学习WPF的思路

    我不想浪费大家的时间,直接奔主题了. 首先大家要明白,WPF跟Winform的区别,优点,缺点. 首先入门来讲 Winform简单点,WPF会难一点.所以第一次接触C# 我推荐用Winform项目去学 ...

  5. AQS 入门

    一 AQS简介 路径:java.util.concurrent.locks.AbstractOwnableSynchronizer. 定义:AQS提供了一种 通过维护一个volatile修饰 int类 ...

  6. linux sudo root 权限绕过漏洞(CVE-2019-14287)

    0x01 逛圈子社区论坛 看到了 linux sudo root 权限绕过漏洞(CVE-2019-14287) 跟着复现下 综合来说 这个漏洞作用不大  需要以下几个前提条件 1.知道当前普通用户的密 ...

  7. Spring Cloud Gateway使用简介

    Spring Cloud Gateway是类似Nginx的网关路由代理,有替代原来Spring cloud zuul之意: Spring 5 推出了自己的Spring Cloud Gateway,支持 ...

  8. 用css和js实现侧边菜单栏点击和鼠标滑动特效

    1点击效果: 2关键代码: css: #div{ display:inline-block; width:100px; height:150px; border-radius: 5px; blackg ...

  9. 学习笔记之vim的使用

    很多刚学习linux编程的人总是对vim有一种恐惧,我自己就是这么回事的. 可是当你努力的去尝试学习使用后,才发现它的精髓所在. 在我看来,让vim变得好用的前提是要安装两个插件,ctags和tagl ...

  10. Cpython和Jython的对比介绍

    CPython 当我们从Python官方网站下载并安装好Python 3.x后,我们就直接获得了一个官方版本的解释器:CPython.这个解释器是用C语言开发的,所以叫CPython.在命令行下运行p ...