1. 什么是Hook

经常会听到钩子函数(hook function)这个概念,最近在看目标检测开源框架mmdetection,里面也出现大量Hook的编程方式,那到底什么是hook?hook的作用是什么?

  • what is hook ?钩子hook,顾名思义,可以理解是一个挂钩,作用是有需要的时候挂一个东西上去。具体的解释是:钩子函数是把我们自己实现的hook函数在某一时刻挂接到目标挂载点上。

  • hook函数的作用 举个例子,hook的概念在windows桌面软件开发很常见,特别是各种事件触发的机制; 比如C++的MFC程序中,要监听鼠标左键按下的时间,MFC提供了一个onLeftKeyDown的钩子函数。很显然,MFC框架并没有为我们实现onLeftKeyDown具体的操作,只是为我们提供一个钩子,当我们需要处理的时候,只要去重写这个函数,把我们需要操作挂载在这个钩子里,如果我们不挂载,MFC事件触发机制中执行的就是空的操作。

从上面可知

  • hook函数是程序中预定义好的函数,这个函数处于原有程序流程当中(暴露一个钩子出来)

  • 我们需要再在有流程中钩子定义的函数块中实现某个具体的细节,需要把我们的实现,挂接或者注册(register)到钩子里,使得hook函数对目标可用

  • hook 是一种编程机制,和具体的语言没有直接的关系

  • 如果从设计模式上看,hook模式是模板方法的扩展

  • 钩子只有注册的时候,才会使用,所以原有程序的流程中,没有注册或挂载时,执行的是空(即没有执行任何操作)

本文用python来解释hook的实现方式,并展示在开源项目中hook的应用案例。hook函数和我们常听到另外一个名称:回调函数(callback function)功能是类似的,可以按照同种模式来理解。

2. hook实现例子

据我所知,hook函数最常使用在某种流程处理当中。这个流程往往有很多步骤。hook函数常常挂载在这些步骤中,为增加额外的一些操作,提供灵活性。

下面举一个简单的例子,这个例子的目的是实现一个通用往队列中插入内容的功能。流程步骤有2个

  • 需要再插入队列前,对数据进行筛选 input_filter_fn

  • 插入队列 insert_queue

  1. class ContentStash(object):
  2.     """
  3.     content stash for online operation
  4.     pipeline is
  5.     1. input_filter: filter some contents, no use to user
  6.     2. insert_queue(redis or other broker): insert useful content to queue
  7.     """
  8.  
  9.     def __init__(self):
  10.         self.input_filter_fn = None
  11.         self.broker = []
  12.  
  13.     def register_input_filter_hook(self, input_filter_fn):
  14.         """
  15.         register input filter function, parameter is content dict
  16.         Args:
  17.             input_filter_fn: input filter function
  18.  
  19.         Returns:
  20.  
  21.         """
  22.         self.input_filter_fn = input_filter_fn
  23.  
  24.     def insert_queue(self, content):
  25.         """
  26.         insert content to queue
  27.         Args:
  28.             content: dict
  29.  
  30.         Returns:
  31.  
  32.         """
  33.         self.broker.append(content)
  34.  
  35.     def input_pipeline(self, content, use=False):
  36.         """
  37.         pipeline of input for content stash
  38.         Args:
  39.             use: is use, defaul False
  40.             content: dict
  41.  
  42.         Returns:
  43.  
  44.         """
  45.         if not use:
  46.             return
  47.  
  48.         # input filter
  49.         if self.input_filter_fn:
  50.             _filter = self.input_filter_fn(content)
  51.             
  52.         # insert to queue
  53.         if not _filter:
  54.             self.insert_queue(content)
  55.  
  56.  
  57.  
  58. # test
  59. ## 实现一个你所需要的钩子实现:比如如果content 包含time就过滤掉,否则插入队列
  60. def input_filter_hook(content):
  61.     """
  62.     test input filter hook
  63.     Args:
  64.         content: dict
  65.  
  66.     Returns: None or content
  67.  
  68.     """
  69.     if content.get('time') is None:
  70.         return
  71.     else:
  72.         return content
  73.  
  74.  
  75. # 原有程序
  76. content = {'filename': 'test.jpg', 'b64_file': "#test", 'data': {"result": "cat", "probility": 0.9}}
  77. content_stash = ContentStash('audit', work_dir='')
  78.  
  79. # 挂上钩子函数, 可以有各种不同钩子函数的实现,但是要主要函数输入输出必须保持原有程序中一致,比如这里是content
  80. content_stash.register_input_filter_hook(input_filter_hook)
  81.  
  82. # 执行流程
  83. content_stash.input_pipeline(content)
  84.  

