前言

本篇文章收录于语义分割专栏,如果对语义分割领域感兴趣的,可以去看看专栏,会对经典的模型以及代码进行详细的讲解哦!其中会包含可复现的代码!带大家深入语义分割的领域,将从原理,代码深入讲解,希望大家能从中有所收获,其中很多内容都包含着自己的一些想法以及理解,如果有错误的地方欢迎大家批评指正。

论文名称:《U-Net: Convolutional Networks for Biomedical Image Segmentation》

论文地址:U-Net: Convolutional Networks for Biomedical Image Segmentation

欢迎继续来到语义分割专栏系列的第二篇,本文将继续带大家来学习语义分割领域的经典模型:U-net

背景介绍

在上一篇中我们已经详细的介绍过了FCN的原理以及代码实现,本篇中我们要介绍的是U-net,是遵循着FCN的原理,并进行了相应的改进,使其适应小样本的简单分割问题。

首先我们来看U-net出现的历史背景,U-net的设计初衷是为了解决医学图像分割中的挑战,但是因为其优秀合理的架构,故其也被广泛应用于各种语义分割任务中,包括卫星图像分析、农业图像处理、自动驾驶、遥感图像等。

首先我们来看当时在医学图像分割面临的哪些困难?

首先就是数据稀缺与数据标注困难:因为其专业性,所以医学图像标注成本都会比较高,需要专业医生;同时处于伦理道德方面,其可用于训练的图像数据数量远少于自然图像数据(如ImageNet等);其次就是分割精度需求高:因为在医疗领域,即使像素级的微小误差也可能影响诊断结果;所以我们的模型就需要能够兼顾全局上下文局部细节

那么如果是你,在当时的背景条件下,你希望能够在医学图像分割中取得更好的结果,你会怎么做?首先我们来想,我们的数据训练量不会很多,那么我们的模型肯定不能够过于复杂了,其网络设计一定是参数比较小的,适合小样本训练的,但是去设计个咋样的架构才是合适的呢?这就是很具有创新性的一步了。其次分割精度需求高,那么我们的模型肯定要能够兼顾上下文的信息和局部的细节信息。其实这个在FCN已经给出了个答案,通过跳跃连接的方式,兼顾上下文信息,但是U-net的跳跃连接方式又有着一些差别

U-net核心剖析

是的,你遇到这些问题你怎么办?希望大家都能够设身处地去思考,我们才能够明白每一项工作的创新意义,同时进行深入思考,很多时候你可能也会有自己的想法,每个创新性的想法都是不经意间的,希望大家都能够有思维的碰撞。好了,话说回来,我们来看看U-net的作者时怎么做的,就是那些问题,看看人家如何进行解决的。

编码解码结构(U形状)

我始终觉得U-net是一个超级符合对称美学的架构,左侧的结构是特征提取部分,右侧的结构就是上采样部分,当然也有人将其称为编码器解码器结构。同时由于其网络的整体结构像一个大写的英文字母U,所以叫做U-net。

在Encoder部分,我们将输入图像经过多个卷积和池化操作,逐步提取其语义特征。每经过一个 2x2 最大池化层(红色箭头,采用的是无填充的方式),所以feature map的尺寸会减半。

在Decoder部分,通过 2x2 反卷积操作(绿色箭头)将特征图尺寸逐步放大两倍。上采样后的特征图与编码器部分对应尺度的特征图进行concat融合(蓝色箭头代表 3x3 卷积操作,步长为 1,有效填充,每次操作后特征图尺寸减少 2)。为了进行拼接,需要对尺寸较大的feature map进行crop操作(灰色箭头),使其与上采样后的特征图尺寸匹配。

最终输出层使用 1x1 卷积层进行分类,输出两层,分别代表前景和背景。输入图像为 572x572,输出图像为 388x388,说明经过网络处理后,输出结果与原图尺寸不完全对应。(这里大家可能会有点疑问,为什么语义分割任务会输入输出的shape不匹配,这个后面会有说明。)

卷积模式

