这个自定义损失函数的背景:(一般回归用的损失函数是MSE, 但要看实际遇到的情况而有所改变)
我们现在想要做一个回归,来预估某个商品的销量,现在我们知道,一件商品的成本是1元,售价是10元。
如果我们用均方差来算的话,如果预估多一个,则损失一块钱,预估少一个,则损失9元钱(少赚的)。
显然,我宁愿预估多了,也不想预估少了。
所以,我们就自己定义一个损失函数,用来分段地看,当yhat 比 y大时怎么样,当yhat比y小时怎么样。
(yhat沿用吴恩达课堂中的叫法)
 
import tensorflow as tf
from numpy.random import RandomState
batch_size = 8
# 两个输入节点
x = tf.placeholder(tf.float32, shape=(None, 2), name="x-input")
# 回归问题一般只有一个输出节点
y_ = tf.placeholder(tf.float32, shape=(None, 1), name="y-input")
# 定义了一个单层的神经网络前向传播的过程,这里就是简单加权和
w1 = tf.Variable(tf.random_normal([2, 1], stddev=1, seed=1))
y = tf.matmul(x, w1)
# 定义预测多了和预测少了的成本
loss_less = 10
loss_more = 1
#在windows下,下面用这个where替代,因为调用tf.select会报错
loss = tf.reduce_sum(tf.where(tf.greater(y, y_), (y - y_)*loss_more, (y_-y)*loss_less))
train_step = tf.train.AdamOptimizer(0.001).minimize(loss)
#通过随机数生成一个模拟数据集
rdm = RandomState(1)
dataset_size = 128
X = rdm.rand(dataset_size, 2)
"""
设置回归的正确值为两个输入的和加上一个随机量,之所以要加上一个随机量是
为了加入不可预测的噪音,否则不同损失函数的意义就不大了,因为不同损失函数
都会在能完全预测正确的时候最低。一般来说,噪音为一个均值为0的小量,所以
这里的噪音设置为-0.05, 0.05的随机数。
"""
Y = [[x1 + x2 + rdm.rand()/10.0-0.05] for (x1, x2) in X]
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
steps = 5000
for i in range(steps):
start = (i * batch_size) % dataset_size
end = min(start + batch_size, dataset_size)
sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end]})
print(sess.run(w1))
[[ 1.01934695]
[ 1.04280889]
最终结果如上面所示。
因为我们当初生成训练数据的时候,y是x1 + x2,所以回归结果应该是1,1才对。
但是,由于我们加了自己定义的损失函数,所以,倾向于预估多一点。

如果,我们将loss_less和loss_more对调,我们看一下结果: 

[[ 0.95525807]
[ 0.9813394 ]]

通过这个例子,我们可以看出,对于相同的神经网络,不同的损失函数会对训练出来的模型产生重要的影响。 


引用:以上实例为《Tensorflow实战 Google深度学习框架》中提供。

tensorflow 自定义损失函数示例的更多相关文章

  1. 机器学习之路: tensorflow 自定义 损失函数

    git: https://github.com/linyi0604/MachineLearning/tree/master/07_tensorflow/ import tensorflow as tf ...

  2. 吴裕雄 python 神经网络——TensorFlow 自定义损失函数

    import tensorflow as tf from numpy.random import RandomState batch_size = 8 x = tf.placeholder(tf.fl ...

  3. Tensorflow 损失函数(loss function)及自定义损失函数(三)

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/limiyudianzi/article ...

  4. TensorFlow笔记-06-神经网络优化-损失函数,自定义损失函数,交叉熵

    TensorFlow笔记-06-神经网络优化-损失函数,自定义损失函数,交叉熵 神经元模型:用数学公式比表示为:f(Σi xi*wi + b), f为激活函数 神经网络 是以神经元为基本单位构成的 激 ...

  5. tensflow自定义损失函数

    tensflow 不仅支持经典的损失函数,还可以优化任意的自定义损失函数. 预测商品销量时,如果预测值比真实销量大,商家损失的是生产商品的成本:如果预测值比真实值小,损失的则是商品的利润. 比如如果一 ...

  6. tensorflow2 自定义损失函数使用的隐藏坑

    Keras的核心原则是逐步揭示复杂性,可以在保持相应的高级便利性的同时,对操作细节进行更多控制.当我们要自定义fit中的训练算法时,可以重写模型中的train_step方法,然后调用fit来训练模型. ...

  7. 01_MUI之Boilerplate中:HTML5示例,动态组件,自定义字体示例,自定义字体示例,图标字体示例

     1安装HBuilder5.0.0,安装后的界面截图如下: 2 按照https://www.muicss.com/docs/v1/css-js/boilerplate-html中的说明,创建上图的 ...

  8. 深度学习之卷积神经网络CNN及tensorflow代码实现示例

    深度学习之卷积神经网络CNN及tensorflow代码实现示例 2017年05月01日 13:28:21 cxmscb 阅读数 151413更多 分类专栏: 机器学习 深度学习 机器学习   版权声明 ...

  9. Tensorflow%20实战Google深度学习框架 4.2.2 自定义损失函数源代码

    import os import tab import tensorflow as tf from numpy.random import RandomState print "hello ...

随机推荐

  1. Tomcat请求头过大

    今天开发反应Tomcat的请求头过大 <Connector port="8280" protocol="HTTP/1.1" connectionTimeo ...

  2. .net ef core 领域设计代码转换(上篇)

    一.前言 .net core 2.0正式版已经发布几个月了,经过研究,决定把项目转移过来,新手的话可以先看一些官方介绍 传送门:https://docs.microsoft.com/zh-cn/dot ...

  3. Mybatis整理_01

    Mybatis专题 Mybaits介绍 Mybatis是一个持久化框架,它有不同语言的版本,比如.NET和Java都有Mybatis对应的类库:它有大多数ORM框架都具有的功能,比如自定义的SQL语句 ...

  4. Visual Studio 生成DLL文件

    新建一个项目,在菜单栏中选择“项目”/“**属性”选项,该页面中将“输出类型”下拉列表中的选项选择为“类库”,然后重新生成一下该项目,或者在“Visual Studio 2008命令提示”中输入以下命 ...

  5. [转载] 红黑树(Red Black Tree)- 对于 JDK TreeMap的实现

    转载自http://blog.csdn.net/yangjun2/article/details/6542321 介绍另一种平衡二叉树:红黑树(Red Black Tree),红黑树由Rudolf B ...

  6. 前端面试题(4)iframe有哪些优点?iframe缺点是什么?

    优点: iframe能够原封不动的把嵌入的网页展现出来. 如果有多个网页引用iframe,那么你只需要修改iframe的内容,就可以实现调用的每一个页面内容的更改,方便快捷. 网页如果为了统一风格,头 ...

  7. Python 爬虫练习(一) 爬取国内代理ip

    简单的正则表达式练习,爬取代理 ip. 仅爬取前三页,用正则匹配过滤出 ip 地址和 端口,分别作为key.value 存入 validip 字典. 如果要确定代理 ip 是否真的可用,还需要再对代理 ...

  8. 深入浅出多线程——ReentrantLock (一)

    ReentrantLock是一个排它重入锁,与synchronized关键字语意类似,但比其功能更为强大.该类位于java.util.concurrent.locks包下,是Lock接口的实现类.基本 ...

  9. 分布式缓存之Ehcache与terracotta - Terracotta服务器概念篇

    1.介绍 Terracotta服务器为Terracotta产品提供分布式数据平台.Terracotta服务器集群被称为Terracotta服务器阵列(TSA).Terracotta服务器阵列可以从单个 ...

  10. 【Centos】解决设置JAVA_HOME不断失效问题

    问题还原: 我们都知道,要修改centos的全局配置,可以在/etc/profile这个文件里面修改,比如,我需要修改JAVA_HOME变量 ,那么一般来说我们只要在其中修改,source 一下就行了 ...