SciTech-BigDataAIML-Tensorflow-Writing your own callbacks
Introduction
A callback is a powerful tool for customizing the behavior of a Keras model during training, evaluation, or inference.
Examples include tf.keras.callbacks.TensorBoard to visualize training progress and results with TensorBoard, or tf.keras.callbacks.ModelCheckpoint to periodically save your model during training.
In this guide, you will learn what a Keras callback is, what it can do, and how you can build your own.
We provide a few demos of simple callback applications to get you started.
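Before writing custom callbacks, it helps to see the two built-in ones mentioned above in use. The following is a minimal sketch of passing keras.callbacks.ModelCheckpoint and keras.callbacks.TensorBoard to fit(). The toy model, the random data and the temporary paths are assumptions made only so that the snippet runs on its own.

```python
import os
import tempfile

import numpy as np
import keras

# Toy regression data (random, for illustration only).
x = np.random.rand(64, 8).astype("float32")
y = np.random.rand(64, 1).astype("float32")

model = keras.Sequential([keras.Input(shape=(8,)), keras.layers.Dense(1)])
model.compile(optimizer="rmsprop", loss="mean_squared_error")

# Hypothetical output locations inside a throwaway directory.
work_dir = tempfile.mkdtemp()
ckpt_path = os.path.join(work_dir, "ckpt.weights.h5")
log_dir = os.path.join(work_dir, "logs")

callbacks = [
    # Save the weights at the end of every epoch (default save_freq="epoch").
    keras.callbacks.ModelCheckpoint(ckpt_path, save_weights_only=True),
    # Write event files that TensorBoard can display.
    keras.callbacks.TensorBoard(log_dir=log_dir),
]
history = model.fit(x, y, batch_size=16, epochs=2, verbose=0, callbacks=callbacks)
```

Running `tensorboard --logdir` on the log directory would then show the training curves.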
Keras callbacks overview
All callbacks subclass the keras.callbacks.Callback class and override a set of methods called at various stages of training, testing, and predicting.
Callbacks are useful to get a view on internal states and statistics of the model during training.
You can pass a list of callbacks (as the keyword argument callbacks) to the following model methods:
- keras.Model.fit()
- keras.Model.evaluate()
- keras.Model.predict()
An overview of callback methods
Global methods
- on_(train|test|predict)_begin(self, logs=None): Called at the beginning of fit/evaluate/predict.
- on_(train|test|predict)_end(self, logs=None): Called at the end of fit/evaluate/predict.
Batch-level methods for training/testing/predicting
- on_(train|test|predict)_batch_begin(self, batch, logs=None): Called right before processing a batch during training/testing/predicting.
- on_(train|test|predict)_batch_end(self, batch, logs=None): Called at the end of training/testing/predicting a batch. Within this method, logs is a dict containing the metrics results.
Epoch-level methods (training only)
- on_epoch_begin(self, epoch, logs=None): Called at the beginning of an epoch during training.
- on_epoch_end(self, epoch, logs=None): Called at the end of an epoch during training.
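To see the order in which these methods fire, here is a minimal sketch of a callback that only records the name of each hook it receives during a fit() call with validation data. The class name HookRecorder and the toy data are hypothetical. Validation inside fit() runs through the test-level hooks, and it finishes before on_epoch_end is called.

```python
import numpy as np
import keras


class HookRecorder(keras.callbacks.Callback):
    """Hypothetical callback that appends the name of each hook it sees."""

    def __init__(self):
        super().__init__()
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_train_end(self, logs=None):
        self.events.append("train_end")

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append("epoch_begin")

    def on_epoch_end(self, epoch, logs=None):
        self.events.append("epoch_end")

    def on_train_batch_end(self, batch, logs=None):
        self.events.append("train_batch_end")

    def on_test_begin(self, logs=None):
        self.events.append("test_begin")

    def on_test_end(self, logs=None):
        self.events.append("test_end")


# 8 samples with batch_size=4 gives 2 training batches per epoch.
x = np.random.rand(8, 4).astype("float32")
y = np.random.rand(8, 1).astype("float32")
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mean_squared_error")

recorder = HookRecorder()
model.fit(x, y, batch_size=4, epochs=2, validation_data=(x, y),
          verbose=0, callbacks=[recorder])
print(recorder.events)
```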
A basic example
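Let's take a look at a concrete example. To get started, let's import TensorFlow, define a callback that prints the log keys of every hook it receives, load a subset of MNIST, and define a simple Sequential Keras model: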
import tensorflow as tf
import keras
class CustomCallback(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        keys = list(logs.keys())
        print("Starting training; got log keys: {}".format(keys))
    def on_train_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop training; got log keys: {}".format(keys))
    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))
    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))
    def on_test_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start testing; got log keys: {}".format(keys))
    def on_test_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop testing; got log keys: {}".format(keys))
    def on_predict_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start predicting; got log keys: {}".format(keys))
    def on_predict_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop predicting; got log keys: {}".format(keys))
    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))
    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))
    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))
    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))
    def on_predict_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))
    def on_predict_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))
# Load example MNIST data and pre-process it
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0
# Limit the data to 1000 samples
x_train = x_train[:1000]
y_train = y_train[:1000]
x_test = x_test[:1000]
y_test = y_test[:1000]
# Define the Keras model to add callbacks to
def get_model():
    model = keras.Sequential()
    model.add(keras.layers.Dense(1, input_dim=784))
    model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
        loss="mean_squared_error",
        metrics=["mean_absolute_error"],
    )
    return model
model = get_model()
model.fit(x_train, y_train, batch_size=128, epochs=1,
validation_split=0.5, verbose=0, callbacks=[CustomCallback()],)
res = model.evaluate(x_test, y_test, batch_size=128,
verbose=0, callbacks=[CustomCallback()])
res = model.predict(x_test, batch_size=128, callbacks=[CustomCallback()])
Usage of logs dict
The logs dict contains the loss value and all the metrics at the end of a batch or epoch.
In this example, that means the loss and the mean absolute error.
class LossAndErrorPrintingCallback(keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None):
        print("Up to batch {}, the average loss is {:7.2f}.".format(batch, logs["loss"]))
    def on_test_batch_end(self, batch, logs=None):
        print("Up to batch {}, the average loss is {:7.2f}.".format(batch, logs["loss"]))
    def on_epoch_end(self, epoch, logs=None):
        print(
            "The average loss for epoch {} is {:7.2f} "
            "and mean absolute error is {:7.2f}.".format(
                epoch, logs["loss"], logs["mean_absolute_error"]
            )
        )
model = get_model()
model.fit(
x_train,
y_train,
batch_size=128,
epochs=2,
verbose=0,
callbacks=[LossAndErrorPrintingCallback()],
)
res = model.evaluate(
x_test,
y_test,
batch_size=128,
verbose=0,
callbacks=[LossAndErrorPrintingCallback()],
)
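When fit() is given validation data, the epoch-level logs dict also carries the validation results under a val_ prefix. A minimal sketch, with a hypothetical EpochLogCollector callback and random toy data, that stores a copy of every epoch's logs:

```python
import numpy as np
import keras


class EpochLogCollector(keras.callbacks.Callback):
    """Hypothetical callback that keeps a copy of each epoch's logs dict."""

    def __init__(self):
        super().__init__()
        self.epoch_logs = []

    def on_epoch_end(self, epoch, logs=None):
        # Store a copy so later edits by other callbacks don't change it.
        self.epoch_logs.append(dict(logs or {}))


x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="rmsprop", loss="mean_squared_error",
              metrics=["mean_absolute_error"])

collector = EpochLogCollector()
model.fit(x, y, batch_size=8, epochs=3, validation_split=0.25,
          verbose=0, callbacks=[collector])
for i, logs in enumerate(collector.epoch_logs):
    print(i, sorted(logs))
```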
Usage of self.model attribute
In addition to receiving log information when one of their methods is called, callbacks have access to the model associated with the current round of training/evaluation/inference: self.model.
Here are a few of the things you can do with self.model in a callback:
- Set self.model.stop_training = True to immediately interrupt training.
- Mutate hyperparameters of the optimizer (available as self.model.optimizer), such as self.model.optimizer.learning_rate.
- Save the model at periodic intervals.
- Record the output of model.predict() on a few test samples at the end of each epoch, to use as a sanity check during training.
- Extract visualizations of intermediate features at the end of each epoch, to monitor what the model is learning over time.
- etc.
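As a sketch of the fourth item in the list above, the following hypothetical PredictionSnapshot callback calls self.model.predict() on a small fixed batch at the end of every epoch and keeps the results for later inspection. The class name, the sample batch and the toy data are assumptions.

```python
import numpy as np
import keras


class PredictionSnapshot(keras.callbacks.Callback):
    """Hypothetical callback: store predictions on fixed samples each epoch."""

    def __init__(self, samples):
        super().__init__()
        self.samples = samples
        self.snapshots = []

    def on_epoch_end(self, epoch, logs=None):
        # self.model is the model currently being trained.
        preds = self.model.predict(self.samples, verbose=0)
        self.snapshots.append(preds)


x = np.random.rand(40, 4).astype("float32")
y = np.random.rand(40, 1).astype("float32")
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="rmsprop", loss="mean_squared_error")

snapshot = PredictionSnapshot(x[:3])
model.fit(x, y, batch_size=10, epochs=4, verbose=0, callbacks=[snapshot])
for epoch, preds in enumerate(snapshot.snapshots):
    print(epoch, preds.ravel())
```

Comparing the stored arrays across epochs gives a quick sanity check that the outputs are actually changing as training proceeds.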
Let's see this in action in a couple of examples.
Examples of Keras callback applications
Early stopping at minimum loss
This first example shows the creation of a Callback that stops training when the loss reaches its minimum, by setting the attribute self.model.stop_training (boolean).
Optionally, you can provide an argument patience to specify how many epochs we should wait before stopping after having reached a local minimum.
tf.keras.callbacks.EarlyStopping provides a more complete and general implementation.
import numpy as np
class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
    """Stop training when the loss is at its min, i.e. the loss stops decreasing.
    Arguments:
        patience: Number of epochs to wait after min has been hit. After this
            number of no improvement, training stops.
    """
    def __init__(self, patience=0):
        super().__init__()
        self.patience = patience
        # best_weights to store the weights at which the minimum loss occurs.
        self.best_weights = None
    def on_train_begin(self, logs=None):
        # The number of epochs it has waited when loss is no longer minimum.
        self.wait = 0
        # The epoch the training stops at.
        self.stopped_epoch = 0
        # Initialize the best as infinity (np.inf; the np.Inf alias was removed in NumPy 2.0).
        self.best = np.inf
    def on_epoch_end(self, epoch, logs=None):
        current = logs.get("loss")
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            # Record the best weights if current results is better (less).
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)
    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))
model = get_model()
model.fit(x_train, y_train,
batch_size=64, steps_per_epoch=5, epochs=30,
verbose=0, callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()],
)
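For comparison, here is a minimal sketch using the built-in EarlyStopping callback (tf.keras.callbacks.EarlyStopping, reached here as keras.callbacks.EarlyStopping). Its monitor, patience and restore_best_weights arguments cover what the custom class above does by hand; the toy data and the argument values are assumptions.

```python
import numpy as np
import keras

x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
              loss="mean_squared_error")

early_stop = keras.callbacks.EarlyStopping(
    monitor="loss",             # quantity to watch in the epoch logs
    patience=2,                 # epochs without improvement before stopping
    restore_best_weights=True,  # on an early stop, roll back to the best epoch
)
history = model.fit(x, y, batch_size=16, epochs=30, verbose=0,
                    callbacks=[early_stop])
print("Epochs run:", len(history.history["loss"]))
```

The built-in version also offers min_delta and mode, which the custom class does not handle.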
Learning rate scheduling
In this example, we show how a custom Callback can be used to dynamically change the learning rate of the optimizer during the course of training.
See tf.keras.callbacks.LearningRateScheduler for a more general implementation.
class CustomLearningRateScheduler(keras.callbacks.Callback):
    """Learning rate scheduler which sets the learning rate according to schedule.
    Arguments:
        schedule: a function that takes an epoch index
            (integer, indexed from 0) and current learning rate
            as inputs and returns a new learning rate as output (float).
    """
    def __init__(self, schedule):
        super().__init__()
        self.schedule = schedule
    def on_epoch_begin(self, epoch, logs=None):
        # Use `learning_rate` throughout: the legacy `lr` alias is absent in newer Keras.
        if not hasattr(self.model.optimizer, "learning_rate"):
            raise ValueError('Optimizer must have a "learning_rate" attribute.')
        # Get the current learning rate from model's optimizer as a Python float.
        lr = float(np.array(self.model.optimizer.learning_rate))
        # Call schedule function to get the scheduled learning rate.
        scheduled_lr = self.schedule(epoch, lr)
        # Set the value back to the optimizer before this epoch starts.
        self.model.optimizer.learning_rate = scheduled_lr
        print("\nEpoch %05d: Learning rate is %6.4f." % (epoch, scheduled_lr))
# (epoch to start, learning rate) tuples
LR_SCHEDULE = [ (3, 0.05), (6, 0.01), (9, 0.005), (12, 0.001), ]
def lr_schedule(epoch, lr):
    """Helper function to retrieve the scheduled learning rate based on epoch."""
    if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:
        return lr
    for i in range(len(LR_SCHEDULE)):
        if epoch == LR_SCHEDULE[i][0]:
            return LR_SCHEDULE[i][1]
    return lr
model = get_model()
model.fit(x_train, y_train,
batch_size=64, steps_per_epoch=5, epochs=15,
verbose=0, callbacks=[
LossAndErrorPrintingCallback(),
CustomLearningRateScheduler(lr_schedule),
],
)
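The built-in keras.callbacks.LearningRateScheduler accepts a function with the same (epoch, lr) signature and applies it at the start of each epoch. A minimal sketch, assuming a hypothetical schedule that starts at 0.1 and halves every epoch, on random toy data:

```python
import numpy as np
import keras


def halving_schedule(epoch, lr):
    """Hypothetical schedule: start at 0.1 and halve every epoch."""
    return 0.1 * (0.5 ** epoch)


x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
              loss="mean_squared_error")

scheduler = keras.callbacks.LearningRateScheduler(halving_schedule)
model.fit(x, y, batch_size=8, epochs=3, verbose=0, callbacks=[scheduler])

# The schedule was last applied at the start of epoch 2, i.e. 0.1 * 0.5**2.
final_lr = float(np.array(model.optimizer.learning_rate))
print(final_lr)
```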