这里我们可以看到一个细节,我们在Encoder部分的时候卷积之后,feature map的shape都会减小,这是因为其采用的卷积模式是valid。卷积一共有三个模式,分别是full mode、same mode、valid mode

  • full mode:从卷积核刚开始与我们的图像进行相交的时候就开始卷积操作。
  • same mode:当卷积核的中心(K)与图像的边角重合时,我们就开始做卷积运算,可见卷积核的运动范围比full模式小了一圈。当然了这里的same还有一个意思,就是当卷积之后输出的feature map尺寸保持不变(相对于输入图片)。
  • valid mode:当卷积核全部在图像里面的时候,进行卷积运算,可见卷积核的移动范围较same更小了。

这里我们我用蓝色表示卷积核,橙色表示我们的图像部分,白色部分位填充,相信通过下图能够更加清晰的了解不同的卷积方式了。

跳跃连接

不知道大家还记不记得我们之前在FCN中也讲到了跳跃连接的,我们这里回顾下:

首先我们需要明白一个事情就是,我们的网络在进行特征提取的时候是从低级语义信息不断不断的进行提取到最后的输出的高级语义信息的。网络的低层提取的语义信息更多代表了图像的纹理、边缘等一些显性的信息,网络的高层所提取的一些语义信息更多的就是其数据核心的抽象的语义信息了。那我们最后进行语义分割的特征图的语义信息肯定损失了很多关键的细节、边缘信息了,并且最后还会有上采样的过程,这个现象就会更加加剧。我们想要最后进行语义分割也能够有些这些细节信息怎么办?

这就是跳跃连接了。想要低层语义信息,直接把低层的语义信息加回来不就好了,简单粗暴,但是同样的也非常有效。这就是我们在FCN中跳跃连接的方式,直接将对应位置的信息进行相加,即就是相当于是add操作。

但是在U-net中的跳跃连接方式是concat,从图中也能看出,我们是将之前的低级语义信息与我们在后来提取到的高级语义信息进行通道上的相加了,不是对应位置像素直接相加。那么二者有什么区别呢?

add

我们来看,以下是 keras 中对 add 的实现源码,pytorch的封装更复杂一些,不过原理都是一样的,看这个就行:

def _merge_function(self, inputs):
output = inputs[0]
for i in range(1, len(inputs)):
output += inputs[i]
return output

其中 inputs 为待融合的特征图,inputs[0]、inputs[1]……等的通道数一样,且特征图宽与高也一样。

从代码中可以很容易地看出,add 方式有以下特点

  1. 做的是对应通道对应位置的值的相加,通道数不变
  2. 描述图像的特征个数不变,但是每个特征下的信息却增加了。

concat

同样的,我们通过阅读下面代码实例帮助理解 concat 的工作原理:

import torch

# 创建两个张量
t1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
t2 = torch.tensor([[7, 8, 9], [10, 11, 12]]) # 沿第1维拼接
result_1 = torch.cat([t1, t2], dim=1)
print(result_1)
# 输出: tensor([[ 1, 2, 3, 7, 8, 9],
# [ 4, 5, 6, 10, 11, 12]])

在模型网路当中,数据通常为 4 个维度,即 num×channels×height×width ,因此默认值 1 表示的是 channels 通道进行拼接。如:

combine = torch.cat([d1, add1, add2, add3, add4], 1)

从代码中可以很容易地看出,concat 方式有以下特点:

  1. 做的是通道的合并,通道数变多了
  2. 描述图像的特征个数变多,但是每个特征下的信息却不变。

所以到这里,我们就能够很清晰的知道add操作和concat操作的不同了。

操作 描述 优点 缺点 补充
add - 相当于加了一种prior - 要求两路输入的对应通道特征图语义类似 - 计算量少 - 特征提取能力差 - 对应通道信息类似时,可融合多通道信息 - 尺度不一致时,小尺度特征可能被淹没
concat - 通过训练学习整合两个特征图通道之间的信息 - 特征提取能力强 - 计算量大(是add的2倍) - 能提取更合适的信息,效果更好

其他细节

overlap-tile策略

因为医学图像是一般都是相当大的,我们在分割的时候就不可能将原图直接输入网络,所以需要用一个滑动窗口把原图扫一遍,使用原图的切片进行训练或测试。可以看图,其中红框标出来的是要分割区域。但是我们在切图时要包含周围区域,overlap另一个重要原因是周围overlap部分可以为分割区域边缘部分提供纹理等信息。

