关于categorical cross entropy 和 binary cross entropy的比较,差异一般体现在不同的分类(二分类、多分类等)任务目标,可以参考文章keras中两种交叉熵损失函数的探讨,其结合keras的API讨论了两者的计算原理和应用原理。

本文主要是介绍TF中的接口调用方式。

一、二分类交叉熵

对应的是网络输出单个节点,这个节点将被sigmoid处理,使用阈值分类为0或者1的问题。此类问题logits和labels必须具有相同的type和shape

原理介绍

x = logits, z = labels.
logistic loss 计算式为:
其中交叉熵(cross entripy)基本函数式

z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
    = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
    = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
    = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
    = (1 - z) * x + log(1 + exp(-x))
    = x - x * z + log(1 + exp(-x))

对于x<0时,为了避免计算exp(-x)时溢出,我们使用以下这种形式表示

x - x * z + log(1 + exp(-x))
    = log(exp(x)) - x * z + log(1 + exp(-x))
    = - x * z + log(1 + exp(x))

综合x>0和x<0的情况,并防止溢出我们使用如下公式,

max(x, 0) - x *z + log(1 + exp(-abs(x)))

接口介绍

import numpy as np
import tensorflow as tf input_data = tf.Variable(np.random.rand(1, 3), dtype=tf.float32)
# np.random.rand()传入一个shape,返回一个在[0,1)区间符合均匀分布的array output = tf.nn.sigmoid_cross_entropy_with_logits(logits=input_data, labels=[[1.0, 0.0, 0.0]])
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
print(sess.run(output))
# [[ 0.5583781 1.06925142 1.08170223]]

二、多分类交叉熵

对应的是网络输出多个节点,每个节点表示1个class的得分,使用Softmax最终处理的分类问题。

原理介绍

cross_entropy = -tf.reduce_mean(y * tf.log(tf.clip_by_value(y_pre, 1e-10, 1.0))

调用一下:

import tensorflow as tf

input_data = tf.Variable([[0.2, 0.1, 0.9], [0.3, 0.4, 0.6]], dtype=tf.float32)
labels=tf.constant([[1,0,0], [0,1,0]], dtype=tf.float32) cross_entropy = -tf.reduce_mean(labels * tf.log(tf.clip_by_value(input_data, 1e-10, 1.0))) with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
print(sess.run(output))

接口介绍

softmax之后,计算输出层全部节点各自的交叉熵(输出向量而非标量)

cross_entropy_mean = tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=tf.argmax(labels,1), logits=logits), name='cross_entropy') cross_entropy_mean = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(
logits=logits, labels=labels), name='cross_entropy')

tf.nn.softmax_cross_entropy_with_logits()

函数的参数label是稀疏表示的,比如表示一个3分类的一个样本的标签,稀疏表示的形式为[0,0,1]这个表示这个样本为第3个分类,而非稀疏表示就表示为2,同理[0,1,0]就表示样本属于第2个分类,而其非稀疏表示为1。

import tensorflow as tf

input_data = tf.Variable([[0.2, 0.1, 0.9], [0.3, 0.4, 0.6]], dtype=tf.float32)
output = tf.nn.softmax_cross_entropy_with_logits(logits=input_data, labels=[[1,0,0],
[0,1,0]])
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
print(sess.run(output))

tf.nn.sparse_softmax_cross_entropy_with_logits()

此函数大致与tf.nn.softmax_cross_entropy_with_logits的计算方式相同,
适用于每个类别相互独立且排斥的情况,一幅图只能属于一类,而不能同时包含一条狗和一只大象

但是在对于labels的处理上有不同之处,labels从shape来说此函数要求shape为[batch_size],
labels[i]是[0,num_classes)的一个索引, type为int32或int64,即labels限定了是一个一阶tensor,
并且取值范围只能在分类数之内,表示一个对象只能属于一个类别

import tensorflow as tf

input_data = tf.Variable([[0.2, 0.1, 0.9], [0.3, 0.4, 0.6]], dtype=tf.float32)
output = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=input_data, labels=[0, 2])
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
print(sess.run(output))
# [ 1.36573195 0.93983102]

比tf.nn.softmax_cross_entropy_with_logits多了一步将labels稀疏化的操作。因为深度学习中,图片一般是用非稀疏的标签的,所以tf.nn.sparse_softmax_cross_entropy_with_logits()的频率比tf.nn.softmax_cross_entropy_with_logits高。

不过两者输出尺寸等于输入shape去掉最后一维(上面输入[2*3],输出[2]),所以均常和tf.reduce_mean()连用。

『TensorFlow』分类问题与两种交叉熵的更多相关文章

  1. 『TensorFlow』SSD源码学习_其一:论文及开源项目文档介绍

    一.论文介绍 读论文系列:Object Detection ECCV2016 SSD 一句话概括:SSD就是关于类别的多尺度RPN网络 基本思路: 基础网络后接多层feature map 多层feat ...

  2. 『TensorFlow』模型保存和载入方法汇总

    『TensorFlow』第七弹_保存&载入会话_霸王回马 一.TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 ...

  3. 『TensorFlow』分布式训练_其三_多机分布式

    本节中的代码大量使用『TensorFlow』分布式训练_其一_逻辑梳理中介绍的概念,是成熟的多机分布式训练样例 一.基本概念 Cluster.Job.task概念:三者可以简单的看成是层次关系,tas ...

  4. 『TensorFlow』专题汇总

    TensorFlow:官方文档 TensorFlow:项目地址 本篇列出文章对于全零新手不太合适,可以尝试TensorFlow入门系列博客,搭配其他资料进行学习. Keras使用tf.Session训 ...

  5. 『TensorFlow』滑动平均

    滑动平均会为目标变量维护一个影子变量,影子变量不影响原变量的更新维护,但是在测试或者实际预测过程中(非训练时),使用影子变量代替原变量. 1.滑动平均求解对象初始化 ema = tf.train.Ex ...

  6. 『TensorFlow』流程控制

    『PyTorch』第六弹_最小二乘法对比PyTorch和TensorFlow TensorFlow 控制流程操作 TensorFlow 提供了几个操作和类,您可以使用它们来控制操作的执行并向图中添加条 ...

  7. 『TensorFlow』命令行参数解析

    argparse很强大,但是我们未必需要使用这么繁杂的东西,TensorFlow自己封装了一个简化版本的解析方式,实际上是对argparse的封装 脚本化调用tensorflow的标准范式: impo ...

  8. 『TensorFlow』SSD源码学习_其五:TFR数据读取&数据预处理

    Fork版本项目地址:SSD 一.TFR数据读取 创建slim.dataset.Dataset对象 在train_ssd_network.py获取数据操作如下,首先需要slim.dataset.Dat ...

  9. 『TensorFlow』读书笔记_降噪自编码器

    『TensorFlow』降噪自编码器设计  之前学习过的代码,又敲了一遍,新的收获也还是有的,因为这次注释写的比较详尽,所以再次记录一下,具体的相关知识查阅之前写的文章即可(见上面链接). # Aut ...

随机推荐

  1. 模拟器中安装APK报Error:INSTALL_FAILED_NO_MATCHING_ABIS

    1.启动AVD Manager.exe 2.将APP的安装包.apk直接拖到模拟器中,报错 3.原来是代码里由于大小限制只开放了armeabi-v7a这个ABI,但创建的模拟机支持的CPU是其他类型的 ...

  2. Jemter 压测基础(一)——基本概念、JMeter安装使用、分布式测试、导出测试结果、编写测试报告

    Jemter   压测基础(一) 1.压力测试的基本概念: 1.吞吐率(Requestspersecond) 服务器并发处理能力的量化描述,单位是reqs/s,指的是某个并发用户数下单位时间内处理的请 ...

  3. Pycharm调试:进入调用函数后返回

    在菜单栏的view中勾选toolbar,然后点击工具栏中左箭头返回到调用函数处.

  4. vue 中使用 axios 请求接口,请求会发送两次问题

    在开发项目过程中,发现在使用axios调用接口都会有两个请求,第一个请求时,看不到请求参数,也看不到请求的结果:只有第二次请求时才会有相应的请求参数以及请求结果: 那为甚么会有这么一次额外的请求呢,后 ...

  5. Failed to load driver class com.mysql.jdbc.Driver from HikariConfig class classloader sun.misc.Launcher$AppClassLoader@18b4aac2

    1.错误日志 Failed to load driver class com.mysql.jdbc.Driver from HikariConfig class classloader sun.mis ...

  6. Cocos Creator 鼠标事件

    鼠标事件// 使用枚举类型来注册node.on(cc.Node.EventType.MOUSE_DOWN, function (event) {console.log('Mouse down');}, ...

  7. Unable to register MBean [HikariDataSource (HikariPool-0)] with key 'dataSou rce'; nested exception is javax.management.InstanceAlreadyExistsException: com.z axxer.hikari:name=dataSource,type=HikariDa

    今天启动项目看到已经启动起来,但是看到控制台有红色,没注意是什么问题,具体在细看下,发现是一个Tomcat中发布了两个实例. 解决办法:去发布路径下,全部删掉或者删掉不用的即可.

  8. Openstack-Namespaces

    介绍OpenStack neutron使用Linux网络命名空间来避免物理网络和虚拟网络间的冲突,或者不同虚拟网络间的冲突. 网络命名空间就是一个独立的网络协议栈,它有自己的网络接口,路由,以及防火墙 ...

  9. Linux 配置SSH 无密钥登陆

    根据SSH 协议,每次登陆必须输入密码,比较麻烦,SSH还提供了公钥登陆,可以省去输入密码的步骤. 公钥登陆:用户将自己的公钥存储在远程主机上,登陆的时候,远程主机会向用户发送一串随机字符串,用户用自 ...

  10. Java 五大原则

    1.单一职责 不论是在设计类,接口还是方法,单一职责都会处处体现,单一职责的定义:我们把职责定义为系统变化的原因.所有在定义类,接口,方法的时候.定义完以后再去想一想是不能多于一个的动机去改变这个类, ...