训练神经网络模型有时需要观察模型内部模块的输入输出,或是期望在不修改原始模块结构的情况下调整中间模块的输出,pytorch可以用hook回调函数来实现这一功能。主要使用四个hook注册函数:register_forward_hook、register_forward_pre_hook、register_full_backward_hook、register_full_backward_pre_hook。这四个函数可以被继承nn.Module的任意模块调用,传入hook函数并进行注册,从而在执行该模块的相应阶段调用hook函数实现所需功能。

register_forward_hook(self, hook, *, prepend, with_kwargs)

  为模块注册一个在该模块前向传播之后执行的回调函数。

  hook(module, args, output):需执行的回调函数对象,module为当前模块引用,args为当前模块前向传播输入,output为当前模块前向传播输出。可以返回修改后的output来修改该模块前向传播输出。

  prepend:将该hook函数放在回调函数列表最前面,从而最先执行,否则放在队列最后。

  with_kwargs:hook函数是否传入关键字参数,如果为True,则hook额外增加关键字参数,变为 hook(module, args, kwargs, output)。注意!如果with_kwargs=False,模块传入的关键字参数将不会被捕获,坑了我一个下午。

  register_forward_hook注册函数本身返回一个handle句柄,可执行handle.remove()将注册的该hook函数移除。

register_forward_pre_hook(self, hook, *, prepend, with_kwargs)

  为模块注册一个在该模块前向传播之前执行的回调函数。

  hook(module, args):args为该模块前向传播输入。可以返回修改后的args来修改该模块前向传播输入。

  其它参数、特性与前面一致。

register_full_backward_hook(self, hook, prepend)

  为模块注册一个在该模块反向传播之后执行的回调函数。

  hook(module, grad_input, grad_output):grad_input与grad_output分别为该模块前向传播输入和输出的梯度。可以返回修改后的grad_input来修改该模块前向传播输入的梯度。

register_full_backward_pre_hook(self, hook, prepend)

  为模块注册一个在该模块反向传播之前执行的回调函数。

  hook(module, grad_output):grad_output为该模块前向传播输出的梯度。可以返回修改后的grad_output来修改这一梯度。

 