但是这样的策略会带来一个问题,图像边界的图像块没有周围像素,卷积会使图像边缘处的信息丢失。因此其对周围像素采用了镜像扩充。下图中红框部分为原始图片,其周围扩充的像素点均由原图沿白线对称得到。这样,边界图像块也能得到准确的预测。

另一个问题是,这样的操作会带来图像重叠问题,即第一块图像周围的部分会和第二块图像重叠。所以还记得我们之前讲解网络结构的时候吗?其输入图像为 572x572,但是最终的输出图像为 388x388,我认为就是通过这样的方式和我们在concat时候的crop操作来让模型只关注图像的黄色区域内的部分。

弹性形变

为了解决任务中数据缺乏的问题,我们常常都是会采用一些数据增强的方法来扩充数据集。常见的增强方式包括对图像进行旋转、平移等仿射变换,或进行镜像处理。在此基础上,U-net 论文中使用一种更适合医学图像的数据增强方式——弹性变换。该方法最初在 MNIST 手写数字识别任务中使用,发现通过对原图进行弹性变形可以显著提升模型识别准确率。因为U-Net 处理的图像数据来自细胞组织,而细胞边界本身就具有自然的、不规则形变特性,因此使用弹性变换可以模拟真实情况下的结构畸变,从而提升模型的泛化能力。

弹性变换的基本原理是:为图像的每个像素坐标引入一个在 (−1,1)(-1, 1) 区间内的随机扰动,这些扰动通过高斯滤波平滑后,再乘以一个缩放系数来控制最终的形变幅度。最终,原图中位置 (x,y)(x, y) 的像素被映射到新的位置 $(x+δ_x,y+δ_y)(x + \delta_x, y + \delta_y)$,新图像的像素值通过插值从原图获得,即新位置的值来自原图对应位置的值。

图示中展示了在相同扰动强度下,不同高斯标准差带来的形变效果。结果表明,第二幅图的形变效果在真实感和增强效果之间达到了较好的平衡。

这个时候我们在回头看,最初的核心两个问题:数据稀缺与数据标注困难分割精度需求高

通过设计了轻巧的U型网络,采用了大量数据增强的方式,使得其能够更好的适应小样本的任务。通过多尺度融合 + 跳跃连接,提升了对小物体和边界的感知能力;并且跳跃连接还能够避免深层网络中“语义信息丰富但空间信息丢失”的问题,从而能够保证分割精度。

U-net模型代码

这里同样的 我自己也尝试去复现了U-net模型代码,当然细节上跟原论文中的U-net不是完全一样,原来的U-net模型是适用于医学图像分割任务,所以其有部分设计也是为了医学图像分割设计的,我这里复现的U-net代码更适合普遍的语义分割任务,其输入输出的shape大小是相同的。

首先是我将所有的上采样下采样中的卷积部分集成到了一起,看模型结构能够看出,每个部分都是两次卷积,所以代码如下,就在设置不同stage的时候设置好输入输出通道即可。

class Down_Up_Conv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
super(Down_Up_Conv, self).__init__()
self.conv_block = nn.Sequential(
nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
nn.BatchNorm2d(out_channels),
nn.ReLU()
) def forward(self, x):
return self.conv_block(x)

然后这是跳跃连接的代码,同时我们采取了crop操作。我们通过获取两个feature map的长宽,然后再对齐之后进行再通道维上的拼接,代码如下,还是比较好理解的。

