cross entropy

交叉熵,tensorflow  对 cross entropy 进行了集成:

1. 二分类和多分类公式集成,共用一个 API;

p(x) 真实标签,q(x) 预测概率;

2. 把 sigmoid 、softmax 等集成到 cross entropy 中;

正常情况下,神经网络最后的输出需要通过 softmax 转换成概率,然后再套用公式计算交叉熵,tf 的集成 API 直接输入神经网络的输出即可

tf.nn.softmax_cross_entropy_with_logits

集成了 softmax 和 cross entropy 的 API

def softmax_cross_entropy_with_logits(
_sentinel=None, # pylint: disable=invalid-name
labels=None,
logits=None,
dim=-1,
name=None,
axis=None)

示例

#our NN's output
logits = tf.constant([[1.0,2.0,3.0],[1.0,2.0,3.0],[1.0,2.0,3.0]])
#step1:do softmax
y = tf.nn.softmax(logits) #true label
y_= tf.constant([[0.0,0.0,1.0],[0.0,0.0,1.0],[0.0,0.0,1.0]]) #step2:do cross_entropy
cross_entropy = -tf.reduce_sum(y_*tf.log(y)) #do cross_entropy just one step
cross_entropy2 = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_)) # dont forget tf.reduce_sum()!! with tf.Session() as sess:
softmax_value=sess.run(y)
c_e = sess.run(cross_entropy)
c_e2 = sess.run(cross_entropy2)
print(softmax_value)
print(c_e) # 1.222818
print(c_e2) # 1.2228179

可以看到手动计算 和 API 计算的结果是一样的

tf.nn.sparse_softmax_cross_entropy_with_logits

API 参数同上;

sparse,稀疏编码,把类别进行稀疏编码,如共 3 个类别,样本属于第 2 个,则需要编码为 [0,1,0];    【对实际 label 的 sparse】

集成了 稀疏编码、softmax 和交叉熵;

# our NN's output
logits = tf.constant([[1.0,2.0,3.0],[1.0,2.0,3.0],[1.0,2.0,3.0]]) # true label
# 注意这里标签必须是浮点数,不然在后面计算tf.multiply时就会因为类型不匹配tf_log的float32数据类型而出错
y_= tf.constant([[0,0,1.0],[0,0,1.0],[0,0,1.0]]) # 这个是稀疏的标签 # 手算交叉熵
y = tf.nn.softmax(logits)
tf_log = tf.log(y)
pixel_wise_mult = tf.multiply(y_,tf_log)
cross_entropy = -tf.reduce_sum(pixel_wise_mult) #将标签稠密化
dense_y = tf.argmax(y_,1) # [2 2 2]
cross_entropy2_step1 = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=dense_y,logits=logits)
cross_entropy2_step2 = tf.reduce_sum(cross_entropy2_step1) with tf.Session() as sess:
cross_entropy_value=sess.run(cross_entropy)
sparse_cross_entropy2_step2_value=sess.run([cross_entropy2_step2])
print(sess.run(dense_y)) # [2 2 2]
print("step4:cross_entropy result=\n%s\n"%(cross_entropy_value)) # 1.222818
print("Function(tf.reduce_sum) result=\n%s\n"%(sparse_cross_entropy2_step2_value)) # 1.2228179

tf.nn.sigmoid_cross_entropy_with_logits

API 参数同上;

这个 API 适用于 一个样本有多个 label 的情况,如在目标检测中,一张图像上可能有 猫,可能有狗,输出的 label 可能为 [0,1,1,0];

它的本质不是多分类,而是多个二分类;

