ViT

概括

论文题目:AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE

论文地址:https://openreview.net/pdf?id=YicbFdNTTy

作者来自 Google

亮点:

  • 一些有趣的特性:

    • CNN 处理不太好但是 ViT 可以处理好的例子:
    • 遮挡
    • 数据分布偏移
    • 加入对抗性的 patch
    • 排列

作者认为:

  • 对于 CNN 的依赖是不必要的
  • 纯 Transformer 可以做到和 CNN 媲美的结果
  • Transformer 需要更少的训练资源,即使如此,也需要 2500 TPUv3 天数。这里说的少,只是跟更耗卡的模型做对比。
  • 在 CV 使用 Transformer
    • 难点:

      • 像素点过多,而 maxlength 太短
    • 于是前人提出许多思路,降低 length 长度:
      • 用 ResNet 最后的特征图 \(14 \times 14=196\) 输入 transformer
      • 用局部小窗口或者把图像这个二维的拆成两个一维的向量
      • 没有硬件加速,模型都无法做到太大
    • 大规模还是 ResNet 效果最好
  • 于是 ViT 做法:
    • 模型

      • 把一张图片分成很多 patch,每一个 patch 的大小为 16 * 16
      • 由于 224 / 16 = 14,因此共有 14 * 14 个 patch
      • 所以 length 为 14 * 14 = 196
      • 然后把每一个 patch 通过一个 linear embedding,这些再作为输入传给 Transformer
    • 训练
      • 用有监督的方式训练,原因是 cv 还是需要有监督的
    • 有钱
      • 之前有一个思路一样的,但是那个作者没钱
  • 一些结果
    • 中型大小数据集上,ViT 比同等大小的 ResNet 要弱几个点,作者认为原因有

      • Transformer 模型缺少一些 CNN 的归纳偏置

        • 局部临近 locality

          • 相邻的图片有相邻的特征
        • 平移等变性 translation equivariance:
          • \(f(g(x)) = g(f(x))\)
          • 即函数顺序变换最后结果也一样。
          • 在 CNN 中,即为相同的物体无论平移到哪里,只要遇到相同的卷积核,那么输出一样。
      • 因此 Transformer 缺少一些 CNN 拥有的前置知识,需要自己从数据里学。
    • 于是作者又在大规模的数据集上学习,效果很好,得到与 ResNet 相近或者更好的结果。

模型

模型介绍

  1. 首先给定一张图,然后把这张图打成许多个 patch
  2. 然后每个 patch 通过一个线性的投射层得到一个特征。
  3. 再通过 Patch + Position Embedding 的方式,把位置编码弄上去。
  4. 丢进 TFM enc。
  5. 然后拿 [CLS] 最后对应的 representation 丢进一个 MLP Head 里。

图像的维度从 \(224 \times 224 \times 3\),变成了一个有 196 个有的 \(16 \times 16 \times 3 = 768\) 维度的 patch。

线性投射层是一个全连接层,维度是 \(D \times D = 768 \times 768\).

然后要加上 [CLS] token

最后加上位置编码信息

  • 具体是把位置编码编成一张表,每一个位置对应一个向量
  • 位置编码,三种效果差不多
    • 作者做了 1D 的位置编码。常规方法。
    • 2D 的做法:假设之前 1D 的位置编码维度是 D,2D 的位置编码横坐标有 \(\frac{D}{2}\) 的维度,纵坐标亦然,然后直接拼在一起。
    • 相对位置编码:两个 patch 之间的距离可以用相对距离来表示

归纳偏置,ViT 比 CNN 少了很多归纳偏置

  • CNN 中,局部性和平移等变性在模型每一层都有所体现,因此先验知识贯穿模型始终。
  • 而 ViT 中,只有 MLP 这个层有局部性和平移等变性,自注意力层是全局的。

作者还做了一个混合的网络,前面 CNN,后面 TFM

预训练以及更大的图片

图片更大,patch 的个数也变了,于是位置编码也会变。

  • 作者直接用了 pytorch 官方自带的 interpolate 函数做 2D 插值。
  • 这个插值只能算一个临时解决方案,是 ViT 的局限性。

实验

主要对比 ResNet,ViT,hybrid

下游任务主要是分类

实验结果

ImageNet 结果

  • 中型大小数据集上,ViT 比同等大小的 ResNet 要弱几个点
  • 大型数据集上,ViT 几乎全面超过 ResNet

线性 5-shot 评估

  • 没有经过微调
  • 结果同上

  • 同等预训练计算复杂度下,ViT 比 ResNet 强
  • 预训练计算次数小时,混合模型最强
  • 数据越来越大时,ViT 越来越强,接近超越混合模型和 ResNet
  • ViT 和 ResNet 模型似乎都没有饱和,仍然可以继续往上走

看李沐的 ViT 串讲的更多相关文章

  1. CLIP改进工作串讲(上)学习笔记

    看了跟李沐学AI系列朱毅老师讲的CLIP改进工作串讲,这里记录一下. 1.分割 分割的任务其实跟分类很像,其实就是把图片上的分类变成像素级别上的分类,但是往往图片上能用的技术都能用到像素级别上来.所以 ...

  2. 视频+图文串讲:MySQL 行锁、间隙锁、Next-Key-Lock、以及实现记录存在的话就更新,如果记录不存在的话就插入如何保证并发安全

    导读 Hi,大家好!我是白日梦!本文是MySQL专题的第 27 篇. 下文还是白日梦以自导自演的方式,围绕"如何实现记录存在的话就更新,如果记录不存在的话就插入."展开本话题.看看 ...

  3. 0607pm克隆&引用类&加载类&面向对象串讲&函数重载

    克隆class Ren{ public $name; public $sex; function __construct($n,$s) { $this->name=$n; $this->s ...

  4. CLIP 改进工作串讲(下)学习笔记

    1.图像生成 1.1CLIPasso(semantically-aware object sketching) 将物体的照片变成简笔画的形式,希望即使有最少的线条,也能识别出来物体. 问题定义,在纸上 ...

  5. getElementById返回的是什么?串讲HTML DOM

    1. getElementById()返回的是什么? 这个函数使用的最普遍,但是你有没有深入探究下,这个函数究竟返回的是什么么?我们来一起看看. var mydivEle = document.get ...

  6. 集成学习-Boosting 模型深度串讲

    首先强调一下,这篇文章适合有很好的基础的人 梯度下降 这里不系统讲,只介绍相关的点,便于理解后文 先放一个很早以前写的 梯度下降 实现 logistic regression 的代码 def tidu ...

  7. 全网最牛X的!!! MySQL两阶段提交串讲

    目录 一.吹个牛 二.事务及它的特性 三.简单看下两阶段提交的流程 四.两阶段写日志用意? 五.加餐:sync_binlog = 1 问题 六.如何判断binlog和redolog是否达成了一致 七. ...

  8. 全网最清楚的:MySQL的insert buffer和change buffer 串讲

    目录 一.前言 二.问题引入 2.1.聚簇索引 2.2.普通索引 三.change buffer存在的意义 四.再看change buffer 五.change buffer 的限制 六.change ...

  9. .NET 基础串讲

    C#基础 .NET介绍 —计算机发展史 第一代语言:机器语言 0101 第二代语言:汇编语言, 用一些简洁的英文字母.符号串来替代一个特定指令的二进制串 第三代语言:接近于数学语言或人的自然语言,同时 ...

  10. 技术串讲 CAS 有用

    CAS,全称为Compare and Swap,即比较-替换.假设有三个操作数:内存值V.旧的预期值A.要修改的值B,当且仅当预期值A和内存值V相同时,才会将内存值修改为B并返回true,否则什么都不 ...

随机推荐

  1. 使用组合逻辑电路驱动VGA显示器

    使用组合逻辑电路驱动VGA显示器 1. 概述 本文讲述一种不使用缓冲存储器驱动VGA显示的简单方法.其中,VGA分辨率采用DE10-Lite建议使用的640X480.像素的时钟25MHz,刷新率59. ...

  2. JAVA下唯一一款搞定OLTP+OLAP的强类型查询这就是最好用的ORM相见恨晚

    JAVA下唯一一款搞定OLTP+OLAP的强类型查询这就是最好用的ORM相见恨晚 介绍 首先非常感谢 FreeSQL 提供的部分源码,让我借鉴了不少功能点,整体设计并没有参考FreeSQL(因为jav ...

  3. 如何在局域网内两台电脑上进行webapi的在线调试

    原文地址:https://www.zhaimaojun.top/Note/5475298(我自己的博客) 局域网内WebApi的远程调试方法: 第一步:管理员方式运行Vs并打开需要运行的项目,如果已经 ...

  4. 前端调用DRI后端API出现跨域资源共享(CORS)问题解决办法

    目录 1. 引言 2. 跨源资源共享和实现方法 3. 在Django项目中配置django-cors-headers库 Reference 1. 引言 在进行后端API开发时,有时会遇到"跨 ...

  5. n个人围成一圈,顺序排号从1到n。从第一个人开始报数(从一到三如此循环)。凡是报到三的出局,最后剩下的一个人原始编号为?

    #include<stdio.h> int main(){ int num,n,i=0,flag=0; //num记录剩余人数,n记录总人数,i为原始编号,flag为编号123时的编号 p ...

  6. Vben-admin---ApiSelect Invalid prop: type check failed for prop "onUpdate:value". Expected Function, got Array

    在basicFrom组件里添加一个ApiSelect, <template #localSearch="{ model, field }"> <ApiSelect ...

  7. Python语言:散修笔记

    文章目录 前言 转义字符的使用 原字符 变量的定义 类型转换 注释 接收用户信息 运算规则 整除运算 幂运算 比较运算符 布尔运算 运算优先级 对象的布尔值 if else elif分支结构 条件表达 ...

  8. 使用interface化解一场因操作系统不同导致的编译问题

    场景描述 起因: 因项目需求,需要编写一个agent, 需支持Linux和Windows操作系统. Agent里面有一个功能需要获取到服务器上所有已经被占用的端口. 实现方式:针对不同的操作系统,实现 ...

  9. 给大家分享一套非常棒的python机器学习课程

    给大家分享一套非常棒的python机器学习课程--<AI小天才:让小学生轻松掌握机器学习>,2024年5月完结新课,提供配套的代码+笔记+软件包下载!学完本课程,可以轻松掌握机器学习的全面 ...

  10. c++ RTTI Runtime Type Identification 运行阶段类型识别

    NoVirtualBase* NvirBase = new NovirtualDerivd(); NvirBase->print(); // auto nd1 = dynamic_cast< ...