论文提出类特定控制门CSG来引导网络学习类特定的卷积核,并且加入正则化方法来稀疏化CSG矩阵,进一步保证类特定。从实验结果来看,CSG的稀疏性能够引导卷积核与类别的强关联,在卷积核层面产生高度类相关的特征表达,从而提升网络的性能以及可解释性



来源:晓飞的算法工程笔记 公众号

论文: Training Interpretable Convolutional Neural Networks by Differentiating Class-specific Filters

Introduction


  卷积神经网络虽然在多个视觉任务中有很好的表现,但可解释性的欠缺导致其在需要人类信任或互动的应用中受到限制,而论文认为类别与卷积核间的多对多关系是造成卷积网络可解释性差的主要原因,称之为filter-class entanglement。如上图所示,卷积网络通常提取包含多个语义概念的混合特征,比如类别、场景和颜色等,去除entanglement能够更好地解释每个卷积核的作用。

  受细胞分化的启发,论文提出在最后的卷积层中学习类特定卷积核,希望卷积核能够"分化"成针对不同类别的分组,如图1右所示,单个卷积核专门负责特定类别的识别。为了实现这个想法,论文设计了可学习的类特定门控CSG(Class-Specific Gate)来引导将卷积核分配给不同的类别,只有当特定类别作为输入时,对应卷积核输出的特征才能被使用。

  论文的主要贡献如下:

  • 提出新的训练策略来学习更灵活的卷积核与类别的关系,每个卷积核仅提取一个或少量类别的相关特征。
  • 提出通过卷积特征和类别预测的互信息来验证卷积核与类别的关系,并且基于此设计了一个度量方法来测量网络的filter-class entanglement。
  • 通过实验证明论文提出的方法能够消除卷积核的冗余以及增强可解释性,可应用于目标定位和对抗样本检测。

Ideally Class-Specific Filters


  如图2所示,理想的类特定卷积核应该只对应一个类别,为了明确定义,使用矩阵$G\in [0, 1]^{C\times K}$来表示卷积核和类别的相关性,矩阵元素$G^k_c\in [0,1]$代表$k$卷积核和$c$类别的相关性。对于输入样本$(x,y)\in D$,取矩阵$G$的行$G_y \in [0, 1]K$作为控制门,将不相关的卷积核输出置为零。定义$\tilde{y}$为正常网络结构(STD)直接预测的类概率向量,$\tilde{y}G$为加入矩阵$G$(处理倒数第二层的特征图)后的网络(CSG)预测的类概率向量,若存在$G$(所有列为one-hot)使得$\tilde{y}^G$和$\tilde{y}$几乎不存在差异时,称该卷积核为理想的类特定卷积核。

Problem formulation


  为了让网络在训练中分化类特定卷积核,论文在标准的前行推理(standard path, STD)中引入可学习的类特定控制门(Class-Specific Gate path, CSG) ,用来有选择性地阻隔不相关特征维度。

The Original Problem

  如上图所示,论文的目标是训练包含理想类特定卷积核的网络,网络参数为$\theta$,包含两条前向推理路径:

  • 标准路径STD预测$\tilde{y}_{\theta}$。
  • 包含矩阵$G$的类特定门路径(CSG)预测$\tilde{y}^G_{\theta}$。

  CSG将倒数第二层的输出乘以可学习控制门$G_y$,$y$为输入样本的标签。 为了找到准确描述类别与卷积核关系的控制门矩阵$G$,需要在二值空间中搜索$G$使得CSG路径有最好的分类效果,即优化问题$\Phi_0(\theta)=\underset{G}{min}CE(y||\tilde{y}^G_{\theta}), \forall k\in{1,2,\cdots, K}$,$G_k$是one-hot编码,$\Phi_0$用来验证网络中分化的卷积核的性能,将$\Phi_0$加入到训练损失中作为正则化项,得到整体网络的优化目标:

  $CE(y||\tilde{y}_{\theta})$保证准确率,$\lambda_1 \Phi_0(\theta)$引导$G$的稀疏性。但公式1其实是很难优化的,首先很难保证每个卷积核是绝对地只对应一个类别,通常都是多类别共享特征,其次,非连续空间的二值向量很难通过梯度下降优化。

