pytorch的四个hook函数
训练神经网络模型有时需要观察模型内部模块的输入输出,或是期望在不修改原始模块结构的情况下调整中间模块的输出,pytorch可以用hook回调函数来实现这一功能。主要使用四个hook注册函数:register_forward_hook、register_forward_pre_hook、register_full_backward_hook、register_full_backward_pre_hook。这四个函数可以被继承nn.Module的任意模块调用,传入hook函数并进行注册,从而在执行该模块的相应阶段调用hook函数实现所需功能。
register_forward_hook(self, hook, *, prepend, with_kwargs)
为模块注册一个在该模块前向传播之后执行的回调函数。
hook(module, args, output):需执行的回调函数对象,module为当前模块引用,args为当前模块前向传播输入,output为当前模块前向传播输出。可以返回修改后的output来修改该模块前向传播输出。
prepend:将该hook函数放在回调函数列表最前面,从而最先执行,否则放在队列最后。
with_kwargs:hook函数是否传入关键字参数,如果为True,则hook额外增加关键字参数,变为 hook(module, args, kwargs, output)。注意!如果with_kwargs=False,模块传入的关键字参数将不会被捕获,坑了我一个下午。
register_forward_hook注册函数本身返回一个handle句柄,可执行handle.remove()将注册的该hook函数移除。
register_forward_pre_hook(self, hook, *, prepend, with_kwargs)
为模块注册一个在该模块前向传播之前执行的回调函数。
hook(module, args):args为该模块前向传播输入。可以返回修改后的args来修改该模块前向传播输入。
其它参数、特性与前面一致。
register_full_backward_hook(self, hook, prepend)
为模块注册一个在该模块反向传播之后执行的回调函数。
hook(module, grad_input, grad_output):grad_input与grad_output分别为该模块前向传播输入和输出的梯度。可以返回修改后的grad_input来修改该模块前向传播输入的梯度。
register_full_backward_pre_hook(self, hook, prepend)
为模块注册一个在该模块反向传播之前执行的回调函数。
hook(module, grad_output):grad_output为该模块前向传播输出的梯度。可以返回修改后的grad_output来修改这一梯度。
pytorch的四个hook函数的更多相关文章
- [PyTorch 学习笔记] 5.2 Hook 函数与 CAM 算法
本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson5/hook_fmap_vis.py https://gi ...
- Java语言程序设计(基础篇) 第四章 数学函数、字符和字符串
第四章 数学函数.字符和字符串 4.2 常用数学函数 方法分三类:三角函数方法(trigonometric method).指数函数方法(exponent method)和服务方法(service m ...
- SQL2005四个排名函数(row_number、rank、dense_rank和ntile)的比较
排名函数是SQL Server2005新加的功能.在SQL Server2005中有如下四个排名函数: .row_number .rank .dense_rank .ntile 下面分别介绍一下这四个 ...
- 四个排名函数(row_number、rank、dense_rank和ntile)的比较
排名函数是SQL Server2005新加的功能.在SQL Server2005中有如下四个排名函数: 1.row_number 2.rank 3.dense_rank 4.ntile 下面分别介绍一 ...
- HOOK函数(一)——进程内HOOK
什么是HOOK呢?其实很简单,HOOK就是对Windows消息进行拦截检查处理的一个函数.在Windows的消息机制中,当用户产生消息时,应用程序通过调用GetMessage函数取出消息,然后把消息放 ...
- sql 的是四个排名函数
四个排名函数的用法: http://www.cnblogs.com/xhyang110/archive/2009/10/27/1590448.html 字符串分割:http://www.cnblogs ...
- Python 全栈开发四 python基础 函数
一.函数的基本语法和特性 函数的定义 函数一词来源于数学,但编程中的「函数」概念,与数学中的函数是有很大不同的.函数是指将一组语句的集合通过一个名字(函数名)封装起来,要想执行这个函数,只需调用其函数 ...
- JMeter学习(十四)JMeter函数学习(转载)
转载自 http://www.cnblogs.com/yangxia-test JMeter函数是一些能够转化在测试树中取样器或者其他配置元件的域的特殊值.一个函数的调用就像这样:${_functio ...
- Python进阶----反射(四个方法),函数vs方法(模块types 与 instance()方法校验 ),双下方法的研究
Python进阶----反射(四个方法),函数vs方法(模块types 与 instance()方法校验 ),双下方法的研究 一丶反射 什么是反射: 反射的概念是由Smith在1982年首次提出的 ...
- Pytorch中randn和rand函数的用法
Pytorch中randn和rand函数的用法 randn torch.randn(*sizes, out=None) → Tensor 返回一个包含了从标准正态分布中抽取的一组随机数的张量 size ...
随机推荐
- 用MySQL的GROUP_CONCAT函数轻松解决多表联查的聚合问题
大家好呀,我是summo,最近遇到了一个功能需求,虽然也是CURD,但属于那种比较复杂一点的CURD,话不多说,我们先看一下需求. 需求如下: 有三张表,学生表.课程表.学生课程关联表,关联关系如下图 ...
- C#项目—彩票选号
C#彩票选号软件 今天做了一个彩票选号的小软件,将学到的知识点总结如下(新手小白,多提意见): 1.写程序的思路 实体类(属性.方法) No1. 随机数组集合(属性) No2. 创建集合对象(构造方法 ...
- 记 某List.sort()后排序结果异常
背景:某次查看日志,发现数据不符合预期,希望获取的是降序排序,但是部分数据是乱序的 已知List.sort()方法应该不会出异常,所以应该是判断先后方法出问题了 果然,因为一开始写代码时,没有考虑到差 ...
- SQL 求中位值
题目A median is defined as a number separating the higher half of a data set from the lower half. Quer ...
- Google Analytics & Ads 学习笔记
更新: 2021-09-13 Naming conversion for event category, action, label https://support.google.com/analyt ...
- 邀请参与 2022 第三季度 Flutter 开发者调查
自 Flutter 3 发布之后,我们在以移动端为中心到多平台框架的路线上稳步前行,用 Dart 2.17 的新语言特性帮助大家提升工作效率,并对核心工具进行了改进,让您在跨平台打造优秀体验时更加得心 ...
- Flutter Engage China 开发者常见问题解答 | 上篇
再次感谢大家对 Flutter Engage China 活动 的关注和积极参与!我们在活动前后收到了很多来自开发者的反馈和问题,Flutter 团队和演讲嘉宾在直播 Q&A 环节中也针对部分 ...
- 使用ValueConverters扩展实现枚举控制页面的显示
1.ValueConverters 本库包含了IValueConverter接口的的最常用的实现,ValueConverters用于从视图到视图模型的值得转换,某些情况下,可用进行反向转换.里面有一些 ...
- .net6 中 Blazor PageTitle 设置无效的解决方法
直接在 razor 页面里添加 <PageTitle>xxx</PageTitle> 标签无效时的解决方法 For using the <PageTitle> ta ...
- pytorch: grad can be implicitly created only for scalar outputs
运行这段代码 import torch import numpy as np import matplotlib.pyplot as plt x = torch.ones(2,2,requires_g ...