def sigmoid(x):
return 1.0/(1+np.exp(-x)) labels = np.array([[1.,0.,0.],[0.,1.,0.],[0.,0.,1.]])
logits = np.array([[11.,8.,7.],[10.,14.,3.],[1.,2.,4.]])
y_pred = sigmoid(logits)
prob_error1 = -labels * np.log(y_pred) - (1 - labels) * np.log(1 - y_pred) labels1 = np.array([[0.,1.,0.],[1.,1.,0.],[0.,0.,1.]]) # 不一定只属于一个类别
logits1 = np.array([[1.,8.,7.],[10.,14.,3.],[1.,2.,4.]])
y_pred1 = sigmoid(logits1)
prob_error11 = -labels1 * np.log(y_pred1) - (1 - labels1) * np.log(1 - y_pred1) with tf.Session() as sess:
print(prob_error1)
# [[1.67015613e-05 8.00033541e+00 7.00091147e+00]
# [1.00000454e+01 8.31528373e-07 3.04858735e+00]
# [1.31326169e+00 2.12692801e+00 1.81499279e-02]]
print(sess.run(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,logits=logits)))
# [[1.67015613e-05 8.00033541e+00 7.00091147e+00]
# [1.00000454e+01 8.31528373e-07 3.04858735e+00]
# [1.31326169e+00 2.12692801e+00 1.81499279e-02]]
print(prob_error11)
# [[1.31326169e+00 3.35406373e-04 7.00091147e+00]
# [4.53988992e-05 8.31528373e-07 3.04858735e+00]
# [1.31326169e+00 2.12692801e+00 1.81499279e-02]]
print(sess.run(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels1,logits=logits1)))
### 同上

tf.nn.weighted_cross_entropy_with_logits

它是 sigmoid_cross_entropy_with_logits 的扩展

def weighted_cross_entropy_with_logits(labels=None,
logits=None,
pos_weight=None,
name=None,
targets=None):
"""Computes a weighted cross entropy. labels * -log(sigmoid(logits)) * pos_weight +
(1 - labels) * -log(1 - sigmoid(logits)) pos_weight: A coefficient to use on the positive examples
"""

tf.losses.softmax_cross_entropy

增加了一个权重,当权重为 1 时,等价于 tf.nn.softmax_cross_entropy_with_logits

#our NN's output
logits = tf.constant([[1.0,2.0,3.0],[1.0,2.0,3.0],[1.0,2.0,3.0]])
#step1:do softmax
y = tf.nn.softmax(logits) #true label
y_= tf.constant([[0.0,0.0,1.0],[0.0,0.0,1.0],[0.0,0.0,1.0]]) loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_))
# tf.losses.softmax_cross_entropy(y_, logits, weights=1.)
tf.losses.softmax_cross_entropy(y_, logits, weights=0.5) with tf.Session() as sess:
print(sess.run(loss)) # 0.40760598
print(sess.run(tf.losses.get_total_loss())) # 0.40760598 weights=1 时想等, weights=0.5 时为 0.20380299

均方差

tensorflow 其实没有提供这个 API,自己实现也很方便

y = tf.constant([0.9, 2.1, 2.8])
y_pred = tf.constant([1, 2, 3], dtype=tf.float32)
err1 = tf.reduce_sum(tf.square(y - y_pred)) / 3
err2 = tf.reduce_mean(tf.square(y - y_pred)) sess = tf.Session()
print(sess.run(err1)) # 0.020000001
print(sess.run(err2)) # 0.020000001

参考资料:

https://blog.csdn.net/marsjhao/article/details/72630147

https://blog.csdn.net/weixin_42561002/article/details/87802096  tf.losses.softmax_cross_entropy()及相邻函数中weights参数的设置