pytorch的四个hook函数的更多相关文章

  1. [PyTorch 学习笔记] 5.2 Hook 函数与 CAM 算法

    本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson5/hook_fmap_vis.py https://gi ...

  2. Java语言程序设计(基础篇) 第四章 数学函数、字符和字符串

    第四章 数学函数.字符和字符串 4.2 常用数学函数 方法分三类:三角函数方法(trigonometric method).指数函数方法(exponent method)和服务方法(service m ...

  3. SQL2005四个排名函数(row_number、rank、dense_rank和ntile)的比较

    排名函数是SQL Server2005新加的功能.在SQL Server2005中有如下四个排名函数: .row_number .rank .dense_rank .ntile 下面分别介绍一下这四个 ...

  4. 四个排名函数(row_number、rank、dense_rank和ntile)的比较

    排名函数是SQL Server2005新加的功能.在SQL Server2005中有如下四个排名函数: 1.row_number 2.rank 3.dense_rank 4.ntile 下面分别介绍一 ...

  5. HOOK函数(一)——进程内HOOK

    什么是HOOK呢?其实很简单,HOOK就是对Windows消息进行拦截检查处理的一个函数.在Windows的消息机制中,当用户产生消息时,应用程序通过调用GetMessage函数取出消息,然后把消息放 ...

  6. sql 的是四个排名函数

    四个排名函数的用法: http://www.cnblogs.com/xhyang110/archive/2009/10/27/1590448.html 字符串分割:http://www.cnblogs ...

  7. Python 全栈开发四 python基础 函数

    一.函数的基本语法和特性 函数的定义 函数一词来源于数学,但编程中的「函数」概念,与数学中的函数是有很大不同的.函数是指将一组语句的集合通过一个名字(函数名)封装起来,要想执行这个函数,只需调用其函数 ...

  8. JMeter学习(十四)JMeter函数学习(转载)

    转载自 http://www.cnblogs.com/yangxia-test JMeter函数是一些能够转化在测试树中取样器或者其他配置元件的域的特殊值.一个函数的调用就像这样:${_functio ...

  9. Python进阶----反射(四个方法),函数vs方法(模块types 与 instance()方法校验 ),双下方法的研究

    Python进阶----反射(四个方法),函数vs方法(模块types 与 instance()方法校验 ),双下方法的研究 一丶反射 什么是反射: ​ 反射的概念是由Smith在1982年首次提出的 ...

  10. Pytorch中randn和rand函数的用法

    Pytorch中randn和rand函数的用法 randn torch.randn(*sizes, out=None) → Tensor 返回一个包含了从标准正态分布中抽取的一组随机数的张量 size ...

随机推荐

  1. Maven 换源

    ~/.m2/settings.xml <settings> ... <mirrors> ... <mirror> <id>alimaven</id ...

  2. Unrecognized SSL message, plaintext connection?

    报错:Unrecognized SSL message, plaintext connection? 修改:把 requestContext.setScheme(Scheme.HTTPS);修改为 r ...

  3. NumPy从入门到放弃

    看前建议: 本文以jupyter notebook为编辑器进行示例,建议有一定python基础后再进行学习. python的安装:https://www.cnblogs.com/scfssq/p/17 ...

  4. 神经网络之卷积篇:详解池化层(Pooling layers)

    详解池化层 除了卷积层,卷积网络也经常使用池化层来缩减模型的大小,提高计算速度,同时提高所提取特征的鲁棒性,来看一下. 先举一个池化层的例子,然后再讨论池化层的必要性.假如输入是一个4×4矩阵,用到的 ...

  5. 【YashanDB知识库】服务端是GBK编码,导致从22.2.12.100升级到22.2.13.100失败问题

    问题现象 问题单:22.2.12.100升级到22.2.13.100失败 现象:如下图,从22.2.12.100升级到22.2.13.100失败,报错. 问题风险及影响 版本升级失败,影响上线 问题发 ...

  6. 商业银行国际结算规模创新高,合合信息AI助力金融行业智能处理多版式文档

    随着我国外贸新业态的快速增长,银行国际结算业务在服务实体经济发展.促进贸易投资便利化进程中发挥了越来越重要的作用.根据中国银行业协会近日发布的<中国贸易金融行业发展报告(2023-2024)&g ...

  7. TS-TCC: 通过时序和上下文对比学习时间序列表征《Time-Series Representation Learning via Temporal and Contextual Contrasting》(时间序列、时序表征、时间和上下文对比、对比学习、自监督学习、半监督学习)

    现在是2023年11月14日的22:15,肝不动了,要不先回寝室吧,明天把这篇看了,然后把文档写了.OK,明天的To Do List. 现在是2023年11月15日的10:35,继续. 论文:Time ...

  8. 2021年12月墨天轮国产数据库排行榜: openGauss节节攀升拿下榜眼,GaussDB与TDSQL你争我夺各进一位

    2021年12月的国产数据库流行度排行榜已在墨天轮发布,本月共有189家数据库参与排名.为使国产数据库排名更加专业与客观,本月起,排行榜加入了三方评测.生态.专利数.论文数等新的指标.其中三方测评方面 ...

  9. ssh建立github连接 基于ssh密钥

    1. 建立公钥和私钥 ps:公钥放在github上面的,私钥放在自己本地电脑 : 先生成密钥:打开 gitbash 输入命令: ssh-keygen -t rsa -b 4096 -C "z ...

  10. 58. vue常用的api

    1. nextTick  使用场景:通过异步渲染的页面解构不能直接dom操作,要使用 nextTick (延迟回调)等待一下 :nextTick 的作用:感知dom的更新完成,类似于 updated ...