• 文章原创自:微信公众号「机器学习炼丹术」
  • 作者:炼丹兄
  • 联系方式:微信cyx645016617

  • 代码来自github

【前言】:看代码的时候,也许会不理解VIT中各种组件的含义,但是这个文章的目的是了解其实现。在之后看论文的时候,可以做到心中有数,而不是一片茫然。

VIT类

初始化

和之前的学习一样,从大模型类开始看起,然后一点一点看小模型类:

class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' self.patch_size = patch_size self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.patch_to_embedding = nn.Linear(patch_dim, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout) self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) self.pool = pool
self.to_latent = nn.Identity() self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)

在实际的调用中,是如下调用的:

model = ViT(
dim=128,
image_size=224,
patch_size=32,
num_classes=2,
channels=3,
).to(device)

输入参数讲解:

  • image_size:图片的大小;
  • patch_size:把图片划分成小的patch,小的patch的尺寸;
  • num_classes:这次分类任务的类别总数;
  • channels:输入图片的通道数。

VIT类中初始化的组件:

  • num_patches:一个图片划分成多少个patch,因为图片224,patch32,所以划分成7x7=49个patches;
  • patch_dim:3x32x32,理解为一个patch中的元素个数;

......这样展示是不是非常的麻烦,还要上下来回翻看代码,所以我写成注释的形式

class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
# image_size=224,patch_size=32,num_classes=2,channels=3,dim=128
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
# num_pathes = (224//32)**2 = 7*7=49
num_patches = (image_size // patch_size) ** 2
# patch_dim = 3*32*32
patch_dim = channels * patch_size ** 2
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
# self.patch_size = 32
self.patch_size = patch_size
# self.pos_embedding是一个形状为(1,50,128)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
# self.patch_to_embedding是一个从3*32*32到128映射的线性层
self.patch_to_embedding = nn.Linear(patch_dim, dim)
# self.cls_token是一个随机初始化的形状为(1,1,128)这样的变量
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout) # Transformer后面会讲解
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) self.pool = pool
self.to_latent = nn.Identity() self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)

forward

现在看VIT的推理过程:

    def forward(self, img, mask = None):
# p=32
p = self.patch_size
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x) # x.shape=[b,49,128]
b, n, _ = x.shape # n = 49 cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1) # x.shape=[b,50,128]
x += self.pos_embedding[:, :(n + 1)] # x.shape=[b,50,128]
x = self.dropout(x) x = self.transformer(x, mask) # x.shape=[b,50,128],mask=None x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] x = self.to_latent(x)
return self.mlp_head(x)
  • 这里的代码用到了from einops import rearrange, repeat,这个库函数,einops是一个库函数,是对张量进行操作的库函数,支持pytorch,TF等。
  • einops.rearrange是把输入的img,从[b,3,224,224]的形状改成[b,3,7,32,7,32]的形状,通过矩阵的转置换成[b,7,7,32,32,3]的样子,最后合并成[b,49,32x32x3]
  • self.patch_to_embedding,输出的x的形状为[b,49,128];
  • einops.repeat是把self.cls_token从[1,1,128]复制成[b,1,128]

现在,我们知道从patch到embedding是用线性层实现的。

transformer

class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
# dim=128,depth=12,heads=8,dim_head=64,mlp_dim=128
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
]))
def forward(self, x, mask = None):
for attn, ff in self.layers:
x = attn(x, mask = mask)
x = ff(x)
return x
  • self.layers中包含depth组的Attention+FeedForward模块。
  • 这里需要记得,输入的x的尺寸为[b,50,128]

Attention

