Neyshabur B., Sedghi H., Zhang C. What is being transferred in transfer learning?

arXiv preprint arXiv 2008.11687, 2020.

迁移学习到底迁移了什么?

主要内容

  • T: 普通训练的模型
  • P: 预训练的模型
  • RI: 随机初始化的模型
  • RI-T: 随机初始化再经过普通训练的模型
  • P-T: 在预训练的基础上再fine-tuning的模型

本文的预训练都是在ImageNet上, 然后在CheXpert和DomainNet(分为real, clipart, quickdraw)上测试.

feature reuse

大家认为迁移学习有用的一个直觉就是迁移学习通过特征的复用来样本少的数据提供一个较好的特征先验.



通过上面的图可以看到, P-T总是能够表现优于RI-T, 这能够支撑我们的观点. 但是, 为什么数据差别特别大的时候, 预训练还是有用呢(此时feature reuse的作用应该不是很明显)? 作者将图片按照不同的block size打乱(就像最开始的那些乱七八糟的图片). 这个时候, 模型应该只能抓住浅层的特征, 抽象的特征是没法被很好提取的, 结果如下图所示.

  • 当打乱的程度加剧(block size变小), 任务越发困难;
  • 相对正确率差距\((A_{P-T}-A_{RI-T})/A_{P-T} \%\)随着block size减小而减小(clipart, real), 这说明feature reuse很有效果, quickdraw 相反是由于其数据集和预训练的数据集相差过大, 但是即便如此, 在quickdraw上预训练还是有效的, 说明存在除了feature reuse外的因素;
  • P-T的训练速度(右图)一直很稳定, 而RI-T的训练速度则在block size下降的时候有一个急剧的下降, 这说明feature reuse并不是影响P-T训练速度的主要因素.

mistakes and feature similarity

这部分通过探究不同模型有哪些common和uncommon的mistakes来揭示预训练的作用.

P-T在简单样本上的成功率很高, 而在比较模糊难以判断的样本上比较难(而此时RI-T往往比较好), 这说明P-T有着很强的先验.

通过 centered kernel alignment (CKA) 来衡量特征之间的相似度:



可以发现, 基于预训练的模型之间的特征相似度很高, 而RI-T与别的模型相似度很低, 即便是两个相同初始化的RI-T. 说明预训练模型之间往往是在重复利用相同的特征.

下表为不同模型的参数的\(\ell_2\)距离, 同样能够反映上面一点.

loss landscape

用\(\Theta, \tilde{\Theta}\)表示两个checkpoint的参数, 通过线性插值

\[\{\Theta_{\lambda} = (1- \lambda) \Theta + \lambda \tilde{\Theta}: \lambda \in [0, 1]\},
\]

考量模型在\(\Theta_{\lambda}\)下的表现.



上图, 左为DomainNET real, 右为quickdraw, 可见预训练模型之间的loss landscape是很光滑的, 不同于RI-T.

module criticality

如果我们将训练好后的模型的某一层参数替换为其初始参数, 然后观察替换前后的正确率就能一定程度上判断这个层在整个网络中的重要性, module criticality就是一个这样的类似的指标.

下图反映了不同模型的不同层的criticality.

下图反映了RI-T的训练后的参数\(\theta\)其实加了扰动反而性能更好? 而P-T的就相当稳定.

pre-trained checkpoint

我们选pre-trained模型的时候, 往往是通过正确率指标来判断的, 但是事实上, 这个判断并不十分准确, 事实上我们可以早一步地选取checkpoint (直观上理解, 大概是只要参数进入了那个光滑的盆地就行了).

