Residual Attention

文章: Residual Attention: A Simple but Effective Method for Multi-Label Recognition, ICCV2021

下面说一下我对这篇文章的浅陋之见, 如有错误, 请多包涵指正.

文章的核心方法

如下图所示为其处理流程:

图中 X 为CNN骨干网络提取得到的feature, 其大小为 d*h*w , 为1个batch数据. 一般 d*h*w=2048*7*7 .

从图中可以看到, 有2个分支, 一个是 average pooling, 一个是 spatial pooling, 最后二者加权融合得到 residual attention .

Spatial pooling

其过程为:

这里有个 1*1 的卷积操作FC , 其大小为 C*d*1*1 , C 为类别数, 如果直接使用矩阵乘法计算, FC(X) 后的大小为 C*h*w .

但文章中的公式是将其展开为对每个空间点单独计算, 其中 \(\pmb{m_i}\)​​​ 为 FCi 个类别的参数, 其大小为 d*1*1, 计算得到的 \(s^i_j\)​​​ 为第 i 个类别在第 j 个位置的概率, \(\pmb{a^i}\)​​​​ 为第 i 个类别的特征, 其大小为 d*1 .

如果, \(\pmb{m_i}\) 和 \(\pmb{a^i}\) 计算就可以得到第 i 个类别的概率. 这样就可以用到每个空间点的特征, 有利于不同目标不同类别物体的分类识别.

公式中有个温度参数 T 用来控制 \(s^i_j\)​​​​ 的大小, 当 T 趋于无穷时, spatial pooling 就变成了 max pooling

Average pooling

其过程为:

上式其实就是一般分类模型的做法, 全局均值池化.

Residual Attention

如下所示, 将上述2个过程进行加权融合:

其中, \(\pmb{f^i}\) 大小为 d*1, \(\pmb{m_i}^T \pmb{f^i}\) 为第 i 个类别的概率.

至于为什么叫 Residual Attention , 文章中的说法是:

the max pooling among different spatial regions for every class, is in fact a class-specific attention operation, which can be further viewed as a residual component of the class-agnostic global average pooling.

我的理解是, 公式5形式有点像 residual 形式.

文章实验结果

多标签

如下表所示为作者对多个数据集的测试, 除了ImageNet 为单标签外, 其它都为多标签. 可以看到多标签提升还是不错的.

热力图

由于利用到了不同位置空间点的信息, 获得的 heatmap 会更加准确, 文章中给出了一张结果, 如下:

我觉得这里有个遗憾的是, 文中没有进行对比.

个人理解

关于原理

根据流程图, 结合文中作者给出的核心代码, 其基本原理就是 average pooling + max pooling.

上述代码中: y_avg 大小为 C*1, 为 average pooling ; y_max 大小为 C*1, 为 max pooling .

下面是上述代码的一个例子, y_raw 的大小为 1*3*9 , B=1, C=3, H3H, W=3:

可以看到, y_avg 刚好为 average pooling , y_max 刚好为 max pooling .

关于公式

公式中的温度参数 T 用于调整参数大小, 而给出的核心代码中, 只有T趋于无穷的情况(等价于max pooling), 对于多个 Head 的情况, T=2,3,4,5 等, 代码中是如何体现出来的?

关于效果

对于 multi-label , 使用了 spatial poolingmulti-head 来提高效果, 从实验结果来看, 确实有效果, 但对于单标签情况, max pooling 应该改善不大, 从实验结果上看也确实可以看到, 单标签数据集上, 最高提升了0.02个百分点.

测试代码

测试代码如下, 可以参考这里.

import torch
from torch import nn class ResidualAttention(nn.Module):
def __init__(self, channel=512, num_class=1000, la=0.2):
super().__init__()
self.la = la
self.fc = nn.Conv2d(in_channels=channel, out_channels=num_class, kernel_size=1, stride=1, bias=False) def forward(self, x):
y_raw = self.fc(x).flatten(2) # b, num_class, h*w
y_avg = torch.mean(y_raw, dim=2) # b, num_class
y_max = torch.max(y_raw, dim=2)[0] # b, num_class
score = y_avg + self.la * y_max
return score if __name__ == '__main__': channel = 4
num_class = 3
batchsize = 1
input = torch.randn(batchsize, channel, 3, 3)
resatt = ResidualAttention(channel=channel, num_class=num_class, la=0.2)
output = resatt(input)
print(output.shape)

