越简单越接近本质。

参考资料

U-Net: Convolutional Networks for Biomedical Image Segmentation

Abstract & Introduction

论文中有几个关键词:

  • contracting path 收缩路径;
  • expansive path 扩张路径;
  • precise localization 更精确的位置信息;
  • overlap-tile 边界镜像翻转;
  • random elastic deformations 随机弹性形变;
  • invariance 尺度不变性;
  • touching cells 指距离很近的两个细胞;
  • seamless tilling 无缝拼接;

好了,说完这些关键词我们来看看这篇论文,这篇论文和他的结构一样简单易懂,很能说明问题。

首先,作者主要拿自己的网络和一个基于sliding window的方法做对比,作者先diss了一下这个方法存在以下问题:

Deep neural networks segment neuronal membranes in electron microscopy images (NIPS2012)

  • 非常慢,计算冗余(sliding window的毛病大家都懂);
  • 在位置精确性和特征提取之间存在一个平衡,因为更多的特征意味着更多的max-pooling,则会丢失掉更多位置信息。

作者的输入多层特征的思想是受以下论文启发:

  • Hypercolumns for object segmentation and fine-grained localization (2014)
  • Image segmentation with cascaded hierarchical models and logistic disjunctive normal networks (2013)

这两篇论文指出把多层特征(the features from multiple layers)输入到classifier能够得到更好的特征提取和更好的位置信息(good localization and the use of context are possible at the same time)。

U-Net和其他网络的不同之处在于,上采样(Upsampling)过程中也有很多维特征,让特征流向更高分辨率的卷积层。

由于网络使用的卷积是3x3 unpadded convolutions,所以特征图会缩小,为了让输出的图像和输入图像的大小无缝拼接(seamless tilling),则要用到边界镜像翻转(overlap-tile),具体做法如下图:

Architecture

网络结构

使用3x3 unpadded convolutions,所以特征图会不断缩小,在横向拼接特征的时候,也要对特征图进行裁剪,以保持特征图大小一致。

全部使用ReLU激活函数。

权值初始化使用何恺明的方法:

Surpassing humanlevel performance on imagenet classification

具体做法就是一个标准差满足sqrt(2/N)的高斯分布,其中的N代表一个神经元的输入节点数(例如一个3x3卷积核的输入是64维的话,那么N=9x64=576)

训练

在训练时作者更倾向于更大的图像输入,所以干脆将batch_size设置为1,所以在优化器的使用方面,使用到了带有动量的优化器,并且动量设置的很大(0.99),这么做是为了让以前的样本可以决定当前梯度更新的方向(因为batch_size太小啦,可以理解)。

损失函数就是pixel-wise soft-max + cross_entropy了。

数据增强

随机弹性形变和weight map:

随机弹性形变就是先用3x3的粗网格初始化随机形变,然后从标准差为10pixel的高斯分布中初始化随机位移矢量,再用bicubic双三次插值来计算每个像素的位移。

随机弹性形变的目的是让网络有invariance(尺度不变性)。

那么weight map是为了强制让网络学习touching cells之间的背景,这些位于touching cells之间的背景在损失函数的权重很高,如下图:

weight map的具体计算方式如下:

代码

最后来看看代码吧:https://github.com/milesial/Pytorch-UNet

整体模型:

class UNet(nn.Module):
def __init__(self, n_channels, n_classes):
super(UNet, self).__init__()
self.inc = inconv(n_channels, 64)
self.down1 = down(64, 128)
self.down2 = down(128, 256)
self.down3 = down(256, 512)
self.down4 = down(512, 512)
self.up1 = up(1024, 256)
self.up2 = up(512, 128)
self.up3 = up(256, 64)
self.up4 = up(128, 64)
self.outc = outconv(64, n_classes) def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
x = self.outc(x)
return F.sigmoid(x)

细节部分:

class double_conv(nn.Module):
'''(conv => BN => ReLU) * 2'''
def __init__(self, in_ch, out_ch):
super(double_conv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
) def forward(self, x):
x = self.conv(x)
return x class inconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(inconv, self).__init__()
self.conv = double_conv(in_ch, out_ch) def forward(self, x):
x = self.conv(x)
return x class down(nn.Module):
def __init__(self, in_ch, out_ch):
super(down, self).__init__()
self.mpconv = nn.Sequential(
nn.MaxPool2d(2),
double_conv(in_ch, out_ch)
) def forward(self, x):
x = self.mpconv(x)
return x class up(nn.Module):
def __init__(self, in_ch, out_ch, bilinear=True):
super(up, self).__init__() # would be a nice idea if the upsampling could be learned too,
# but my machine do not have enough memory to handle all those weights
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
else:
self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2) self.conv = double_conv(in_ch, out_ch) def forward(self, x1, x2):
x1 = self.up(x1) # input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3] x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
diffY // 2, diffY - diffY//2)) x = torch.cat([x2, x1], dim=1)
x = self.conv(x)
return x class outconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(outconv, self).__init__()
self.conv = nn.Conv2d(in_ch, out_ch, 1) def forward(self, x):
x = self.conv(x)
return x

训练:

optimizer = optim.SGD(net.parameters(),
lr=lr,
momentum=0.9,
weight_decay=0.0005) criterion = nn.BCELoss()

[AI] 论文笔记 - U-Net 简单而又接近本质的分割网络的更多相关文章

  1. [AI] 论文笔记 - CVPR2018 Super SloMo: High Quality Estimation of Multiple Intermediate Frames for Video Interpolation

    写在前面 原始视频(30fps) 补帧后的视频(240fps) 本文是博主在做实验的过程中使用到的方法,刚好也做为了本科毕设的翻译文章,现在把它搬运到博客上来,因为觉得这篇文章的思路真的不错. 这篇文 ...

  2. 【论文笔记】Learning Fashion Compatibility with Bidirectional LSTMs

    论文:<Learning Fashion Compatibility with Bidirectional LSTMs> 论文地址:https://arxiv.org/abs/1707.0 ...

  3. Deep Learning论文笔记之(八)Deep Learning最新综述

    Deep Learning论文笔记之(八)Deep Learning最新综述 zouxy09@qq.com http://blog.csdn.net/zouxy09 自己平时看了一些论文,但老感觉看完 ...

  4. 论文笔记:Mastering the game of Go with deep neural networks and tree search

    Mastering the game of Go with deep neural networks and tree search Nature 2015  这是本人论文笔记系列第二篇 Nature ...

  5. 论文笔记:CNN经典结构1(AlexNet,ZFNet,OverFeat,VGG,GoogleNet,ResNet)

    前言 本文主要介绍2012-2015年的一些经典CNN结构,从AlexNet,ZFNet,OverFeat到VGG,GoogleNetv1-v4,ResNetv1-v2. 在论文笔记:CNN经典结构2 ...

  6. AI理论学习笔记(一):深度学习的前世今生

    AI理论学习笔记(一):深度学习的前世今生 大家还记得以深度学习技术为基础的电脑程序AlphaGo吗?这是人类历史中在某种意义的第一次机器打败人类的例子,其最大的魅力就是深度学习(Deep Learn ...

  7. Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现(转)

    Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现 zouxy09@qq.com http://blog.csdn.net/zouxy09          自己平时看了一些论文, ...

  8. 论文笔记之:Visual Tracking with Fully Convolutional Networks

    论文笔记之:Visual Tracking with Fully Convolutional Networks ICCV 2015  CUHK 本文利用 FCN 来做跟踪问题,但开篇就提到并非将其看做 ...

  9. Twitter 新一代流处理利器——Heron 论文笔记之Heron架构

    Twitter 新一代流处理利器--Heron 论文笔记之Heron架构 标签(空格分隔): Streaming-process realtime-process Heron Architecture ...

随机推荐

  1. php提供一维数组模糊查询

    2019年9月30日14:36:15 提供一维数组模糊查询,只支持utf-8 内部处理是Unicode 编码特殊编码格式的可能会出错 if (!function_exists('arrayFuzzyQ ...

  2. js:如何在iframe重载前执行特定动作

    问题说明: 点击左侧菜单时,右侧页面中的iframe加载菜单内容,在iframe加载的页面A中使用了websocket.点击其它菜单时,无法主动关闭websocket, 可能会造成websocket链 ...

  3. Mac Pro 2015休眠掉电解决办法

    硬件:Mac Pro 2015 系统:MacOs Mojave 10.14.3 问题:合盖的时候,休眠1小时掉电10%,由于之前是128G原装盘不会有这个问题,后面购买了M.2转接卡,更换1T Int ...

  4. 创建Windows Service

    基本参照使用C#创建Windows服务,添加了部分内容 目录 创建Windows Service 可视化管理Windows Service 调试 示例代码 创建Windows Service 选择C# ...

  5. 解决SQL语句在Dapper执行超时比Query慢的问题

    在语句结尾加上 Add OPTION (RECOMPILE) to the end https://stackoverflow.com/questions/10933366/sp-executesql ...

  6. [ARM-Linux开发]Linux open函数

    Linux open函数 open 函数用于打开和创建文件.以下是 open 函数的简单描述 #include <fcntl.h> int open(const char *pathnam ...

  7. webapi+swagger ui 文档描述

    代码:GitHub swagger ui在我们的.NET CORE和.NET Framework中的展现形式是不一样的,如果有了解的,在.NET CORE中的是比.NET Framework好的.两张 ...

  8. 资源对象的池化, java极简实现,close资源时,自动回收

    https://www.cnblogs.com/piepie/p/10498953.html 在java程序中对于资源,例如数据库连接,这类不能并行共享的资源对象,一般采用资源池的方式进行管理. 资源 ...

  9. ubuntu supervisor管理uwsgi+nginx

    一.概述 superviosr是一个Linux/Unix系统上的进程监控工具,他/她upervisor是一个Python开发的通用的进程管理程序,可以管理和监控Linux上面的进程,能将一个普通的命令 ...

  10. N皇后问题的python实现

    数据结构中常见的问题,最近复习到了,用python做一遍. # 检测(x,y)这个位置是否合法(不会被其他皇后攻击到) def is_attack(queue, x, y): for i in ran ...