Residual Attention

文章: Residual Attention: A Simple but Effective Method for Multi-Label Recognition, ICCV2021

下面说一下我对这篇文章的浅陋之见, 如有错误, 请多包涵指正.

文章的核心方法

如下图所示为其处理流程:

图中 X 为CNN骨干网络提取得到的feature, 其大小为 d*h*w , 为1个batch数据. 一般 d*h*w=2048*7*7 .

从图中可以看到, 有2个分支, 一个是 average pooling, 一个是 spatial pooling, 最后二者加权融合得到 residual attention .

Spatial pooling

其过程为:

这里有个 1*1 的卷积操作FC , 其大小为 C*d*1*1 , C 为类别数, 如果直接使用矩阵乘法计算, FC(X) 后的大小为 C*h*w .

但文章中的公式是将其展开为对每个空间点单独计算, 其中 \(\pmb{m_i}\)​​​ 为 FCi 个类别的参数, 其大小为 d*1*1, 计算得到的 \(s^i_j\)​​​ 为第 i 个类别在第 j 个位置的概率, \(\pmb{a^i}\)​​​​ 为第 i 个类别的特征, 其大小为 d*1 .

如果, \(\pmb{m_i}\) 和 \(\pmb{a^i}\) 计算就可以得到第 i 个类别的概率. 这样就可以用到每个空间点的特征, 有利于不同目标不同类别物体的分类识别.

公式中有个温度参数 T 用来控制 \(s^i_j\)​​​​ 的大小, 当 T 趋于无穷时, spatial pooling 就变成了 max pooling

Average pooling

其过程为:

上式其实就是一般分类模型的做法, 全局均值池化.

Residual Attention

如下所示, 将上述2个过程进行加权融合:

其中, \(\pmb{f^i}\) 大小为 d*1, \(\pmb{m_i}^T \pmb{f^i}\) 为第 i 个类别的概率.

至于为什么叫 Residual Attention , 文章中的说法是:

the max pooling among different spatial regions for every class, is in fact a class-specific attention operation, which can be further viewed as a residual component of the class-agnostic global average pooling.

我的理解是, 公式5形式有点像 residual 形式.

文章实验结果

多标签

如下表所示为作者对多个数据集的测试, 除了ImageNet 为单标签外, 其它都为多标签. 可以看到多标签提升还是不错的.

热力图

由于利用到了不同位置空间点的信息, 获得的 heatmap 会更加准确, 文章中给出了一张结果, 如下:

我觉得这里有个遗憾的是, 文中没有进行对比.

个人理解

关于原理

根据流程图, 结合文中作者给出的核心代码, 其基本原理就是 average pooling + max pooling.

上述代码中: y_avg 大小为 C*1, 为 average pooling ; y_max 大小为 C*1, 为 max pooling .

下面是上述代码的一个例子, y_raw 的大小为 1*3*9 , B=1, C=3, H3H, W=3:

可以看到, y_avg 刚好为 average pooling , y_max 刚好为 max pooling .

关于公式

公式中的温度参数 T 用于调整参数大小, 而给出的核心代码中, 只有T趋于无穷的情况(等价于max pooling), 对于多个 Head 的情况, T=2,3,4,5 等, 代码中是如何体现出来的?

关于效果

对于 multi-label , 使用了 spatial poolingmulti-head 来提高效果, 从实验结果来看, 确实有效果, 但对于单标签情况, max pooling 应该改善不大, 从实验结果上看也确实可以看到, 单标签数据集上, 最高提升了0.02个百分点.

测试代码

测试代码如下, 可以参考这里.

import torch
from torch import nn class ResidualAttention(nn.Module):
def __init__(self, channel=512, num_class=1000, la=0.2):
super().__init__()
self.la = la
self.fc = nn.Conv2d(in_channels=channel, out_channels=num_class, kernel_size=1, stride=1, bias=False) def forward(self, x):
y_raw = self.fc(x).flatten(2) # b, num_class, h*w
y_avg = torch.mean(y_raw, dim=2) # b, num_class
y_max = torch.max(y_raw, dim=2)[0] # b, num_class
score = y_avg + self.la * y_max
return score if __name__ == '__main__': channel = 4
num_class = 3
batchsize = 1
input = torch.randn(batchsize, channel, 3, 3)
resatt = ResidualAttention(channel=channel, num_class=num_class, la=0.2)
output = resatt(input)
print(output.shape)