tf.trainable_variables() and tf.all_variables()的更多相关文章

  1. tf.trainable_variables和tf.all_variables的对比

    tf.trainable_variables返回的是可以用来训练的变量列表 tf.all_variables返回的是所有变量的列表

  2. tf.trainable_variables()

    https://blog.csdn.net/shwan_ma/article/details/78879620 一般来说,打印tensorflow变量的函数有两个:tf.trainable_varia ...

  3. tf.variable和tf.get_Variable以及tf.name_scope和tf.variable_scope的区别

    在训练深度网络时,为了减少需要训练参数的个数(比如具有simase结构的LSTM模型).或是多机多卡并行化训练大数据大模型(比如数据并行化)等情况时,往往需要共享变量.另外一方面是当一个深度学习模型变 ...

  4. 【TensorFlow基础】tf.add 和 tf.nn.bias_add 的区别

    1. tf.add(x,  y, name) Args: x: A `Tensor`. Must be one of the following types: `bfloat16`, `half`, ...

  5. TensorFlow 辨异 —— tf.placeholder 与 tf.Variable

    https://blog.csdn.net/lanchunhui/article/details/61712830 https://www.cnblogs.com/silence-tommy/p/70 ...

  6. TF.VARIABLE、TF.GET_VARIABLE、TF.VARIABLE_SCOPE以及TF.NAME_SCOPE关系

    1. tf.Variable与tf.get_variable tensorflow提供了通过变量名称来创建或者获取一个变量的机制.通过这个机制,在不同的函数中可以直接通过变量的名字来使用变量,而不需要 ...

  7. tensorflow笔记4:函数:tf.assign()、tf.assign_add()、tf.identity()、tf.control_dependencies()

    函数原型: tf.assign(ref, value, validate_shape=None, use_locking=None, name=None)   Defined in tensorflo ...

  8. 理解 tf.Variable、tf.get_variable以及范围命名方法tf.variable_scope、tf.name_scope

    tensorflow提供了通过变量名称来创建或者获取一个变量的机制.通过这个机制,在不同的函数中可以直接通过变量的名字来使用变量,而不需要将变量通过参数的形式到处传递. 1. tf.Variable( ...

  9. TF:利用TF的train.Saver将训练好的variables(W、b)保存到指定的index、meda文件—Jason niu

    import tensorflow as tf import numpy as np W = tf.Variable([[2,1,8],[1,2,5]], dtype=tf.float32, name ...

随机推荐

  1. css图片的全屏显示代码-css3

    <!DOCTYPE html><html lang="en"> <head> <meta charset="UTF-8" ...

  2. 恶意代码分析-使用apataDNS+inetsim模拟网络环境

    准备工作 虚拟机安装: Win7 Ubuntu apateDNS 密码:wplo inetsim 密码:ghla 客户端Win7需要做的工作 安装apateDNS 服务器端Ubuntu需要做的工作 下 ...

  3. android------DDMS files not found: tools\hprof-conv.exe

    好久没有Eclipse了,使用一下就遇到坑,使用eclipse突然发生这个问题:DDMS files not found: ***\sdk\tools\hprof-conv.exe,无法连接模拟器 在 ...

  4. mysql防注入

    1.对用户输入的数据进行过滤 2.永远不要使用动态拼装sql,可以使用参数化的sql或者直接使用存储过程进行数据查询存取. 3.永远不要使用管理员权限的数据库连接,为每个应用使用单独的权限有限的数据库 ...

  5. vue核心之响应式原理(双向绑定/数据驱动)

    实例化一个vue对象时, Observer类将每个目标对象(即data)的键值转换成getter/setter形式,用于进行依赖收集以及调度更新. Observer src/core/observer ...

  6. dedecms自定义表单时间时间戳值类型的转换方法

    找网站找的别人的方法,记录一下 修改/dede/templets/diy_list.htm,在第42行else前面加上以下代码: else if($fielddata[1]=='datetime') ...

  7. https请求排错过程

    1. 看请求有没有到nginx 此时需要查看nginx的日志.一般每一个项目都会配置一个nginx站点,而一个站点都会又一个nginx配置文件,这个文件位于哪里呢?不出意外应该在:下面,如果找不到的话 ...

  8. 优先队列优化dij算法

    之前已经弄过模板了,但那个复杂一点,这个就是裸的dij,用起来更方便 输入格式:n,m,s,d分别是点数,边数,起点,终点 之后m行,输入x,y,z分别是两点即权值 题目链接:https://www. ...

  9. MySQL查询性能调优化

    一.索引的概念 索引:类似于字典的目录,设置索引可以 加速数据查找,对数据进行约束: 二.索引类型: 主键索引:保证数据唯一性,不能重复+不能为空 普通索引:加速数据查找 唯一索引:加速查找+不能重复 ...

  10. nginx配置location总结及rewrite规则写法(2)

    2. Rewrite规则 rewrite功能就是,使用nginx提供的全局变量或自己设置的变量,结合正则表达式和标志位实现url重写以及重定向.rewrite只能放在server{},location ...