Key_words: Continual learning, strong pretrained model, fix, fusion
Create_time: April 14, 2022 6:32 PM
Edited_by: Huang Yujun
Publisher: CVPR 2022
Score /5: ️
Status: Finished
Org: AWS AI Labs

[29]Class-Incremental Learning with Strong Pre-trained Models.pdf

1. Motivation

  • 目前的研究并未考虑新类之间重叠(具有相同类标签)的情况,默认序列间类别不重叠,但这与实际情况不符
  • 初始阶段使用大量类别数据初始化特征提取器,从而获得一个学习到丰富特征的特征提取器

本文认为训练得不错的特征提取器就已经能够在新类上表现很好,因此并不需要在每次学习新类时fintune整个网络,上图为本文所做假设的验证性实验。

左图,实验设置为,分别使用100, 200, ... ,800 个类别数据作为初始类预训练一个特征提取器后

  1. 固定这个特征提取器并训练一个新类的全连接分类器
  2. 微调整个特征提取器及全连接分类器

对上述两种设置进行对比,结果发现固定特征提取器的新类分类准确类,随初始化类别增多而升高,说明初始类别数越多越有助于新类的学习。

右图,为探索 仅finetune 对应层+初始阶段类别变化 对模型性能的影响。本文发现,初始阶段类别数非常多时,是否 finetune 高层对最终准确率影响不大

此外,本文预设了一种新的场景,即每个批次的类别数据中,可能包含类标号一样的数据。如第 i 批次数据出现过类别为”狗“的类,在后续批次数据中可能再次出现。

2. Contribution

  • 提出了一种在新场景下(不同批次数据之间,旧类可能会再次出现)的持续学习方法
  • 首次对初始阶段的训练进行研究
  • 提出了一种能够融合各个阶段特征提取器特征的分类器

这篇工作并没有在结构或者损失函数上有跨越性的创新,但提出了一种比较有意思的解决方向。

3. Methodology

模型训练具体可以分为4个阶段:

  • 初始阶段预训练得到一个学习过丰富特征的特征提取器(使用 ImageNet 中的800个类)
  • 训练各阶段新类数据的特征提取器,同时学习如何融合对特征提取器的输出
  • 如果发现序列数据中含学习过的旧类数据,通过 knowledge pooler 合并对应的特征

3.1 Pretrain Stage

本文的 backbone 包含两部分,一部分是 \(\phi_{s}\) 为 Resnet10(第1~3个block)使用800个类训练好后固定,另外一部分是并行连接的特征提取器 \(\phi_{b}\)(第4个block,每个批次数据对应一个,通过复制上一阶段的参数做新阶段参数的初始化)。

3.2 Training pipeline

Stage-I Feature augmentation(FA)

这样通过多个阶段学习,可以得到网络中具有多条分支的第4层参数集合:\(\{\Phi_{b},W_{b},\Phi_{n1},W_{n1},...,\Phi_{nT},W_{nT}\}\) ,其中 b 表示base,是初始阶段参数,nT表示第T个阶段,\(\Phi\)表示特征提取器(最后一层),\(W\) 表示各阶段对应的全连接分类器。后续的讨论,都基于这个特征提取器,后续训练均是在 freeze 这个特征提取器的条件下完成的。

Stage-II Fusion

这一阶段需要解决的问题是如何设计 Figure 3 中对输出特征向量进行融合的网络结构(途中打问号的区域)。针对这个问题,本文探索了两种 baseline 用作对比,即 Figure 4 中的(a)(b)。这两种 baseline 想解决的都是如何选择哪条通路的输出作为最终输出的问题-->是属于 base 的数据\(D_{b}\)(800个初始类),还是属于novel 的数据 \(D_{n}\)(后续学习到的类,文中的例子是T=1,即只有初始阶段和第一个阶段)。

\[\hat{y}_{d}=\hat{r}(x)=argmax_{l}\hat{p}^{(l)}(x;\Phi_{s},\Phi_{d},W_{d}),d\in \{b,n\}
\]