[论文阅读] Residual Attention(Multi-Label Recognition)的更多相关文章

  1. 论文阅读(Xiang Bai——【PAMI2017】An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition)

    白翔的CRNN论文阅读 1.  论文题目 Xiang Bai--[PAMI2017]An End-to-End Trainable Neural Network for Image-based Seq ...

  2. 论文阅读:Prominent Object Detection and Recognition: A Saliency-based Pipeline

    论文阅读:Prominent Object Detection and Recognition: A Saliency-based Pipeline  如上图所示,本文旨在解决一个问题:给定一张图像, ...

  3. 论文笔记——Deep Residual Learning for Image Recognition

    论文地址:Deep Residual Learning for Image Recognition ResNet--MSRA何凯明团队的Residual Networks,在2015年ImageNet ...

  4. [论文理解]Deep Residual Learning for Image Recognition

    Deep Residual Learning for Image Recognition 简介 这是何大佬的一篇非常经典的神经网络的论文,也就是大名鼎鼎的ResNet残差网络,论文主要通过构建了一种新 ...

  5. 论文阅读:Face Recognition: From Traditional to Deep Learning Methods 《人脸识别综述:从传统方法到深度学习》

     论文阅读:Face Recognition: From Traditional to Deep Learning Methods  <人脸识别综述:从传统方法到深度学习>     一.引 ...

  6. RAM: Residual Attention Module for Single Image Super-Resolution

    1. 摘要 注意力机制是深度神经网络的一个设计趋势,其在各种计算机视觉任务中都表现突出.但是,应用到图像超分辨领域的注意力模型大都没有考虑超分辨和其它高层计算机视觉问题的天然不同. 作者提出了一个新的 ...

  7. [论文阅读]阿里DIEN深度兴趣进化网络之总体解读

    [论文阅读]阿里DIEN深度兴趣进化网络之总体解读 目录 [论文阅读]阿里DIEN深度兴趣进化网络之总体解读 0x00 摘要 0x01论文概要 1.1 文章信息 1.2 基本观点 1.2.1 DIN的 ...

  8. Deep Residual Learning for Image Recognition (ResNet)

    目录 主要内容 代码 He K, Zhang X, Ren S, et al. Deep Residual Learning for Image Recognition[C]. computer vi ...

  9. Deep Reinforcement Learning for Dialogue Generation 论文阅读

    本文来自李纪为博士的论文 Deep Reinforcement Learning for Dialogue Generation. 1,概述 当前在闲聊机器人中的主要技术框架都是seq2seq模型.但 ...

随机推荐

  1. POJ 1696 Space Ant 点积计算夹角

    题意: 一只特别的蚂蚁,只能直走或者左转.在一个平面上,有很多株植物,这只蚂蚁每天需要进食一株,这只蚂蚁从起点为(0,miny)的点开始出发.求最多能活多少天 分析: 肯定是可以吃到所有植物的,以当前 ...

  2. 什么样的CRM系统适合以客户为中心的企业?

    我们不难发现,现代的企业非常依赖CRM系统,这是因为20%的优质客户能够给企业带来80%的利润,而老客户的推荐可以带来60%的客户增长.那么,什么样的CRM系统适合企业?随着信息技术的发展,客户开始拥 ...

  3. idea中IDEA优化配置,提高启动和运行速度

    IDEA优化配置,提高启动和运行速度 IDEA默认启动配置主要考虑低配置用户,参数不高,导致 启动慢,然后运行也不流畅,这里我们需要优化下启动和运行配置: 找到idea安装的bin目录: D:\ide ...

  4. asp.net mvc中的路由

    [Route] 路由 [Route("~/")] 忽略路由前缀 [Route("person/{id:int}")] 路由内联约束 [Route("h ...

  5. ARTS第十三周(阅读Tomcat源码)

    1.Algorithm:每周至少做一个 leetcode 的算法题2.Review:阅读并点评至少一篇英文技术文章3.Tip:学习至少一个技术技巧4.Share:分享一篇有观点和思考的技术文章 考研真 ...

  6. 选择适合入门的自动化测试框架TestNG 基于Java语言的入门选择之一

    对于测试工程师新手来说,最痛苦的莫过于入门,其实只要入门3个月左右,对于自动化测试,所有的测试工程师除了喜爱,就是更爱.自动化测试工作,是从根本上解放人性,不用重复去完成鼠标的点点点,例如以下测试常常 ...

  7. c语言:逗号运算符

    #include <stdio.h> main() { int a,s,d; s=2,d=3; a=12+(s+2,d+4); printf("%d\n",a); in ...

  8. MapReduce学习总结之Combiner、Partitioner、Jobhistory

    一.Combiner 在MapReduce编程模型中,在Mapper和Reducer之间有一个非常重要的组件,主要用于解决MR性能瓶颈问题 combiner其实属于优化方案,由于带宽限制,应该尽量ma ...

  9. 前端开发入门到进阶第三集【js和jquery的执行时间与页面加载的关系】

    https://blog.csdn.net/u014179029/article/details/81603561 [原文链接]:https://www.cnblogs.com/eric-qin/p/ ...

  10. ls仅列出当前目录下的所有目录

    ls -d */ -d仅列出目录本身,而不列出其中的内容 *通配符,所有的字符 /目录的标识