由于pytorch会自动舍弃图计算的中间结果,所以想要获取这些数值就需要使用钩子函数。

钩子函数包括Variable的钩子和nn.Module钩子,用法相似。

一、register_hook

import torch
from torch.autograd import Variable grad_list = [] def print_grad(grad):
grad_list.append(grad) x = Variable(torch.randn(2, 1), requires_grad=True)
y = x+2
z = torch.mean(torch.pow(y, 2))
lr = 1e-3
y.register_hook(print_grad)
z.backward()
x.data -= lr*x.grad.data print(grad_list)
[Variable containing:
1.5653
3.5175
[torch.FloatTensor of size 2x1]
]

二、register_forward_hook & register_backward_hook

这两个函数的功能类似于variable函数的register_hook,可在module前向传播或反向传播时注册钩子。

每次前向传播执行结束后会执行钩子函数(hook)。前向传播的钩子函数具有如下形式:hook(module, input, output) -> None,而反向传播则具有如下形式:hook(module, grad_input, grad_output) -> Tensor or None

钩子函数不应修改输入和输出,并且在使用后应及时删除,以避免每次都运行钩子增加运行负载。钩子函数主要用在获取某些中间结果的情景,如中间某一层的输出或某一层的梯度。这些结果本应写在forward函数中,但如果在forward函数中专门加上这些处理,可能会使处理逻辑比较复杂,这时候使用钩子技术就更合适一些。下面考虑一种场景,有一个预训练好的模型,需要提取模型的某一层(不是最后一层)的输出作为特征进行分类,但又不希望修改其原有的模型定义文件,这时就可以利用钩子函数。下面给出实现的伪代码。

model = VGG()
features = t.Tensor()
def hook(module, input, output):
'''把这层的输出拷贝到features中'''
features.copy_(output.data) handle = model.layer8.register_forward_hook(hook)
_ = model(input)
# 用完hook后删除
handle.remove()

测试LeNet网络

