Keras的核心原则是逐步揭示复杂性,可以在保持相应的高级便利性的同时,对操作细节进行更多控制。当我们要自定义fit中的训练算法时,可以重写模型中的train_step方法,然后调用fit来训练模型。

这里以tensorflow2官网中的例子来说明:

import numpy as np
import tensorflow as tf
from tensorflow import keras
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
class CustomModel(keras.Model):
tf.random.set_seed(100)
def train_step(self, data):
# Unpack the data. Its structure depends on your model and
# on what you pass to `fit()`.
x, y = data with tf.GradientTape() as tape:
y_pred = self(x, training=True) # Forward pass
# Compute the loss value
# (the loss function is configured in `compile()`)
loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses) # Compute gradients
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
# Update weights
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
# Update metrics (includes the metric that tracks the loss)
self.compiled_metrics.update_state(y, y_pred)
# Return a dict mapping metric names to current value
return {m.name: m.result() for m in self.metrics} # Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss=tf.losses.MSE, metrics=["mae"]) # Just use `fit` as usual model.fit(x, y, epochs=1, shuffle=False)
32/32 [==============================] - 0s 1ms/step - loss: 0.2783 - mae: 0.4257

<tensorflow.python.keras.callbacks.History at 0x7ff7edf6dfd0>

这里的loss是tensorflow库中实现了的损失函数,如果想自定义损失函数,然后将损失函数传入model.compile中,能正常按我们预想的work吗?

答案竟然是否定的,而且没有错误提示,只是loss计算不会符合我们的预期。

def custom_mse(y_true, y_pred):
return tf.reduce_mean((y_true - y_pred)**2, axis=-1)
a_true = tf.constant([1., 1.5, 1.2])
a_pred = tf.constant([1., 2, 1.5])
custom_mse(a_true, a_pred)
<tf.Tensor: shape=(), dtype=float32, numpy=0.11333332>
tf.losses.MSE(a_true, a_pred)
<tf.Tensor: shape=(), dtype=float32, numpy=0.11333332>

以上结果证实了我们自定义loss的正确性,下面我们直接将自定义的loss置入compile中的loss参数中,看看会发生什么。

my_model = CustomModel(inputs, outputs)
my_model.compile(optimizer="adam", loss=custom_mse, metrics=["mae"])
my_model.fit(x, y, epochs=1, shuffle=False)
32/32 [==============================] - 0s 820us/step - loss: 0.1628 - mae: 0.3257

<tensorflow.python.keras.callbacks.History at 0x7ff7edeb7810>

我们看到,这里的loss与我们与标准的tf.losses.MSE明显不同。这说明我们自定义的loss以这种方式直接传递进model.compile中,是完全错误的操作。

正确运用自定义loss的姿势是什么呢?下面揭晓。

loss_tracker = keras.metrics.Mean(name="loss")
mae_metric = keras.metrics.MeanAbsoluteError(name="mae") class MyCustomModel(keras.Model):
tf.random.set_seed(100)
def train_step(self, data):
# Unpack the data. Its structure depends on your model and
# on what you pass to `fit()`.
x, y = data with tf.GradientTape() as tape:
y_pred = self(x, training=True) # Forward pass
# Compute the loss value
# (the loss function is configured in `compile()`)
loss = custom_mse(y, y_pred)
# loss += self.losses # Compute gradients
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
# Update weights
self.optimizer.apply_gradients(zip(gradients, trainable_vars)) # Compute our own metrics
loss_tracker.update_state(loss)
mae_metric.update_state(y, y_pred)
return {"loss": loss_tracker.result(), "mae": mae_metric.result()} @property
def metrics(self):
# We list our `Metric` objects here so that `reset_states()` can be
# called automatically at the start of each epoch
# or at the start of `evaluate()`.
# If you don't implement this property, you have to call
# `reset_states()` yourself at the time of your choosing.
return [loss_tracker, mae_metric] # Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
my_model_beta = MyCustomModel(inputs, outputs)
my_model_beta.compile(optimizer="adam") # Just use `fit` as usual my_model_beta.fit(x, y, epochs=1, shuffle=False)
32/32 [==============================] - 0s 960us/step - loss: 0.2783 - mae: 0.4257

<tensorflow.python.keras.callbacks.History at 0x7ff7eda3d810>

终于,通过跳过在 compile() 中传递损失函数,而在 train_step 中手动完成所有计算内容,我们获得了与之前默认tf.losses.MSE完全一致的输出,这才是我们想要的结果。

总结一下,当我们在模型中想用自定义的损失函数,不能直接传入fit函数,而是需要在train_step中手动传入,完成计算过程。

