版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/limiyudianzi/article/details/80697711
我主要分三篇文章给大家介绍tensorflow的损失函数,本篇为tensorflow自定义损失函数。 
(一)tensorflow内置的四个损失函数 
(二)其他损失函数 
(三)自定义损失函数
自定义损失函数是损失函数章节的结尾,学习自定义损失函数,对于提高分类分割等问题的准确率很有帮助,同时探索新型的损失函数也可以让你文章多多。这里我们介绍构建自定义损失函数的方法,并且介绍可以均衡正负例的loss,以及在多分类中可以解决样本数量不均衡的loss的方法。
首先为了有足够的知识学会自定义损失函数,我们需要知道tensorflow都能实现什么样的操作。其实答案是你常见的数学运算都可以,所以说只要你能把心中的损失函数表达为数学式的形式,那么你就能够将其转变为损失函数的形式。下面介绍一些常见的函数:
四则运算:tf.add(Tensor1,Tensor2),tf.sub(Tensor1,Tensor2), tf.mul(Tensor1,Tensor2),tf.div(Tensor1,Tensor2)这里的操作也可以被正常的加减乘除的负号所取代。这里想要指出的是乘法和除法的规则和numpy库是一样的,是matlab中的点乘而不是矩阵的乘法,矩阵的乘法是tf.matmul(Tensor1, Tensor2)
基础运算: 取模运算(tf.mod()),绝对值(tf.abs()),平方(tf.square()),四舍五入取整(tf.round()),取平方根(tf.sqrt()) ,自然对数的幂(tf.exp()) ,取对数(tf.log()) ,幂(tf.pow()) ,正弦运算(tf.sin())。以上的这些数学运算以及很多没有被提及的运算在tensorflow中都可以自己被求导,所以大家不用担心还需要自己写反向传播的问题,只要你的操作是由tensorflow封装的基础操作实现的,那么反向传播就可以自动的实现。
条件判断,通过条件判断语句我们就可以实现分段的损失函数,利用tf.where(condition, tensor_x, tensor_y) 如果说条件condition是True那么就会返回tensor_x,如果是False则返回tensor_y。注:旧版本的tensorflow可以用tf.select实现这个操作。
比较操作,为了获得condition这个参数,我们可以用tf.greater( tensor_x, tensor_y)如果说tensor_x大于tensor_y则返回True。
tf.reduce_sum(),tf.reduce_mean()这两个操作是重要的loss操作,因为loss是一个数字,而通常计算得到的是一个高维的矩阵,因此用降维加法和降维取平均,可以将一个高维的矩阵变为一个数字。
有了上面的这些操作,我们就可以实现基本的损失函数的构建了,比如我们构建交叉熵损失函数:
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y)+(y_-1)* tf.log(1-y))) 
再介绍一个均衡正负样本数量的loss,首先在训练数据中,通常来讲正例比较少,负例比较多,如疾病筛查中,有病的占少数而绝大部分人是健康的,这种数量不均衡的数据可能会让分类器倾向于将所有的示例都分为健康人,因为这样整体的准确率可能就能达到90%以上,为此,可以用调整loss权重的方式来缓解样本数量不均衡的问题,如:
pos_ratio=num_of_positive/num_all # 病人占总体的比例,较小如0.1
neg_ratio=num_of_negative/num_all # 正常人占总体的比例,较大如0.9
cross_entropy = tf.reduce_mean(-neg_ratio*tf.reduce_sum(y_ * tf.log(y)+pos_ratio*(y_-1)* tf.log(1-y))) 
在这里我们给病人的损失项乘了一个较大的系数,使得一旦占少数的病人被错分为健康人的时候,代价就非常的大。同样的给正常人的损失项乘了一个较小的系数,使其诊断错误时对网络的影像较小。 
这也符合实际情况,即使健康人在筛查时被通知可能患病,只要再进一步检查就可以。但是如果在筛查的时候将病人误分为健康人,那么付出的就可能是生命的代价了。
以上是二分类的例子,那么在多分类的时候应该如何做呢?我们也可以通过乘系数这样的方式解决问题,这里我们认为标签是one_hot形式的如:
class1_weight=0.2 # 第一类的权重系数
class2_weight=0.5 # 第二类的权重系数
class3_weight=0.3 # 第三类的权重系数
cross_entropy = tf.reduce_mean(-class1_weight*tf.reduce
_sum(y_[:,0] * tf.log(y[:,0])-class2_weight*tf.reduce
_sum(y_[:,1] * tf.log(y[:,1])-class3_weight*tf.reduce
_sum(y_[:,2] * tf.log(y[:,2])) 
因为标签和预测的结果都是one_hot的形式,因此在这里y[:,0]就是第一类的概率值,其中第一个维度的长度是minibatch的大小。同理y[:,0]就是第二类的概率值,我们在不同的项上乘上不同类别的权重系数,就可以一定程度上解决样本数量不均衡所带来的困扰。
————————————————
版权声明:本文为CSDN博主「Liu-Kevin」的原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/limiyudianzi/article/details/80697711

Tensorflow 损失函数(loss function)及自定义损失函数(三)的更多相关文章

  1. 损失函数(Loss function) 和 代价函数(Cost function)

    1损失函数和代价函数的区别: 损失函数(Loss function):指单个训练样本进行预测的结果与实际结果的误差. 代价函数(Cost function):整个训练集,所有样本误差总和(所有损失函数 ...

  2. 损失函数(Loss Function) -1

    http://www.ics.uci.edu/~dramanan/teaching/ics273a_winter08/lectures/lecture14.pdf Loss Function 损失函数 ...

  3. 损失函数(loss function)

    通常而言,损失函数由损失项(loss term)和正则项(regularization term)组成.发现一份不错的介绍资料: http://www.ics.uci.edu/~dramanan/te ...

  4. 损失函数(loss function) 转

    原文:http://luowei828.blog.163.com/blog/static/310312042013101401524824 通常而言,损失函数由损失项(loss term)和正则项(r ...

  5. 惩罚因子(penalty term)与损失函数(loss function)

    penalty term 和 loss function 看起来很相似,但其实二者完全不同. 惩罚因子: penalty term的作用是把受限优化问题转化为非受限优化问题. 比如我们要优化: min ...

  6. tensorflow2 自定义损失函数使用的隐藏坑

    Keras的核心原则是逐步揭示复杂性,可以在保持相应的高级便利性的同时,对操作细节进行更多控制.当我们要自定义fit中的训练算法时,可以重写模型中的train_step方法,然后调用fit来训练模型. ...

  7. [machine learning] Loss Function view

    [machine learning] Loss Function view 有关Loss Function(LF),只想说,终于写了 一.Loss Function 什么是Loss Function? ...

  8. loss function与cost function

    实际上,代价函数(cost function)和损失函数(loss function 亦称为 error function)是同义的.它们都是事先定义一个假设函数(hypothesis),通过训练集由 ...

  9. 【深度学习】一文读懂机器学习常用损失函数(Loss Function)

    最近太忙已经好久没有写博客了,今天整理分享一篇关于损失函数的文章吧,以前对损失函数的理解不够深入,没有真正理解每个损失函数的特点以及应用范围,如果文中有任何错误,请各位朋友指教,谢谢~ 损失函数(lo ...

随机推荐

  1. 如何检测Windows中的横向渗透攻击

    一.前言 横向渗透攻击技术是复杂网络攻击中广泛使用的一种技术,特别是在高级持续威胁(Advanced Persistent Threats,APT)中更加热衷于使用这种攻击方法.攻击者可以利用这些技术 ...

  2. Nexus6p手机root和安装xposed

    进行root前需要两个前提条件 解锁OEM 进入开发者选项:设置-〉关于-〉一直点版本号会出现,usb调试打开 手机连接pc命令行输入: adb reboot bootloader 进入bootloa ...

  3. Kafka Streams开发入门(3)

    背景 上一篇我们介绍了Kafka Streams中的消息过滤操作filter,今天我们展示一个对消息进行转换Key的操作,依然是结合一个具体的实例展开介绍.所谓转换Key是指对流处理中每条消息的Key ...

  4. 微服务——SpringCloud(Eureka注册中心搭建)

    IDE:IDEA,说实话,真不怎么喜欢用Eclipse这个IDE,太锻炼人了 配置模式:Grandle 微服务框架:SpringCloud 第一步 创建一个Spring Initializr项目 第二 ...

  5. mysql 的逻辑架构 与 存储引擎的介绍

    mysql 的逻辑架构分为三层: 最上层的服务大多数基于网络的客户端.服务器的工具或者服务都有类似的架构,比如连接处理,授权认证.安全等 第二层架构:mysql的核心服务功能都在这一层,包括查询解析, ...

  6. 洛谷 P1522 牛的旅行 Cow Tours

    题目链接:https://www.luogu.org/problem/P1522 思路:编号,然后跑floyd,这是很清楚的.然后记录每个点在这个联通块中的最远距离. 然后分连通块,枚举两个点(不属于 ...

  7. Educational Codeforces Round 65 (Rated for Div. 2)题解

    Educational Codeforces Round 65 (Rated for Div. 2)题解 题目链接 A. Telephone Number 水题,代码如下: Code #include ...

  8. httprunner学习13-环境变量.env

    前言 一般来说,在进行实际应用的开发过程中,应用会拥有不同的运行环境,通常会有以下环境: 本地开发环境 测试环境 生产环境 在不同环境中,我们可能会使用不同的数据库或邮件发送驱动等配置,这时候则需要通 ...

  9. BZOJ5509: [Tjoi2019]甲苯先生的滚榜

    题解 开n个平衡树对每个AC数维护罚时,然后不同AC数用树状数组维护即可. 其实挺好写的...就是评测的时候评的巨久... #include <bits/stdc++.h> using n ...

  10. 《快活帮》第九次团队作业:【Beta】Scrum meeting 2

    项目 内容 这个作业属于哪个课程 2016计算机科学与工程学院软件工程(西北师范大学) 这个作业的要求在哪里 实验十三 团队作业9:BETA冲刺与团队项目验收 团队名称 快活帮 作业学习目标 (1)掌 ...