def crop_and_concat(upsampled, bypass):
"""
将两个 feature map 在 H 和 W 上对齐后拼接(dim=1)
- upsampled: 解码器上采样后的特征图 (N, C1, H1, W1)
- bypass: 编码器传来的特征图 (N, C2, H2, W2)
"""
h1, w1 = upsampled.shape[2], upsampled.shape[3]
h2, w2 = bypass.shape[2], bypass.shape[3] # 计算差值
delta_h = h2 - h1
delta_w = w2 - w1 # 对 encoder 输出进行中心裁剪
bypass_cropped = bypass[:, :,
delta_h // 2: delta_h // 2 + h1,
delta_w // 2: delta_w // 2 + w1] # 拼接通道维
return torch.cat([upsampled, bypass_cropped], dim=1)

然后就是搭建我们的U-net模型了,这还是比较容易的,将encoder部分的五个阶段的下采样卷积定义好,注意通道数的变换,然后就是Decoder的上采样的过程,我们使用的是转置卷积,上采样后还有卷积过程,所以我们按照U-net的模型图搭建即可。注意,我这里是把maxpooling给摘出来了的,每个下采样卷积之后都会有一个maxpooling层,这个可别忘了,在forward里面有体现。定义好模型参数之后就是模型参数的初始化了,这个步骤可千万不能忘。

class UNet(nn.Module):
def __init__(self, num_classes=2):
super(UNet, self).__init__()
self.stage_down1=Down_Up_Conv(3, 64)
self.stage_down2=Down_Up_Conv(64, 128)
self.stage_down3=Down_Up_Conv(128, 256)
self.stage_down4=Down_Up_Conv(256, 512)
self.stage_down5=Down_Up_Conv(512, 1024) self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2,padding=1)
self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2,padding=1)
self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2,padding=1)
self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2,padding=1) self.stage_up4=Down_Up_Conv(1024, 512)
self.stage_up3=Down_Up_Conv(512, 256)
self.stage_up2=Down_Up_Conv(256, 128)
self.stage_up1=Down_Up_Conv(128, 64)
self.stage_out=Down_Up_Conv(64, num_classes)
self.maxpool = nn.MaxPool2d(kernel_size=2) self.initialize_weights() def initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0) def forward(self, x):
stage1 = self.stage_down1(x)
x = self.maxpool(stage1)
stage2 = self.stage_down2(x)
x = self.maxpool(stage2)
stage3 = self.stage_down3(x)
x = self.maxpool(stage3)
stage4 = self.stage_down4(x)
x = self.maxpool(stage4)
stage5 = self.stage_down5(x) x = self.up4(stage5) x = self.stage_up4(crop_and_concat(x, stage4))
x = self.up3(x)
x = self.stage_up3(crop_and_concat(x, stage3))
x = self.up2(x)
x = self.stage_up2(crop_and_concat(x, stage2))
x = self.up1(x)
x = self.stage_up1(crop_and_concat(x, stage1))
out = self.stage_out(x)
return out

结语

希望上列所述内容对你有所帮助,如果有错误的地方欢迎大家批评指正!

并且如果可以的话希望大家能够三连鼓励一下,谢谢大家!

如果你觉得讲的还不错想转载,可以直接转载,不过麻烦指出本文来源出处即可,谢谢!

参考资料

本文参考了下列的文章内容,集百家之长汇聚于此,同时包含自己的思考想法

UNET详解和UNET++介绍(零基础)-CSDN博客

图像分割必备知识点 | Unet详解 理论+ 代码 - 知乎

深度学习系列-UNet网络 - 知乎

