本章代码:

这篇文章主要介绍了如何使用 Hook 函数提取网络中的特征图进行可视化,和 CAM(class activation map, 类激活图)

Hook 函数概念

Hook 函数是在不改变主体的情况下,实现额外功能。由于 PyTorch 是基于动态图实现的,因此在一次迭代运算结束后,一些中间变量如非叶子节点的梯度和特征图,会被释放掉。在这种情况下想要提取和记录这些中间变量,就需要使用 Hook 函数。

PyTorch 提供了 4 种 Hook 函数。

torch.Tensor.register_hook(hook)

功能:注册一个反向传播 hook 函数,仅输入一个参数,为张量的梯度。

hook函数:

hook(grad)

参数:

  • grad:张量的梯度

代码如下:

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b) # 保存梯度的 list
a_grad = list() # 定义 hook 函数,把梯度添加到 list 中
def grad_hook(grad):
a_grad.append(grad) # 一个张量注册 hook 函数
handle = a.register_hook(grad_hook) y.backward() # 查看梯度
print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
# 查看在 hook 函数里 list 记录的梯度
print("a_grad[0]: ", a_grad[0])
handle.remove()

结果如下:

gradient: tensor([5.]) tensor([2.]) None None None
a_grad[0]: tensor([2.])

在反向传播结束后,非叶子节点张量的梯度被清空了。而通过hook函数记录的梯度仍然可以查看。

hook函数里面可以修改梯度的值,无需返回也可以作为新的梯度赋值给原来的梯度。代码如下:

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b) a_grad = list() def grad_hook(grad):
grad *= 2
return grad*3 handle = w.register_hook(grad_hook) y.backward() # 查看梯度
print("w.grad: ", w.grad)
handle.remove()

结果是:

w.grad:  tensor([30.])

torch.nn.Module.register_forward_hook(hook)

功能:注册 module 的前向传播hook函数,可用于获取中间的 feature map。

hook函数:

hook(module, input, output)

参数:

  • module:当前网络层
  • input:当前网络层输入数据
  • output:当前网络层输出数据

下面代码执行的功能是 $3 \times 3$ 的卷积和 $2 \times 2$ 的池化。我们使用register_forward_hook()记录中间卷积层输入和输出的 feature map。

    class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 2, 3)
self.pool1 = nn.MaxPool2d(2, 2) def forward(self, x):
x = self.conv1(x)
x = self.pool1(x)
return x def forward_hook(module, data_input, data_output):
fmap_block.append(data_output)
input_block.append(data_input) # 初始化网络
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_() # 注册hook
fmap_block = list()
input_block = list()
net.conv1.register_forward_hook(forward_hook) # inference
fake_img = torch.ones((1, 1, 4, 4)) # batch size * channel * H * W
output = net(fake_img) # 观察
print("output shape: {}\noutput value: {}\n".format(output.shape, output))
print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))

输出如下:

output shape: torch.Size([1, 2, 1, 1])
output value: tensor([[[[ 9.]],
[[18.]]]], grad_fn=<MaxPool2DWithIndicesBackward>)
feature maps shape: torch.Size([1, 2, 2, 2])
output value: tensor([[[[ 9., 9.],
[ 9., 9.]],
[[18., 18.],
[18., 18.]]]], grad_fn=<ThnnConv2DBackward>)
input shape: torch.Size([1, 1, 4, 4])
input value: (tensor([[[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]]]]),)

torch.Tensor.register_forward_pre_hook()

功能:注册 module 的前向传播前的hook函数,可用于获取输入数据。

hook函数:

hook(module, input)

参数:

  • module:当前网络层
  • input:当前网络层输入数据

torch.Tensor.register_backward_hook()

功能:注册 module 的反向传播的hook函数,可用于获取梯度。

hook函数:

hook(module, grad_input, grad_output)

参数:

  • module:当前网络层
  • input:当前网络层输入的梯度数据
  • output:当前网络层输出的梯度数据

代码如下:

    class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 2, 3)
self.pool1 = nn.MaxPool2d(2, 2) def forward(self, x):
x = self.conv1(x)
x = self.pool1(x)
return x def forward_hook(module, data_input, data_output):
fmap_block.append(data_output)
input_block.append(data_input) def forward_pre_hook(module, data_input):
print("forward_pre_hook input:{}".format(data_input)) def backward_hook(module, grad_input, grad_output):
print("backward hook input:{}".format(grad_input))
print("backward hook output:{}".format(grad_output)) # 初始化网络
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_() # 注册hook
fmap_block = list()
input_block = list()
net.conv1.register_forward_hook(forward_hook)
net.conv1.register_forward_pre_hook(forward_pre_hook)
net.conv1.register_backward_hook(backward_hook) # inference
fake_img = torch.ones((1, 1, 4, 4)) # batch size * channel * H * W
output = net(fake_img) loss_fnc = nn.L1Loss()
target = torch.randn_like(output)
loss = loss_fnc(target, output)
loss.backward()

输出如下:

forward_pre_hook input:(tensor([[[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]]]]),)
backward hook input:(None, tensor([[[[0.5000, 0.5000, 0.5000],
[0.5000, 0.5000, 0.5000],
[0.5000, 0.5000, 0.5000]]],
[[[0.5000, 0.5000, 0.5000],
[0.5000, 0.5000, 0.5000],
[0.5000, 0.5000, 0.5000]]]]), tensor([0.5000, 0.5000]))
backward hook output:(tensor([[[[0.5000, 0.0000],
[0.0000, 0.0000]],
[[0.5000, 0.0000],
[0.0000, 0.0000]]]]),)

hook函数实现机制

hook函数实现的原理是在module__call()__函数进行拦截,__call()__函数可以分为 4 个部分:

  • 第 1 部分是实现 _forward_pre_hooks
  • 第 2 部分是实现 forward 前向传播
  • 第 3 部分是实现 _forward_hooks
  • 第 4 部分是实现 _backward_hooks

由于卷积层也是一个module,因此可以记录_forward_hooks

    def __call__(self, *input, **kwargs):
# 第 1 部分是实现 _forward_pre_hooks
for hook in self._forward_pre_hooks.values():
result = hook(self, input)
if result is not None:
if not isinstance(result, tuple):
result = (result,)
input = result # 第 2 部分是实现 forward 前向传播
if torch._C._get_tracing_state():
result = self._slow_forward(*input, **kwargs)
else:
result = self.forward(*input, **kwargs) # 第 3 部分是实现 _forward_hooks
for hook in self._forward_hooks.values():
hook_result = hook(self, input, result)
if hook_result is not None:
result = hook_result # 第 4 部分是实现 _backward_hooks
if len(self._backward_hooks) > 0:
var = result
while not isinstance(var, torch.Tensor):
if isinstance(var, dict):
var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
else:
var = var[0]
grad_fn = var.grad_fn
if grad_fn is not None:
for hook in self._backward_hooks.values():
wrapper = functools.partial(hook, self)
functools.update_wrapper(wrapper, hook)
grad_fn.register_hook(wrapper)
return result

Hook 函数提取网络的特征图

下面通过hook函数获取 AlexNet 每个卷积层的所有卷积核参数,以形状作为 key,value 对应该层多个卷积核的 list。然后取出每层的第一个卷积核,形状是 [1, in_channle, h, w],转换为 [in_channle, 1, h, w],使用 TensorBoard 进行可视化,代码如下:

    writer = SummaryWriter(comment='test_your_comment', filename_suffix="_test_your_filename_suffix")

    # 数据
path_img = "imgs/lena.png" # your path to image
normMean = [0.49139968, 0.48215827, 0.44653124]
normStd = [0.24703233, 0.24348505, 0.26158768] norm_transform = transforms.Normalize(normMean, normStd)
img_transforms = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
norm_transform
]) img_pil = Image.open(path_img).convert('RGB')
if img_transforms is not None:
img_tensor = img_transforms(img_pil)
img_tensor.unsqueeze_(0) # chw --> bchw # 模型
alexnet = models.alexnet(pretrained=True) # 注册hook
fmap_dict = dict()
for name, sub_module in alexnet.named_modules(): if isinstance(sub_module, nn.Conv2d):
key_name = str(sub_module.weight.shape)
fmap_dict.setdefault(key_name, list())
# 由于AlexNet 使用 nn.Sequantial 包装,所以 name 的形式是:features.0 features.1
n1, n2 = name.split(".") def hook_func(m, i, o):
key_name = str(m.weight.shape)
fmap_dict[key_name].append(o) alexnet._modules[n1]._modules[n2].register_forward_hook(hook_func) # forward
output = alexnet(img_tensor) # add image
for layer_name, fmap_list in fmap_dict.items():
fmap = fmap_list[0]# 取出第一个卷积核的参数
fmap.transpose_(0, 1) # 把 BCHW 转换为 CBHW nrow = int(np.sqrt(fmap.shape[0]))
fmap_grid = vutils.make_grid(fmap, normalize=True, scale_each=True, nrow=nrow)
writer.add_image('feature map in {}'.format(layer_name), fmap_grid, global_step=322)

使用 TensorBoard 进行可视化如下:

CAM(class activation map, 类激活图)

暂未完成。列出两个参考文章。

参考资料

如果你觉得这篇文章对你有帮助,不妨点个赞,让我有更多动力写出好文章。

