Residual Attention

文章: Residual Attention: A Simple but Effective Method for Multi-Label Recognition, ICCV2021

下面说一下我对这篇文章的浅陋之见, 如有错误, 请多包涵指正.

文章的核心方法

如下图所示为其处理流程:

图中 X 为CNN骨干网络提取得到的feature, 其大小为 d*h*w , 为1个batch数据. 一般 d*h*w=2048*7*7 .

从图中可以看到, 有2个分支, 一个是 average pooling, 一个是 spatial pooling, 最后二者加权融合得到 residual attention .

Spatial pooling

其过程为:

这里有个 1*1 的卷积操作FC , 其大小为 C*d*1*1 , C 为类别数, 如果直接使用矩阵乘法计算, FC(X) 后的大小为 C*h*w .

但文章中的公式是将其展开为对每个空间点单独计算, 其中 \(\pmb{m_i}\)​​​ 为 FCi 个类别的参数, 其大小为 d*1*1, 计算得到的 \(s^i_j\)​​​ 为第 i 个类别在第 j 个位置的概率, \(\pmb{a^i}\)​​​​ 为第 i 个类别的特征, 其大小为 d*1 .

如果, \(\pmb{m_i}\) 和 \(\pmb{a^i}\) 计算就可以得到第 i 个类别的概率. 这样就可以用到每个空间点的特征, 有利于不同目标不同类别物体的分类识别.

公式中有个温度参数 T 用来控制 \(s^i_j\)​​​​ 的大小, 当 T 趋于无穷时, spatial pooling 就变成了 max pooling

Average pooling

其过程为:

上式其实就是一般分类模型的做法, 全局均值池化.

Residual Attention

如下所示, 将上述2个过程进行加权融合:

其中, \(\pmb{f^i}\) 大小为 d*1, \(\pmb{m_i}^T \pmb{f^i}\) 为第 i 个类别的概率.

至于为什么叫 Residual Attention , 文章中的说法是:

the max pooling among different spatial regions for every class, is in fact a class-specific attention operation, which can be further viewed as a residual component of the class-agnostic global average pooling.

我的理解是, 公式5形式有点像 residual 形式.

文章实验结果

多标签

如下表所示为作者对多个数据集的测试, 除了ImageNet 为单标签外, 其它都为多标签. 可以看到多标签提升还是不错的.

热力图

由于利用到了不同位置空间点的信息, 获得的 heatmap 会更加准确, 文章中给出了一张结果, 如下:

我觉得这里有个遗憾的是, 文中没有进行对比.

个人理解

关于原理

根据流程图, 结合文中作者给出的核心代码, 其基本原理就是 average pooling + max pooling.

上述代码中: y_avg 大小为 C*1, 为 average pooling ; y_max 大小为 C*1, 为 max pooling .

下面是上述代码的一个例子, y_raw 的大小为 1*3*9 , B=1, C=3, H3H, W=3:

可以看到, y_avg 刚好为 average pooling , y_max 刚好为 max pooling .

关于公式

公式中的温度参数 T 用于调整参数大小, 而给出的核心代码中, 只有T趋于无穷的情况(等价于max pooling), 对于多个 Head 的情况, T=2,3,4,5 等, 代码中是如何体现出来的?

关于效果

对于 multi-label , 使用了 spatial poolingmulti-head 来提高效果, 从实验结果来看, 确实有效果, 但对于单标签情况, max pooling 应该改善不大, 从实验结果上看也确实可以看到, 单标签数据集上, 最高提升了0.02个百分点.

测试代码

测试代码如下, 可以参考这里.

import torch
from torch import nn class ResidualAttention(nn.Module):
def __init__(self, channel=512, num_class=1000, la=0.2):
super().__init__()
self.la = la
self.fc = nn.Conv2d(in_channels=channel, out_channels=num_class, kernel_size=1, stride=1, bias=False) def forward(self, x):
y_raw = self.fc(x).flatten(2) # b, num_class, h*w
y_avg = torch.mean(y_raw, dim=2) # b, num_class
y_max = torch.max(y_raw, dim=2)[0] # b, num_class
score = y_avg + self.la * y_max
return score if __name__ == '__main__': channel = 4
num_class = 3
batchsize = 1
input = torch.randn(batchsize, channel, 3, 3)
resatt = ResidualAttention(channel=channel, num_class=num_class, la=0.2)
output = resatt(input)
print(output.shape)

