Keras的核心原则是逐步揭示复杂性,可以在保持相应的高级便利性的同时,对操作细节进行更多控制。当我们要自定义fit中的训练算法时,可以重写模型中的train_step方法,然后调用fit来训练模型。

这里以tensorflow2官网中的例子来说明:

import numpy as np
import tensorflow as tf
from tensorflow import keras
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
class CustomModel(keras.Model):
tf.random.set_seed(100)
def train_step(self, data):
# Unpack the data. Its structure depends on your model and
# on what you pass to `fit()`.
x, y = data with tf.GradientTape() as tape:
y_pred = self(x, training=True) # Forward pass
# Compute the loss value
# (the loss function is configured in `compile()`)
loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses) # Compute gradients
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
# Update weights
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
# Update metrics (includes the metric that tracks the loss)
self.compiled_metrics.update_state(y, y_pred)
# Return a dict mapping metric names to current value
return {m.name: m.result() for m in self.metrics} # Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss=tf.losses.MSE, metrics=["mae"]) # Just use `fit` as usual model.fit(x, y, epochs=1, shuffle=False)
32/32 [==============================] - 0s 1ms/step - loss: 0.2783 - mae: 0.4257

<tensorflow.python.keras.callbacks.History at 0x7ff7edf6dfd0>

这里的loss是tensorflow库中实现了的损失函数,如果想自定义损失函数,然后将损失函数传入model.compile中,能正常按我们预想的work吗?

答案竟然是否定的,而且没有错误提示,只是loss计算不会符合我们的预期。

def custom_mse(y_true, y_pred):
return tf.reduce_mean((y_true - y_pred)**2, axis=-1)
a_true = tf.constant([1., 1.5, 1.2])
a_pred = tf.constant([1., 2, 1.5])
custom_mse(a_true, a_pred)
<tf.Tensor: shape=(), dtype=float32, numpy=0.11333332>
tf.losses.MSE(a_true, a_pred)
<tf.Tensor: shape=(), dtype=float32, numpy=0.11333332>

以上结果证实了我们自定义loss的正确性,下面我们直接将自定义的loss置入compile中的loss参数中,看看会发生什么。

my_model = CustomModel(inputs, outputs)
my_model.compile(optimizer="adam", loss=custom_mse, metrics=["mae"])
my_model.fit(x, y, epochs=1, shuffle=False)
32/32 [==============================] - 0s 820us/step - loss: 0.1628 - mae: 0.3257

<tensorflow.python.keras.callbacks.History at 0x7ff7edeb7810>

我们看到,这里的loss与我们与标准的tf.losses.MSE明显不同。这说明我们自定义的loss以这种方式直接传递进model.compile中,是完全错误的操作。

正确运用自定义loss的姿势是什么呢?下面揭晓。

loss_tracker = keras.metrics.Mean(name="loss")
mae_metric = keras.metrics.MeanAbsoluteError(name="mae") class MyCustomModel(keras.Model):
tf.random.set_seed(100)
def train_step(self, data):
# Unpack the data. Its structure depends on your model and
# on what you pass to `fit()`.
x, y = data with tf.GradientTape() as tape:
y_pred = self(x, training=True) # Forward pass
# Compute the loss value
# (the loss function is configured in `compile()`)
loss = custom_mse(y, y_pred)
# loss += self.losses # Compute gradients
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
# Update weights
self.optimizer.apply_gradients(zip(gradients, trainable_vars)) # Compute our own metrics
loss_tracker.update_state(loss)
mae_metric.update_state(y, y_pred)
return {"loss": loss_tracker.result(), "mae": mae_metric.result()} @property
def metrics(self):
# We list our `Metric` objects here so that `reset_states()` can be
# called automatically at the start of each epoch
# or at the start of `evaluate()`.
# If you don't implement this property, you have to call
# `reset_states()` yourself at the time of your choosing.
return [loss_tracker, mae_metric] # Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
my_model_beta = MyCustomModel(inputs, outputs)
my_model_beta.compile(optimizer="adam") # Just use `fit` as usual my_model_beta.fit(x, y, epochs=1, shuffle=False)
32/32 [==============================] - 0s 960us/step - loss: 0.2783 - mae: 0.4257

<tensorflow.python.keras.callbacks.History at 0x7ff7eda3d810>

终于,通过跳过在 compile() 中传递损失函数,而在 train_step 中手动完成所有计算内容,我们获得了与之前默认tf.losses.MSE完全一致的输出,这才是我们想要的结果。

总结一下,当我们在模型中想用自定义的损失函数,不能直接传入fit函数,而是需要在train_step中手动传入,完成计算过程。

