[论文阅读] Residual Attention(Multi-Label Recognition)
Residual Attention
文章: Residual Attention: A Simple but Effective Method for Multi-Label Recognition, ICCV2021
下面说一下我对这篇文章的浅陋之见, 如有错误, 请多包涵指正.
文章的核心方法
如下图所示为其处理流程:

图中 X 为CNN骨干网络提取得到的feature, 其大小为 d*h*w , 为1个batch数据. 一般 d*h*w=2048*7*7 .
从图中可以看到, 有2个分支, 一个是 average pooling, 一个是 spatial pooling, 最后二者加权融合得到 residual attention .
Spatial pooling
其过程为:

这里有个 1*1 的卷积操作FC , 其大小为 C*d*1*1 , C 为类别数, 如果直接使用矩阵乘法计算, FC(X) 后的大小为 C*h*w .
但文章中的公式是将其展开为对每个空间点单独计算, 其中 \(\pmb{m_i}\) 为 FC 第i 个类别的参数, 其大小为 d*1*1, 计算得到的 \(s^i_j\) 为第 i 个类别在第 j 个位置的概率, \(\pmb{a^i}\) 为第 i 个类别的特征, 其大小为 d*1 .
如果, \(\pmb{m_i}\) 和 \(\pmb{a^i}\) 计算就可以得到第 i 个类别的概率. 这样就可以用到每个空间点的特征, 有利于不同目标不同类别物体的分类识别.
公式中有个温度参数 T 用来控制 \(s^i_j\) 的大小, 当 T 趋于无穷时, spatial pooling 就变成了 max pooling
Average pooling
其过程为:

上式其实就是一般分类模型的做法, 全局均值池化.
Residual Attention
如下所示, 将上述2个过程进行加权融合:

其中, \(\pmb{f^i}\) 大小为 d*1, \(\pmb{m_i}^T \pmb{f^i}\) 为第 i 个类别的概率.
至于为什么叫 Residual Attention , 文章中的说法是:
the max pooling among different spatial regions for every class, is in fact a class-specific attention operation, which can be further viewed as a residual component of the class-agnostic global average pooling.
我的理解是, 公式5形式有点像 residual 形式.
文章实验结果
多标签
如下表所示为作者对多个数据集的测试, 除了ImageNet 为单标签外, 其它都为多标签. 可以看到多标签提升还是不错的.

热力图
由于利用到了不同位置空间点的信息, 获得的 heatmap 会更加准确, 文章中给出了一张结果, 如下:

我觉得这里有个遗憾的是, 文中没有进行对比.
个人理解
关于原理
根据流程图, 结合文中作者给出的核心代码, 其基本原理就是 average pooling + max pooling.

上述代码中: y_avg 大小为 C*1, 为 average pooling ; y_max 大小为 C*1, 为 max pooling .
下面是上述代码的一个例子, y_raw 的大小为 1*3*9 , B=1, C=3, H3H, W=3:

可以看到, y_avg 刚好为 average pooling , y_max 刚好为 max pooling .
关于公式
公式中的温度参数 T 用于调整参数大小, 而给出的核心代码中, 只有T趋于无穷的情况(等价于max pooling), 对于多个 Head 的情况, T=2,3,4,5 等, 代码中是如何体现出来的?
关于效果
对于 multi-label , 使用了 spatial pooling 和 multi-head 来提高效果, 从实验结果来看, 确实有效果, 但对于单标签情况, max pooling 应该改善不大, 从实验结果上看也确实可以看到, 单标签数据集上, 最高提升了0.02个百分点.
测试代码
测试代码如下, 可以参考这里.
import torch
from torch import nn
class ResidualAttention(nn.Module):
def __init__(self, channel=512, num_class=1000, la=0.2):
super().__init__()
self.la = la
self.fc = nn.Conv2d(in_channels=channel, out_channels=num_class, kernel_size=1, stride=1, bias=False)
def forward(self, x):
y_raw = self.fc(x).flatten(2) # b, num_class, h*w
y_avg = torch.mean(y_raw, dim=2) # b, num_class
y_max = torch.max(y_raw, dim=2)[0] # b, num_class
score = y_avg + self.la * y_max
return score
if __name__ == '__main__':
channel = 4
num_class = 3
batchsize = 1
input = torch.randn(batchsize, channel, 3, 3)
resatt = ResidualAttention(channel=channel, num_class=num_class, la=0.2)
output = resatt(input)
print(output.shape)
[论文阅读] Residual Attention(Multi-Label Recognition)的更多相关文章
- 论文阅读(Xiang Bai——【PAMI2017】An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition)
白翔的CRNN论文阅读 1. 论文题目 Xiang Bai--[PAMI2017]An End-to-End Trainable Neural Network for Image-based Seq ...
- 论文阅读:Prominent Object Detection and Recognition: A Saliency-based Pipeline
论文阅读:Prominent Object Detection and Recognition: A Saliency-based Pipeline 如上图所示,本文旨在解决一个问题:给定一张图像, ...
- 论文笔记——Deep Residual Learning for Image Recognition
论文地址:Deep Residual Learning for Image Recognition ResNet--MSRA何凯明团队的Residual Networks,在2015年ImageNet ...
- [论文理解]Deep Residual Learning for Image Recognition
Deep Residual Learning for Image Recognition 简介 这是何大佬的一篇非常经典的神经网络的论文,也就是大名鼎鼎的ResNet残差网络,论文主要通过构建了一种新 ...
- 论文阅读:Face Recognition: From Traditional to Deep Learning Methods 《人脸识别综述:从传统方法到深度学习》
论文阅读:Face Recognition: From Traditional to Deep Learning Methods <人脸识别综述:从传统方法到深度学习> 一.引 ...
- RAM: Residual Attention Module for Single Image Super-Resolution
1. 摘要 注意力机制是深度神经网络的一个设计趋势,其在各种计算机视觉任务中都表现突出.但是,应用到图像超分辨领域的注意力模型大都没有考虑超分辨和其它高层计算机视觉问题的天然不同. 作者提出了一个新的 ...
- [论文阅读]阿里DIEN深度兴趣进化网络之总体解读
[论文阅读]阿里DIEN深度兴趣进化网络之总体解读 目录 [论文阅读]阿里DIEN深度兴趣进化网络之总体解读 0x00 摘要 0x01论文概要 1.1 文章信息 1.2 基本观点 1.2.1 DIN的 ...
- Deep Residual Learning for Image Recognition (ResNet)
目录 主要内容 代码 He K, Zhang X, Ren S, et al. Deep Residual Learning for Image Recognition[C]. computer vi ...
- Deep Reinforcement Learning for Dialogue Generation 论文阅读
本文来自李纪为博士的论文 Deep Reinforcement Learning for Dialogue Generation. 1,概述 当前在闲聊机器人中的主要技术框架都是seq2seq模型.但 ...
随机推荐
- 跟我一起学Go系列:Go gRPC 安全认证方式-Token和自定义认证
Go gRPC 系列: 跟我一起学Go系列:gRPC安全认证机制-SSL/TLS认证 跟我一起学 Go 系列:gRPC 拦截器使用 跟我一起学 Go 系列:gRPC 入门必备 接上一篇继续讲 gRPC ...
- POJ 3984 迷宫(BFS)
入门BFS,第一次做,部分借鉴了大牛的 #include <iostream> #include <cstdio> #include <queue> using n ...
- 【spring源码系列】之【Bean的属性赋值】
每次进入源码的世界,就像完成一场奇妙的旅行! 1. 属性赋值概述 上一篇讲述了bean实例化中的创建实例过程,实例化后就需要对类中的属性进行依赖注入操作,本篇将重点分析属性赋值相关流程.其中属性赋值, ...
- 面试题二:JVM
JVM垃圾回收的时候如何确定垃圾? 有2种方式: 引用计数 每个对象都有一个引用计数属性,新增一个引用时计数加1,引用释放时计数减1,计数为0时可以回收: 缺点:无法解决对象循环引用的问题: 可达性分 ...
- 『无为则无心』Python函数 — 27、Python函数的返回值
目录 1.返回值概念 2.return关键字的作用 3.返回值可以返回的数据类型 4.函数如何返回多个值 5.fn5 和 fn5()的区别 6.总结: 1.返回值概念 例如:我们去超市购物,比如买饮料 ...
- mysql导入脚本
#登陆 mysql -u root -p #创建数据库 CREATE DATABASE `gps` CHARACTER SET utf8 COLLATE utf8_general_ci; #选择数据库 ...
- Ambiguous mapping found. Cannot map 'competeController' bean method
报错: Error creating bean with name 'org.springframework.web.servlet.mvc.method.annotation.RequestMapp ...
- 关于scrollview的无限滚动效果实现
起因及需求:做过阅读器的朋友应该知道,一般的阅读器都会有仿真.平移等特效.最近赶上真空期,项目不忙,有点时间,于是想抓起来,总结点干货. 仿真翻页及平滑翻页的基本实现: 仿真翻页,使用系统自带的UIP ...
- 【保姆级】Python项目(Flask网页)部署到Docker的完整过程
大家好,我是辰哥~ 前提:相信看到这篇文章的读者应该已经学会了Docker的安装以及Docker的基本使用,如果还不会的可以参考我之前的文章进行详细学习! 1.安装版:2300+字!在不同系统上安装D ...
- ARTS第十二周
1.Algorithm:每周至少做一个 leetcode 的算法题2.Review:阅读并点评至少一篇英文技术文章3.Tip:学习至少一个技术技巧4.Share:分享一篇有观点和思考的技术文章 以下是 ...