版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/limiyudianzi/article/details/80697711
我主要分三篇文章给大家介绍tensorflow的损失函数,本篇为tensorflow自定义损失函数。 
(一)tensorflow内置的四个损失函数 
(二)其他损失函数 
(三)自定义损失函数
自定义损失函数是损失函数章节的结尾,学习自定义损失函数,对于提高分类分割等问题的准确率很有帮助,同时探索新型的损失函数也可以让你文章多多。这里我们介绍构建自定义损失函数的方法,并且介绍可以均衡正负例的loss,以及在多分类中可以解决样本数量不均衡的loss的方法。
首先为了有足够的知识学会自定义损失函数,我们需要知道tensorflow都能实现什么样的操作。其实答案是你常见的数学运算都可以,所以说只要你能把心中的损失函数表达为数学式的形式,那么你就能够将其转变为损失函数的形式。下面介绍一些常见的函数:
四则运算:tf.add(Tensor1,Tensor2),tf.sub(Tensor1,Tensor2), tf.mul(Tensor1,Tensor2),tf.div(Tensor1,Tensor2)这里的操作也可以被正常的加减乘除的负号所取代。这里想要指出的是乘法和除法的规则和numpy库是一样的,是matlab中的点乘而不是矩阵的乘法,矩阵的乘法是tf.matmul(Tensor1, Tensor2)
基础运算: 取模运算(tf.mod()),绝对值(tf.abs()),平方(tf.square()),四舍五入取整(tf.round()),取平方根(tf.sqrt()) ,自然对数的幂(tf.exp()) ,取对数(tf.log()) ,幂(tf.pow()) ,正弦运算(tf.sin())。以上的这些数学运算以及很多没有被提及的运算在tensorflow中都可以自己被求导,所以大家不用担心还需要自己写反向传播的问题,只要你的操作是由tensorflow封装的基础操作实现的,那么反向传播就可以自动的实现。
条件判断,通过条件判断语句我们就可以实现分段的损失函数,利用tf.where(condition, tensor_x, tensor_y) 如果说条件condition是True那么就会返回tensor_x,如果是False则返回tensor_y。注:旧版本的tensorflow可以用tf.select实现这个操作。
比较操作,为了获得condition这个参数,我们可以用tf.greater( tensor_x, tensor_y)如果说tensor_x大于tensor_y则返回True。
tf.reduce_sum(),tf.reduce_mean()这两个操作是重要的loss操作,因为loss是一个数字,而通常计算得到的是一个高维的矩阵,因此用降维加法和降维取平均,可以将一个高维的矩阵变为一个数字。
有了上面的这些操作,我们就可以实现基本的损失函数的构建了,比如我们构建交叉熵损失函数:
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y)+(y_-1)* tf.log(1-y))) 
再介绍一个均衡正负样本数量的loss,首先在训练数据中,通常来讲正例比较少,负例比较多,如疾病筛查中,有病的占少数而绝大部分人是健康的,这种数量不均衡的数据可能会让分类器倾向于将所有的示例都分为健康人,因为这样整体的准确率可能就能达到90%以上,为此,可以用调整loss权重的方式来缓解样本数量不均衡的问题,如:
pos_ratio=num_of_positive/num_all # 病人占总体的比例,较小如0.1
neg_ratio=num_of_negative/num_all # 正常人占总体的比例,较大如0.9
cross_entropy = tf.reduce_mean(-neg_ratio*tf.reduce_sum(y_ * tf.log(y)+pos_ratio*(y_-1)* tf.log(1-y))) 
在这里我们给病人的损失项乘了一个较大的系数,使得一旦占少数的病人被错分为健康人的时候,代价就非常的大。同样的给正常人的损失项乘了一个较小的系数,使其诊断错误时对网络的影像较小。 
这也符合实际情况,即使健康人在筛查时被通知可能患病,只要再进一步检查就可以。但是如果在筛查的时候将病人误分为健康人,那么付出的就可能是生命的代价了。
以上是二分类的例子,那么在多分类的时候应该如何做呢?我们也可以通过乘系数这样的方式解决问题,这里我们认为标签是one_hot形式的如:
class1_weight=0.2 # 第一类的权重系数
class2_weight=0.5 # 第二类的权重系数
class3_weight=0.3 # 第三类的权重系数
cross_entropy = tf.reduce_mean(-class1_weight*tf.reduce
_sum(y_[:,0] * tf.log(y[:,0])-class2_weight*tf.reduce
_sum(y_[:,1] * tf.log(y[:,1])-class3_weight*tf.reduce
_sum(y_[:,2] * tf.log(y[:,2])) 
因为标签和预测的结果都是one_hot的形式,因此在这里y[:,0]就是第一类的概率值,其中第一个维度的长度是minibatch的大小。同理y[:,0]就是第二类的概率值,我们在不同的项上乘上不同类别的权重系数,就可以一定程度上解决样本数量不均衡所带来的困扰。
————————————————
版权声明:本文为CSDN博主「Liu-Kevin」的原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/limiyudianzi/article/details/80697711

Tensorflow 损失函数(loss function)及自定义损失函数(三)的更多相关文章

  1. 损失函数(Loss function) 和 代价函数(Cost function)

    1损失函数和代价函数的区别: 损失函数(Loss function):指单个训练样本进行预测的结果与实际结果的误差. 代价函数(Cost function):整个训练集,所有样本误差总和(所有损失函数 ...

  2. 损失函数(Loss Function) -1

    http://www.ics.uci.edu/~dramanan/teaching/ics273a_winter08/lectures/lecture14.pdf Loss Function 损失函数 ...

  3. 损失函数(loss function)

    通常而言,损失函数由损失项(loss term)和正则项(regularization term)组成.发现一份不错的介绍资料: http://www.ics.uci.edu/~dramanan/te ...

  4. 损失函数(loss function) 转

    原文:http://luowei828.blog.163.com/blog/static/310312042013101401524824 通常而言,损失函数由损失项(loss term)和正则项(r ...

  5. 惩罚因子(penalty term)与损失函数(loss function)

    penalty term 和 loss function 看起来很相似,但其实二者完全不同. 惩罚因子: penalty term的作用是把受限优化问题转化为非受限优化问题. 比如我们要优化: min ...

  6. tensorflow2 自定义损失函数使用的隐藏坑

    Keras的核心原则是逐步揭示复杂性,可以在保持相应的高级便利性的同时,对操作细节进行更多控制.当我们要自定义fit中的训练算法时,可以重写模型中的train_step方法,然后调用fit来训练模型. ...

  7. [machine learning] Loss Function view

    [machine learning] Loss Function view 有关Loss Function(LF),只想说,终于写了 一.Loss Function 什么是Loss Function? ...

  8. loss function与cost function

    实际上,代价函数(cost function)和损失函数(loss function 亦称为 error function)是同义的.它们都是事先定义一个假设函数(hypothesis),通过训练集由 ...

  9. 【深度学习】一文读懂机器学习常用损失函数(Loss Function)

    最近太忙已经好久没有写博客了,今天整理分享一篇关于损失函数的文章吧,以前对损失函数的理解不够深入,没有真正理解每个损失函数的特点以及应用范围,如果文中有任何错误,请各位朋友指教,谢谢~ 损失函数(lo ...

随机推荐

  1. js学习之面向对象

    一.创建对象的方法 1. {} 字面量创建 var person ={ name: "lisi", age: , say: function(){ alert(this.name) ...

  2. 如何将一个div盒子水平垂直都居中?

    html代码如下: 固定样式: 方法一:利用定位(常用方法,推荐) .parent{ position:relative; } .child{ position:absolute; top:50%; ...

  3. Web前端2019面试总结2

    1.js继承: 想要继承,就必须要提供个父类(继承谁,提供继承的属性) 组合继承(组合原型链继承和借用构造函数继承)(常用) 重点:结合了两种模式的优点,传参和复用 特点:1.可以继承父类原型上的属性 ...

  4. FreeRTOS 任务通知模拟事件标志组

    实验 //设置事件位的任务 void eventsetbit_task(void *pvParameters) { u8 key; while(1) { if(EventGroupTask_Handl ...

  5. rpm安装与yum安装的区别

      linux下的安装包多为rpm安装包.通常安装方法为 rpm -ivh 包的路径+包名.rpm 其中参数-i为安装 -v显示信息 -h显示进度条.这三个参数基本捆绑使用rpm的路径不单可以是本地磁 ...

  6. springboot+内置改为外置tomcat

    1.pom.xml springboot项目利用的是自己内置的tomcat,这边就是不依赖内置的tomcat,将其编译的作用域设置为provided <dependency> <gr ...

  7. 如何预防SQL注入

    归纳一下,主要有以下几点: 1.永远不要信任用户的输入.对用户的输入进行校验,可以通过正则表达式,或限制长度:对单引号和 双"-"进行转换等. 2.永远不要使用动态拼装sql,可以 ...

  8. Httpd服务入门知识-Httpd服务常见配置案例之基于客户端来源地址实现访问控制

    Httpd服务入门知识-Httpd服务常见配置案例之基于客户端来源地址实现访问控制 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 一.Options  1>.OPTIONS指 ...

  9. Java基础--static关键字

    不管是平时阅读源代码,还是笔试.面试中,static关键字还是经常被问道,这篇文章主要来重新复习一下该关键字. 一.static用途 static方便在没有创建对象的时候调用方法或者变量. stati ...

  10. java多线程的几种实现方式

    java多线程的几种实现方式 1.继承Thread类,重写run方法2.实现Runnable接口,重写run方法,实现Runnable接口的实现类的实例对象作为Thread构造函数的target3.通 ...