3. hook在开源框架中的应用

3.1 keras

在深度学习训练流程中,hook函数体现的淋漓尽致。

一个训练过程(不包括数据准备),会轮询多次训练集,每次称为一个epoch,每个epoch又分为多个batch来训练。流程先后拆解成:

  • 开始训练

  • 训练一个epoch前

  • 训练一个batch前

  • 训练一个batch后

  • 训练一个epoch后

  • 评估验证集

  • 结束训练

这些步骤是穿插在训练一个batch数据的过程中,这些可以理解成是钩子函数,我们可能需要在这些钩子函数中实现一些定制化的东西,比如在训练一个epoch后我们要保存下训练的模型,在结束训练时用最好的模型执行下测试集的效果等等。

keras中是通过各种回调函数来实现钩子hook功能的。这里放一个callback的父类,定制时只要继承这个父类,实现你过关注的钩子就可以了。

  1. @keras_export('keras.callbacks.Callback')
  2. class Callback(object):
  3.   """Abstract base class used to build new callbacks.
  4.  
  5.   Attributes:
  6.       params: Dict. Training parameters
  7.           (eg. verbosity, batch size, number of epochs...).
  8.       model: Instance of `keras.models.Model`.
  9.           Reference of the model being trained.
  10.  
  11.   The `logs` dictionary that callback methods
  12.   take as argument will contain keys for quantities relevant to
  13.   the current batch or epoch (see method-specific docstrings).
  14.   """
  15.  
  16.   def __init__(self):
  17.     self.validation_data = None  # pylint: disable=g-missing-from-attributes
  18.     self.model = None
  19.     # Whether this Callback should only run on the chief worker in a
  20.     # Multi-Worker setting.
  21.     # TODO(omalleyt): Make this attr public once solution is stable.
  22.     self._chief_worker_only = None
  23.     self._supports_tf_logs = False
  24.  
  25.   def set_params(self, params):
  26.     self.params = params
  27.  
  28.   def set_model(self, model):
  29.     self.model = model
  30.  
  31.   @doc_controls.for_subclass_implementers
  32.   @generic_utils.default
  33.   def on_batch_begin(self, batch, logs=None):
  34.     """A backwards compatibility alias for `on_train_batch_begin`."""
  35.  
  36.   @doc_controls.for_subclass_implementers
  37.   @generic_utils.default
  38.   def on_batch_end(self, batch, logs=None):
  39.     """A backwards compatibility alias for `on_train_batch_end`."""
  40.  
  41.   @doc_controls.for_subclass_implementers
  42.   def on_epoch_begin(self, epoch, logs=None):
  43.     """Called at the start of an epoch.
  44.  
  45.     Subclasses should override for any actions to run. This function should only
  46.     be called during TRAIN mode.
  47.  
  48.     Arguments:
  49.         epoch: Integer, index of epoch.
  50.         logs: Dict. Currently no data is passed to this argument for this method
  51.           but that may change in the future.
  52.     """
  53.  
  54.   @doc_controls.for_subclass_implementers
  55.   def on_epoch_end(self, epoch, logs=None):
  56.     """Called at the end of an epoch.
  57.  
  58.     Subclasses should override for any actions to run. This function should only
  59.     be called during TRAIN mode.
  60.  
  61.     Arguments:
  62.         epoch: Integer, index of epoch.
  63.         logs: Dict, metric results for this training epoch, and for the
  64.           validation epoch if validation is performed. Validation result keys
  65.           are prefixed with `val_`.
  66.     """
  67.  
  68.   @doc_controls.for_subclass_implementers
  69.   @generic_utils.default
  70.   def on_train_batch_begin(self, batch, logs=None):
  71.     """Called at the beginning of a training batch in `fit` methods.
  72.  
  73.     Subclasses should override for any actions to run.
  74.  
  75.     Arguments:
  76.         batch: Integer, index of batch within the current epoch.
  77.         logs: Dict, contains the return value of `model.train_step`. Typically,
  78.           the values of the `Model`'s metrics are returned.  Example:
  79.           `{'loss': 0.2, 'accuracy': 0.7}`.
  80.     """
  81.     # For backwards compatibility.
  82.     self.on_batch_begin(batch, logs=logs)
  83.  
  84.   @doc_controls.for_subclass_implementers
  85.   @generic_utils.default
  86.   def on_train_batch_end(self, batch, logs=None):
  87.     """Called at the end of a training batch in `fit` methods.
  88.  
  89.     Subclasses should override for any actions to run.
  90.  
  91.     Arguments:
  92.         batch: Integer, index of batch within the current epoch.
  93.         logs: Dict. Aggregated metric results up until this batch.
  94.     """
  95.     # For backwards compatibility.
  96.     self.on_batch_end(batch, logs=logs)
  97.  
  98.   @doc_controls.for_subclass_implementers
  99.   @generic_utils.default
  100.   def on_test_batch_begin(self, batch, logs=None):
  101.     """Called at the beginning of a batch in `evaluate` methods.
  102.  
  103.     Also called at the beginning of a validation batch in the `fit`
  104.     methods, if validation data is provided.
  105.  
  106.     Subclasses should override for any actions to run.
  107.  
  108.     Arguments:
  109.         batch: Integer, index of batch within the current epoch.
  110.         logs: Dict, contains the return value of `model.test_step`. Typically,
  111.           the values of the `Model`'s metrics are returned.  Example:
  112.           `{'loss': 0.2, 'accuracy': 0.7}`.
  113.     """
  114.  
  115.   @doc_controls.for_subclass_implementers
  116.   @generic_utils.default
  117.   def on_test_batch_end(self, batch, logs=None):
  118.     """Called at the end of a batch in `evaluate` methods.
  119.  
  120.     Also called at the end of a validation batch in the `fit`
  121.     methods, if validation data is provided.
  122.  
  123.     Subclasses should override for any actions to run.
  124.  
  125.     Arguments:
  126.         batch: Integer, index of batch within the current epoch.
  127.         logs: Dict. Aggregated metric results up until this batch.
  128.     """
  129.  
  130.   @doc_controls.for_subclass_implementers
  131.   @generic_utils.default
  132.   def on_predict_batch_begin(self, batch, logs=None):
  133.     """Called at the beginning of a batch in `predict` methods.
  134.  
  135.     Subclasses should override for any actions to run.
  136.  
  137.     Arguments:
  138.         batch: Integer, index of batch within the current epoch.
  139.         logs: Dict, contains the return value of `model.predict_step`,
  140.           it typically returns a dict with a key 'outputs' containing
  141.           the model's outputs.
  142.     """
  143.  
  144.   @doc_controls.for_subclass_implementers
  145.   @generic_utils.default
  146.   def on_predict_batch_end(self, batch, logs=None):
  147.     """Called at the end of a batch in `predict` methods.
  148.  
  149.     Subclasses should override for any actions to run.
  150.  
  151.     Arguments:
  152.         batch: Integer, index of batch within the current epoch.
  153.         logs: Dict. Aggregated metric results up until this batch.
  154.     """
  155.  
  156.   @doc_controls.for_subclass_implementers
  157.   def on_train_begin(self, logs=None):
  158.     """Called at the beginning of training.
  159.  
  160.     Subclasses should override for any actions to run.
  161.  
  162.     Arguments:
  163.         logs: Dict. Currently no data is passed to this argument for this method
  164.           but that may change in the future.
  165.     """
  166.  
  167.   @doc_controls.for_subclass_implementers
  168.   def on_train_end(self, logs=None):
  169.     """Called at the end of training.
  170.  
  171.     Subclasses should override for any actions to run.
  172.  
  173.     Arguments:
  174.         logs: Dict. Currently the output of the last call to `on_epoch_end()`
  175.           is passed to this argument for this method but that may change in
  176.           the future.
  177.     """
  178.  
  179.   @doc_controls.for_subclass_implementers
  180.   def on_test_begin(self, logs=None):
  181.     """Called at the beginning of evaluation or validation.
  182.  
  183.     Subclasses should override for any actions to run.
  184.  
  185.     Arguments:
  186.         logs: Dict. Currently no data is passed to this argument for this method
  187.           but that may change in the future.
  188.     """
  189.  
  190.   @doc_controls.for_subclass_implementers
  191.   def on_test_end(self, logs=None):
  192.     """Called at the end of evaluation or validation.
  193.  
  194.     Subclasses should override for any actions to run.
  195.  
  196.     Arguments:
  197.         logs: Dict. Currently the output of the last call to
  198.           `on_test_batch_end()` is passed to this argument for this method
  199.           but that may change in the future.
  200.     """
  201.  
  202.   @doc_controls.for_subclass_implementers
  203.   def on_predict_begin(self, logs=None):
  204.     """Called at the beginning of prediction.
  205.  
  206.     Subclasses should override for any actions to run.
  207.  
  208.     Arguments:
  209.         logs: Dict. Currently no data is passed to this argument for this method
  210.           but that may change in the future.
  211.     """
  212.  
  213.   @doc_controls.for_subclass_implementers
  214.   def on_predict_end(self, logs=None):
  215.     """Called at the end of prediction.
  216.  
  217.     Subclasses should override for any actions to run.
  218.  
  219.     Arguments:
  220.         logs: Dict. Currently no data is passed to this argument for this method
  221.           but that may change in the future.
  222.     """
  223.  
  224.   def _implements_train_batch_hooks(self):
  225.     """Determines if this Callback should be called for each train batch."""
  226.     return (not generic_utils.is_default(self.on_batch_begin) or
  227.             not generic_utils.is_default(self.on_batch_end) or
  228.             not generic_utils.is_default(self.on_train_batch_begin) or
  229.             not generic_utils.is_default(self.on_train_batch_end))

