1 遍历子模块直接提取

对于简单的模型,可以采用直接遍历子模块的方法,取出相应name模块的输出,不对模型做任何改动。该方法的缺点在于,只能得到其子模块的输出,而对于使用nn.Sequensial()中包含很多层的模型,无法获得其指定层的输出

示例 resnet18取出layer1的输出

from torchvision.models import resnet18
import torch model = resnet18(pretrained=True)
print("model:", model)
out = []
x = torch.randn(1, 3, 224, 224)
return_layer = "layer1"
for name, module in model.named_children():
x = module(x)
if name == return_layer:
out.append(x.data)
break
print(out[0].shape) # torch.Size([1, 64, 56, 56])

2 IntermediateLayerGetter类

torchvison中提供了IntermediateLayerGetter类,该方法同样只能得到其子模块的输出,而对于使用nn.Sequensial()中包含很多层的模型,无法获得其指定层的输出

from torchvision.models._utils import IntermediateLayerGetter

IntermediateLayerGetter类的pytorch源码

class IntermediateLayerGetter(nn.ModuleDict):
"""
Module wrapper that returns intermediate layers from a model It has a strong assumption that the modules have been registered
into the model in the same order as they are used.
This means that one should **not** reuse the same nn.Module
twice in the forward if you want this to work. Additionally, it is only able to query submodules that are directly
assigned to the model. So if `model` is passed, `model.feature1` can
be returned, but not `model.feature1.layer2`. Args:
model (nn.Module): model on which we will extract the features
return_layers (Dict[name, new_name]): a dict containing the names
of the modules for which the activations will be returned as
the key of the dict, and the value of the dict is the name
of the returned activation (which the user can specify).
"""
_version = 2
__annotations__ = {
"return_layers": Dict[str, str],
} def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
if not set(return_layers).issubset([name for name, _ in model.named_children()]):
raise ValueError("return_layers are not present in model")
orig_return_layers = return_layers
return_layers = {str(k): str(v) for k, v in return_layers.items()} # 重新构建backbone,将没有使用到的模块全部删掉
layers = OrderedDict()
for name, module in model.named_children():
layers[name] = module
if name in return_layers:
del return_layers[name]
if not return_layers:
break super(IntermediateLayerGetter, self).__init__(layers)
self.return_layers = orig_return_layers def forward(self, x: Tensor) -> Dict[str, Tensor]:
out = OrderedDict()
for name, module in self.items():
x = module(x)
if name in self.return_layers:
out_name = self.return_layers[name]
out[out_name] = x
return out

示例 使用IntermediateLayerGetter类 改 resnet34+unet 完整代码见gitee

import torch
from torchvision.models import resnet18, vgg16_bn, resnet34
from torchvision.models._utils import IntermediateLayerGetter model = resnet34()
stage_indices = ['relu', 'layer1', 'layer2', 'layer3', 'layer4']
return_layers = dict([(str(j), f"stage{i}") for i, j in enumerate(stage_indices)])
model= IntermediateLayerGetter(model, return_layers=return_layers)
input = torch.randn(1, 3, 224, 224)
output = model(input)
print([(k, v.shape) for k, v in output.items()])

3 create_feature_extractor函数

使用create_feature_extractor方法,创建一个新的模块,该模块将给定模型中的中间节点作为字典返回,用户指定的键作为字符串,请求的输出作为值。该方法比 IntermediateLayerGetter方法更通用, 不局限于获得模型第一层子模块的输出。比如下面的vgg,池化层都在子模块feature中,上面的方法无法取出,因此推荐使用create_feature_extractor方法。

示例 FCN论文中以vgg为backbone,分别取出三个池化层的输出

import torch
from torchvision.models import vgg16_bn
from torchvision.models.feature_extraction import create_feature_extractor model = vgg16_bn()
model = create_feature_extractor(model, {"features.43": "pool5", "features.33": "pool4", "features.23": "pool3"})
input = torch.randn(1, 3, 224, 224)
output = model(input)
print([(k, v.shape) for k, v in output.items()])