式子中 \(\hat{r}(x)\) 为分支选择函数,输出为 0 即选择base,,输出为 1 即选择 novel

Figure 4 (a)为文中提到的 Confidence-based routing,判断方式为通过对各个独立分类头输出的类别置信度进行比较,选出最大的,从而确定是属于 base 还是 novel,数学表达为

Figure 4 (b)为文中提到的 Learning-based routing,判断方式为对输出的两个特征向量进行拼接,然后使用一个全连接分类器学习如何区分 base 和 novel。数学表达为:(\(\oplus\)为concat操作)

考虑到学习如何区分新旧类时,数据中存在类别样本间的不均衡,本文针对存储样本以及新类样本的损失函数为:


其中,公式 (4) 是 binary cross-enctropy (不同于 cross-entropy 每一类对应输出均会产生loss输出,即同一时刻所有输出对应的通路都可更新,而 CE loss 只有一条通路可更新),\(r=1_{[x\in D_{n}]}\) 是 x 所对应的 onehot label;公式(5) 中 \(\varepsilon\) 为存储样本。考虑到新类旧类数据的不均衡,即 \(|D_{n}| \gg |\varepsilon|\) ,本文对loss进行了均衡化处理。

3.3 General score fusion network

经过上面2种 baseline 的对比测试后,作者提出了一种融合各个高维特征提取器输出的方法,示意图见 Figure 4 (c)(注意,本文是保留各个特征提取器的分类头的)。具体操作要点是:

  • 固定特征提取器的参数 \(\{\Phi_{s},\Phi_{b},W_{b},\Phi_{n1},W_{n1},...,\Phi_{nT},W_{nT}\}\)
  • 模型最终做推断时,只需要使用各个特征提取器的全连接输出logits socore \(z_{d}=W_{d}^{T}h_{d}, W_{d}\in R^{k\times |y_{d}|}\) concat到一起,然后做softmax即可得到各个类别的输出概率
  • 为了能够让各个知识能够在各个分支融合,作者提出使用 \(\varepsilon \cup D_{nt}\) 来学习各个输出 logits 之间的关联知识权重 \(W_{dd'}\in R^{k\times |y_{d}|},d,d'\in \{b,n1,...,nt\},d\neq d'\) 。 \(W_{dd'}\) 表示其为连接第 d 个分支的特征到第 d’ 个分支。d 分支输出与其他分支输出融合的方式为直接相加。融合过程数学表达为

  • 融合后的各分支输出直接 concat 形成一个完整的输出logits,从而完成推断的工作
\[z_{a}=\tilde{z}_{b}\oplus \tilde{z}_{n1} \oplus \cdot \cdot \cdot \oplus \tilde{z}_{nt}, \tilde{z}_{d}\in R^{|y_{d}|},d\in\{b,n1,...,nt\}
\]

Overlap knowledge integration

针对本文前面提到的新旧批次数据中类别 overlap 的情况,本文方法采取的策略是:直接对相同类别对应的输出做 pooling(average pooling 或 max pooling,实验结果表明 average pooling 比 max pooling效果好)见图 Figure 4 (c)

