摘要:本文提出了一个概念简单但对卷积神经网络非常有效的注意力模块。

本文分享自华为云社区《论文解读系列三十:无参数的注意力模块SimAm论文解读》,作者:谷雨润一麦。

摘要

本文提出了一个概念简单但对卷积神经网络非常有效的注意力模块。相比于现有的通道注意力和空间注意力机制,本文直接在网络层中推理出三维的注意力权重而且不增加任何参数量。确切地来说,本文基于著名的神经科学理论提出了通过优化能量函数来查找每个神经元的重要性。本文通过求解能量函数解析解的方式,进一步将代码实现控制在十行以内。SimAm模的另一个优势是大多数操作都是基于定义的能量函数的解决方案,因此不需要花太多的精力做结构调整。在各个视觉任务上的定量实验都表明本文提出的模块在改善卷积网络的表征能力上具有灵活性和有效性。

动机

现有的注意力基础模块存在两个问题。一个是他们只能在通道或者空间维度中的一个维度对特征进行精炼,但在空间和通道同时变化的空间缺乏灵活性。第二是他们的结构往往需要基于一系列的复杂操作,例如池化。文本基于完善的神经科学理论提出的模块很好的解决了上述两个问题。具体来说,为了让网络学习到更具区分性的神经元,本文提出直接从当前的神经元推理出三维的权重,然后反过来去优化这些神经元。为了有效的推理出三维的权重,本文基于神经科学的知识定义了一个能量函数,然后获得了该函数的解析解。

方法

在神经科学中,信息丰富的神经元通常表现出与周围神经元不同的放电模式。而且,激活神经元通常会抑制周围神经元,即空域抑制。换句话说,展现出空域抑制效应的神经元在视觉处理任务中应该被赋予更高的重要性。最简单的寻找重要神经元的方法就是度量神经元之间的线性可分性。基于这些神经科学的发现,本文针对每个神经元定义了如下的能量函数:

其中,$\hat t=w_t t+b_t, \hat x_i=w_t x_i + b_t$是$t$和$x_i$的线性变换,$t$和$x_i$是输入特征$\textbf{X}\in \mathbb{R}^{C\times H\times W}$的单通道中的目标神经元和其他神经元。$i$是在空间维度上的索引,$M=H\times M$是一个通道上的神经元的数量。$w_t$和$b_t$是线性变换的权重和偏置。式(1)中的所有值都是标量。当$\hat t=y_t$并且对其他说有神经元都有$\hat x_i =y_o$时,式(1)得到最小值,其中$y_t$和$y_o$是两个不同的值。最小化公式(1)等价于找到同一通道内目标神经元$t$​和其他神经元的线性可分性。为简单起见,本文采用二值标签并添加正则项。最终的能量函数如下式:

理论上,每个通道都会有$M$个这样能量函数,如果用像SGD这样的梯度下降算法去求解这些等式的话,计算开销将会非常大。幸运地是,等式(2)中$w_t$和$b_t$都可以快速求得解析解,如下式所示:

其中$u_t=\frac{1}{M-1}\sum{i=1}^{M-1}x_i$和$\sigma_t^2=\frac{1}{M-1}\sum{i}^{M-1}(s_i-\mu_t)^2$​是对应通道中出去神经元$t$​​​后所有神经元的均值和方差。从公式(3)和公式(4)可以看出解析解都是在单通道上得到的,因此可以合理的推测同一个通道的其他神经元也满足相同的分布。基于这个假设,就可以在所有神经元上计算均值和方差,在同一通道上的所有神经元都可以复用这个均值和方差。因此可以大大减少每个位置重复计算$\mu$和$\sigma$​的开销,最终每个位置的最小能量可以通过下式得到:

其中$\mu=\frac{1}{M}\sum{i=1}^{M}x_i$和$\hat\sigma^2=\frac{1}{M}\sum{i=1}^{M}(x_i-\hat\mu)^2$。等式(5)说明,能量$e_t^$越低,神经元$t$和周围神经元的区别越大,在视觉处理中也越重要。因此,本文通过$1/e_t^$​​​来表示每个神经元的重要性。根据Hillard等人<sup>1</sup>的研究,哺乳动物大脑中的注意力调节通常表现为对神经元反应的增益效应。因此本文直接用了缩放而不是相加的操作来做特征提炼,整个模块的提炼过程如下:

其中$\Epsilon$是$e_t^*$在所有通道和空间维度的汇总,$sigmoid$​是用来约束过大的值,它不会影响每个神经元的相对大小,因为它是一个单调函数。

​ 实际上除了计算每个通道的均值和方差外,其他所有的操作都是元素级别点对点的操作 。因此利用Pytorch可以几行代码实现公式(6)的功能,如图一所示。

图一 SimAM的pytorch风格实现

实验

CIFAR 分类实验

​ 在CIFAR 10类数据和100类数据上分别做了实验,并和其他四中注意力机制进行了对比,本文提出的模块在不增加任何参数的情况下在多个模型上都表现出了优越性,实验结果如图二所示。

图二 五种不同的注意力模块在不同模型上CIFAR图像分类任务上的top-1准确率

[1]: Hillyard, S. A., Vogel, E. K., and Luck, S. J. Sensory Gain Control (Amplification) as a Mechanism of Selective Attention: Electrophysiological and Neuroimaging evidence. Philosophical Transactions of the Royal Society of London. Series B: Biological Sciences, 353(1373): 1257–1270, 1998.

点击关注,第一时间了解华为云新鲜技术~