4 hook函数

  hook函数是程序中预定义好的函数,这个函数处于原有程序流程当中(暴露一个钩子出来)。我们需要再在有流程中钩子定义的函数块中实现某个具体的细节,需要把我们的实现,挂接或者注册(register)到钩子里,使得hook函数对目标可用。hook 是一种编程机制,和具体的语言没有直接的关系。

  Pytorch的hook编程可以在不改变网络结构的基础上有效获取、改变模型中间变量以及梯度等信息。在pytorch中,Module对象有register_forward_hook(hook) 和 register_backward_hook(hook) 两种方法,两个的操作对象都是nn.Module类,如神经网络中的卷积层(nn.Conv2d),全连接层(nn.Linear),池化层(nn.MaxPool2d, nn.AvgPool2d),激活层(nn.ReLU)或者nn.Sequential定义的小模块等。register_forward_hook是获取前向传播的输出的,即特征图或激活值register_backward_hook是获取反向传播的输出的,即梯度值。(这边只讲register_forward_hook,其余见链接

示例 获取resnet18的avgpool层的输入输出

import torch
from torchvision.models import resnet18 model = resnet18()
fmap_block = dict() # 装feature map
def forward_hook(module, input, output):
fmap_block['input'] = input
fmap_block['output'] = output layer_name = 'avgpool'
for (name, module) in model.named_modules():
if name == layer_name:
module.register_forward_hook(hook=forward_hook) input = torch.randn(64, 3, 224, 224)
output = model(input)
print(fmap_block['input'][0].shape)
print(fmap_block['output'].shape)

  

参考

1. Pytorch提取预训练模型特定中间层的输出

2. Pytorch的hook技术——获取预训练/已训练好模型的特定中间层输出

 

取出预训练模型中间层的输出(pytorch)的更多相关文章

  1. 【tf.keras】tf.keras加载AlexNet预训练模型

    目录 从 PyTorch 中导出模型参数 第 0 步:配置环境 第 1 步:安装 MMdnn 第 2 步:得到 PyTorch 保存完整结构和参数的模型(pth 文件) 第 3 步:导出 PyTorc ...

  2. Pytorch——BERT 预训练模型及文本分类

    BERT 预训练模型及文本分类 介绍 如果你关注自然语言处理技术的发展,那你一定听说过 BERT,它的诞生对自然语言处理领域具有着里程碑式的意义.本次试验将介绍 BERT 的模型结构,以及将其应用于文 ...

  3. pytorch预训练模型的下载地址以及解决下载速度慢的方法

    https://github.com/pytorch/vision/tree/master/torchvision/models 几乎所有的常用预训练模型都在这里面 总结下各种模型的下载地址: 1 R ...

  4. PyTorch保存模型与加载模型+Finetune预训练模型使用

    Pytorch 保存模型与加载模型 PyTorch之保存加载模型 参数初始化参 数的初始化其实就是对参数赋值.而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了da ...

  5. [Pytorch]Pytorch加载预训练模型(转)

    转自:https://blog.csdn.net/Vivianyzw/article/details/81061765 东风的地方 1. 直接加载预训练模型 在训练的时候可能需要中断一下,然后继续训练 ...

  6. 【小白学PyTorch】5 torchvision预训练模型与数据集全览

    文章来自:微信公众号[机器学习炼丹术].一个ai专业研究生的个人学习分享公众号 文章目录: 目录 torchvision 1 torchvision.datssets 2 torchvision.mo ...

  7. pytorch中修改后的模型如何加载预训练模型

    问题描述 简单来说,比如你要加载一个vgg16模型,但是你自己需要的网络结构并不是原本的vgg16网络,可能你删掉某些层,可能你改掉某些层,这时你去加载预训练模型,就会报错,错误原因就是你的模型和原本 ...

  8. XLNet预训练模型,看这篇就够了!(代码实现)

    1. 什么是XLNet XLNet 是一个类似 BERT 的模型,而不是完全不同的模型.总之,XLNet是一种通用的自回归预训练方法.它是CMU和Google Brain团队在2019年6月份发布的模 ...

  9. PyTorch-网络的创建,预训练模型的加载

    本文是PyTorch使用过程中的的一些总结,有以下内容: 构建网络模型的方法 网络层的遍历 各层参数的遍历 模型的保存与加载 从预训练模型为网络参数赋值 主要涉及到以下函数的使用 add_module ...

  10. 最强 NLP 预训练模型库 PyTorch-Transformers 正式开源:支持 6 个预训练框架,27 个预训练模型

    先上开源地址: https://github.com/huggingface/pytorch-transformers#quick-tour 官网: https://huggingface.co/py ...

随机推荐

  1. Mac怎么创建加密文件夹

    对于一些使用Mac工作生活有特殊要求以及职业要求有限制的用户来说,加密自己的工作内容以及隐私是非常重要的一件事情.往往用户需要加密的内容项目很多,这个时候我们就需要一个加密文件夹来包含这些内容.那么M ...

  2. 循环2-if与case语法

    一.if语法结构 1. 单分支结构 if < 条件表达式 > then 指令 fi 或者 if < 条件表达式 >:then 指令 fi 2. 双分支结构 if < 条件 ...

  3. cadence报错because the library part is newer than the part in the design cache.Select the part in the cache and choose Design-Update Cache,and then place the part again.

    cadence报错because the library part is newer than the part in the design cache.Select the part in the ...

  4. 错题笔记:int a=b=1这样定义为什么是错误的

    C语言中定义同一类型的多个变量必须以逗号分隔.如: int a,b,c ; =在C语言中是赋值运算符,等号左边的变量,必须是已以定义好的变量才可以. int a=b=1 ; 中,若b已经定义,则是正确 ...

  5. 前端面试问题整理(html和css部分)

    html5新增属性有哪些? 如何理解语义化标签? 你如何看待前端模块化的? 如何看待前后端分离? 浏览器兼容性问题? 你知道的行内元素.块级元素有哪些? css部分: 1.为什么要初始化css样式? ...

  6. kali linux|01.kali下安装Nessus

    Kali安装Nessus 说明 Nessus是一款基于插件的系统漏洞扫描和分析软件 一.安装 1.下载安装包 https://www.tenable.com/downloads/nessus 查看ka ...

  7. 5、Excel—在做车辆费用汇总的时候,复制出来的数据跟同事的一样,但是合计总数就不一样

    在做车辆费用汇总的时候,复制出来的数据跟同事的一样,但是合计总数就不一样, 刚开始以为是数值问题,明明两份Excel都是同样的数据,为什么合计就不一样呢?(根据同事那份的数据,然后手动在自己的exce ...

  8. 【python】第一模块 步骤五 第二课、Python多线程

    第二课.Python多线程 一.课程介绍 1.1 课程概要 章节概要 进程.线程与并发 对多核的利用 实现一个线程 线程之间的通信 线程的调度和优化 1.2 为什么要学习多线程 (线程)使用场景 快速 ...

  9. layui相关问题总结

    1.layui table回显选中 1) radio: done:function(res, curr, count){ for(var i = 0; i < res.data.length; ...

  10. uni-app微信小程序文本框计数功能

    <view> <textarea placeholder="请输入" @input="sumfontnum" :maxlength=" ...