经过Pooling后的 logits 记为 \(\bar{x}\) ,当且仅当 \(y_{d}\cap y_{d'}=\emptyset\) 时, \(\tilde{z}_{a}=z_{a}\) 成立。

3.3 Balanced optimization

在最终得到一个融合后的拼接向量后,本文方法会 freeze 整个特征提取部分,单独训练全连接分类器,此时,的损失函数为公式(8):

这里同样考虑到了类别样本均衡(如\(|D_{nt}|\gg|\varepsilon|\)),公式中 \(B \in \varepsilon \cup D_{nt}\) 为分别从存储类别数据及最后一个阶段新类数据中采样的均衡数据集,损失函数为交叉熵。

分支输出选择函数可表示为:

其中,\(W_{r,aux}\in R^{(t+1)\times (t+1)}\) 为一个全连接路径分类器的参数。

最终的 loss 函数可表示为:

此外,为了防止训练过程中过度倾向于 base classes,作者会对 \(h_{d}\) 做 normalize 以及 scale。但需注意的是,对于 base 的logits,为了防止融合时过去倾向于base,文中还设置了一个超参数 \(\beta \in [0,1]\) 用于调节 base logits 的值(base logits 先乘以这个系数后在进行fusion操作)

4. Experiments

本文中只用到了一个 ImageNet1000 这个自然图像数据集(说服力有点弱),

4.1 参数\(\alpha,\beta\) 的灵敏度实验

4.2 与其他现有方法的对比(无 Overlap)

本文使用的指标为:

  • \(Acc_{all}\) ,最后一个阶段结束后,在所有类别数据上的acc
  • \(Acc_{base}\),第一阶段 800 个类的acc
  • \(Acc_{novel}\),新类的acc
  • \(Acc_{ovlp}\),overlap 部分数据的acc
  • \(Acc_{avg}=\frac{\sum_{d\in\{b,n1,...,nt\}}Acc_{d}}{t+1}\),

说明一下,文中按照 训练集、验证集、测试集 的方式划分数据,因此本文的参数是在验证集中挑选最好的产生。因此就需要选定一个性能指标去挑选验证集上“最好”的模型。下表中的 \(best-Acc_{all},best-balanced,best-Acc_{avg}\) 分别表示使用 \(Acc_{all},balanced=\frac{Acc_{all}+Acc_{avg}}{2},Acc_{avg}\) 指标下选择的模型。

Table 1 为只有一个阶段的新类的结果(800base+novel)

其中, joint learning(oracle) 指使用所有base类进行训练();“fc-only”是在分类器部分只用了全连接层,设置这组实验的目的是为了保证本方法的参数与其他对比方法的参数量接近。(奇怪?)

Table 2 为有多个含新类阶段的结果(800base+novel)

4.2 Fusion策略的对比实验(无overlap)

下表中的标识为:

  • FA:本文提出的 feature augmentaion(即初始阶段使用800个类训练)
  • RT:retrain
  • FT:fineture
  • FeatCat+RT:重新训练一个输入为 将所有特征concat在一起 的全连接分类器
  • LogitCat:重新训练(RT)或微调(FT)一个将 logits 拼接的全连接分类器

4.3 Base 特征提取器层数的影响(无Overlap)

4.4 Overlap情况下的表现

4.5 Base中包含所有新类的情况

【论文阅读笔记】Class-Incremental Learning with Strong Pre-trained Models的更多相关文章

  1. 论文阅读笔记 Improved Word Representation Learning with Sememes

    论文阅读笔记 Improved Word Representation Learning with Sememes 一句话概括本文工作 使用词汇资源--知网--来提升词嵌入的表征能力,并提出了三种基于 ...

  2. [论文阅读笔记] metapath2vec: Scalable Representation Learning for Heterogeneous Networks

    [论文阅读笔记] metapath2vec: Scalable Representation Learning for Heterogeneous Networks 本文结构 解决问题 主要贡献 算法 ...

  3. [论文阅读笔记] node2vec Scalable Feature Learning for Networks

    [论文阅读笔记] node2vec:Scalable Feature Learning for Networks 本文结构 解决问题 主要贡献 算法原理 参考文献 (1) 解决问题 由于DeepWal ...

  4. [论文阅读笔记] Adversarial Learning on Heterogeneous Information Networks

    [论文阅读笔记] Adversarial Learning on Heterogeneous Information Networks 本文结构 解决问题 主要贡献 算法原理 参考文献 (1) 解决问 ...

  5. [论文阅读笔记] Adversarial Mutual Information Learning for Network Embedding

    [论文阅读笔记] Adversarial Mutual Information Learning for Network Embedding 本文结构 解决问题 主要贡献 算法原理 实验结果 参考文献 ...

  6. [置顶] 人工智能(深度学习)加速芯片论文阅读笔记 (已添加ISSCC17,FPGA17...ISCA17...)

    这是一个导读,可以快速找到我记录的关于人工智能(深度学习)加速芯片论文阅读笔记. ISSCC 2017 Session14 Deep Learning Processors: ISSCC 2017关于 ...

  7. Nature/Science 论文阅读笔记

    Nature/Science 论文阅读笔记 Unsupervised word embeddings capture latent knowledge from materials science l ...

  8. 论文阅读笔记 - YARN : Architecture of Next Generation Apache Hadoop MapReduceFramework

    作者:刘旭晖 Raymond 转载请注明出处 Email:colorant at 163.com BLOG:http://blog.csdn.net/colorant/ 更多论文阅读笔记 http:/ ...

  9. 论文阅读笔记 - Mesos: A Platform for Fine-Grained ResourceSharing in the Data Center

    作者:刘旭晖 Raymond 转载请注明出处 Email:colorant at 163.com BLOG:http://blog.csdn.net/colorant/ 更多论文阅读笔记 http:/ ...

  10. 论文阅读笔记 Word Embeddings A Survey

    论文阅读笔记 Word Embeddings A Survey 收获 Word Embedding 的定义 dense, distributed, fixed-length word vectors, ...

随机推荐

  1. 动手搭建ssm框架

    现在很多公司用的开源框架很多都是ssm框架的一个结构,这里我自己试着自己搭一个简单的框架,大家共同学习.下面一起跟着我搭建吧,本人菜鸟,有任何不对的地方有望指出. 框架结构:spring(4.3.9. ...

  2. Mysql explain 每个属性含义

    Mysql explain explain 常用于分析sql语句的执行效率,使用时在正常的select语句之前添加explain并执行就会返回执行信息,返回的执行信息如下:  id:id列的编号是se ...

  3. 学习JavaScript第五周

    MySQL基本内容: 访问:2种 ​ 1.图形化界面 - 傻瓜式 ​ 要求:同时打开apache和mysql ​ 访问:127.0.0.1:端口号/phpmyadmin ​ localhost:端口号 ...

  4. CSS 常用样式-字体属性

    字体类样式我们已经学习过字号font-size.字体font-family两个属性,接下来还有几个常用的字体属性. 粗细 font-weight: 作用:设置文字是否加粗显示. 属性名:font-we ...

  5. 11.30linux学习第十一天

    今天老刘上课,第7章收尾,第8章开了个头. 7.1.3  磁盘阵列+备份盘 RAID 10磁盘阵列中最多允许50%的硬盘设备发生故障,但是存在这样一种极端情况,即同一RAID 1磁盘阵列中的硬盘设备若 ...

  6. 图模配置文件之 flow.json

    flow.json文件是用来配置图模导入时,各种不同的图模导入时,分别应该使用哪个映射文件对模型进行处理.在不同地区使用不同的格式的图模文件时,需要修改flow.json中相关的配置,来适应相应的图模 ...

  7. Otto Group Product Classification

    遇到的坑: 做多分类,用CrossEntropyLoss时,训练时候的正确标签的范围应该是[0,n-1],而不是[1,n],不然会报 IndexError: Target is out of boun ...

  8. SQL Server获取连接的IP地址

    来源:http://www.itpub.net/thread-193247-1-1.html 先保存,以后研究一下 1 *--获取连接SQL服务器的信息 2 3 所有连接本机的:操作的数据库名,计算机 ...

  9. 解决和根源:Unsolicited response received on idle HTTP channel starting with xxx

    环境:golang,使用http client,服务器:iis +aspx.net动作:head请求或其他此问题见于各种请求情况.核心是,http在活动期间收到了非预期的信息.一开始我也很纳-闷,因为 ...

  10. Httpt请求

    在c#中常见发送http请求的方式如下 HttpWebRequest: .net 平台原生提供,这是.NET创建者最初开发用于使用HTTP请求的标准类.使用HttpWebRequest可以让开发者控制 ...