『PyTorch』第十六弹_hook技术
由于pytorch会自动舍弃图计算的中间结果,所以想要获取这些数值就需要使用钩子函数。
钩子函数包括Variable的钩子和nn.Module钩子,用法相似。
一、register_hook
import torch
from torch.autograd import Variable grad_list = [] def print_grad(grad):
grad_list.append(grad) x = Variable(torch.randn(2, 1), requires_grad=True)
y = x+2
z = torch.mean(torch.pow(y, 2))
lr = 1e-3
y.register_hook(print_grad)
z.backward()
x.data -= lr*x.grad.data print(grad_list)
[Variable containing:
1.5653
3.5175
[torch.FloatTensor of size 2x1]
]
二、register_forward_hook & register_backward_hook
这两个函数的功能类似于variable函数的register_hook,可在module前向传播或反向传播时注册钩子。
每次前向传播执行结束后会执行钩子函数(hook)。前向传播的钩子函数具有如下形式:hook(module, input, output) -> None,而反向传播则具有如下形式:hook(module, grad_input, grad_output) -> Tensor or None。
钩子函数不应修改输入和输出,并且在使用后应及时删除,以避免每次都运行钩子增加运行负载。钩子函数主要用在获取某些中间结果的情景,如中间某一层的输出或某一层的梯度。这些结果本应写在forward函数中,但如果在forward函数中专门加上这些处理,可能会使处理逻辑比较复杂,这时候使用钩子技术就更合适一些。下面考虑一种场景,有一个预训练好的模型,需要提取模型的某一层(不是最后一层)的输出作为特征进行分类,但又不希望修改其原有的模型定义文件,这时就可以利用钩子函数。下面给出实现的伪代码。
model = VGG()
features = t.Tensor()
def hook(module, input, output):
'''把这层的输出拷贝到features中'''
features.copy_(output.data) handle = model.layer8.register_forward_hook(hook)
_ = model(input)
# 用完hook后删除
handle.remove()
测试LeNet网络
import torch as t
import torch.nn as nn
import torch.nn.functional as F class LeNet(nn.Module):
def __init__(self):
super(LeNet,self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.conv2 = nn.Conv2d(6,16,5)
self.fc1 = nn.Linear(16*5*5,120)
self.fc2 = nn.Linear(120,84)
self.fc3 = nn.Linear(84,10) def forward(self,x):
x = F.max_pool2d(F.relu(self.conv1(x)),(2,2))
x = F.max_pool2d(F.relu(self.conv2(x)),2)
x = x.view(x.size()[0], -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
先模拟一下单次的向前传播,
net = LeNet()
img = t.autograd.Variable((t.arange(32*32*1).view(1,1,32,32)))
net(img)
Variable containing: Columns 0 to 7
27.6373 -13.4590 23.0988 -16.4491 -8.8454 -15.6934 -4.8512 1.3490 Columns 8 to 9
3.7801 -15.9396
[torch.FloatTensor of size 1x10]
仿照上面示意,进行钩子注册,获取第一卷积层输出结果,
def hook(module, inputdata, output):
'''把这层的输出拷贝到features中'''
print(output.data) handle = net.conv2.register_forward_hook(hook)
net(img)
# 用完hook后删除
handle.remove()
……
……
[torch.FloatTensor of size 1x16x10x10]
看看hook能识别什么
import torch
from torch import nn
import torch.functional as F
from torch.autograd import Variable def for_hook(module, input, output):
print(module)
for val in input:
print("input val:",val)
for out_val in output:
print("output val:", out_val) class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, x): return x+1 model = Model()
x = Variable(torch.FloatTensor([1]), requires_grad=True)
handle = model.register_forward_hook(for_hook)
print(model(x))
handle.remove()
可见对于目标层,其输入输出都可以获取到,
Model(
)
input val: Variable containing:
1
[torch.FloatTensor of size 1] output val: Variable containing:
2
[torch.FloatTensor of size 1] Variable containing:
2
[torch.FloatTensor of size 1]
『PyTorch』第十六弹_hook技术的更多相关文章
- 『PyTorch』第十二弹_nn.Module和nn.functional
大部分nn中的层class都有nn.function对应,其区别是: nn.Module实现的layer是由class Layer(nn.Module)定义的特殊类,会自动提取可学习参数nn.Para ...
- 『PyTorch』第十五弹_torch.nn.Module的属性设置&查询
一.背景知识 python中两个属相相关方法 result = obj.name 会调用builtin函数getattr(obj,'name')查找对应属性,如果没有name属性则调用obj.__ge ...
- 『PyTorch』第十四弹_torch.nn.Module类属性
nn.Module基类的构造函数: def __init__(self): self._parameters = OrderedDict() self._modules = OrderedDict() ...
- 『PyTorch』第十弹_循环神经网络
RNN基础: 『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练 TensorFlow RNN: 『TensotFlow』基础RNN网络分类问题 『TensotFlow』基础R ...
- 『MXNet』第十二弹_再谈新建计算节点
上一节我们已经谈到了计算节点,但是即使是官方文档介绍里面相关内容也过于简略,我们使用Faster-RCNN代码中的新建节点为例,重新介绍一下新建节点的调用栈. 1.调用新建节点 参数分为三部分,op_ ...
- 『PyTorch』第九弹_前馈网络简化写法
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下 在前面的例子中,基本上都是将每一层的输出直接作为下一层的 ...
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...
- 『PyTorch』第三弹重置_Variable对象
『PyTorch』第三弹_自动求导 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Varibale包含三个属性: data ...
- 『PyTorch』第二弹重置_Tensor对象
『PyTorch』第二弹_张量 Tensor基础操作 简单的初始化 import torch as t Tensor基础操作 # 构建张量空间,不初始化 x = t.Tensor(5,3) x -2. ...
随机推荐
- python的数据结构之数字和字符串(四)
一.数字 Python Number 数据类型用于存储数值.数据类型是不允许改变的,这就意味着如果改变 Number 数据类型的值,将重新分配内存空间. Python 支持四种不同的数值类型: 整型( ...
- WindowsServer-性能计数器
https://jingyan.baidu.com/article/59703552e764e48fc00740dd.html
- 简单的HTML5 canvas游戏工作原理
HTML5已经不是一个新名词.它看上去很cool,有很多feature,大多数人普遍看好它的发展.对于我来说,最感兴趣的是它的canvas标签,可以结合Javascript来绘制游戏画面. 我们可以在 ...
- Jsoup请求http或https返回json字符串工具类
Jsoup请求http或https返回json字符串工具类 所需要的jar包如下: jsoup-1.8.1.jar 依赖jar包如下: httpclient-4.5.4.jar; httpclient ...
- 解决mysql的Too many connections
解决: /etc/my.cnf vim编辑 添加 max_connections= wait_timeout= 然后执行code service mysqld reload service mysql ...
- php在Nginx环境下进行刷新缓存立即输出,实现常驻进程轮询。
以下面这段代码并不会逐个输出,而是当浏览器筹够一定字节数进行统一输出,结果显而易见,10秒后一次性输出所有内容 for($i=0;$i<10;$i++){ echo $i.'</br> ...
- 20145307陈俊达_安卓逆向分析_Xposed的hook技术研究
20145307陈俊达_安卓逆向分析_Xposed的hook技术研究 引言 其实这份我早就想写了,xposed这个东西我在安卓SDK 4.4.4的时候就在玩了,root后安装架构,起初是为了实现一些屌 ...
- HDU 2841 Visible Trees(容斥)题解
题意:有一块(1,1)到(m,n)的地,从(0,0)看能看到几块(如果两块地到看的地方三点一线,后面的地都看不到). 思路:一开始是想不到容斥...后来发现被遮住的地都有一个特点,若(a,b)有gcd ...
- 【Git安装】centos安装git
1 yum install git 安装后的默认存放地点/usr/bin/git
- BZOJ4415: [Shoi2013]发牌 树状数组+二分
Description 假设一开始,荷官拿出了一副新牌,这副牌有N张不同的牌,编号依次为1到N.由于是新牌,所以牌是按照顺序排好的,从牌库顶开始,依次为1, 2,……直到N,N号牌在牌库底.为了发完所 ...