越简单越接近本质。

参考资料

U-Net: Convolutional Networks for Biomedical Image Segmentation

Abstract & Introduction

论文中有几个关键词:

  • contracting path 收缩路径;
  • expansive path 扩张路径;
  • precise localization 更精确的位置信息;
  • overlap-tile 边界镜像翻转;
  • random elastic deformations 随机弹性形变;
  • invariance 尺度不变性;
  • touching cells 指距离很近的两个细胞;
  • seamless tilling 无缝拼接;

好了,说完这些关键词我们来看看这篇论文,这篇论文和他的结构一样简单易懂,很能说明问题。

首先,作者主要拿自己的网络和一个基于sliding window的方法做对比,作者先diss了一下这个方法存在以下问题:

Deep neural networks segment neuronal membranes in electron microscopy images (NIPS2012)

  • 非常慢,计算冗余(sliding window的毛病大家都懂);
  • 在位置精确性和特征提取之间存在一个平衡,因为更多的特征意味着更多的max-pooling,则会丢失掉更多位置信息。

作者的输入多层特征的思想是受以下论文启发:

  • Hypercolumns for object segmentation and fine-grained localization (2014)
  • Image segmentation with cascaded hierarchical models and logistic disjunctive normal networks (2013)

这两篇论文指出把多层特征(the features from multiple layers)输入到classifier能够得到更好的特征提取和更好的位置信息(good localization and the use of context are possible at the same time)。

U-Net和其他网络的不同之处在于,上采样(Upsampling)过程中也有很多维特征,让特征流向更高分辨率的卷积层。

由于网络使用的卷积是3x3 unpadded convolutions,所以特征图会缩小,为了让输出的图像和输入图像的大小无缝拼接(seamless tilling),则要用到边界镜像翻转(overlap-tile),具体做法如下图:

Architecture

网络结构

使用3x3 unpadded convolutions,所以特征图会不断缩小,在横向拼接特征的时候,也要对特征图进行裁剪,以保持特征图大小一致。

全部使用ReLU激活函数。

权值初始化使用何恺明的方法:

Surpassing humanlevel performance on imagenet classification

具体做法就是一个标准差满足sqrt(2/N)的高斯分布,其中的N代表一个神经元的输入节点数(例如一个3x3卷积核的输入是64维的话,那么N=9x64=576)

训练

在训练时作者更倾向于更大的图像输入,所以干脆将batch_size设置为1,所以在优化器的使用方面,使用到了带有动量的优化器,并且动量设置的很大(0.99),这么做是为了让以前的样本可以决定当前梯度更新的方向(因为batch_size太小啦,可以理解)。

损失函数就是pixel-wise soft-max + cross_entropy了。

数据增强

随机弹性形变和weight map:

随机弹性形变就是先用3x3的粗网格初始化随机形变,然后从标准差为10pixel的高斯分布中初始化随机位移矢量,再用bicubic双三次插值来计算每个像素的位移。

随机弹性形变的目的是让网络有invariance(尺度不变性)。

那么weight map是为了强制让网络学习touching cells之间的背景,这些位于touching cells之间的背景在损失函数的权重很高,如下图:

weight map的具体计算方式如下:

代码

最后来看看代码吧:https://github.com/milesial/Pytorch-UNet

整体模型:

class UNet(nn.Module):
def __init__(self, n_channels, n_classes):
super(UNet, self).__init__()
self.inc = inconv(n_channels, 64)
self.down1 = down(64, 128)
self.down2 = down(128, 256)
self.down3 = down(256, 512)
self.down4 = down(512, 512)
self.up1 = up(1024, 256)
self.up2 = up(512, 128)
self.up3 = up(256, 64)
self.up4 = up(128, 64)
self.outc = outconv(64, n_classes) def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
x = self.outc(x)
return F.sigmoid(x)

细节部分:

class double_conv(nn.Module):
'''(conv => BN => ReLU) * 2'''
def __init__(self, in_ch, out_ch):
super(double_conv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
) def forward(self, x):
x = self.conv(x)
return x class inconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(inconv, self).__init__()
self.conv = double_conv(in_ch, out_ch) def forward(self, x):
x = self.conv(x)
return x class down(nn.Module):
def __init__(self, in_ch, out_ch):
super(down, self).__init__()
self.mpconv = nn.Sequential(
nn.MaxPool2d(2),
double_conv(in_ch, out_ch)
) def forward(self, x):
x = self.mpconv(x)
return x class up(nn.Module):
def __init__(self, in_ch, out_ch, bilinear=True):
super(up, self).__init__() # would be a nice idea if the upsampling could be learned too,
# but my machine do not have enough memory to handle all those weights
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
else:
self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2) self.conv = double_conv(in_ch, out_ch) def forward(self, x1, x2):
x1 = self.up(x1) # input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3] x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
diffY // 2, diffY - diffY//2)) x = torch.cat([x2, x1], dim=1)
x = self.conv(x)
return x class outconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(outconv, self).__init__()
self.conv = nn.Conv2d(in_ch, out_ch, 1) def forward(self, x):
x = self.conv(x)
return x

训练:

optimizer = optim.SGD(net.parameters(),
lr=lr,
momentum=0.9,
weight_decay=0.0005) criterion = nn.BCELoss()