tensorflow2 自定义损失函数使用的隐藏坑的更多相关文章

  1. TensorFlow笔记-06-神经网络优化-损失函数,自定义损失函数,交叉熵

    TensorFlow笔记-06-神经网络优化-损失函数,自定义损失函数,交叉熵 神经元模型:用数学公式比表示为:f(Σi xi*wi + b), f为激活函数 神经网络 是以神经元为基本单位构成的 激 ...

  2. tensorflow 自定义损失函数示例

    这个自定义损失函数的背景:(一般回归用的损失函数是MSE, 但要看实际遇到的情况而有所改变) 我们现在想要做一个回归,来预估某个商品的销量,现在我们知道,一件商品的成本是1元,售价是10元. 如果我们 ...

  3. tensflow自定义损失函数

    tensflow 不仅支持经典的损失函数,还可以优化任意的自定义损失函数. 预测商品销量时,如果预测值比真实销量大,商家损失的是生产商品的成本:如果预测值比真实值小,损失的则是商品的利润. 比如如果一 ...

  4. 机器学习之路: tensorflow 自定义 损失函数

    git: https://github.com/linyi0604/MachineLearning/tree/master/07_tensorflow/ import tensorflow as tf ...

  5. Tensorflow 损失函数(loss function)及自定义损失函数(三)

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/limiyudianzi/article ...

  6. SpringMVC自定义配置消息转换器踩坑总结

    问题描述 最近在开发时候碰到一个问题,springmvc页面向后台传数据的时候,通常我是这样处理的,在前台把数据打成一个json,在后台接口中使用@requestbody定义一个对象来接收,但是这次数 ...

  7. Fidder详解-工具简介(保存会话、decode解码、Repaly、自定义会话框、隐藏会话、会话排序)

    前言 本文会对Fidder这款工具的一些重要功能,进行详细讲解,带大家进入Fidder的世界,本文会让你明白,Fidder不仅是一个抓包分析工具,也是一个请求发送工具,更加可以当作为Mock Serv ...

  8. 隐藏软键盘(解决自定义Dialog中无法隐藏的问题)

    /** * Dialog中隐藏软键盘不管用 * @param activity */ public static void HideSoftKeyBoard(Activity activity){ t ...

  9. IOS 极光推送自定义通知遇到的一些坑

    主要方法: //自定义推送 - (void)networkDidReceiveMessage:(NSNotification *)notification { NSDictionary * userI ...

随机推荐

  1. Markdown个人常用语法

    Markdown实用格式 标题 # 标题一级 ## 标题二级 ### 标题三级 #### 标题四级 ##### 标题五级 ###### 标题六级 粗体.斜体和删除线 **加粗字体** *斜体字体* * ...

  2. UG_PS Parasolid相关的操作

    Open C UF_PS_ask_current_highest_tagUF_PS_ask_current_partitionUF_PS_ask_entity_partitionUF_PS_ask_j ...

  3. Java显式锁

    Java 显式锁. 一.显式锁 什么是显式锁? 由自己手动获取锁,然后手动释放的锁. 有了 synchronized(内置锁) 为什么还要 Lock(显示锁)? 使用 synchronized 关键字 ...

  4. MySQL 最佳实践 —— 高效插入数据

    当你需要在 MySQL 数据库中批量插入数百万条数据时,你就会意识到,逐条发送 INSERT 语句并不是一个可行的方法. MySQL 文档中有些值得一读的 INSERT 优化技巧. 在这篇文章里,我将 ...

  5. 如果你这么去理解HashMap就会发现它真的很简单

    Java中的HashMap相信大家都不陌生,也是大家编程时最常用的数据结构之一,各种面试题更是恨不得掘地三尺的去问HashMap.HashTable.ConcurrentHashMap,无论面试题多么 ...

  6. 重新整理 .net core 实践篇—————HttpClientFactory[三十二]

    前言 简单整理一下HttpClientFactory . 正文 这个HttpFactory 主要有下面的功能: 管理内部HttpMessageHandler 的生命周期,灵活应对资源问题和DNS刷新问 ...

  7. 在微信小程序中使用阿里图标库Iconfont

    首先想要使用图标,只用上图的五个iconfont相关文件就可以了.(下下来的文件iconfont.wxss开始是.css的后缀,手动改成.wxss就可以在小程序中使用) 然后在app.wxss中引入i ...

  8. 重新点亮linux 命令树————帮助命令[一]

    前言 重新整理一下linux的命令. 正文 这里首先介绍帮助命令. 帮助命令常用的有三个: man help info 那么就来看下这三个. man 第一个man,man不是男人的意思,而是manua ...

  9. 『心善渊』Selenium3.0基础 — 13、Selenium操作下拉菜单

    目录 1.使用Selenium中的Select类来处理下拉菜单(推荐) 2.下拉菜单对象的其他操作(了解) 3.通过元素二次定位方式操作下拉菜单(重点) (1)了解元素二次定位 (2)示例: 页面中的 ...

  10. 用python的matplotlib根据文件里面的数字画图像折线图

    思路:用open打开文件,再用a=filename.readlines()提取每行的数据作为列表的值,然后传递列表给matplotlib并引入对应库画出图像 代码实现:import matplotli ...