由于pytorch会自动舍弃图计算的中间结果,所以想要获取这些数值就需要使用钩子函数。

钩子函数包括Variable的钩子和nn.Module钩子,用法相似。

一、register_hook

import torch
from torch.autograd import Variable grad_list = [] def print_grad(grad):
grad_list.append(grad) x = Variable(torch.randn(2, 1), requires_grad=True)
y = x+2
z = torch.mean(torch.pow(y, 2))
lr = 1e-3
y.register_hook(print_grad)
z.backward()
x.data -= lr*x.grad.data print(grad_list)
[Variable containing:
1.5653
3.5175
[torch.FloatTensor of size 2x1]
]

二、register_forward_hook & register_backward_hook

这两个函数的功能类似于variable函数的register_hook,可在module前向传播或反向传播时注册钩子。

每次前向传播执行结束后会执行钩子函数(hook)。前向传播的钩子函数具有如下形式:hook(module, input, output) -> None,而反向传播则具有如下形式:hook(module, grad_input, grad_output) -> Tensor or None

钩子函数不应修改输入和输出,并且在使用后应及时删除,以避免每次都运行钩子增加运行负载。钩子函数主要用在获取某些中间结果的情景,如中间某一层的输出或某一层的梯度。这些结果本应写在forward函数中,但如果在forward函数中专门加上这些处理,可能会使处理逻辑比较复杂,这时候使用钩子技术就更合适一些。下面考虑一种场景,有一个预训练好的模型,需要提取模型的某一层(不是最后一层)的输出作为特征进行分类,但又不希望修改其原有的模型定义文件,这时就可以利用钩子函数。下面给出实现的伪代码。

model = VGG()
features = t.Tensor()
def hook(module, input, output):
'''把这层的输出拷贝到features中'''
features.copy_(output.data) handle = model.layer8.register_forward_hook(hook)
_ = model(input)
# 用完hook后删除
handle.remove()

测试LeNet网络

