自监督图像论文复现 | BYOL(pytorch)| 2020
继续上一篇的内容,上一篇讲解了Bootstrap Your Onw Latent自监督模型的论文和结构:
https://juejin.cn/post/6922347006144970760
现在我们看看如何用pytorch来实现这个结构,并且在学习的过程中加深对论文的理解。
github:https://github.com/lucidrains/byol-pytorch
【前沿】:这个代码我没有实际跑过,毕竟我只是一个没有GPU的小可怜。
主要模型代码
class BYOL(nn.Module):
def __init__(
self,
net,
image_size,
hidden_layer = -2,
projection_size = 256,
projection_hidden_size = 4096,
augment_fn = None,
augment_fn2 = None,
moving_average_decay = 0.99,
use_momentum = True
):
super().__init__()
self.net = net
# default SimCLR augmentation
DEFAULT_AUG = torch.nn.Sequential(
RandomApply(
T.ColorJitter(0.8, 0.8, 0.8, 0.2),
p = 0.3
),
T.RandomGrayscale(p=0.2),
T.RandomHorizontalFlip(),
RandomApply(
T.GaussianBlur((3, 3), (1.0, 2.0)),
p = 0.2
),
T.RandomResizedCrop((image_size, image_size)),
T.Normalize(
mean=torch.tensor([0.485, 0.456, 0.406]),
std=torch.tensor([0.229, 0.224, 0.225])),
)
self.augment1 = default(augment_fn, DEFAULT_AUG)
self.augment2 = default(augment_fn2, self.augment1)
self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer)
self.use_momentum = use_momentum
self.target_encoder = None
self.target_ema_updater = EMA(moving_average_decay)
self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)
# get device of network and make wrapper same device
device = get_module_device(net)
self.to(device)
# send a mock image tensor to instantiate singleton parameters
self.forward(torch.randn(2, 3, image_size, image_size, device=device))
@singleton('target_encoder')
def _get_target_encoder(self):
target_encoder = copy.deepcopy(self.online_encoder)
set_requires_grad(target_encoder, False)
return target_encoder
def reset_moving_average(self):
del self.target_encoder
self.target_encoder = None
def update_moving_average(self):
assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder'
assert self.target_encoder is not None, 'target encoder has not been created yet'
update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
def forward(self, x, return_embedding = False):
if return_embedding:
return self.online_encoder(x)
image_one, image_two = self.augment1(x), self.augment2(x)
online_proj_one, _ = self.online_encoder(image_one)
online_proj_two, _ = self.online_encoder(image_two)
online_pred_one = self.online_predictor(online_proj_one)
online_pred_two = self.online_predictor(online_proj_two)
with torch.no_grad():
target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder
target_proj_one, _ = target_encoder(image_one)
target_proj_two, _ = target_encoder(image_two)
target_proj_one.detach_()
target_proj_two.detach_()
loss_one = loss_fn(online_pred_one, target_proj_two.detach())
loss_two = loss_fn(online_pred_two, target_proj_one.detach())
loss = loss_one + loss_two
return loss.mean()
- 先看
forward()函数,发现输入一个图片给模型,然后返回值是这个图片计算的loss - 如果是推理过程,那么
return_embedding=True,那么返回的值就是online network中的encoder部分输出的东西,不用在考虑后面的predictor,这里需要注意代码中的encoder其实是论文中的encoder+projector; - 图片经过self.augment1和self.augment2处理成两个不同的图片,在上一篇中,我们称之为view;
- 两个图片都经过online-encoder,这里可能会有疑问:不是应该一个图片经过online network,另外一个经过target network吗?为什么这两个都经过online-encoder,你说的没错,这里只是方便后面计算symmetric loss,因为要计算对称损失,所以两个图片都要经过online network和target network。
- 在target network中推理的内容,都不需要记录梯度,因为target network是根据online network的参数更新的
- 如果
self.use_momentum=False,那么就不使用论文中的更新target network的方式,而是直接把online network复制给target network,不过我发现!这个github代码虽然有600多stars,但是这里的就算你的self.use_momentum=True,其实也是把online network复制给了target network啊哈哈,那么就不在这里深究了。 - 最后计算通过
loss_fn计算损失,然后return loss.mean()
所以,目前位置,我们发现这个BYOL的结构其实很简单,目前还有疑点的地方有4个:
- online_encoder如何定义?
- predictor如何定义?
- 图像增强方法如何定义?
- loss_fn损失函数如何定义?
augment
从上面的代码中可以看到这一段:
# default SimCLR augmentation
DEFAULT_AUG = torch.nn.Sequential(
RandomApply(
T.ColorJitter(0.8, 0.8, 0.8, 0.2),
p = 0.3
),
T.RandomGrayscale(p=0.2),
T.RandomHorizontalFlip(),
RandomApply(
T.GaussianBlur((3, 3), (1.0, 2.0)),
p = 0.2
),
T.RandomResizedCrop((image_size, image_size)),
T.Normalize(
mean=torch.tensor([0.485, 0.456, 0.406]),
std=torch.tensor([0.229, 0.224, 0.225])),
)
self.augment1 = default(augment_fn, DEFAULT_AUG)
self.augment2 = default(augment_fn2, self.augment1)
可以看到:
- 这个就是图像增强的pipeline,而augment1和augment2可以自定义,默认的话就是augment1和augment2都是上面的DEFAULT_AUG;
from torchvision import transforms as T
比较陌生的可能就是torchvision.transforms.ColorJitter()这个方法了。
从官方API上可以看到,这个方法其实就是随机的修改图片的亮度,对比度,饱和度和色调
encoder+projector
class NetWrapper(nn.Module):
def __init__(self, net, projection_size, projection_hidden_size, layer = -2):
super().__init__()
self.net = net
self.layer = layer
self.projector = None
self.projection_size = projection_size
self.projection_hidden_size = projection_hidden_size
self.hidden = None
self.hook_registered = False
def _find_layer(self):
if type(self.layer) == str:
modules = dict([*self.net.named_modules()])
return modules.get(self.layer, None)
elif type(self.layer) == int:
children = [*self.net.children()]
return children[self.layer]
return None
def _hook(self, _, __, output):
self.hidden = flatten(output)
def _register_hook(self):
layer = self._find_layer()
assert layer is not None, f'hidden layer ({self.layer}) not found'
handle = layer.register_forward_hook(self._hook)
self.hook_registered = True
@singleton('projector')
def _get_projector(self, hidden):
_, dim = hidden.shape
projector = MLP(dim, self.projection_size, self.projection_hidden_size)
return projector.to(hidden)
def get_representation(self, x):
if self.layer == -1:
return self.net(x)
if not self.hook_registered:
self._register_hook()
_ = self.net(x)
hidden = self.hidden
self.hidden = None
assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
return hidden
def forward(self, x, return_embedding = False):
representation = self.get_representation(x)
if return_embedding:
return representation
projector = self._get_projector(representation)
projection = projector(representation)
return projection, representation
这个就是基本的encoder+projector,里面包含encoder和projector。
encoder
这个在初始化NetWrapper的时候,需要作为参数传递进来,所以看了训练文件,发现这个模型为:
from torchvision import models, transforms
resnet = models.resnet50(pretrained=True)
所以encoder和论文中说的一样,是一个resnet50。如果我记得没错,这个resnet输出的是一个(batch_size,1000)这样子的tensor。
projector
调用到了MLP这个东西:
class MLP(nn.Module):
def __init__(self, dim, projection_size, hidden_size = 4096):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_size),
nn.BatchNorm1d(hidden_size),
nn.ReLU(inplace=True),
nn.Linear(hidden_size, projection_size)
)
def forward(self, x):
return self.net(x)
是全连接层+BN+激活层的结构。和论文中说的差不多,并且在最后的全连接层后面没有加上BN+relu。经过这个MLP,返回的是一个(batch_size,projection_size)这样形状的tensor。
predictor
self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)
这个predictor,其实就是和projector一模一样的东西,可以看到predictor的输入和输出的特征数量都是projection_size。
这里因为我对自监督的体系没有完整的阅读论文,只是最先看了这个BYOL,所以我无法说明这个predictor为什么存在。从表现来看,是为了防止online network和target network的结构完全相同,如果完全相同的话可能会让两个模型训练出完全一样的效果,也就是loss=0的情况。假设
loss_fn
def loss_fn(x, y):
x = F.normalize(x, dim=-1, p=2)
y = F.normalize(y, dim=-1, p=2)
return 2 - 2 * (x * y).sum(dim=-1)
这部分和论文中一致。
综上所属,这个BYOL框架是一个简单,又有趣的无监督架构。
自监督图像论文复现 | BYOL(pytorch)| 2020的更多相关文章
- Visualizing and Understanding Convolutional Networks论文复现笔记
目录 Visualizing and Understanding Convolutional Networks 论文复现笔记 Abstract Introduction Approach Visual ...
- Facebook 发布深度学习工具包 PyTorch Hub,让论文复现变得更容易
近日,PyTorch 社区发布了一个深度学习工具包 PyTorchHub, 帮助机器学习工作者更快实现重要论文的复现工作.PyTorchHub 由一个预训练模型仓库组成,专门用于提高研究工作的复现性以 ...
- 图像风格迁移(Pytorch)
图像风格迁移 最后要生成的图片是怎样的是难以想象的,所以朴素的监督学习方法可能不会生效, Content Loss 根据输入图片和输出图片的像素差别可以比较损失 \(l_{content} = \fr ...
- 小白经典CNN论文复现系列(一):LeNet1989
小白的经典CNN复现系列(一):LeNet-1989 之前的浙大AI作业的那个系列,因为后面的NLP的东西我最近大概是不会接触到,所以我们先换一个系列开始更新博客,就是现在这个经典的CNN复现啦(。・ ...
- GAN生成图像论文总结
GAN Theory Modifyingthe Optimization of GAN 题目 内容 GAN DCGAN WGAN Least-square GAN Loss Sensi ...
- 化繁为简,弱监督目标定位领域的新SOTA - 伪监督目标定位方法(PSOL) | CVPR 2020
论文提出伪监督目标定位方法(PSOL)来解决目前弱监督目标定位方法的问题,该方法将定位与分类分开成两个独立的网络,然后在训练集上使用Deep descriptor transformation(DDT ...
- 训练一个图像分类器demo in PyTorch【学习笔记】
[学习源]Tutorials > Deep Learning with PyTorch: A 60 Minute Blitz > Training a Classifier 本文相当于 ...
- 库、教程、论文实现,这是一份超全的PyTorch资源列表(Github 2.2K星)
项目地址:https://github.com/bharathgs/Awesome-pytorch-list 列表结构: NLP 与语音处理 计算机视觉 概率/生成库 其他库 教程与示例 论文实现 P ...
- 复现ICCV 2017经典论文—PyraNet
. 过去几年发表于各大 AI 顶会论文提出的 400 多种算法中,公开算法代码的仅占 6%,其中三分之一的论文作者分享了测试数据,约 54% 的分享包含“伪代码”.这是今年 AAAI 会议上一个严峻的 ...
随机推荐
- JPA 缓存
JPA有两种类型的缓存: EntityManager自身就是一种缓存.事务中从数据库获取的和写入到数据库的数据会被缓存(什么样的数据会被缓存,在后面有介绍).在一个程序中也许会有很多个不同的Entit ...
- 在jsp页面动态添加,删除文本框模块
jsp代码============ <table class="crud-content-info" > <tr > <td align=" ...
- easyui 动态添加input标签
动态添加easyui控件<input class=" easyui-textbox" > 这样是无效的,因为easyui没有实时监控,所以必须动态渲染$.parser. ...
- [LeetCode]160. Intersection of Two Linked Lists判断交叉链表的交点
方法要记住,和判断是不是交叉链表不一样 方法是将两条链表的路径合并,两个指针分别从a和b走不同路线会在交点处相遇 public ListNode getIntersectionNode(ListNod ...
- Java ClassLoader浅析
双亲委派 提起 java 类加载器,自然绕不开其双亲委派模型 什么是双亲委派 提起双亲委派,首先想到便是那张经典的向上委派图 一般场景下,当某个类将要被加载时,由系统上下文默认的类加载器Thread. ...
- 每日一个linux命令6 -- rmdir
rmdir doc 如果doc为空目录则删除,否则无法删除. rmdir -p test2/test3 递归删除空目录,首先判断test3,如果test3为空,则删除test3,此时判断test2,如 ...
- netty心跳检测机制
既然是网络通信那么心跳检测肯定是离不开的,netty心跳检测分为读.写.全局 bootstrap.childHandler(new ChannelInitializer<SocketChanne ...
- linux mysql source 导入大文件报错解决办法
找到mysql的配置文件目录 my.cnf interactive_timeout = 120wait_timeout = 120max_allowed_packet = 500M 在导入过程中可能会 ...
- VMware 安装 Centos7 超详细过程
https://www.runoob.com/w3cnote/vmware-install-centos7.html centos7安装参考文档 VMware 安装 Centos7 超详细过程 分类 ...
- #2020征文-开发板#使用Python开发鸿蒙应用--2021.01.07直播图文
写在前面: 每年的过年前夕,手中的项目一定会告急...而自己又缺乏三头六臂七十二变等特技,所以只能在鸿蒙社区先消失一阵子了.今天再看社区的帖子,发现大家的进步可不一般,各种案例示例层出不穷,一片欣欣向 ...