二分~多分~Softmax~理预

一、简介

 在二分类问题中,你可以根据神经网络节点的输出,通过一个激活函数如Sigmoid,将其转换为属于某一类的概率,为了给出具体的分类结果,你可以取0.5作为阈值,凡是大于0.5的样本被认为是正类,小于0.5则认为是负类

 然而这样的做法并不容易推广到多分类问题。多分类问题神经网络最常用的方法是根据类别个数n,设置n个输出节点,这样每个样本神经网络都会给出一个n维数组作为输出结果,然后我们运用激活函数如softmax,将输出转换为一种概率分布,其中的每一个概率代表了该样本属于某类的概率。

 比如一个手写数字识别这种简单的10分类问题,对于数字1的识别,神经网络模型的输出结果应该越接近\([0, 1, 0, 0, 0, 0, 0, 0, 0, 0]\)越好,其中\([0, 1, 0, 0, 0, 0, 0, 0, 0, 0]\)是最理想的结果了

 但是如何衡量一个神经网络输出向量和理想的向量的接近程度呢?交叉熵(cross entropy)就是这个评价方法之一,他刻画了两个概率分布之间的距离,是多分类问题中常用的一种损失函数

二、交叉熵

 给定两个概率分布:p(理想结果即正确标签向量)和q(神经网络输出结果即经过softmax转换后的结果向量),则通过q来表示p的交叉熵为:

\(H(p, q) = - \sum_xp(x)logq(x)\)

 注意:既然p和q都是一种概率分布,那么对于任意的x,应该属于\([0, 1]\)并且所有概率和为1

\(\forall x p(X=x) \epsilon [0,1]\)且\(\sum_xp(X=x) =1\)

 交叉熵刻画的是通过概率分布q来表达概率分布p的困难程度,其中p是正确答案,q是预测值,也就是交叉熵值越小,两个概率分布越接近

三、三分类实例讲解交叉熵

 其中某个样本的正确答案即p是\([1,0, 0]\),某模型经过Softmax激活后的答案即预测值q是\([0.5, 0.4, 0.1]\),那么这个预测值和正确答案之间的交叉熵为:

\(H(p=[1,0,0], q=[0.5,0.4,0.1]) = -(1*log0.5 + 0*log0.4 + 0*log0.1) \approx 0.3\)

 如果另外一个模型的预测值q是\([0.8, 0.1, 0.1]\),那么这个预测值和正确答案之间的交叉熵为:

\(H(p=[1,0,0], q=[0.8,0.1,0.1]) = -(1*log0.8 + 0*log0.1 + 0*log0.1) \approx 0.1\)

 从直观上可以很容易的知道第二个预测答案要优于第一个,通过交叉熵计算得到的结果也是一致的(第二个交叉熵值更小)

 而TF中很容易做到交叉熵的计算:

import tensorflow as tf

cross_entropy = -tf.reduce_mean( y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)) )

 上述代码包含了四种不同的TF运算,解释如下:

 1)tf.clip_by_value():将一个张量中的数值限制在一个范围内,如限制在\([0.1, 1.0]\)范围内,可以避免一些运算错误,如预测结果q中元素可能为0,这样的话log0是无效的

v = tf.constant([[1.0, 2.0, 3.0], [4.0,5.0,6.0] ])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(tf.clip_by_value(v, 2.5, 4.5).eval())

  

 2)tf.log():对张量中的所有元素依次求对数

v = tf.constant([[1.0, 2.0, 3.0], [4.0,5.0,6.0] ])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(tf.log(v).eval())

  

 3)乘法运算:*操作,是元素之间直接相乘,而矩阵相乘用tf.matmul函数来完成

v1 = tf.constant([ [1.0, 2.0], [3.0, 4.0]])
v2 = tf.constant([ [5.0, 6.0], [7.0, 8.0]])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print( (v1 * v2).eval() )
print( tf.matmul(v1, v2).eval() )

  

 4)求和:上面三个运算完成了每个样例中每个类别的交叉熵\(p(x)logq(x)\)的计算,还未进行求和运算

  即:三步计算后得到的结果是个n * m的二维矩阵,其中n为一个batch中样本数量,m为分类的类别数量,比如十分类问题,m为10.根据交叉熵公式

  最后是要将每行中m个结果相加得到每个样本的交叉熵,然后再对这n行取平均得到一个batch的平均交叉熵,即-tf.reduce_mean()函数来实现

 总结:因为交叉熵一般会与softmax一起使用,所以TF对这两个功能进行了封装,并提供了tf.nn.softmax_cross_entropy_with_logits函数

cross_entropy = tf.nn.softmax_cross_entropy_with_logits(y, y_)
#y:原始神经网络的输出结果
#y_:标准答案
#这样一个函数即可实现使用了Softmax后的交叉熵

 关于Softmax计算过程,可以参考:实战Google深度学习框架-C4-深层神经网络