import torch as t
import torch.nn as nn
import torch.nn.functional as F class LeNet(nn.Module):
def __init__(self):
super(LeNet,self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.conv2 = nn.Conv2d(6,16,5)
self.fc1 = nn.Linear(16*5*5,120)
self.fc2 = nn.Linear(120,84)
self.fc3 = nn.Linear(84,10) def forward(self,x):
x = F.max_pool2d(F.relu(self.conv1(x)),(2,2))
x = F.max_pool2d(F.relu(self.conv2(x)),2)
x = x.view(x.size()[0], -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x

先模拟一下单次的向前传播,

net = LeNet()
img = t.autograd.Variable((t.arange(32*32*1).view(1,1,32,32)))
net(img)
Variable containing:

Columns 0 to 7
27.6373 -13.4590 23.0988 -16.4491 -8.8454 -15.6934 -4.8512 1.3490 Columns 8 to 9
3.7801 -15.9396
[torch.FloatTensor of size 1x10]

仿照上面示意,进行钩子注册,获取第一卷积层输出结果,

def hook(module, inputdata, output):
'''把这层的输出拷贝到features中'''
print(output.data) handle = net.conv2.register_forward_hook(hook)
net(img)
# 用完hook后删除
handle.remove()

……

……

[torch.FloatTensor of size 1x16x10x10]

看看hook能识别什么

import torch
from torch import nn
import torch.functional as F
from torch.autograd import Variable def for_hook(module, input, output):
print(module)
for val in input:
print("input val:",val)
for out_val in output:
print("output val:", out_val) class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, x): return x+1 model = Model()
x = Variable(torch.FloatTensor([1]), requires_grad=True)
handle = model.register_forward_hook(for_hook)
print(model(x))
handle.remove()

可见对于目标层,其输入输出都可以获取到,

Model(
)
input val: Variable containing:
1
[torch.FloatTensor of size 1] output val: Variable containing:
2
[torch.FloatTensor of size 1] Variable containing:
2
[torch.FloatTensor of size 1]

『PyTorch』第十六弹_hook技术的更多相关文章

  1. 『PyTorch』第十二弹_nn.Module和nn.functional

    大部分nn中的层class都有nn.function对应,其区别是: nn.Module实现的layer是由class Layer(nn.Module)定义的特殊类,会自动提取可学习参数nn.Para ...

  2. 『PyTorch』第十五弹_torch.nn.Module的属性设置&查询

    一.背景知识 python中两个属相相关方法 result = obj.name 会调用builtin函数getattr(obj,'name')查找对应属性,如果没有name属性则调用obj.__ge ...

  3. 『PyTorch』第十四弹_torch.nn.Module类属性

    nn.Module基类的构造函数: def __init__(self): self._parameters = OrderedDict() self._modules = OrderedDict() ...

  4. 『PyTorch』第十弹_循环神经网络

    RNN基础: 『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练 TensorFlow RNN: 『TensotFlow』基础RNN网络分类问题 『TensotFlow』基础R ...

  5. 『MXNet』第十二弹_再谈新建计算节点

    上一节我们已经谈到了计算节点,但是即使是官方文档介绍里面相关内容也过于简略,我们使用Faster-RCNN代码中的新建节点为例,重新介绍一下新建节点的调用栈. 1.调用新建节点 参数分为三部分,op_ ...

  6. 『PyTorch』第九弹_前馈网络简化写法

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下 在前面的例子中,基本上都是将每一层的输出直接作为下一层的 ...

  7. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...

  8. 『PyTorch』第三弹重置_Variable对象

    『PyTorch』第三弹_自动求导 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Varibale包含三个属性: data ...

  9. 『PyTorch』第二弹重置_Tensor对象

    『PyTorch』第二弹_张量 Tensor基础操作 简单的初始化 import torch as t Tensor基础操作 # 构建张量空间,不初始化 x = t.Tensor(5,3) x -2. ...

随机推荐

  1. Java EE业务处理流程与XML的引入

    Java EE基于MVC架构的业务处理流程 MVC架构业务处理流程 XML定义 XML是可扩展标记语言,标准通用标记语言的子集,是一种用于标记电子文件使其具有结构性的标记语言.XML被设计用于数据的存 ...

  2. 20145316《网络对抗》Exp9 Web安全基础实践学习总结

    20145316<网络对抗>Exp9 Web安全基础实践学习总结 基础问题回答 SQL注入攻击原理,如何防御 SQL注入,就是攻击者通过把SQL命令插入到Web表单递交或输入域名或页面请求 ...

  3. JavaScript中hoisting(悬置/置顶解析/预解析) 实例解释,全局对象,隐含的全局概念

    JavaScript中hoisting(悬置/置顶解析/预解析) 实例解释,全局对象,隐含的全局概念 <html> <body> <script type="t ...

  4. DNS正反向区域解析(二)

    域名查询工具 Nslookup命令 >server 202.106.0.20 #指定DNS服务器 >set q=A #指定要查询的类型(A,PTR,MX,CNAME,NS) >www ...

  5. 高并发下,php使用uniqid函数生成唯一标识符的四种方案

    PHP uniqid()函数可用于生成不重复的唯一标识符,该函数基于微秒级当前时间戳.在高并发或者间隔时长极短(如循环代码)的情况下,会出现大量重复数据.即使使用了第二个参数,也会重复,最好的方案是结 ...

  6. python列表list

    1.通过中括号[ ]括起来,用逗号分隔每个元素,元素可以是数字.字符串.布尔值.列表.元组.字典.集合 2.列表有序(体现在每次打印结果都一样),因此可通过下标索引的方式取元素,下标从0开始,li[m ...

  7. Confluence5.8部分空间名称显示为问号的解决方案

    Confluence5.8部分空间名称显示为问号的解决方案 原因: 连接MySQL的时候,有没有在连接串中指定&useUnicode=true&characterEncoding=ut ...

  8. 使用fragment添加底部导航栏

    切记:fragment一定要放在framlayout中,不然不会被替换完全(就是切换之后原来的fagment可能还会存在) main.xml <LinearLayout xmlns:androi ...

  9. 高通平台读写nv总结【转】

    本文转载自:https://blog.csdn.net/suofeng12345/article/details/52713993 一,引言      1. 什么是NV       高通平台的NV,保 ...

  10. 调试工具--console用法收藏

    1.使用console进行性能测试和计算代码运行时间:http://www.cnblogs.com/0603ljx/p/4387628.html 2.console命令详解:http://www.cn ...