论文解读丨无参数的注意力模块SimAm的更多相关文章

  1. 论文解读丨表格识别模型TableMaster

    摘要:在此解决方案中把表格识别分成了四个部分:表格结构序列识别.文字检测.文字识别.单元格和文字框对齐.其中表格结构序列识别用到的模型是基于Master修改的,文字检测模型用到的是PSENet,文字识 ...

  2. 论文解读丨基于局部特征保留的图卷积神经网络架构(LPD-GCN)

    摘要:本文提出一种基于局部特征保留的图卷积网络架构,与最新的对比算法相比,该方法在多个数据集上的图分类性能得到大幅度提升,泛化性能也得到了改善. 本文分享自华为云社区<论文解读:基于局部特征保留 ...

  3. 论文解读丨【CVPR 2022】不使用人工标注提升文字识别器性能

    摘要:本文提出了一种针对文字识别的半监督方法.区别于常见的半监督方法,本文的针对文字识别这类序列识别问题做出了特定的设计. 本文分享自华为云社区<[CVPR 2022] 不使用人工标注提升文字识 ...

  4. 注意力论文解读(1) | Non-local Neural Network | CVPR2018 | 已复现

    文章转自微信公众号:[机器学习炼丹术] 参考目录: 目录 0 概述 1 主要内容 1.1 Non local的优势 1.2 pytorch复现 1.3 代码解读 1.4 论文解读 2 总结 论文名称: ...

  5. [论文解读] 阿里DIEN整体代码结构

    [论文解读] 阿里DIEN整体代码结构 目录 [论文解读] 阿里DIEN整体代码结构 0x00 摘要 0x01 文件简介 0x02 总体架构 0x03 总体代码 0x04 模型基类 4.1 基本逻辑 ...

  6. CVPR2020论文解读:三维语义分割3D Semantic Segmentation

    CVPR2020论文解读:三维语义分割3D Semantic Segmentation xMUDA: Cross-Modal Unsupervised Domain Adaptation  for 3 ...

  7. 图像分类:CVPR2020论文解读

    图像分类:CVPR2020论文解读 Towards Robust Image Classification Using Sequential Attention Models 论文链接:https:// ...

  8. 从单一图像中提取文档图像:ICCV2019论文解读

    从单一图像中提取文档图像:ICCV2019论文解读 DewarpNet: Single-Image Document Unwarping With Stacked 3D and 2D Regressi ...

  9. 自监督学习(Self-Supervised Learning)多篇论文解读(上)

    自监督学习(Self-Supervised Learning)多篇论文解读(上) 前言 Supervised deep learning由于需要大量标注信息,同时之前大量的研究已经解决了许多问题.所以 ...

  10. NLP论文解读:无需模板且高效的语言微调模型(上)

    原创作者 | 苏菲 论文题目: Prompt-free and Efficient Language Model Fine-Tuning 论文作者: Rabeeh Karimi Mahabadi 论文 ...

随机推荐

  1. Unity3D 选择焦点切换

    using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.T ...

  2. CF85B [Embassy Queue]

    Problem 题目简述 有 \(n\) 个人分别在 \(c_i\) 的时刻来,他们都要在 \(k_1\),\(k_2\) 和 \(k_3\) 窗口干不同的事,当有后面一人也排在在同一窗口时,必须等待 ...

  3. Web SSH 的原理与在 ASP.NET Core SignalR 中的实现

    前言 有个项目,需要在前端有个管理终端可以 SSH 到主控机的终端,如果不考虑用户使用 vim 等需要在控制台内现实界面的软件的话,其实使用 Process 类型去启动相应程序就够了.而这次的需求则需 ...

  4. 沫沫漫画网Js逆向分析爬取全站资源入库处理图片合并

    网站分析 打开目标网站:https://www.momomh.com/ 选择一部漫画作为分析对象:<渴望:爱火难耐> 进到漫画详情页这里,发现并没有需要逆向分析.直接可以获取漫画信息.随便 ...

  5. 矩阵重叠 (3.18 leetcode每日打卡)

    度简单66收藏分享切换为英文关注反馈矩形以列表 [x1, y1, x2, y2] 的形式表示,其中 (x1, y1) 为左下角的坐标,(x2, y2) 是右上角的坐标. 如果相交的面积为正,则称两矩形 ...

  6. VScode 中利用virtualenv建立 Python 虚拟环境

    ! https://zhuanlan.zhihu.com/p/638114885 VScode 建立 Python 虚拟环境 主要目的:创建一个与默认 python 版本不同的 python 虚拟环境 ...

  7. MySQL调优的一些总结

    SQL优化可以从那几个方面去优化 1.基本写法优化: 1.少使用select * ,尽量使用具体字段: 2.对于条件来说等号之类两边的字段类型要一致,字符串不加单引号索引会失效: 3.尽量少使用Ord ...

  8. 【Python】【OpenCV】轮廓检测

    Code: 1 import cv2 2 import numpy as np 3 4 img = np.zeros((200, 200), dtype=np.uint8) 5 img[50:150, ...

  9. Pikachu漏洞靶场 敏感信息泄露

    敏感信息泄露 概述 由于后台人员的疏忽或者不当的设计,导致不应该被前端用户看到的数据被轻易的访问到. 比如: 通过访问url下的目录,可以直接列出目录下的文件列表; 输入错误的url参数后报错信息里面 ...

  10. Weblogic获取端口IP

    Weblogic获取端口IP 获取端口IP只为了判断哪个节点 调用 private WebMBeanServer server=new WebMBeanServer(); synlog.info(&q ...