版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/limiyudianzi/article/details/80697711
我主要分三篇文章给大家介绍tensorflow的损失函数,本篇为tensorflow自定义损失函数。 
(一)tensorflow内置的四个损失函数 
(二)其他损失函数 
(三)自定义损失函数
自定义损失函数是损失函数章节的结尾,学习自定义损失函数,对于提高分类分割等问题的准确率很有帮助,同时探索新型的损失函数也可以让你文章多多。这里我们介绍构建自定义损失函数的方法,并且介绍可以均衡正负例的loss,以及在多分类中可以解决样本数量不均衡的loss的方法。
首先为了有足够的知识学会自定义损失函数,我们需要知道tensorflow都能实现什么样的操作。其实答案是你常见的数学运算都可以,所以说只要你能把心中的损失函数表达为数学式的形式,那么你就能够将其转变为损失函数的形式。下面介绍一些常见的函数:
四则运算:tf.add(Tensor1,Tensor2),tf.sub(Tensor1,Tensor2), tf.mul(Tensor1,Tensor2),tf.div(Tensor1,Tensor2)这里的操作也可以被正常的加减乘除的负号所取代。这里想要指出的是乘法和除法的规则和numpy库是一样的,是matlab中的点乘而不是矩阵的乘法,矩阵的乘法是tf.matmul(Tensor1, Tensor2)
基础运算: 取模运算(tf.mod()),绝对值(tf.abs()),平方(tf.square()),四舍五入取整(tf.round()),取平方根(tf.sqrt()) ,自然对数的幂(tf.exp()) ,取对数(tf.log()) ,幂(tf.pow()) ,正弦运算(tf.sin())。以上的这些数学运算以及很多没有被提及的运算在tensorflow中都可以自己被求导,所以大家不用担心还需要自己写反向传播的问题,只要你的操作是由tensorflow封装的基础操作实现的,那么反向传播就可以自动的实现。
条件判断,通过条件判断语句我们就可以实现分段的损失函数,利用tf.where(condition, tensor_x, tensor_y) 如果说条件condition是True那么就会返回tensor_x,如果是False则返回tensor_y。注:旧版本的tensorflow可以用tf.select实现这个操作。
比较操作,为了获得condition这个参数,我们可以用tf.greater( tensor_x, tensor_y)如果说tensor_x大于tensor_y则返回True。
tf.reduce_sum(),tf.reduce_mean()这两个操作是重要的loss操作,因为loss是一个数字,而通常计算得到的是一个高维的矩阵,因此用降维加法和降维取平均,可以将一个高维的矩阵变为一个数字。
有了上面的这些操作,我们就可以实现基本的损失函数的构建了,比如我们构建交叉熵损失函数:
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y)+(y_-1)* tf.log(1-y))) 
再介绍一个均衡正负样本数量的loss,首先在训练数据中,通常来讲正例比较少,负例比较多,如疾病筛查中,有病的占少数而绝大部分人是健康的,这种数量不均衡的数据可能会让分类器倾向于将所有的示例都分为健康人,因为这样整体的准确率可能就能达到90%以上,为此,可以用调整loss权重的方式来缓解样本数量不均衡的问题,如:
pos_ratio=num_of_positive/num_all # 病人占总体的比例,较小如0.1
neg_ratio=num_of_negative/num_all # 正常人占总体的比例,较大如0.9
cross_entropy = tf.reduce_mean(-neg_ratio*tf.reduce_sum(y_ * tf.log(y)+pos_ratio*(y_-1)* tf.log(1-y))) 
在这里我们给病人的损失项乘了一个较大的系数,使得一旦占少数的病人被错分为健康人的时候,代价就非常的大。同样的给正常人的损失项乘了一个较小的系数,使其诊断错误时对网络的影像较小。 
这也符合实际情况,即使健康人在筛查时被通知可能患病,只要再进一步检查就可以。但是如果在筛查的时候将病人误分为健康人,那么付出的就可能是生命的代价了。
以上是二分类的例子,那么在多分类的时候应该如何做呢?我们也可以通过乘系数这样的方式解决问题,这里我们认为标签是one_hot形式的如:
class1_weight=0.2 # 第一类的权重系数
class2_weight=0.5 # 第二类的权重系数
class3_weight=0.3 # 第三类的权重系数
cross_entropy = tf.reduce_mean(-class1_weight*tf.reduce
_sum(y_[:,0] * tf.log(y[:,0])-class2_weight*tf.reduce
_sum(y_[:,1] * tf.log(y[:,1])-class3_weight*tf.reduce
_sum(y_[:,2] * tf.log(y[:,2])) 
因为标签和预测的结果都是one_hot的形式,因此在这里y[:,0]就是第一类的概率值,其中第一个维度的长度是minibatch的大小。同理y[:,0]就是第二类的概率值,我们在不同的项上乘上不同类别的权重系数,就可以一定程度上解决样本数量不均衡所带来的困扰。
————————————————
版权声明:本文为CSDN博主「Liu-Kevin」的原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/limiyudianzi/article/details/80697711

Tensorflow 损失函数(loss function)及自定义损失函数(三)的更多相关文章

  1. 损失函数(Loss function) 和 代价函数(Cost function)

    1损失函数和代价函数的区别: 损失函数(Loss function):指单个训练样本进行预测的结果与实际结果的误差. 代价函数(Cost function):整个训练集,所有样本误差总和(所有损失函数 ...

  2. 损失函数(Loss Function) -1

    http://www.ics.uci.edu/~dramanan/teaching/ics273a_winter08/lectures/lecture14.pdf Loss Function 损失函数 ...

  3. 损失函数(loss function)

    通常而言,损失函数由损失项(loss term)和正则项(regularization term)组成.发现一份不错的介绍资料: http://www.ics.uci.edu/~dramanan/te ...

  4. 损失函数(loss function) 转

    原文:http://luowei828.blog.163.com/blog/static/310312042013101401524824 通常而言,损失函数由损失项(loss term)和正则项(r ...

  5. 惩罚因子(penalty term)与损失函数(loss function)

    penalty term 和 loss function 看起来很相似,但其实二者完全不同. 惩罚因子: penalty term的作用是把受限优化问题转化为非受限优化问题. 比如我们要优化: min ...

  6. tensorflow2 自定义损失函数使用的隐藏坑

    Keras的核心原则是逐步揭示复杂性,可以在保持相应的高级便利性的同时,对操作细节进行更多控制.当我们要自定义fit中的训练算法时,可以重写模型中的train_step方法,然后调用fit来训练模型. ...

  7. [machine learning] Loss Function view

    [machine learning] Loss Function view 有关Loss Function(LF),只想说,终于写了 一.Loss Function 什么是Loss Function? ...

  8. loss function与cost function

    实际上,代价函数(cost function)和损失函数(loss function 亦称为 error function)是同义的.它们都是事先定义一个假设函数(hypothesis),通过训练集由 ...

  9. 【深度学习】一文读懂机器学习常用损失函数(Loss Function)

    最近太忙已经好久没有写博客了,今天整理分享一篇关于损失函数的文章吧,以前对损失函数的理解不够深入,没有真正理解每个损失函数的特点以及应用范围,如果文中有任何错误,请各位朋友指教,谢谢~ 损失函数(lo ...

随机推荐

  1. React Context API

    使用React 开发程序的时候,组件中的数据共享是通过数据提升,变成父组件中的属性,然后再把属性向下传递给子组件来实现的.但当程序越来越复杂,需要共享的数据也越来越多,最后可能就把共享数据直接提升到最 ...

  2. underscore_1: map()

    map()是underscore.js中一个处理数组和对象的方法. params: 1. array || obj 2. callback 3. content 上下文指向 使用: var obj = ...

  3. 爬虫requests库 之爬虫贴吧

    首先要观察爬虫的URL规律,爬取一个贴吧所有页的数据,观察点击下一页时URL是如何变化的. 思路: 定义一个类,初始化方法什么都不用管 定义一个run方法,用来实现主要逻辑 3 class Tieba ...

  4. 创建简易的SpringBoot项目

    创建简易的SpringBoot项目 这两天在学习springboot,菜鸟刚刚知道这个东西,看着springboot项目下那一大堆目录都不知道从何下手,还是静下心来从最简单的创建一个项目入手,这路和大 ...

  5. DLL Injection with Delphi(转载)

    原始链接 I had recently spent some time playing around with the simple to use DelphiDetours package from ...

  6. Windows Server 2012 R2 配置IIS

    efs:http://www.07net01.com/storage_networking/windows_server_2012_anzhuang_IIS8_bingzhichi_asp_45191 ...

  7. java 获取当前方法的被调用信息(被那个方法那个类那一行调用)

    public void testMethod(){ Test1 t1 = new Test1(); t1.my(); } public static void main(String[] args) ...

  8. IDEA创建本地Spark程序,并本地运行

    1   IDEA创建maven项目进行测试 v创建一个新项目,步骤如下: 选择“Enable Auto-Import”,加载完后:选择“Enable Auto-Import”,加载完后: 添加SDK依 ...

  9. Python:基础复习

    一.数据类型 对象的三大特征:值.身份.类型: 1)数字 Number 整型.浮点型 只有 int 和 float 两种类型: type(2/2):float 类型:2/2 == 1.0: type( ...

  10. Werkzeug——python web开发工具包

    转载请注明原文地址:https://www.cnblogs.com/ygj0930/p/10826062.html 一:Werkzeug是个啥 1)Werkzeug是一个工具包,它封装了很多东西,诸如 ...