class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads # 64 x 8
self.heads = heads # 8
self.scale = dim_head ** -0.5 self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) def forward(self, x, mask = None):
b, n, _, h = *x.shape, self.heads # n=50,h=8
# self.to_qkv(x)得到的尺寸为[b,50,64x8x3],然后chunk成3份
# 也就是说,qkv是一个三元tuple,每一份都是[b,50,64x8]的大小
qkv = self.to_qkv(x).chunk(3, dim = -1)
# 把每一份从[b,50,64x8]变成[b,8,50,64]的形式
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
# 这一步不太好理解,q和k都是[b,8,50,64]的形式,50理解为特征数量,64为特征变量
# dots.shape=[b,8,50,50]
dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
# 不考虑mask这一块的内容
mask_value = -torch.finfo(dots.dtype).max if mask is not None:
mask = F.pad(mask.flatten(1), (1, 0), value = True)
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
mask = mask[:, None, :] * mask[:, :, None]
dots.masked_fill_(~mask, mask_value)
del mask
# 对[b,8,50,50]的最后一个维度做softmax
attn = dots.softmax(dim=-1) # 这个attn就是计算出来的自注意力值,和v做点乘,out.shape=[b,8,50,64]
out = torch.einsum('bhij,bhjd->bhid', attn, v)
# out.shape变成[b,50,8x64]
out = rearrange(out, 'b h n d -> b n (h d)')
# out.shape重新变成[b,60,128]
out = self.to_out(out)
return out

综上所属,这个attention其实就是一个自注意力模块,输入的是[b,50,128],返回的也是[b,50,128]。实现的过程因为使用了torch.einsum所以有些复杂,但是总的来说,和我之前讲过的一篇论文"non-local"模块,是完全一样的。torch.einsum和torch.mm原理相同,只是因为torch.mm不支持高纬度的张量做矩阵乘法。

PreNorm

class PreNorm(nn.Module):
def __init__(self, dim, fn):
# dim=128,fn=Attention/FeedForward
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)

先对输入的x(x.shape=[b,50,128])做一个layerNormalization层归一化,然后再放到上面的Attention模块中做自注意力。

Residual

class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x

一个残差模块罢了。

FeedForward

class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
# dim=128,hidden_dim=128
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)

就是两个线性层,这里有意思的是GELU()激活函数,这个激活函数可以直接使用torch.nn.GELU()调用,回头有机会再好好讲一下GELU()的原理。

transformer总结

Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
  • 第一个就是,先对输入做layerNormalization,然后放到attention得到attention的结果,然后结果和做layerNormalization之前的输入相加做一个残差链接;
  • 第二个就是,x->LayerNormalization->FeedForward线性层->y,然后这个y和输入的x相加,做残差连接。

VIT总结

回顾一下整个流程:

  • 一个图片224x224,分成了49个32x32的patch;
  • 对这么多的patch做embedding,成49个128向量;
  • 再拼接一个cls_tokens,变成50个128向量;
  • 再加上pos_embedding,还是50个128向量;
  • 这些向量输入到transformer中进行自注意力的特征提取;
  • 输出的是50个128向量,然后对这个50个求军职,变成一个128向量;
  • 然后线性层把128维变成2维从而完成二分类任务的transformer模型。

问题:我对NLP了解不深入,有没有人可以回答一下这个问题:cls_tokens和pos_embedding的用处是什么?

VIT Vision Transformer | 先从PyTorch代码了解的更多相关文章

  1. ICCV2021 | 渐进采样式Vision Transformer

    ​  前言  ViT通过简单地将图像分割成固定长度的tokens,并使用transformer来学习这些tokens之间的关系.tokens化可能会破坏对象结构,将网格分配给背景等不感兴趣的区域,并引 ...

  2. ICCV2021 | Tokens-to-Token ViT:在ImageNet上从零训练Vision Transformer

    ​  前言  本文介绍一种新的tokens-to-token Vision Transformer(T2T-ViT),T2T-ViT将原始ViT的参数数量和MAC减少了一半,同时在ImageNet上从 ...

  3. 目标检测之Faster-RCNN的pytorch代码详解(数据预处理篇)

    首先贴上代码原作者的github:https://github.com/chenyuntc/simple-faster-rcnn-pytorch(非代码作者,博文只解释代码) 今天看完了simple- ...

  4. ICCV2021 | Vision Transformer中相对位置编码的反思与改进

    ​前言  在计算机视觉中,相对位置编码的有效性还没有得到很好的研究,甚至仍然存在争议,本文分析了相对位置编码中的几个关键因素,提出了一种新的针对2D图像的相对位置编码方法,称为图像RPE(IRPE). ...

  5. (原)SphereFace及其pytorch代码

    转载请注明出处: http://www.cnblogs.com/darkknightzh/p/8524937.html 论文: SphereFace: Deep Hypersphere Embeddi ...

  6. (转载)PyTorch代码规范最佳实践和样式指南

    A PyTorch Tools, best practices & Styleguide 中文版:PyTorch代码规范最佳实践和样式指南 This is not an official st ...

  7. PyTorch代码调试利器: 自动print每行代码的Tensor信息

    本文介绍一个用于 PyTorch 代码的实用工具 TorchSnooper.作者是TorchSnooper的作者,也是PyTorch开发者之一. GitHub 项目地址: https://github ...

  8. 如何将tensorflow1.x代码改写为pytorch代码(以图注意力网络(GAT)为例)

    之前讲解了图注意力网络的官方tensorflow版的实现,由于自己更了解pytorch,所以打算将其改写为pytorch版本的. 对于图注意力网络还不了解的可以先去看看tensorflow版本的代码, ...

  9. pointnet.pytorch代码解析

    pointnet.pytorch代码解析 代码运行 Training cd utils python train_classification.py --dataset <dataset pat ...