What is being transferred in transfer learning?的更多相关文章

  1. (转)Understanding, generalisation, and transfer learning in deep neural networks

    Understanding, generalisation, and transfer learning in deep neural networks FEBRUARY 27, 2017   Thi ...

  2. 迁移学习( Transfer Learning )

    在传统的机器学习的框架下,学习的任务就是在给定充分训练数据的基础上来学习一个分类模型:然后利用这个学习到的模型来对测试文档进行分类与预测.然而,我们看到机器学习算法在当前的Web挖掘研究中存在着一个关 ...

  3. 【迁移学习】2010-A Survey on Transfer Learning

    资源:http://www.cse.ust.hk/TL/ 简介: 一个例子: 关于照片的情感分析. 源:比如你之前已经搜集了大量N种类型物品的图片进行了大量的人工标记(label),耗费了巨大的人力物 ...

  4. 迁移学习(Transfer Learning)(转载)

    原文地址:http://blog.csdn.net/miscclp/article/details/6339456 在传统的机器学习的框架下,学习的任务就是在给定充分训练数据的基础上来学习一个分类模型 ...

  5. Transfer learning across two sentiment classes using deep learning

    用深度学习的跨情感分类的迁移学习 情感分析主要用于预测人们在自然语言中表达的思想和情感. 摘要部分:two types of sentiment:sentiment polarity and poli ...

  6. 读论文系列:Deep transfer learning person re-identification

    读论文系列:Deep transfer learning person re-identification arxiv 2016 by Mengyue Geng, Yaowei Wang, Tao X ...

  7. 迁移学习-Transfer Learning

    迁移学习两种类型: ConvNet as fixed feature extractor:利用在大数据集(如ImageNet)上预训练过的ConvNet(如AlexNet,VGGNet),移除最后几层 ...

  8. CVPR2018: Unsupervised Cross-dataset Person Re-identification by Transfer Learning of Spatio-temporal Patterns

    论文可以在arxiv下载,老板一作,本人二作,也是我们实验室第一篇CCF A类论文,这个方法我们称为TFusion. 代码:https://github.com/ahangchen/TFusion 解 ...

  9. pytorch例子学习——TRANSFER LEARNING TUTORIAL

    参考:https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html 以下是两种主要的迁移学习场景 微调convnet : ...

随机推荐

  1. 学习Java 2021/10/7

    java重写Override 重载Overload 重写是子类对父类的允许访问的方法的实现过程进行重新编写,返回值和形参都不能改变.即外壳不变,核心重写 重写规则: 参数列表与被重写方法的参数列表必须 ...

  2. Mybatis相关知识点(一)

    MyBatis入门 (一)介绍 MyBatis 本是apache的一个开源项目iBatis, 2010年这个项目由apache software foundation 迁移到了google code, ...

  3. Oracle数据库导入与导出方法简述

    说明: 1.数据库数据导入导出方法有多种,可以通过exp/imp命令导入导出,也可以用第三方工具导出,如:PLSQL 2.如果熟悉命令,建议用exp/imp命令导入导出,避免第三方工具版本差异引起的问 ...

  4. IOS_UIButton去掉系统的按下高亮置灰效果

    setAdjustsImageWhenHighlighted   // default is YES. if YES, image is drawn darker when highlighted(p ...

  5. 【Services】【Web】【tomcat】配置tomcat支持https传输

    1. 基础: 1.1. 描述:内网的tomcat接到外网nginx转发过来的请求之后需要和外网的客户端进行通讯,为了保证通讯内容的安装,使用tomcat使用https协议. 1.2. 链接:http: ...

  6. 通过SSE(Server-Send Event)实现服务器主动向浏览器端推送消息

    一.SSE介绍 1.EventSource 对象 SSE 的客户端 API 部署在EventSource对象上.下面的代码可以检测浏览器是否支持 SSE. if ('EventSource' in w ...

  7. 1945-祖安 say hello-String

    1 #define _CRT_SECURE_NO_WARNINGS 1 2 #include<bits/stdc++.h> 3 char str[100][40]; 4 char s[10 ...

  8. java多线程4:volatile关键字

    上文说到了 synchronized,那么就不得不说下 volatile关键字了,它们两者经常协同处理多线程的安全问题. volatile保证可见性 那么volatile的作用是什么呢? 在jvm运行 ...

  9. Python 的元类设计起源自哪里?

    一个元老级的 Python 核心开发者曾建议我们( 点击阅读),应该广泛学习其它编程语言的优秀特性,从而提升 Python 在相关领域的能力.在关于元编程方面,他的建议是学习 Hy 和 Ruby.但是 ...

  10. FlashFXP链接到服务器上,如果www目录下的文件隐藏

    FlashFXP链接到服务器上,如果www目录下的文件隐藏,那么请按照如下设置,就可以显示隐藏的文件了 [站点]->[站点管理器]->选项,然后按照如下设置: