SciTech-BigDataAIML-Tensorflow-Writing your own callbacks
Introduction
A callback is a powerful tool to customize the behavior of a Keras model during training, evaluation, or inference.
Examples include tf.keras.callbacks.TensorBoard to visualize training progress and results with TensorBoard, or tf.keras.callbacks.ModelCheckpoint to periodically save your model during training.
In this guide, you will learn what a Keras callback is, what it can do, and how you can build your own.
We provide a few demos of simple callback applications to get you started.
Keras callbacks overview
All callbacks subclass the keras.callbacks.Callback class,
and override a set of methods called at various stages of training, testing, and predicting.
Callbacks are useful to get a view on internal states and statistics of the model during training.
You can pass a list of callbacks (as the keyword argument callbacks) to the following model methods:
keras.Model.fit()
keras.Model.evaluate()
keras.Model.predict()
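For instance, assuming you already have a compiled model and training data (the log directory and checkpoint path below are illustrative, not from this guide), built-in callbacks are passed like this:
model.fit(
    x_train,
    y_train,
    epochs=2,
    callbacks=[
        tf.keras.callbacks.TensorBoard(log_dir="./logs"),   # illustrative path
        tf.keras.callbacks.ModelCheckpoint("model.keras"),  # illustrative path
    ],
)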
An overview of callback methods
Global methods
on_(train|test|predict)_begin(self, logs=None): Called at the beginning of fit/evaluate/predict.
on_(train|test|predict)_end(self, logs=None): Called at the end of fit/evaluate/predict.
Batch-level methods for training/testing/predicting
on_(train|test|predict)_batch_begin(self, batch, logs=None): Called right before processing a batch during training/testing/predicting.
on_(train|test|predict)_batch_end(self, batch, logs=None): Called at the end of training/testing/predicting a batch. Within this method, logs is a dict containing the metrics results.
Epoch-level methods (training only)
on_epoch_begin(self, epoch, logs=None): Called at the beginning of an epoch during training.
on_epoch_end(self, epoch, logs=None): Called at the end of an epoch during training.
An example
Let's take a look at a concrete example. To get started, let's import tensorflow and keras, define a custom callback that logs when each of its methods is called, load some MNIST data, and define a simple Sequential Keras model to attach the callback to:
import tensorflow as tf
import keras
class CustomCallback(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        keys = list(logs.keys())
        print("Starting training; got log keys: {}".format(keys))

    def on_train_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop training; got log keys: {}".format(keys))

    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_test_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start testing; got log keys: {}".format(keys))

    def on_test_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop testing; got log keys: {}".format(keys))

    def on_predict_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start predicting; got log keys: {}".format(keys))

    def on_predict_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop predicting; got log keys: {}".format(keys))

    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))
# Load example MNIST data and pre-process it
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0
# Limit the data to 1000 samples
x_train = x_train[:1000]
y_train = y_train[:1000]
x_test = x_test[:1000]
y_test = y_test[:1000]
# Define the Keras model to add callbacks to
def get_model():
    model = keras.Sequential()
    model.add(keras.layers.Dense(1, input_dim=784))
    model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
        loss="mean_squared_error",
        metrics=["mean_absolute_error"],
    )
    return model
model = get_model()
model.fit(x_train, y_train, batch_size=128, epochs=1,
          validation_split=0.5, verbose=0, callbacks=[CustomCallback()])
res = model.evaluate(x_test, y_test, batch_size=128,
                     verbose=0, callbacks=[CustomCallback()])
res = model.predict(x_test, batch_size=128, callbacks=[CustomCallback()])
Usage of logs dict
The logs dict contains the loss value, and all the metrics at the end of a batch or epoch.
The following example prints the loss and the mean absolute error.
class LossAndErrorPrintingCallback(keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None):
        print("Up to batch {}, the average loss is {:7.2f}.".format(batch, logs["loss"]))

    def on_test_batch_end(self, batch, logs=None):
        print("Up to batch {}, the average loss is {:7.2f}.".format(batch, logs["loss"]))

    def on_epoch_end(self, epoch, logs=None):
        print(
            "The average loss for epoch {} is {:7.2f} "
            "and mean absolute error is {:7.2f}.".format(
                epoch, logs["loss"], logs["mean_absolute_error"]
            )
        )
model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=2,
    verbose=0,
    callbacks=[LossAndErrorPrintingCallback()],
)
res = model.evaluate(
    x_test,
    y_test,
    batch_size=128,
    verbose=0,
    callbacks=[LossAndErrorPrintingCallback()],
)
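One defensive detail worth noting: as the method signatures above show, logs defaults to None, so a callback can guard before reading keys. A minimal sketch (the callback name is ours, not part of Keras):
class SafeLoggingCallback(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}  # guard: logs defaults to None in the signature
        loss = logs.get("loss")
        if loss is not None:
            print("Epoch {}: loss = {:.4f}".format(epoch, loss))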
Usage of self.model attribute
In addition to receiving log information when one of their methods is called, callbacks have access to the model associated with the current round of training/evaluation/inference: self.model.
Here are a few of the things you can do with self.model in a callback:
- Set self.model.stop_training = True to immediately interrupt training.
- Mutate hyperparameters of the optimizer (available as self.model.optimizer), such as self.model.optimizer.learning_rate.
- Save the model at periodic intervals (a sketch follows this list).
- Record the output of model.predict() on a few test samples at the end of each epoch, to use as a sanity check during training.
- Extract visualizations of intermediate features at the end of each epoch, to monitor what the model is learning over time.
- etc.
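As a quick sketch of the periodic-saving item above (the interval and file path are hypothetical; keras.callbacks.ModelCheckpoint is the built-in way to do this):
class PeriodicSaver(keras.callbacks.Callback):
    """Save the model every `every` epochs (hypothetical helper)."""

    def __init__(self, every=5, path_template="model_at_epoch_{}.keras"):
        super().__init__()
        self.every = every
        self.path_template = path_template

    def on_epoch_end(self, epoch, logs=None):
        # self.model is the model attached to the current round of training.
        if (epoch + 1) % self.every == 0:
            self.model.save(self.path_template.format(epoch + 1))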
Let's see more of this in action in a couple of fuller examples.
Examples of Keras callback applications
Early stopping at minimum loss
This first example shows how to create a Callback that stops training when the minimum of the loss has been reached,
by setting the attribute self.model.stop_training (boolean).
Optionally, you can provide an argument patience to specify how many epochs to wait before stopping after a minimum has been reached.
tf.keras.callbacks.EarlyStopping provides a more complete and general implementation.
import numpy as np
class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
    """Stop training when the loss is at its min, i.e. the loss stops decreasing.

    Arguments:
        patience: Number of epochs to wait after min has been hit. After this
          number of epochs with no improvement, training stops.
    """

    def __init__(self, patience=0):
        super().__init__()
        self.patience = patience
        # best_weights to store the weights at which the minimum loss occurs.
        self.best_weights = None

    def on_train_begin(self, logs=None):
        # The number of epochs waited while the loss is no longer the minimum.
        self.wait = 0
        # The epoch the training stops at.
        self.stopped_epoch = 0
        # Initialize the best as infinity.
        self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get("loss")
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            # Record the best weights if the current result is better (lower loss).
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))
model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=64,
    steps_per_epoch=5,
    epochs=30,
    verbose=0,
    callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()],
)
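For comparison, a rough built-in equivalent of the custom callback above (a sketch; the parameter values are illustrative) uses tf.keras.callbacks.EarlyStopping:
model = get_model()
model.fit(x_train, y_train, batch_size=64, epochs=30, verbose=0,
          callbacks=[
              LossAndErrorPrintingCallback(),
              keras.callbacks.EarlyStopping(monitor="loss", patience=0,
                                            restore_best_weights=True),
          ])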
Learning rate scheduling
In this example, we show how a custom Callback can be used to dynamically change the learning rate of the optimizer during the course of training.
See callbacks.LearningRateScheduler for a more general implementation.
class CustomLearningRateScheduler(keras.callbacks.Callback):
    """Learning rate scheduler which sets the learning rate according to schedule.

    Arguments:
        schedule: a function that takes an epoch index
            (integer, indexed from 0) and current learning rate
            as inputs and returns a new learning rate as output (float).
    """

    def __init__(self, schedule):
        super().__init__()
        self.schedule = schedule

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, "learning_rate"):
            raise ValueError('Optimizer must have a "learning_rate" attribute.')
        # Get the current learning rate from the model's optimizer.
        lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))
        # Call the schedule function to get the scheduled learning rate.
        scheduled_lr = self.schedule(epoch, lr)
        # Set the value back on the optimizer before this epoch starts.
        tf.keras.backend.set_value(self.model.optimizer.learning_rate, scheduled_lr)
        print("\nEpoch %05d: Learning rate is %6.4f." % (epoch, scheduled_lr))
# (epoch to start, learning rate) tuples
LR_SCHEDULE = [(3, 0.05), (6, 0.01), (9, 0.005), (12, 0.001)]


def lr_schedule(epoch, lr):
    """Helper function to retrieve the scheduled learning rate based on epoch."""
    if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:
        return lr
    for i in range(len(LR_SCHEDULE)):
        if epoch == LR_SCHEDULE[i][0]:
            return LR_SCHEDULE[i][1]
    return lr
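A quick sanity check of the helper (these values follow directly from LR_SCHEDULE above):
# The helper only changes the rate at the listed boundary epochs;
# otherwise it returns the incoming rate unchanged.
assert lr_schedule(0, 0.1) == 0.1       # before the first boundary
assert lr_schedule(3, 0.1) == 0.05      # boundary epoch -> new rate
assert lr_schedule(4, 0.05) == 0.05     # between boundaries -> unchanged
assert lr_schedule(13, 0.001) == 0.001  # past the last boundary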
model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=64,
    steps_per_epoch=5,
    epochs=15,
    verbose=0,
    callbacks=[
        LossAndErrorPrintingCallback(),
        CustomLearningRateScheduler(lr_schedule),
    ],
)
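For reference, the same (epoch, lr) -> new lr function plugs directly into the built-in scheduler mentioned above:
model = get_model()
model.fit(x_train, y_train, batch_size=64, steps_per_epoch=5, epochs=15,
          verbose=0,
          callbacks=[
              LossAndErrorPrintingCallback(),
              tf.keras.callbacks.LearningRateScheduler(lr_schedule),
          ])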