随机推荐

  1. ES6 class类 静态方法及类的继承

    一.class类 ES6之前都是定义函数以及函数的原型对象实现类型, 如果想要实现共享构造函数成员,可以用prototype来共享实现 ES6出现之后,使用class类的概念来实现原型的继承 二,静态 ...

  2. 【JavaWeb】JSON 文件

    JSON 文件 什么是 JSON JSON(JavaScript Object Notation),即 JS 对象符号. 是一种轻量级(相对于 XML 来说)的数据交换格式,易于阅读和编写,同时也易于 ...

  3. JavaScript入门-对象

    js对象 本篇主要介绍js里如何创建对象,以及for循环访问对象的成员... 什么是对象? 对象,并不是中文里有男女朋友意思,它是从英文里翻译来的,英文叫[Object],目标,物体,物品的意思. 在 ...

  4. LRU(Least Recently Used)最近未使用置换算法--c实现

    在OS中,一些程序的大小超过内存的大小(比如好几十G的游戏要在16G的内存上跑),便产生了虚拟内存的概念 我们通过给每个进程适当的物理块(内存),只让经常被调用的页面常驻在物理块上,不常用的页面就放在 ...

  5. CTFHub - Web(二)

    目录遍历: 法一: 依次查看目录即可: 法二: 利用脚本:  #!/usr/bin/python3  # -*- coding: utf-8 -*-  # --author:valecalida-- ...

  6. 24V转3.3V稳压芯片,高效率同步降压DC-DC变换器3A输出电流

    PW2312是一个高频,同步,整流,降压,开关模式转换器与内部功率MOSFET.它提供了一个非常紧凑的解决方案,以实现1.5A的峰值输出电流在广泛的输入电源范围内,具有良好的负载和线路调节. PW23 ...

  7. python_mmdt:一种基于敏感哈希生成特征向量的python库(一)

    概述 python_mmdt是一种基于敏感哈希的特征向量生成工具.核心算法使用C实现,提高程序执行效率.同时使用python进行封装,方便研究人员使用. 本篇幅主要介绍涉及的相关基本内容与使用,相关内 ...

  8. 解决Python内CvCapture视频文件格式不支持问题

    解决Python内CvCapture视频文件格式不支持问题 在读取视频文件调用默认的摄像头cv.VideoCapture(0)会出现下面的视频格式问题 CvCapture_MSMF::initStre ...

  9. web框架的本质:

    简单的web框架 web的应用本质其实就是socket服务器,用户所使用的浏览器就是一个cocket客户端,客户使用浏览器发送的请求会被服务接收,服务器会按照http协议的响应协议来回复请求,这样的网 ...

  10. udp 连接

    在今天的内容里,我对 UDP 套接字调用 connect 方法进行了深入的分析.之所以对 UDP 使用 connect,绑定本地地址和端口,是为了让我们的程序可以快速获取异步错误信息的通知,同时也可以 ...