import torch as t
import torch.nn as nn
import torch.nn.functional as F class LeNet(nn.Module):
def __init__(self):
super(LeNet,self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.conv2 = nn.Conv2d(6,16,5)
self.fc1 = nn.Linear(16*5*5,120)
self.fc2 = nn.Linear(120,84)
self.fc3 = nn.Linear(84,10) def forward(self,x):
x = F.max_pool2d(F.relu(self.conv1(x)),(2,2))
x = F.max_pool2d(F.relu(self.conv2(x)),2)
x = x.view(x.size()[0], -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x

先模拟一下单次的向前传播,

net = LeNet()
img = t.autograd.Variable((t.arange(32*32*1).view(1,1,32,32)))
net(img)
Variable containing:

Columns 0 to 7
27.6373 -13.4590 23.0988 -16.4491 -8.8454 -15.6934 -4.8512 1.3490 Columns 8 to 9
3.7801 -15.9396
[torch.FloatTensor of size 1x10]

仿照上面示意,进行钩子注册,获取第一卷积层输出结果,

def hook(module, inputdata, output):
'''把这层的输出拷贝到features中'''
print(output.data) handle = net.conv2.register_forward_hook(hook)
net(img)
# 用完hook后删除
handle.remove()

……

……

[torch.FloatTensor of size 1x16x10x10]

看看hook能识别什么

import torch
from torch import nn
import torch.functional as F
from torch.autograd import Variable def for_hook(module, input, output):
print(module)
for val in input:
print("input val:",val)
for out_val in output:
print("output val:", out_val) class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, x): return x+1 model = Model()
x = Variable(torch.FloatTensor([1]), requires_grad=True)
handle = model.register_forward_hook(for_hook)
print(model(x))
handle.remove()

可见对于目标层,其输入输出都可以获取到,

Model(
)
input val: Variable containing:
1
[torch.FloatTensor of size 1] output val: Variable containing:
2
[torch.FloatTensor of size 1] Variable containing:
2
[torch.FloatTensor of size 1]

『PyTorch』第十六弹_hook技术的更多相关文章

  1. 『PyTorch』第十二弹_nn.Module和nn.functional

    大部分nn中的层class都有nn.function对应,其区别是: nn.Module实现的layer是由class Layer(nn.Module)定义的特殊类,会自动提取可学习参数nn.Para ...

  2. 『PyTorch』第十五弹_torch.nn.Module的属性设置&查询

    一.背景知识 python中两个属相相关方法 result = obj.name 会调用builtin函数getattr(obj,'name')查找对应属性,如果没有name属性则调用obj.__ge ...

  3. 『PyTorch』第十四弹_torch.nn.Module类属性

    nn.Module基类的构造函数: def __init__(self): self._parameters = OrderedDict() self._modules = OrderedDict() ...

  4. 『PyTorch』第十弹_循环神经网络

    RNN基础: 『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练 TensorFlow RNN: 『TensotFlow』基础RNN网络分类问题 『TensotFlow』基础R ...

  5. 『MXNet』第十二弹_再谈新建计算节点

    上一节我们已经谈到了计算节点,但是即使是官方文档介绍里面相关内容也过于简略,我们使用Faster-RCNN代码中的新建节点为例,重新介绍一下新建节点的调用栈. 1.调用新建节点 参数分为三部分,op_ ...

  6. 『PyTorch』第九弹_前馈网络简化写法

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下 在前面的例子中,基本上都是将每一层的输出直接作为下一层的 ...

  7. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...

  8. 『PyTorch』第三弹重置_Variable对象

    『PyTorch』第三弹_自动求导 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Varibale包含三个属性: data ...

  9. 『PyTorch』第二弹重置_Tensor对象

    『PyTorch』第二弹_张量 Tensor基础操作 简单的初始化 import torch as t Tensor基础操作 # 构建张量空间,不初始化 x = t.Tensor(5,3) x -2. ...

随机推荐

  1. Java(20~24)

    1.Collection中的集合称为单列集合,Map中的集合称为双列集合(键值对集合). 2.Map常用方法:map.put()   map.get()   map.remove()   map.ke ...

  2. Applying the Kappa architecture in the telco industry

    https://www.oreilly.com/ideas/applying-the-kappa-architecture-in-the-telco-industry Kappa architectu ...

  3. 02: Redis缓存系统

    目录: 1.1 在centos6.5中安装Redis 1.2 Redis的简介及两种基本操作 1.3 Redis对string操作(第一类) 1.4 redis对Hash操作,字典格式(第二类) 1. ...

  4. 20145104张家明 《Java程序设计》第四次实验设计

    20145104张家明 <Java程序设计>第四次实验设计 这第四次实验报告 我们开始着手安卓了 在电脑上安装了安卓虚拟机

  5. 20145303刘俊谦 Exp7 网络欺诈技术防范

    20145303刘俊谦 Exp7 网络欺诈技术防范 1.实验后回答问题 (1)通常在什么场景下容易受到DNS spoof攻击 局域网内的攻击,arp入侵攻击和DNS欺骗攻击 公共wifi点上的攻击. ...

  6. Python字典猜解

    摘要 目标 使用Python破解WordPress用户密码 使用Python破解zip压缩包密码 思路 通过表单提交项构建数据包,使用字典中的可选字符进行逐一排列组合暴力破解WordPress的用户密 ...

  7. 【转】linux之cp/scp命令+scp命令详解

    linux之cp/scp命令+scp命令详解   名称:cp 使用权限:所有使用者 使用方式: cp [options] source dest cp [options] source... dire ...

  8. 论文笔记——MobileNets(Efficient Convolutional Neural Networks for Mobile Vision Applications)

    论文地址:MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications MobileNet由Go ...

  9. [JavaScript] - form表单转json的插件

    jquery.serializejson.js 之前好像记录过,做项目又用到了再记下 在页面中引入js后就可以使用了 示例: //点击设置微信信息的form表单提交按钮后,执行wxConfig的con ...

  10. [SpringBoot] - 上线一份项目记录

    首先在服务器上运行war包. (新建项目) 其后,选择数据库,因为之前感觉mysql比较难安装,这次就再试一次,之前的PostgreSQL没有问题. 将原有文件进行复制,排除导包错误. 首先测试邮件发 ...