tensorflow2 自定义损失函数使用的隐藏坑的更多相关文章

  1. TensorFlow笔记-06-神经网络优化-损失函数,自定义损失函数,交叉熵

    TensorFlow笔记-06-神经网络优化-损失函数,自定义损失函数,交叉熵 神经元模型:用数学公式比表示为:f(Σi xi*wi + b), f为激活函数 神经网络 是以神经元为基本单位构成的 激 ...

  2. tensorflow 自定义损失函数示例

    这个自定义损失函数的背景:(一般回归用的损失函数是MSE, 但要看实际遇到的情况而有所改变) 我们现在想要做一个回归,来预估某个商品的销量,现在我们知道,一件商品的成本是1元,售价是10元. 如果我们 ...

  3. tensflow自定义损失函数

    tensflow 不仅支持经典的损失函数,还可以优化任意的自定义损失函数. 预测商品销量时,如果预测值比真实销量大,商家损失的是生产商品的成本:如果预测值比真实值小,损失的则是商品的利润. 比如如果一 ...

  4. 机器学习之路: tensorflow 自定义 损失函数

    git: https://github.com/linyi0604/MachineLearning/tree/master/07_tensorflow/ import tensorflow as tf ...

  5. Tensorflow 损失函数(loss function)及自定义损失函数(三)

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/limiyudianzi/article ...

  6. SpringMVC自定义配置消息转换器踩坑总结

    问题描述 最近在开发时候碰到一个问题,springmvc页面向后台传数据的时候,通常我是这样处理的,在前台把数据打成一个json,在后台接口中使用@requestbody定义一个对象来接收,但是这次数 ...

  7. Fidder详解-工具简介(保存会话、decode解码、Repaly、自定义会话框、隐藏会话、会话排序)

    前言 本文会对Fidder这款工具的一些重要功能,进行详细讲解,带大家进入Fidder的世界,本文会让你明白,Fidder不仅是一个抓包分析工具,也是一个请求发送工具,更加可以当作为Mock Serv ...

  8. 隐藏软键盘(解决自定义Dialog中无法隐藏的问题)

    /** * Dialog中隐藏软键盘不管用 * @param activity */ public static void HideSoftKeyBoard(Activity activity){ t ...

  9. IOS 极光推送自定义通知遇到的一些坑

    主要方法: //自定义推送 - (void)networkDidReceiveMessage:(NSNotification *)notification { NSDictionary * userI ...

随机推荐

  1. Waymo的激光雷达计划:进展如何?

    Waymo的激光雷达计划:进展如何? Waymo's Lidar Plan: How's It Working out? 许多自动驾驶汽车(AV)开发商一直在热烈追求激光雷达技术,这一技术之所以重要, ...

  2. eclipse 新建项目不可选择Java Project 解决方法

    解决方法一: 鼠标点击file-new-other,弹出选项框,选中java project,点击next,接下来就是正常创建java protect的流程了,这个虽然也可以解决,但每次新建java项 ...

  3. IDEA骚技巧

    1. var 声明 2. null 判空 3. notnull 判非空 4. nn 判非空 5. for 遍历 6. fori 带索引的遍历 7. not 取反 8. if 条件判断 9. cast ...

  4. springboot静态资源映射规则

    一.所有/webjars/**的请求,都会去classpath:/META-INF/resources/webjars/下的目录去找资源. 二.访问/**,即访问任何资源,如果没有controller ...

  5. 工作流Activiti框架中的LDAP组件使用详解!实现对工作流目录信息的分布式访问及访问控制

    Activiti集成LDAP简介 企业在LDAP系统中保存了用户和群组信息,Activiti提供了一种解决方案,通过简单的配置就可以让activit连接LDAP 用法 要想在项目中集成LDAP,需要在 ...

  6. 【NX二次开发】Block UI 线型

    属性说明 常规         类型 描述     BlockID     String 控件ID     Enable     Logical 是否可操作     Group     Logical ...

  7. Spring MVC 到 Spring BOOT 的简化之路

    背景 Spring vs Spring MVC vs Spring Boot Spring FrameWork Spring 还能解决什么问题 Spring MVC 为什么需要Spring Boot ...

  8. BIM,PIM接入GIS 需要解决的关键技术问题

    随着技术发展,跨界融合已经不是新鲜事物,近两年BIM.PIM+GIS一张图的提出,给行业注入了一股清流. 为GIS行业发展带来了新的契机,同时也带来了一些新的挑战.面对挑战,本文将剖析BIM.PIM+ ...

  9. Vue(9)购物车练习

    购物车案例 经过一系列的学习,我们这里来练习一个购物车的案例   需求:使用vue写一个表单页面,页面上有购买的数量,点击按钮+或者-,可以增加或减少购物车的数量,数量最少不得少于0,点击移除按钮,会 ...

  10. Tkinter 吐槽之二:Event 事件在子元素中共享

    背景 最近想简单粗暴的用 Python 写一个 GUI 的小程序.因为 Tkinter 是 Python 自带的 GUI 解决方案,为了部署方便,就直接选择了 Tkinter. 本来觉得 GUI 发展 ...