Relaxation

  为了解决上面提到的两个问题,论文将one-hot向量$Gk$放宽为稀疏连续向量$Gk\in [0, 1]C$,约束其包含至少一个等于1的元素($||Gk||_{\infty}=1$)。另外,加入正则项$d(||G||_1, g)$来引导$G$的尽量稀疏,当L1向量范数$||G||_1$小于上界$g$时,则不进行惩罚。$d$的常规设计是$d(a,b)=\psi(ReLU(a-b))$,$\psi$可以是各种范数,包括L1、L2和smooth-L1范数。$g$的设置需满足$g\ge K$,因为$||Gk||_{\infty}=1$,共有K个$Gk$。综合上面的方法,$\Phi_0$重新定义为:

  $V_G={G\in [0,1]^{C\times K}:||G^k||_{\infty}=1}$,$\mu$为平衡因子,$\Phi$可看作是filter-class entanglement的损失函数,将$\Phi$替换公式1的$\Phi_0$得到放松后的完整的优化问题:

  公式3可通过梯度下降联合优化$\theta$和$G$得到类特定卷积核,而且$G$能准确地描述卷积核与类别间的相关性,比优化原本离散的优化问题要简单得多。

Optimization

  针对CSG算法的场景,论文提出PGD(approximate projected gradient descent)梯度下降来解决公式3的优化问题,当$G$进行梯度更新后,$Gk$会通过$||Gk||{\infty}$进行归一化,保证$||G^k||{\infty}=1$,然后裁剪到$[0,1]$。

  由于CSG路径阻隔了大部分的特征,所以CSG路径的梯度回传比STD路径弱很多,如果按正常的方式进行训练,收敛效果会很一般。为此,论文提出alternate training scheme,在不同的周期交替地使用STD/CSG路径的梯度。如算法1所示,在CSG路径的周期,使用梯度$\lambda_1 CE(y||\tilde{y}^G_{\theta})+\lambda_2d(||G||1, g)$更新$G$和$\theta$进行更新,而在STD路径的周期,则使用梯度$CE(y||\tilde{y}{\theta})$进行更新。根据实验验证,这种训练方法在训练初期的分类效果会周期性波动,但最终的训练效果比正常的训练方法要好,同时卷积核也能逐渐分化成类特定卷积核。

Experiment


Quantitative Evaluation Metrics

  论文实验使用了3种指标来验证CSG的有效性:

  • classification accuracy,用来计算分类性能。
  • mutual information score,使用互信息矩阵$M\in \mathbb{R}^{K\times C}$来计算类与卷积核的关系,矩阵元素$M_{kc}=MI(a_k||1_{y=c})$为卷积核$k$的特征值与类别$c$间的互信息。为了计算互信息,在多个数据集中采样$(x,y)$,$a_k$由所有样本的对应输出全局平均池化得来,$1_{y=c}$为类别,$MI$的计算直接调用“sklearn.feature selection.mutual_info_classif”方法。互信息分数$MIS=mean_k max_c M_{kc}$,分数越高,则filter-class entanglement现象越少。
  • L1-density,用来度量CSG的稀疏性,计算方法为$\frac{||G||_1}{KC}$

  结果如表1所示,可以看到CSG网络在分类表现上仅比STD网络要稍微好一点,但其它指标要高出很多。

Visualizing the Gate/MI Matrices

  为了展示卷积核与类别间的相关性,对控制门矩阵$G$和互信息矩阵$M$进行可视化:

  • 图a表明CSG训练能得到稀疏的CSG矩阵,每个卷积核仅对应一个或少量类别。
  • 图b1和b2则表明CSG网络比STD网络有更高的互信息得分。
  • 图c表明图a和图b1的最大元素几乎是重叠的,卷积核能够按照稀疏的CSG矩阵进行学习。

Application

  定位任务上的性能对比,这里的定位是直接通过特征图的大小进行定位,非Faster-RCNN之类的。

  对抗样本检测任务上的性能对比。

Conclustion


  论文提出类特定控制门CSG来引导网络学习类特定的卷积核,并且加入正则化方法来稀疏化CSG矩阵,进一步保证类特定。从实验结果来看,CSG的稀疏性能够引导卷积核与类别的强关联,在卷积核层面产生高度类相关的特征表达,从而提升网络的性能以及可解释性。





如果本文对你有帮助,麻烦点个赞或在看呗~

更多内容请关注 微信公众号【晓飞的算法工程笔记】

CSG:清华大学提出通过分化类特定卷积核来训练可解释的卷积网络 | ECCV 2020 Oral的更多相关文章

  1. 腾讯推出超强少样本目标检测算法,公开千类少样本检测训练集FSOD | CVPR 2020

    论文提出了新的少样本目标检测算法,创新点包括Attention-RPN.多关系检测器以及对比训练策略,另外还构建了包含1000类的少样本检测数据集FSOD,在FSOD上训练得到的论文模型能够直接迁移到 ...

  2. CNN tflearn处理mnist图像识别代码解说——conv_2d参数解释,整个网络的训练,主要就是为了学那个卷积核啊。

    官方参数解释: Convolution 2D tflearn.layers.conv.conv_2d (incoming, nb_filter, filter_size, strides=1, pad ...

  3. CSS 类、伪类和伪元素差别具体解释

    CSS中的类(class)是为了方便过滤(即选择)元素,以给这类元素加入样式,class是定义在HTML文档树中的. 可是这在一些情况下是不够用的,比方用户的交互动作(悬停.激活等)会导致元素状态发生 ...

  4. python类(class)中参数self的解释说明

    python类(class)中参数self的简单解释 1.self只有在类的方法中才会有,其他函数或方法是不必带self的. 2.在调用时不必传入相应的参数.3.在类的方法中(如__init__),第 ...

  5. Python 爬取的类封装【将来可能会改造,持续更新...】(2020年寒假小目标09)

    日期:2020.02.09 博客期:148 星期日 按照要求,我来制作 Python 对外爬取类的固定部分的封装,以后在用 Python 做爬取的时候,可以直接使用此类并定义一个新函数来处理CSS选择 ...

  6. 如何用C++封装一个简单的数据流操作类(附源码),从而用于网络上的数据传输和解析?

    历史溯源 由于历史原因,我们目前看到的大部分的网络协议都是基于ASCII码这种纯文本方式,也就是基于字符串的命令行方式,比如HTTP.FTP.POP3.SMTP.Telnet等.早期操作系统UNIX( ...

  7. Fully Convolutional Networks for Semantic Segmentation 译文

    Fully Convolutional Networks for Semantic Segmentation 译文 Abstract   Convolutional networks are powe ...

  8. 深度学习论文翻译解析(十一):OverFeat: Integrated Recognition, Localization and Detection using Convolutional Networks

    论文标题:OverFeat: Integrated Recognition, Localization and Detection using Convolutional Networks 标题翻译: ...

  9. face recognition[Euclidean-distance-based loss][FaceNet]

    本文来自<FaceNet: A Unified Embedding for Face Recognition and Clustering>.时间线为2015年6月.是谷歌的作品. 0 引 ...

随机推荐

  1. Idea使用方式——创建类模板

    问题:创建类或接口时,要添加自定义的默认注释,比如版本,时间等.每个类修改显然不符合程序员的思路,有没有办法通过定义模板来实现? 使用Idea模板 Idea可听过创建类模板来实现. 功能路径:Sett ...

  2. 关于bat批处理的一些操作,如启动jar 关闭进程等

    先说一下学习这个的前提: 公司要写个生成uid的工具,整完了之后就又整批处理工具,出于此目的,也是为了丰富自己的知识,就学习了一下,下面是相关的批处理脚本 我花了半天的时间找了相关的bat批处理,但是 ...

  3. 【问题】【SpringBoot】记一次springboot框架下用jackson解析RequestBody失败的问题

    最近项目中遇到了一个问题,费好大劲才发现问题所在,并且修复了问题,下面分享一下这个问题的定位和修复的新路旅程. 先说下背景:该项目用的是SpringBoot框架,主要功能为对外提供一些Restful ...

  4. 04async await

    async async 函数返回值是一个promise对象,promise对象的状态由async函数的返回值决定   //函数的三种定义 async function hello() { return ...

  5. Zabbix Server宕机报“__zbx_mem_malloc(): out of memory (requested 96 bytes)”

    早上登录Zabbix的时候,发现其提示"Zabbix server is not running: the information displayed may not be current& ...

  6. STS 使用lombox.jar

    在Maven本地仓库中找到 将lombox.jar放在与STS.exe平级的目录下, 然后安装完了以后可能会出先打不开的情况.这个时候只要打开STS.ini文件. 然后修改文件保存

  7. 关于Vue-Router的那些事儿

    Vue-Router Vu-router是Vue.js的官方路由管理器,方便于构建页面,在我们的使用过程中,一般是在router的index.js文件中配置对应的组件和对应的路径 主要的功能 嵌套路由 ...

  8. SpringBoot写后端接口,看这一篇就够了!

    摘要:本文演示如何构建起一个优秀的后端接口体系,体系构建好了自然就有了规范,同时再构建新的后端接口也会十分轻松. 一个后端接口大致分为四个部分组成:接口地址(url).接口请求方式(get.post等 ...

  9. Centos7源码编译安装LAMP环境

    参考地址:https://www.linuxidc.com/Linux/2018-03/151133.htm

  10. defer implement for C/C++ using GCC/Clang extension

    前述: go 中defer 给出了一种,延时调用的方式来释放资源.但是对于C/C++去没有内置的这种属性.对于经常手动管理内存的C/C++有其是C程序员这种特性显得无比重要.这里给出了一种基于GCC/ ...