越简单越接近本质。

参考资料

U-Net: Convolutional Networks for Biomedical Image Segmentation

Abstract & Introduction

论文中有几个关键词:

  • contracting path 收缩路径;
  • expansive path 扩张路径;
  • precise localization 更精确的位置信息;
  • overlap-tile 边界镜像翻转;
  • random elastic deformations 随机弹性形变;
  • invariance 尺度不变性;
  • touching cells 指距离很近的两个细胞;
  • seamless tilling 无缝拼接;

好了,说完这些关键词我们来看看这篇论文,这篇论文和他的结构一样简单易懂,很能说明问题。

首先,作者主要拿自己的网络和一个基于sliding window的方法做对比,作者先diss了一下这个方法存在以下问题:

Deep neural networks segment neuronal membranes in electron microscopy images (NIPS2012)

  • 非常慢,计算冗余(sliding window的毛病大家都懂);
  • 在位置精确性和特征提取之间存在一个平衡,因为更多的特征意味着更多的max-pooling,则会丢失掉更多位置信息。

作者的输入多层特征的思想是受以下论文启发:

  • Hypercolumns for object segmentation and fine-grained localization (2014)
  • Image segmentation with cascaded hierarchical models and logistic disjunctive normal networks (2013)

这两篇论文指出把多层特征(the features from multiple layers)输入到classifier能够得到更好的特征提取和更好的位置信息(good localization and the use of context are possible at the same time)。

U-Net和其他网络的不同之处在于,上采样(Upsampling)过程中也有很多维特征,让特征流向更高分辨率的卷积层。

由于网络使用的卷积是3x3 unpadded convolutions,所以特征图会缩小,为了让输出的图像和输入图像的大小无缝拼接(seamless tilling),则要用到边界镜像翻转(overlap-tile),具体做法如下图:

Architecture

网络结构

使用3x3 unpadded convolutions,所以特征图会不断缩小,在横向拼接特征的时候,也要对特征图进行裁剪,以保持特征图大小一致。

全部使用ReLU激活函数。

权值初始化使用何恺明的方法:

Surpassing humanlevel performance on imagenet classification

具体做法就是一个标准差满足sqrt(2/N)的高斯分布,其中的N代表一个神经元的输入节点数(例如一个3x3卷积核的输入是64维的话,那么N=9x64=576)

训练

在训练时作者更倾向于更大的图像输入,所以干脆将batch_size设置为1,所以在优化器的使用方面,使用到了带有动量的优化器,并且动量设置的很大(0.99),这么做是为了让以前的样本可以决定当前梯度更新的方向(因为batch_size太小啦,可以理解)。

损失函数就是pixel-wise soft-max + cross_entropy了。

数据增强

随机弹性形变和weight map:

随机弹性形变就是先用3x3的粗网格初始化随机形变,然后从标准差为10pixel的高斯分布中初始化随机位移矢量,再用bicubic双三次插值来计算每个像素的位移。

随机弹性形变的目的是让网络有invariance(尺度不变性)。

那么weight map是为了强制让网络学习touching cells之间的背景,这些位于touching cells之间的背景在损失函数的权重很高,如下图:

weight map的具体计算方式如下:

代码

最后来看看代码吧:https://github.com/milesial/Pytorch-UNet

整体模型:

class UNet(nn.Module):
def __init__(self, n_channels, n_classes):
super(UNet, self).__init__()
self.inc = inconv(n_channels, 64)
self.down1 = down(64, 128)
self.down2 = down(128, 256)
self.down3 = down(256, 512)
self.down4 = down(512, 512)
self.up1 = up(1024, 256)
self.up2 = up(512, 128)
self.up3 = up(256, 64)
self.up4 = up(128, 64)
self.outc = outconv(64, n_classes) def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
x = self.outc(x)
return F.sigmoid(x)

细节部分:

class double_conv(nn.Module):
'''(conv => BN => ReLU) * 2'''
def __init__(self, in_ch, out_ch):
super(double_conv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
) def forward(self, x):
x = self.conv(x)
return x class inconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(inconv, self).__init__()
self.conv = double_conv(in_ch, out_ch) def forward(self, x):
x = self.conv(x)
return x class down(nn.Module):
def __init__(self, in_ch, out_ch):
super(down, self).__init__()
self.mpconv = nn.Sequential(
nn.MaxPool2d(2),
double_conv(in_ch, out_ch)
) def forward(self, x):
x = self.mpconv(x)
return x class up(nn.Module):
def __init__(self, in_ch, out_ch, bilinear=True):
super(up, self).__init__() # would be a nice idea if the upsampling could be learned too,
# but my machine do not have enough memory to handle all those weights
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
else:
self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2) self.conv = double_conv(in_ch, out_ch) def forward(self, x1, x2):
x1 = self.up(x1) # input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3] x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
diffY // 2, diffY - diffY//2)) x = torch.cat([x2, x1], dim=1)
x = self.conv(x)
return x class outconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(outconv, self).__init__()
self.conv = nn.Conv2d(in_ch, out_ch, 1) def forward(self, x):
x = self.conv(x)
return x

训练:

optimizer = optim.SGD(net.parameters(),
lr=lr,
momentum=0.9,
weight_decay=0.0005) criterion = nn.BCELoss()

[AI] 论文笔记 - U-Net 简单而又接近本质的分割网络的更多相关文章

  1. [AI] 论文笔记 - CVPR2018 Super SloMo: High Quality Estimation of Multiple Intermediate Frames for Video Interpolation

    写在前面 原始视频(30fps) 补帧后的视频(240fps) 本文是博主在做实验的过程中使用到的方法,刚好也做为了本科毕设的翻译文章,现在把它搬运到博客上来,因为觉得这篇文章的思路真的不错. 这篇文 ...

  2. 【论文笔记】Learning Fashion Compatibility with Bidirectional LSTMs

    论文:<Learning Fashion Compatibility with Bidirectional LSTMs> 论文地址:https://arxiv.org/abs/1707.0 ...

  3. Deep Learning论文笔记之(八)Deep Learning最新综述

    Deep Learning论文笔记之(八)Deep Learning最新综述 zouxy09@qq.com http://blog.csdn.net/zouxy09 自己平时看了一些论文,但老感觉看完 ...

  4. 论文笔记:Mastering the game of Go with deep neural networks and tree search

    Mastering the game of Go with deep neural networks and tree search Nature 2015  这是本人论文笔记系列第二篇 Nature ...

  5. 论文笔记:CNN经典结构1(AlexNet,ZFNet,OverFeat,VGG,GoogleNet,ResNet)

    前言 本文主要介绍2012-2015年的一些经典CNN结构,从AlexNet,ZFNet,OverFeat到VGG,GoogleNetv1-v4,ResNetv1-v2. 在论文笔记:CNN经典结构2 ...

  6. AI理论学习笔记(一):深度学习的前世今生

    AI理论学习笔记(一):深度学习的前世今生 大家还记得以深度学习技术为基础的电脑程序AlphaGo吗?这是人类历史中在某种意义的第一次机器打败人类的例子,其最大的魅力就是深度学习(Deep Learn ...

  7. Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现(转)

    Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现 zouxy09@qq.com http://blog.csdn.net/zouxy09          自己平时看了一些论文, ...

  8. 论文笔记之:Visual Tracking with Fully Convolutional Networks

    论文笔记之:Visual Tracking with Fully Convolutional Networks ICCV 2015  CUHK 本文利用 FCN 来做跟踪问题,但开篇就提到并非将其看做 ...

  9. Twitter 新一代流处理利器——Heron 论文笔记之Heron架构

    Twitter 新一代流处理利器--Heron 论文笔记之Heron架构 标签(空格分隔): Streaming-process realtime-process Heron Architecture ...

随机推荐

  1. web端测试总结

    1.数值型输入框: 条件:demcial(x,y) ,界面显示小数点到y位 通常要检查以下几点: (1)边界值:最大值.最小值.最大值+1.最小值-1  (2)位数:最小位数.最大位数.最小位数-1最 ...

  2. Java调用动态链接库so文件(传参以及处理返回值问题)

    刚来到公司,屁股还没坐稳,老板把我叫到办公室,就让我做一个小程序.我瞬间懵逼了.对小程序一窍不通,还好通过学习小程序视频,两天的时间就做了一个云开发的小程序,但是领导不想核心的代码被别人看到,给了我一 ...

  3. 【idea】全局搜索、替换只显示100条的问题

    没有修改之前 修改之后 如果用的是idea默认的快捷键,按Ctrl+Shift+A,然后输入Registry 如果是eclipse的快捷键

  4. GPU机器安装paddle

    安装基础包 yum -y install epel-release yum -y install kernel-devel yum -y install dkms 编辑文件 /etc/default/ ...

  5. [转帖]TimesTen与Redis的对比

    TimesTen与Redis的对比 2017-02-23 17:25:27 dingdingfish 阅读数 3682更多 分类专栏: TimesTen Oracle Redis In-Memory ...

  6. AtCoder-abc147 (题解)

    A - Blackjack (水题) 题目链接 大致思路: 水题 B - Palindrome-philia (水题) 题目链接 大致思路: 由于整个串是回文串,只要判断前一半和后一半有多少个不同即可 ...

  7. Python基础(七)——文件和异常

    1.1 读取整个文件 我们可以创建一个 test.txt 并写入一些内容,使用 Python 读文件操作,读出文本内容. with open(r'E:\test.txt') as file_objec ...

  8. python_dict json读写文件

    命令汇总: json.dumps(obj) 将python数据转化为json Indent实现缩进,ensure_ascii 是否用ascii解析 json.loads(s) 将json数据转换为py ...

  9. 解决fiddler不能抓取firefox浏览器包的问题(转)

    转自:https://blog.csdn.net/jimmyandrushking/article/details/80819103

  10. .NET 使用 JustAssembly 比较两个不同版本程序集的 API 变化

    原文:.NET 使用 JustAssembly 比较两个不同版本程序集的 API 变化 最近我大幅度重构了我一个库的项目结构,使之使用最新的项目文件格式(基于 Microsoft.NET.Sdk)并使 ...