[AI] 论文笔记 - U-Net 简单而又接近本质的分割网络的更多相关文章

  1. [AI] 论文笔记 - CVPR2018 Super SloMo: High Quality Estimation of Multiple Intermediate Frames for Video Interpolation

    写在前面 原始视频(30fps) 补帧后的视频(240fps) 本文是博主在做实验的过程中使用到的方法,刚好也做为了本科毕设的翻译文章,现在把它搬运到博客上来,因为觉得这篇文章的思路真的不错. 这篇文 ...

  2. 【论文笔记】Learning Fashion Compatibility with Bidirectional LSTMs

    论文:<Learning Fashion Compatibility with Bidirectional LSTMs> 论文地址:https://arxiv.org/abs/1707.0 ...

  3. Deep Learning论文笔记之(八)Deep Learning最新综述

    Deep Learning论文笔记之(八)Deep Learning最新综述 zouxy09@qq.com http://blog.csdn.net/zouxy09 自己平时看了一些论文,但老感觉看完 ...

  4. 论文笔记:Mastering the game of Go with deep neural networks and tree search

    Mastering the game of Go with deep neural networks and tree search Nature 2015  这是本人论文笔记系列第二篇 Nature ...

  5. 论文笔记:CNN经典结构1(AlexNet,ZFNet,OverFeat,VGG,GoogleNet,ResNet)

    前言 本文主要介绍2012-2015年的一些经典CNN结构,从AlexNet,ZFNet,OverFeat到VGG,GoogleNetv1-v4,ResNetv1-v2. 在论文笔记:CNN经典结构2 ...

  6. AI理论学习笔记(一):深度学习的前世今生

    AI理论学习笔记(一):深度学习的前世今生 大家还记得以深度学习技术为基础的电脑程序AlphaGo吗?这是人类历史中在某种意义的第一次机器打败人类的例子,其最大的魅力就是深度学习(Deep Learn ...

  7. Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现(转)

    Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现 zouxy09@qq.com http://blog.csdn.net/zouxy09          自己平时看了一些论文, ...

  8. 论文笔记之:Visual Tracking with Fully Convolutional Networks

    论文笔记之:Visual Tracking with Fully Convolutional Networks ICCV 2015  CUHK 本文利用 FCN 来做跟踪问题,但开篇就提到并非将其看做 ...

  9. Twitter 新一代流处理利器——Heron 论文笔记之Heron架构

    Twitter 新一代流处理利器--Heron 论文笔记之Heron架构 标签(空格分隔): Streaming-process realtime-process Heron Architecture ...

随机推荐

  1. HTML5 VUE单页应用 SEO 优化之 预渲染(prerender-spa-plugin)

    前言:当前 SPA 架构流行的趋势如日中天,前后端分离的业务模式已经成为互联网开发的主流方式,但是 单页面 应用始终存在一个痛点,那就是 SEO, 对于那些需要推广,希望能在百度搜索时排名靠前的网站而 ...

  2. phpspreadsheet 中文文档(三) 计算引擎

    2019年10月11日13:59:52 使用PhpSpreadsheet计算引擎 执行公式计算 由于PhpSpreadsheet表示内存中的电子表格,因此它还提供公式计算功能.单元格可以是值类型(包含 ...

  3. IntelliJ IDEA(2018.3.5) 设置编码为utf-8编码

    位置一: File->Settings->Editor->File Encodings   位置二: File->Other Settings->Default Sett ...

  4. oracle全库查找是否有某个值

    在scott用户下面,搜索含有'要找的值'的数据的表和字段穷举法: declare v_Sql ); v_count number; begin for xx in (select t.OWNER, ...

  5. 【LOJ2292】[THUSC2016]成绩单(区间DP)

    题目 LOJ2292 分析 比较神奇的一个区间 DP ,我看了很多题解都没看懂,大约是我比较菜罢. 先明确一下题意:abcde 取完 c 后变成 abde ,可以取 bd 这样取 c 后新增的连续段. ...

  6. PHP设计模式 - 中介者模式

    中介者模式用于开发一个对象,这个对象能够在类似对象相互之间不直接相互的情况下传送或者调解对这些对象的集合的修改. 一般处理具有类似属性,需要保持同步的非耦合对象时,最佳的做法就是中介者模式.PHP中不 ...

  7. javascript中 for in 、for 、forEach 、for of 、Object.keys().

    一 .for ..in 循环 使用for..in循环时,返回的是所有能够通过对象访问的.可枚举的属性,既包括存在于实例中的属性,也包括存在于原型中的实例.这里需要注意的是使用for-in返回的属性因各 ...

  8. 设计模式 AOP,OOP

    AOP.OOP在字面上虽然非常类似,但却是面向不同领域的两种设计思想. 简单说,AOP面向动词领域,OOP面向名词领域 AOP: (Aspect Oriented Programming) 面向切面编 ...

  9. 简单工厂(SimpleFactory)模式

    简单工厂模式是类的创建模式,又叫做静态工厂方法(Static Factory Method)模式.简单工厂模式是由一个工厂对象决定创建出哪一种产品类的实例. 简单工厂就是将多个if,else...代码 ...

  10. JVM Server与Client运行模式

    JVM Server模式与client模式启动,最主要的差别在于:-Server模式启动时,速度较慢,但是一旦运行起来后,性能将会有很大的提升.原因是: 当虚拟机运行在-client模式的时候,使用的 ...