weighted_cross_entropy_with_logits
weighted_cross_entropy_with_logits
觉得有用的话,欢迎一起讨论相互学习~




weighted_cross_entropy_with_logits(targets, logits, pos_weight, name=None):
此函数功能以及计算方式基本与tf_nn_sigmoid_cross_entropy_with_logits差不多,但是加上了权重的功能,是计算具有权重的sigmoid交叉熵函数
计算方法 :
\]
官方文档定义及推导过程:
通常的cross-entropy交叉熵函数定义如下:
(1 - targets) * -log(1 - sigmoid(logits))\]
对于加了权值pos_weight的交叉熵函数:
(1 - targets) * -log(1 - sigmoid(logits))\]
现在我们使用 x = logits, z = targets, q = pos_weight的代数式
The loss is:
qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
= qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
= qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
= qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
= (1 - z) * x + (qz + 1 - z) * log(1 + exp(-x))
= (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))
我们把l = (1 + (q - 1) * z), 来确保稳定性并且比避免溢出,公式为:
\]
logitsandtargets必须要有相同的数据类型和shape.
参数:
_sentinel:本质上是不用的参数,不用填
targets:一个和logits具有相同的数据类型(type)和尺寸形状(shape)的张量(tensor)
shape:[batch_size,num_classes],单样本是[num_classes]
logits:一个数据类型(type)是float32或float64的张量
pos_weight:正样本的一个系数
name:操作的名字,可填可不填
实例代码
import numpy as np
import tensorflow as tf
input_data = tf.Variable(np.random.rand(3, 3), dtype=tf.float32)
# np.random.rand()传入一个shape,返回一个在[0,1)区间符合均匀分布的array
output = tf.nn.weighted_cross_entropy_with_logits(logits=input_data,
targets=[[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 1.0]],
pos_weight=2.0)
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
print(sess.run(output))
# [[ 1.04947078 0.89594436 0.92146152]
# [ 0.70252579 1.00673866 1.08856964]
# [ 1.07195592 1.18525708 1.04106498]]
weighted_cross_entropy_with_logits的更多相关文章
- TF Boys (TensorFlow Boys ) 养成记(五)
有了数据,有了网络结构,下面我们就来写 cifar10 的代码. 首先处理输入,在 /home/your_name/TensorFlow/cifar10/ 下建立 cifar10_input.py,输 ...
- TensorFlow 常用函数汇总
本文介绍了tensorflow的常用函数,源自网上整理. TensorFlow 将图形定义转换成分布式执行的操作, 以充分利用可用的计算资源(如 CPU 或 GPU.一般你不需要显式指定使用 CPU ...
- 基于 TensorFlow 在手机端实现文档检测
作者:冯牮 前言 本文不是神经网络或机器学习的入门教学,而是通过一个真实的产品案例,展示了在手机客户端上运行一个神经网络的关键技术点 在卷积神经网络适用的领域里,已经出现了一些很经典的图像分类网络,比 ...
- TensorFlow 常用函数与方法
摘要:本文主要对tf的一些常用概念与方法进行描述. tf函数 TensorFlow 将图形定义转换成分布式执行的操作, 以充分利用可用的计算资源(如 CPU 或 GPU.一般你不需要显式指定使用 CP ...
- TensorFlow机器学习实战指南之第二章
一.计算图中的操作 在这个例子中,我们将结合前面所学的知识,传入一个列表到计算图中的操作,并打印返回值: 声明张量和占位符.这里,创建一个numpy数组,传入计算图操作: import tensorf ...
- Tensorflow一些常用基本概念与函数
1.tensorflow的基本运作 为了快速的熟悉TensorFlow编程,下面从一段简单的代码开始: import tensorflow as tf #定义‘符号’变量,也称为占位符 a = tf. ...
- Tensorflow一些常用基本概念与函数(1)
为了快速的熟悉TensorFlow编程,下面从一段简单的代码开始: import tensorflow as tf #定义‘符号’变量,也称为占位符 a = tf.placeholder(" ...
- 『TensorFlow』函数查询列表_神经网络相关
tf.Graph 操作 描述 class tf.Graph tensorflow中的计算以图数据流的方式表示一个图包含一系列表示计算单元的操作对象以及在图中流动的数据单元以tensor对象表现 tf. ...
- 『TensorFlow』网络操作API_中_损失函数及分类器
一.误差值 度量两个张量或者一个张量和零之间的损失误差,这个可用于在一个回归任务或者用于正则的目的(权重衰减). l2_loss tf.nn.l2_loss(t, name=None) 解释:这个函数 ...
随机推荐
- 关于python使用threadpool中的函数单个参数和多个参数用法举例
1.对单个元素的函数使用线程池: # encoding:utf-8 __author__='xijun.gong' import threadpool def func(name): print 'h ...
- Java与算法之(1) - 冒泡排序
冒泡排序法的原理是,每次比较相邻的两个元素,如果它们的顺序错误就把它们交换过来. 例如对4 3 6 2 7 1 5这7个数字进行从小到大的排序,从最左侧开始,首先比较4和3 因为是从小到大排序,4和3 ...
- 2017"百度之星"程序设计大赛 - 资格赛【1001 Floyd求最小环 1002 歪解(并查集),1003 完全背包 1004 01背包 1005 打表找规律+卡特兰数】
度度熊保护村庄 Accepts: 13 Submissions: 488 Time Limit: 2000/1000 MS (Java/Others) Memory Limit: 32768/3276 ...
- Gym 100952A&&2015 HIAST Collegiate Programming Contest A. Who is the winner?【字符串,暴力】
A. Who is the winner? time limit per test:1 second memory limit per test:64 megabytes input:standard ...
- Java学习之类的构建方法(函数)
在学习类的部分时,建立一个对象是这样建立的:(假设Person是类)Person p = new Person():我一直很费解为何new后面是一个函数形式, 今天学完构建方法后,才恍然大悟,豁然 ...
- PHPStorm+PHPStudy配置XDebug
img { max-width: 100% } 上一节里面从PHPStudy+PHPStorm的配置,到最后发布,PHPStorm只是承担了编辑器和发布站点的任务,但是还没有办法像Visual Stu ...
- [国嵌笔记][024][ARM汇编编程概述]
汇编程序用途 1.在bootloader与内核初始化时,还没有建立C语言运行环境,需要用到汇编程序 2.在对访问效率要求很高的情况下,需要用到汇编程序 ARM汇编分类 1.ARM标准汇编:适合于Win ...
- C# 小笔记
1,Using using (var ws = new WebSocket ("ws://dragonsnest.far/Laputa")) { ws.OnMessage += ( ...
- Java calendar获取月份注意事项
Calendar中月份month得取值是从0开始,到11,对应着日历中的1-12月.所以在用此取月份的话,需要在原有基础上加1.
- Linux Server release 7.3 更换阿里网络yum源
查看当前系统下的yum源 [root@localhost ~]# rpm -qa |grep yum yum-3.4.3-150.el7.noarch yum-utils-1.1.31-40.el7. ...