[论文阅读] Residual Attention(Multi-Label Recognition)的更多相关文章

  1. 论文阅读(Xiang Bai——【PAMI2017】An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition)

    白翔的CRNN论文阅读 1.  论文题目 Xiang Bai--[PAMI2017]An End-to-End Trainable Neural Network for Image-based Seq ...

  2. 论文阅读:Prominent Object Detection and Recognition: A Saliency-based Pipeline

    论文阅读:Prominent Object Detection and Recognition: A Saliency-based Pipeline  如上图所示,本文旨在解决一个问题:给定一张图像, ...

  3. 论文笔记——Deep Residual Learning for Image Recognition

    论文地址:Deep Residual Learning for Image Recognition ResNet--MSRA何凯明团队的Residual Networks,在2015年ImageNet ...

  4. [论文理解]Deep Residual Learning for Image Recognition

    Deep Residual Learning for Image Recognition 简介 这是何大佬的一篇非常经典的神经网络的论文,也就是大名鼎鼎的ResNet残差网络,论文主要通过构建了一种新 ...

  5. 论文阅读:Face Recognition: From Traditional to Deep Learning Methods 《人脸识别综述:从传统方法到深度学习》

     论文阅读:Face Recognition: From Traditional to Deep Learning Methods  <人脸识别综述:从传统方法到深度学习>     一.引 ...

  6. RAM: Residual Attention Module for Single Image Super-Resolution

    1. 摘要 注意力机制是深度神经网络的一个设计趋势,其在各种计算机视觉任务中都表现突出.但是,应用到图像超分辨领域的注意力模型大都没有考虑超分辨和其它高层计算机视觉问题的天然不同. 作者提出了一个新的 ...

  7. [论文阅读]阿里DIEN深度兴趣进化网络之总体解读

    [论文阅读]阿里DIEN深度兴趣进化网络之总体解读 目录 [论文阅读]阿里DIEN深度兴趣进化网络之总体解读 0x00 摘要 0x01论文概要 1.1 文章信息 1.2 基本观点 1.2.1 DIN的 ...

  8. Deep Residual Learning for Image Recognition (ResNet)

    目录 主要内容 代码 He K, Zhang X, Ren S, et al. Deep Residual Learning for Image Recognition[C]. computer vi ...

  9. Deep Reinforcement Learning for Dialogue Generation 论文阅读

    本文来自李纪为博士的论文 Deep Reinforcement Learning for Dialogue Generation. 1,概述 当前在闲聊机器人中的主要技术框架都是seq2seq模型.但 ...

随机推荐

  1. FlowNet:simple / correlation 与 相关联操作

    Flow Net : simple / correlation 与 相关联操作 ​ 上一篇文章中(还没来得及写),已经简单的讲解了光流是什么以及光流是如何求得的.同时介绍了几个光流领域的经典传统算法. ...

  2. WebSocket实现实时聊天系统

    WebSocket实现实时聊天系统 等闲变却故人心,却道故人心易变. 简介:前几天看了WebSocket,今天体验下它的实时聊天. 一.项目介绍 WebSocket 实时聊天系统自己一个一码的搞出来还 ...

  3. Nginx:Nginx日志切割方法

    Nginx的日志文件是没有切割(rotate)功能的,但是我们可以写一个脚本来自动切割日志文件. 首先我们要注意两点: 1.切割的日志文件是不重名的,所以需要我们自定义名称,一般就是时间日期做文件名. ...

  4. Spring:Spring项目多接口实现类报错找不到指定类

    spring可以通过applicationContext.xml进行配置接口实现类 applicationContext.xml中可以添加如下配置: 在application.properties中添 ...

  5. HMM实现中文分词

    链接:https://pan.baidu.com/s/1uBjLC61xm4tQ9raDa_M1wQ  提取码:f7l1 推荐:https://blog.csdn.net/longgb123/arti ...

  6. MySQL 数据排序 order by

    1.单一字段排序 select * from tablename order by field1 desc; 排序采用order by+排序字段 升序关键字(asc,desc),排序字段可以放多个,多 ...

  7. Java核心基础第4篇-Java数组的常规操作

    Java数组 一.数组简介 数组是多个相同类型数据的组合,实现对这些数据的统一管理 数组属引用类型,数组型数据是对象(Object) 数组中的元素可以是任何数据类型,包括基本类型和引用类型 数组类型是 ...

  8. asp.net mvc中的路由

    [Route] 路由 [Route("~/")] 忽略路由前缀 [Route("person/{id:int}")] 路由内联约束 [Route("h ...

  9. 如何在Apache HttpClient中设置TLS版本

    1.简介 Apache HttpClient是一个底层.轻量级的客户端HTTP库,用于与HTTP服务器进行通信. 在本教程中,我们将学习如何在使用HttpClient时配置支持的传输层安全(TLS)版 ...

  10. 交通规则:HOV车道

    多乘员车道的限行时间一般为工作日上下班高峰,车上只有一个人时不能走该车道