[论文阅读] Residual Attention(Multi-Label Recognition)的更多相关文章

  1. 论文阅读(Xiang Bai——【PAMI2017】An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition)

    白翔的CRNN论文阅读 1.  论文题目 Xiang Bai--[PAMI2017]An End-to-End Trainable Neural Network for Image-based Seq ...

  2. 论文阅读:Prominent Object Detection and Recognition: A Saliency-based Pipeline

    论文阅读:Prominent Object Detection and Recognition: A Saliency-based Pipeline  如上图所示,本文旨在解决一个问题:给定一张图像, ...

  3. 论文笔记——Deep Residual Learning for Image Recognition

    论文地址:Deep Residual Learning for Image Recognition ResNet--MSRA何凯明团队的Residual Networks,在2015年ImageNet ...

  4. [论文理解]Deep Residual Learning for Image Recognition

    Deep Residual Learning for Image Recognition 简介 这是何大佬的一篇非常经典的神经网络的论文,也就是大名鼎鼎的ResNet残差网络,论文主要通过构建了一种新 ...

  5. 论文阅读:Face Recognition: From Traditional to Deep Learning Methods 《人脸识别综述:从传统方法到深度学习》

     论文阅读:Face Recognition: From Traditional to Deep Learning Methods  <人脸识别综述:从传统方法到深度学习>     一.引 ...

  6. RAM: Residual Attention Module for Single Image Super-Resolution

    1. 摘要 注意力机制是深度神经网络的一个设计趋势,其在各种计算机视觉任务中都表现突出.但是,应用到图像超分辨领域的注意力模型大都没有考虑超分辨和其它高层计算机视觉问题的天然不同. 作者提出了一个新的 ...

  7. [论文阅读]阿里DIEN深度兴趣进化网络之总体解读

    [论文阅读]阿里DIEN深度兴趣进化网络之总体解读 目录 [论文阅读]阿里DIEN深度兴趣进化网络之总体解读 0x00 摘要 0x01论文概要 1.1 文章信息 1.2 基本观点 1.2.1 DIN的 ...

  8. Deep Residual Learning for Image Recognition (ResNet)

    目录 主要内容 代码 He K, Zhang X, Ren S, et al. Deep Residual Learning for Image Recognition[C]. computer vi ...

  9. Deep Reinforcement Learning for Dialogue Generation 论文阅读

    本文来自李纪为博士的论文 Deep Reinforcement Learning for Dialogue Generation. 1,概述 当前在闲聊机器人中的主要技术框架都是seq2seq模型.但 ...

随机推荐

  1. Basic remains java入门题

    Basic remains input:   b p m    读入p进制的p,m,   求p%m   ,以b进制输出 1 import java.util.*; 2 import java.math ...

  2. Docker学不会?不妨看看这篇文章

    大家好,我是辰哥! 上一篇文章(2300+字!在不同系统上安装Docker!)教大家如何在系统上安装docker,今天咱们来学习docker的基本使用. 辰哥将在本文里详细介绍docker的各种使用命 ...

  3. Java:Java多线程实现性能测试

    创建多线程和线程池 import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import ...

  4. nginx反向代理tcp协议的80端口

    需求:内网有一台mqtt协议服务器,需要将外网的mqtt请求通过一台服务器代理到内网的mqtt服务器上.而这台代理服务器不会开放出了80之外的端口,所以只能使用80端口来转发mqtt请求. 步骤:1. ...

  5. Java+Selenium3.3.1环境搭建

    一.背景和目的 selenium从2.0开始,加入了webdriver,实际上,我们说的selenium自动化测试,大部分情况都是在使用webdriver的API.现在去Selenium官网,发现最新 ...

  6. ASP.NET保存图片到sql2008

    //将图片转行为二进制的方式,存储到数据库 string name = FileUpload1.PostedFile.FileName; string type = name.Substring(na ...

  7. [刘阳Java]_eayui-pagination分页组件_第5讲

    分页组件也是很基本的应用,这里我只给出一段简单的代码 关键注意一点:分页组件可以在上面添加buttons按钮,或者自定义分页组件的外观.这些内容需要自行的查看EasyUI的API文档 <!DOC ...

  8. File类与常用IO流第七章——Properties集合

    Properties概述 java.util.Properties extends Hashtable<k,v> implements Map<k,v> Properties类 ...

  9. 在SublimeText3中搭建Verilog开发环境记录(一)

    ------------恢复内容开始------------ ------------恢复内容开始------------ ## 前言 *工欲善其事,必先利其器* 一款好用的撸码软件,能够大大的提高工 ...

  10. 学习Git的基本业务逻辑

    1,基本业务逻辑(假设针对index.html文件中内容): 1,在init版本库之前已写好开头部分:index 对index进行git init版本库: 进入到文件夹中,git init git a ...