【语义分割专栏】2:U-net原理篇(由浅入深)的更多相关文章

  1. 多篇开源CVPR 2020 语义分割论文

    多篇开源CVPR 2020 语义分割论文 前言 1. DynamicRouting:针对语义分割的动态路径选择网络 Learning Dynamic Routing for Semantic Segm ...

  2. 几篇关于RGBD语义分割文章的总结

      最近在调研3D算法方面的工作,整理了几篇多视角学习的文章.还没调研完,先写个大概.   基于RGBD的语义分割的工作重点主要集中在如何将RGB信息和Depth信息融合,主要分为三类:省略. 目录 ...

  3. 语义分割--全卷积网络FCN详解

    语义分割--全卷积网络FCN详解   1.FCN概述 CNN做图像分类甚至做目标检测的效果已经被证明并广泛应用,图像语义分割本质上也可以认为是稠密的目标识别(需要预测每个像素点的类别). 传统的基于C ...

  4. 语义分割:基于openCV和深度学习(一)

    语义分割:基于openCV和深度学习(一) Semantic segmentation with OpenCV and deep learning 介绍如何使用OpenCV.深度学习和ENet架构执行 ...

  5. caffe初步实践---------使用训练好的模型完成语义分割任务

    caffe刚刚安装配置结束,乘热打铁! (一)环境准备 前面我有两篇文章写到caffe的搭建,第一篇cpu only ,第二篇是在服务器上搭建的,其中第二篇因为硬件环境更佳我们的步骤稍显复杂.其实,第 ...

  6. 【Keras】基于SegNet和U-Net的遥感图像语义分割

    上两个月参加了个比赛,做的是对遥感高清图像做语义分割,美其名曰"天空之眼".这两周数据挖掘课期末project我们组选的课题也是遥感图像的语义分割,所以刚好又把前段时间做的成果重新 ...

  7. 笔记:基于DCNN的图像语义分割综述

    写在前面:一篇魏云超博士的综述论文,完整题目为<基于DCNN的图像语义分割综述>,在这里选择性摘抄和理解,以加深自己印象,同时达到对近年来图像语义分割历史学习和了解的目的,博古才能通今!感 ...

  8. 语义分割的简单指南 A Simple Guide to Semantic Segmentation

    语义分割是将标签分配给图像中的每个像素的过程.这与分类形成鲜明对比,其中单个标签被分配给整个图片.语义分段将同一类的多个对象视为单个实体.另一方面,实例分段将同一类的多个对象视为不同的单个对象(或实例 ...

  9. xgboost入门与实战(原理篇)

    sklearn实战-乳腺癌细胞数据挖掘 https://study.163.com/course/introduction.htm?courseId=1005269003&utm_campai ...

  10. MIT提出精细到头发丝的语义分割技术,打造效果惊艳的特效电影

    来自 MIT CSAIL 的研究人员开发了一种精细程度远超传统语义分割方法的「语义软分割」技术,连头发都能清晰地在分割掩码中呈现.在对比实验中,他们的结果远远优于 PSPNet.Mask R-CNN. ...

随机推荐

  1. 线上测试木舟物联网平台之如何通过HTTP网络组件接入设备

    一.概述 木舟 (Kayak) 是什么? 木舟(Kayak)是基于.NET6.0软件环境下的surging微服务引擎进行开发的, 平台包含了微服务和物联网平台.支持异步和响应式编程开发,功能包含了物模 ...

  2. AI 代理的未来是事件驱动的

    AI 代理即将彻底改变企业运营,它们具备自主解决问题的能力.适应性工作流以及可扩展性.但真正的挑战并不是构建更好的模型. 代理需要访问数据.工具,并且能够在不同系统之间共享信息,其输出还需要能被多个服 ...

  3. 查看Unity3D中默认的变量名与按键的映射

    博客地址:https://www.cnblogs.com/zylyehuo/ 选择 Edit/Project Settings/Input Manager 点击 Axes 即可查看对应变量名与按键的映 ...

  4. 准确理解 JS 的 ++ 运算符

    对于刚开始接触前端开发的朋友们来说,可能地一个令人苦恼的问题是关于运算符 ++ 的计算,特别是它还有前置与后置的区别.当它们和一堆运算在一起的时候,常常令人头晕目眩! 我经常性地称它是一个***难人的 ...

  5. 【SpringCloud】SpringCloud Alibaba Nacos服务注册和配置中心

    SpringCloud Alibaba Nacos服务注册和配置中心 感悟 注意:凡是cloud里面,你要开哪个组件,新加哪个注解,第一个就是启动,如@EnableFeignClients,第二个就是 ...

  6. 探秘Transformer系列之(22)--- LoRA

    探秘Transformer系列之(22)--- LoRA 目录 探秘Transformer系列之(22)--- LoRA 0x00 概述 0x01 背景知识 1.1 微调 1.2 PEFT 1.3 秩 ...

  7. Windows7、Windows10跳过创建用户并直接用Administrator身份登录

    windows7 windows10跳过创建用户并直接用Administrator身份登录 一.操作方法: 在界面设置按 按 shift+f10 然后输入 lusrmgr.msc 用户管理控制台开启a ...

  8. 被LangChain4j坑惨了!

    最近在深度体验和使用 Spring AI 和 LangChain4j,从开始的满怀期待五五开,但最后极具痛苦的使用 LangChain4j,让我真正体验到了正规军和草台班子的区别. Spring AI ...

  9. 腾讯云短信发送【java】

    先去官网申请secretId, secretKey,然后创建对应的模板 maven引入包 <dependency> <groupId>com.tencentcloudapi&l ...

  10. 基础 DP 做题记录

    Luogu P1192 台阶问题 Link 简要题意: 给定台阶数 \(n\le10^5\) 和一步至多跨越台阶数 \(k\le10^2\) ,初始在 \(0\) 级,求方案数 \(\pmod {10 ...