这些钩子的原始程序是在模型训练流程中的

keras源码位置: tensorflow\python\keras\engine\training.py

部分摘录如下(## I am hook):

  1. # Container that configures and calls `tf.keras.Callback`s.
  2.       if not isinstance(callbacks, callbacks_module.CallbackList):
  3.         callbacks = callbacks_module.CallbackList(
  4.             callbacks,
  5.             add_history=True,
  6.             add_progbar=verbose != 0,
  7.             model=self,
  8.             verbose=verbose,
  9.             epochs=epochs,
  10.             steps=data_handler.inferred_steps)
  11.  
  12.       ## I am hook
  13.       callbacks.on_train_begin()
  14.       training_logs = None
  15.       # Handle fault-tolerance for multi-worker.
  16.       # TODO(omalleyt): Fix the ordering issues that mean this has to
  17.       # happen after `callbacks.on_train_begin`.
  18.       data_handler._initial_epoch = (  # pylint: disable=protected-access
  19.           self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
  20.       for epoch, iterator in data_handler.enumerate_epochs():
  21.         self.reset_metrics()
  22.         callbacks.on_epoch_begin(epoch)
  23.         with data_handler.catch_stop_iteration():
  24.           for step in data_handler.steps():
  25.             with trace.Trace(
  26.                 'TraceContext',
  27.                 graph_type='train',
  28.                 epoch_num=epoch,
  29.                 step_num=step,
  30.                 batch_size=batch_size):
  31.               ## I am hook
  32.               callbacks.on_train_batch_begin(step)
  33.               tmp_logs = train_function(iterator)
  34.               if data_handler.should_sync:
  35.                 context.async_wait()
  36.               logs = tmp_logs  # No error, now safe to assign to logs.
  37.               end_step = step + data_handler.step_increment
  38.               callbacks.on_train_batch_end(end_step, logs)
  39.         epoch_logs = copy.copy(logs)
  40.  
  41.         # Run validation.
  42.  
  43.         ## I am hook
  44.         callbacks.on_epoch_end(epoch, epoch_logs)

3.2 mmdetection

mmdetection是一个目标检测的开源框架,集成了许多不同的目标检测深度学习算法(pytorch版),如faster-rcnn, fpn, retianet等。里面也大量使用了hook,暴露给应用实现流程中具体部分。

详见https://github.com/open-mmlab/mmdetection

这里看一个训练的调用例子(摘录)(https://github.com/open-mmlab/mmdetection/blob/5d592154cca589c5113e8aadc8798bbc73630d98/mmdet/apis/train.py

  1. def train_detector(model,
  2.                    dataset,
  3.                    cfg,
  4.                    distributed=False,
  5.                    validate=False,
  6.                    timestamp=None,
  7.                    meta=None):
  8.     logger = get_root_logger(cfg.log_level)
  9.  
  10.     # prepare data loaders
  11.  
  12.     # put model on gpus
  13.  
  14.     # build runner
  15.     optimizer = build_optimizer(model, cfg.optimizer)
  16.     runner = EpochBasedRunner(
  17.         model,
  18.         optimizer=optimizer,
  19.         work_dir=cfg.work_dir,
  20.         logger=logger,
  21.         meta=meta)
  22.     # an ugly workaround to make .log and .log.json filenames the same
  23.     runner.timestamp = timestamp
  24.  
  25.     # fp16 setting
  26.     # register hooks
  27.     runner.register_training_hooks(cfg.lr_config, optimizer_config,
  28.                                    cfg.checkpoint_config, cfg.log_config,
  29.                                    cfg.get('momentum_config', None))
  30.     if distributed:
  31.         runner.register_hook(DistSamplerSeedHook())
  32.  
  33.     # register eval hooks
  34.     if validate:
  35.         # Support batch_size > 1 in validation
  36.         eval_cfg = cfg.get('evaluation', {})
  37.         eval_hook = DistEvalHook if distributed else EvalHook
  38.         runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
  39.  
  40.     # user-defined hooks
  41.     if cfg.get('custom_hooks', None):
  42.         custom_hooks = cfg.custom_hooks
  43.         assert isinstance(custom_hooks, list), \
  44.             f'custom_hooks expect list type, but got {type(custom_hooks)}'
  45.         for hook_cfg in cfg.custom_hooks:
  46.             assert isinstance(hook_cfg, dict), \
  47.                 'Each item in custom_hooks expects dict type, but got ' \
  48.                 f'{type(hook_cfg)}'
  49.             hook_cfg = hook_cfg.copy()
  50.             priority = hook_cfg.pop('priority', 'NORMAL')
  51.             hook = build_from_cfg(hook_cfg, HOOKS)
  52.             runner.register_hook(hook, priority=priority)

4. 总结

本文介绍了hook的概念和应用,并给出了python的实现细则。希望对比有帮助。总结如下:

  • hook函数是流程中预定义好的一个步骤,没有实现

  • 挂载或者注册时, 流程执行就会执行这个钩子函数

  • 回调函数和hook函数功能上是一致的

  • hook设计方式带来灵活性,如果流程中有一个步骤,你想让调用方来实现,你可以用hook函数

本文的文字及图片来源于网络,仅供学习、交流使用,不具有任何商业用途,如有问题请及时联系我们以作处理

想要获取更多Python学习资料可以加
QQ:2955637827私聊
或加Q群630390733
大家一起来学习讨论吧!

让你轻松掌握 Python 中的 Hook 钩子函数的更多相关文章

  1. 轻松理解python中的闭包和装饰器 (下)

    在 上篇 我们讲了python将函数做为返回值和闭包的概念,下面我们继续讲解函数做参数和装饰器,这个功能相当方便实用,可以极大地简化代码,就让我们go on吧! 能接受函数做参数的函数我们称之为高阶函 ...

  2. Python 函数式编程 & Python中的高阶函数map reduce filter 和sorted

    1. 函数式编程 1)概念 函数式编程是一种编程模型,他将计算机运算看做是数学中函数的计算,并且避免了状态以及变量的概念.wiki 我们知道,对象是面向对象的第一型,那么函数式编程也是一样,函数是函数 ...

  3. Python中的高阶函数与匿名函数

    Python中的高阶函数与匿名函数 高阶函数 高阶函数就是把函数当做参数传递的一种函数.其与C#中的委托有点相似,个人认为. def add(x,y,f): return f( x)+ f( y) p ...

  4. python中enumerate()函数用法

    python中enumerate()函数用法 先出一个题目:1.有一 list= [1, 2, 3, 4, 5, 6]  请打印输出:0, 1 1, 2 2, 3 3, 4 4, 5 5, 6 打印输 ...

  5. Python中str()与repr()函数的区别——repr() 的输出追求明确性,除了对象内容,还需要展示出对象的数据类型信息,适合开发和调试阶段使用

    Python中str()与repr()函数的区别 from:https://www.jianshu.com/p/2a41315ca47e 在 Python 中要将某一类型的变量或者常量转换为字符串对象 ...

  6. Python中sort和sorted函数代码解析

    Python中sort和sorted函数代码解析 本文研究的主要是Python中sort和sorted函数的相关内容,具体如下. 一.sort函数 sort函数是序列的内部函数 函数原型: L.sor ...

  7. Python中进制转换函数的使用

    Python中进制转换函数的使用 关于Python中几个进制转换的函数使用方法,做一个简单的使用方法的介绍,我们常用的进制转换函数常用的就是int()(其他进制转换到十进制).bin()(十进制转换到 ...

  8. 轻松理解python中的闭包和装饰器(上)

    继面向对象编程之后函数式编程逐渐火起来了,在python中也同样支持函数式编程,我们平时使用的map, reduce, filter等都是函数式编程的例子.在函数式编程中,函数也作为一个变量存在,对应 ...

  9. Python中的内置函数

    2.1 Built-in Functions The Python interpreter has a number of functions built into it that are alway ...

随机推荐

  1. 【PYTEST】第一章常用命令

    pytest入门 安装pytest 运行pytest pytest常用命令 1. 安装pytest pip install pytest 2. 运行pytest 2.1 pytest默认搜索当前目录下 ...

  2. 安装git和lsof

    yum install git yum install lsof 查看80端口 lsof -i:80

  3. mysql 分组查询

    mysql 分组查询 获取id最大的一条 (1)分组查询获取最大id SELECT MAX(id) as maxId FROM `d_table` GROUP BY `parent_id` ; (2) ...

  4. mycat分片及主从(二)

    一.mycat分片规则 经过上一篇幅讲解,应该很清楚分片规则配置文件rule.xml位于$MYCAT_HOME/conf目录,它定义了所有拆分表的规则.在使用过程中可以灵活使用不同的分片算法,或者对同 ...

  5. NameServer路由删除

    NameServer会每隔10s扫描brokerLiveTable状态表,如果BrokerLive的lastUpdateTimestamp的时间戳距当前时间超过120s,则认为Broker失效,移除改 ...

  6. fist-第七天冲刺随笔

    这个作业属于哪个课程 https://edu.cnblogs.com/campus/fzzcxy/2018SE1 这个作业要求在哪里 https://edu.cnblogs.com/campus/fz ...

  7. 老猿学5G扫盲贴:推荐三篇介绍HTTP2协议相关的文章

    专栏:Python基础教程目录 专栏:使用PyQt开发图形界面Python应用 专栏:PyQt入门学习 老猿Python博文目录 老猿学5G博文目录 5G中的服务化接口调用都是基于HTTP2协议的,老 ...

  8. flask-mail 机制

    上课无聊,总结下学习的flask-mail 机制 flask-mail 了解 flask-mail 机制中可以用pip 安装也可以用pycharm里面直接安装. flask-mail是一个能调用smt ...

  9. 第 7 篇 Scrum 冲刺博客

    每天举行会议 会议照片: 昨天已完成的工作与今天计划完成的工作及工作中遇到的困难: 成员姓名 昨天完成工作 今天计划完成的工作 工作中遇到的困难 蔡双浩 补充注释,初步查找bug 修改bug 无 陈创 ...

  10. 算法(图论)——最小生成树及其题目应用(prim和Kruskal算法实现)

    题目 n个村庄间架设通信线路,每个村庄间的距离不同,如何架设最节省开销? Kruskal算法 特点 适用于稀疏图,时间复杂度 是nlogn的. 核心思想 从小到大选取不会产生环的边. 代码实现 代码中 ...