[PyTorch 学习笔记] 5.2 Hook 函数与 CAM 算法的更多相关文章

  1. Hadoop源码学习笔记(2) ——进入main函数打印包信息

    Hadoop源码学习笔记(2) ——进入main函数打印包信息 找到了main函数,也建立了快速启动的方法,然后我们就进去看一看. 进入NameNode和DataNode的主函数后,发现形式差不多: ...

  2. JavaScript学习笔记(七)——函数的定义与调用

    在学习廖雪峰前辈的JavaScript教程中,遇到了一些需要注意的点,因此作为学习笔记列出来,提醒自己注意! 如果大家有需要,欢迎访问前辈的博客https://www.liaoxuefeng.com/ ...

  3. Python学习笔记014——迭代工具函数 内置函数enumerate()

    1 描述 enumerate() 函数用于将一个可遍历的数据对象(如列表.元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中. 2 语法 enumerate(sequ ...

  4. C语言学习笔记---好用的函数memcpy与memset

    这个主要用于我个人的学习笔记,便于以后查询,顺便分享给大家. 想必在用C的时候难免会与数组,指针,内存这几样东西打交道,先以数组为例,例如有一个数组int a[5] = {1, 2, 3, 4, 5} ...

  5. C++学习笔记之——内联函数,引用

    本文为原创作品,转载请注明出处 欢迎关注我的博客:http://blog.csdn.net/hit2015spring和http://www.cnblogs.com/xujianqing/ 作者:晨凫 ...

  6. 大一C语言学习笔记(5)---函数篇-定义函数需要了解注意的地方;定义函数的易错点;详细说明函数的每个组合部分的功能及注意事项

    博主学习C语言是通过B站上的<郝斌C语言自学教程>,对于C语言初学者来说,我认为郝斌真的是在全网C语言学习课程中讲的最全面,到位的一个,这个不是真不是博主我吹他哈,大家可以去B站去看看,C ...

  7. 学习笔记之——C语言 函数

    采用函数的原因: 随着程序规模的变大,产生了以下问题: --main函数变得相当冗杂 --程序复杂度不断提高 --代码前后关联度提高,修改代码往往牵一发而动全身 --变量使用过多,命名都成了问题 -- ...

  8. OpenCV 学习笔记 04 深度估计与分割——GrabCut算法与分水岭算法

    1 使用普通摄像头进行深度估计 1.1 深度估计原理 这里会用到几何学中的极几何(Epipolar Geometry),它属于立体视觉(stereo vision)几何学,立体视觉是计算机视觉的一个分 ...

  9. [转]Python3《机器学习实战》学习笔记(一):k-近邻算法(史诗级干货长文)

    转自http://blog.csdn.net/c406495762/article/details/75172850 版权声明:本文为博主原创文章,未经博主允许不得转载.   目录(?)[-] 一 简 ...

随机推荐

  1. CF1349F 【Slime and Sequences】part2

    由于本文过长,\(\LaTeX\) 炸了,分两篇,part1 优化 我们假装不会欧拉数的通项式(其实是因为它的通项式不容易继续优化?),使用容斥代替掉欧拉数 设 \(\begin{vmatrix}n\ ...

  2. mPaaS 小程序架构解析 | 实操演示小程序如何实现多端开发

    对于 mPaaS 小程序开发框架,想必读者们并不陌生.它源自于支付宝小程序框架,继承了易开发性.跨平台性及 Native 性能,不仅帮助开发者实现面向自有 App 投放小程序,还可快速构建打包,覆盖支 ...

  3. JS 与 jQery 的区别主要在于 DOM

    //目前正在学习前端阶段,把知识点整理.保存下来以便日后查看 首先引入jQery: 需要先引入css,再引入js: jQery需要在js前引入,再引入框架,最后才是js的引入:css也相同,先引入框架 ...

  4. C#LeetCode刷题之#160-相交链表(Intersection of Two Linked Lists)

    问题 该文章的最新版本已迁移至个人博客[比特飞],单击链接 https://www.byteflying.com/archives/3824 访问. 编写一个程序,找到两个单链表相交的起始节点. 例如 ...

  5. 利用BeautifulSoup去除HTML指定标签和去除注释

    去除指定标签 from bs4 import BeautifulSoup #去除属性ul [s.extract() for s in soup("ul")] # 去除属性svg [ ...

  6. bluecms v1.6 代码审计

    0x01 使用seay源代码审计系统进行审计 扫描到了很多个可疑漏洞,不过工具都有一定的误报,下面我们就逐个进行验证 0x02 /ad_js.php SQL注入漏洞 查看源码,我们发现程序通过GET方 ...

  7. CPF 入门教程 - 数据绑定和命令绑定(二)

    CPF netcore跨平台UI框架 系列教程 CPF 入门教程(一) CPF 入门教程 - 数据绑定和命令绑定(二) 数据绑定和Wpf类似,支持双向绑定.数据绑定和命令绑定是UI和业务逻辑分离的基础 ...

  8. GRE 协议 - 和 ISP 用的协议不一样怎么办

    GRE 出现的背景: 随着网络(公司)规模的增大,越来越多的公司需要在跨区域之间建设自己的分公司.但随之也就出现了这样的问题,考虑这样一个场景.公司 A 在北京和上海间开设了两家公司,由于业务的需要, ...

  9. 通过http、https域名访问静态网页、nginx配置负载均衡(nginx配置)

    很多场景下需要可以通过浏览器访问静态网页,不想把服务器ip地址直接暴露出来,通过nginx可以解决这个问题. 实现http域名访问静态网页 1.域名解析配置(本文都是以阿里云为例,其他平台,操作步骤类 ...

  10. CallStub相关

    CallStub相关 调用入口 share/vm/runtime/stubRoutines.hpp // Calls to Java SimonNote: 函数指针结合typedef类型定义 type ...