关于交叉熵(cross entropy),你了解哪些的更多相关文章

  1. 最大似然估计 (Maximum Likelihood Estimation), 交叉熵 (Cross Entropy) 与深度神经网络

    最近在看深度学习的"花书" (也就是Ian Goodfellow那本了),第五章机器学习基础部分的解释很精华,对比PRML少了很多复杂的推理,比较适合闲暇的时候翻开看看.今天准备写 ...

  2. 交叉熵cross entropy和相对熵(kl散度)

    交叉熵可在神经网络(机器学习)中作为损失函数,p表示真实标记的分布,q则为训练后的模型的预测标记分布,交叉熵损失函数可以衡量真实分布p与当前训练得到的概率分布q有多么大的差异. 相对熵(relativ ...

  3. 深度学习中交叉熵和KL散度和最大似然估计之间的关系

    机器学习的面试题中经常会被问到交叉熵(cross entropy)和最大似然估计(MLE)或者KL散度有什么关系,查了一些资料发现优化这3个东西其实是等价的. 熵和交叉熵 提到交叉熵就需要了解下信息论 ...

  4. 『TensorFlow』分类问题与两种交叉熵

    关于categorical cross entropy 和 binary cross entropy的比较,差异一般体现在不同的分类(二分类.多分类等)任务目标,可以参考文章keras中两种交叉熵损失 ...

  5. TensorFlow 实战(一)—— 交叉熵(cross entropy)的定义

    对多分类问题(multi-class),通常使用 cross-entropy 作为 loss function.cross entropy 最早是信息论(information theory)中的概念 ...

  6. 【机器学习基础】交叉熵(cross entropy)损失函数是凸函数吗?

    之所以会有这个问题,是因为在学习 logistic regression 时,<统计机器学习>一书说它的负对数似然函数是凸函数,而 logistic regression 的负对数似然函数 ...

  7. 关于交叉熵损失函数Cross Entropy Loss

    1.说在前面 最近在学习object detection的论文,又遇到交叉熵.高斯混合模型等之类的知识,发现自己没有搞明白这些概念,也从来没有认真总结归纳过,所以觉得自己应该沉下心,对以前的知识做一个 ...

  8. 【联系】二项分布的对数似然函数与交叉熵(cross entropy)损失函数

    1. 二项分布 二项分布也叫 0-1 分布,如随机变量 x 服从二项分布,关于参数 μ(0≤μ≤1),其值取 1 和取 0 的概率如下: {p(x=1|μ)=μp(x=0|μ)=1−μ 则在 x 上的 ...

  9. 熵(Entropy),交叉熵(Cross-Entropy),KL-松散度(KL Divergence)

    1.介绍: 当我们开发一个分类模型的时候,我们的目标是把输入映射到预测的概率上,当我们训练模型的时候就不停地调整参数使得我们预测出来的概率和真是的概率更加接近. 这篇文章我们关注在我们的模型假设这些类 ...

随机推荐

  1. Power Spectral Density

    对于一个特定的信号来说,有时域与频域两个表达形式,时域表现的是信号随时间的变化,频域表现的是信号在不同频率上的分量.在信号处理中,通常会对信号进行傅里叶变换得到该信号的频域表示,从而得到信号在频域上的 ...

  2. Promise.all和Promise.race区别,和使用场景

    一.Pomise.all的使用 常见使用场景 : 多个异步结果合并到一起 Promise.all可以将多个Promise实例包装成一个新的Promise实例.用于将多个Promise实例,包装成一个新 ...

  3. 在一台服务器上配置多个Tomcat的方法

    原文来自:http://blog.csdn.net/lmb55/article/details/49561669 这段时间在开发智能导航的热部署功能,需要从一台服务器去访问其它的24台服务器去进行相关 ...

  4. Codeforces Round #467 Div. 1

    B:显然即相当于能否找一条有长度为奇数的路径使得终点出度为0.如果没有环直接dp即可.有环的话可以考虑死了的spfa,由于每个点我们至多只需要让其入队两次,复杂度变成了优秀的O(kE).事实上就是拆点 ...

  5. Codeforces1065F Up and Down the Tree 【树形DP】

    推荐一道联赛练习题. 题目分析: 你考虑进入一个子树就可能上不来了,如果上得来的话就把能上来的全捡完然后走一个上不来的,所以这就是个基本的DP套路. 代码: #include<bits/stdc ...

  6. IDEA中Maven项目使用Junit4单元测试的写法

    IDEA默认是安装了junit控件的,直接使用就好了 在maven项目的pom.xml文件中添加依赖 <dependency> <groupId>junit</group ...

  7. python学习日记(函数进阶)

    命名空间 内置命名空间 存放了python解释器为我们提供的名字:print,input...等等,他们都是我们熟悉的,拿过来就可以用的方法. 内置的名字在启动解释器(程序运行前)的时候被加载在内存里 ...

  8. 图论杂项细节梳理&模板(虚树,圆方树,仙人掌,欧拉路径,还有。。。)

    orzYCB 虚树 %自为风月马前卒巨佬% 用于优化一类树形DP问题. 当状态转移只和树中的某些关键点有关的时候,我们把这些点和它们两两之间的LCA弄出来,以点的祖孙关系连成一棵新的树,这就是虚树. ...

  9. SCOI 2015 Day2 简要题解

    「SCOI2015」小凸玩密室 题意 小凸和小方相约玩密室逃脱,这个密室是一棵有 $ n $ 个节点的完全二叉树,每个节点有一个灯泡.点亮所有灯泡即可逃出密室.每个灯泡有个权值 $ A_i $,每条边 ...

  10. bit、Byte、bps、Bps、pps、Gbps的单位详细说明及换算

    1. bit 电脑记忆体中最小的单位,在二进位电脑系统中,每1bit 可以代表0 或 1 的数位讯号. 2. Byte 字节单位,一般表示存储介质大小的单位,一个B(常用大写的B来表示Byte)可代表 ...