[AI] 论文笔记 - U-Net 简单而又接近本质的分割网络
越简单越接近本质。

参考资料
U-Net: Convolutional Networks for Biomedical Image Segmentation
Abstract & Introduction
论文中有几个关键词:
- contracting path 收缩路径;
- expansive path 扩张路径;
- precise localization 更精确的位置信息;
- overlap-tile 边界镜像翻转;
- random elastic deformations 随机弹性形变;
- invariance 尺度不变性;
- touching cells 指距离很近的两个细胞;
- seamless tilling 无缝拼接;
好了,说完这些关键词我们来看看这篇论文,这篇论文和他的结构一样简单易懂,很能说明问题。
首先,作者主要拿自己的网络和一个基于sliding window的方法做对比,作者先diss了一下这个方法存在以下问题:
Deep neural networks segment neuronal membranes in electron microscopy images (NIPS2012)
- 非常慢,计算冗余(sliding window的毛病大家都懂);
- 在位置精确性和特征提取之间存在一个平衡,因为更多的特征意味着更多的max-pooling,则会丢失掉更多位置信息。
作者的输入多层特征的思想是受以下论文启发:
- Hypercolumns for object segmentation and fine-grained localization (2014)
- Image segmentation with cascaded hierarchical models and logistic disjunctive normal networks (2013)
这两篇论文指出把多层特征(the features from multiple layers)输入到classifier能够得到更好的特征提取和更好的位置信息(good localization and the use of context are possible at the same time)。
U-Net和其他网络的不同之处在于,上采样(Upsampling)过程中也有很多维特征,让特征流向更高分辨率的卷积层。
由于网络使用的卷积是3x3 unpadded convolutions,所以特征图会缩小,为了让输出的图像和输入图像的大小无缝拼接(seamless tilling),则要用到边界镜像翻转(overlap-tile),具体做法如下图:

Architecture
网络结构
使用3x3 unpadded convolutions,所以特征图会不断缩小,在横向拼接特征的时候,也要对特征图进行裁剪,以保持特征图大小一致。
全部使用ReLU激活函数。
权值初始化使用何恺明的方法:
Surpassing humanlevel performance on imagenet classification
具体做法就是一个标准差满足sqrt(2/N)的高斯分布,其中的N代表一个神经元的输入节点数(例如一个3x3卷积核的输入是64维的话,那么N=9x64=576)
训练
在训练时作者更倾向于更大的图像输入,所以干脆将batch_size设置为1,所以在优化器的使用方面,使用到了带有动量的优化器,并且动量设置的很大(0.99),这么做是为了让以前的样本可以决定当前梯度更新的方向(因为batch_size太小啦,可以理解)。
损失函数就是pixel-wise soft-max + cross_entropy了。
数据增强
随机弹性形变和weight map:
随机弹性形变就是先用3x3的粗网格初始化随机形变,然后从标准差为10pixel的高斯分布中初始化随机位移矢量,再用bicubic双三次插值来计算每个像素的位移。
随机弹性形变的目的是让网络有invariance(尺度不变性)。
那么weight map是为了强制让网络学习touching cells之间的背景,这些位于touching cells之间的背景在损失函数的权重很高,如下图:

weight map的具体计算方式如下:

代码
最后来看看代码吧:https://github.com/milesial/Pytorch-UNet
整体模型:
class UNet(nn.Module):
def __init__(self, n_channels, n_classes):
super(UNet, self).__init__()
self.inc = inconv(n_channels, 64)
self.down1 = down(64, 128)
self.down2 = down(128, 256)
self.down3 = down(256, 512)
self.down4 = down(512, 512)
self.up1 = up(1024, 256)
self.up2 = up(512, 128)
self.up3 = up(256, 64)
self.up4 = up(128, 64)
self.outc = outconv(64, n_classes)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
x = self.outc(x)
return F.sigmoid(x)
细节部分:
class double_conv(nn.Module):
'''(conv => BN => ReLU) * 2'''
def __init__(self, in_ch, out_ch):
super(double_conv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
)
def forward(self, x):
x = self.conv(x)
return x
class inconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(inconv, self).__init__()
self.conv = double_conv(in_ch, out_ch)
def forward(self, x):
x = self.conv(x)
return x
class down(nn.Module):
def __init__(self, in_ch, out_ch):
super(down, self).__init__()
self.mpconv = nn.Sequential(
nn.MaxPool2d(2),
double_conv(in_ch, out_ch)
)
def forward(self, x):
x = self.mpconv(x)
return x
class up(nn.Module):
def __init__(self, in_ch, out_ch, bilinear=True):
super(up, self).__init__()
# would be a nice idea if the upsampling could be learned too,
# but my machine do not have enough memory to handle all those weights
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
else:
self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)
self.conv = double_conv(in_ch, out_ch)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
diffY // 2, diffY - diffY//2))
x = torch.cat([x2, x1], dim=1)
x = self.conv(x)
return x
class outconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(outconv, self).__init__()
self.conv = nn.Conv2d(in_ch, out_ch, 1)
def forward(self, x):
x = self.conv(x)
return x
训练:
optimizer = optim.SGD(net.parameters(),
lr=lr,
momentum=0.9,
weight_decay=0.0005)
criterion = nn.BCELoss()
[AI] 论文笔记 - U-Net 简单而又接近本质的分割网络的更多相关文章
- [AI] 论文笔记 - CVPR2018 Super SloMo: High Quality Estimation of Multiple Intermediate Frames for Video Interpolation
写在前面 原始视频(30fps) 补帧后的视频(240fps) 本文是博主在做实验的过程中使用到的方法,刚好也做为了本科毕设的翻译文章,现在把它搬运到博客上来,因为觉得这篇文章的思路真的不错. 这篇文 ...
- 【论文笔记】Learning Fashion Compatibility with Bidirectional LSTMs
论文:<Learning Fashion Compatibility with Bidirectional LSTMs> 论文地址:https://arxiv.org/abs/1707.0 ...
- Deep Learning论文笔记之(八)Deep Learning最新综述
Deep Learning论文笔记之(八)Deep Learning最新综述 zouxy09@qq.com http://blog.csdn.net/zouxy09 自己平时看了一些论文,但老感觉看完 ...
- 论文笔记:Mastering the game of Go with deep neural networks and tree search
Mastering the game of Go with deep neural networks and tree search Nature 2015 这是本人论文笔记系列第二篇 Nature ...
- 论文笔记:CNN经典结构1(AlexNet,ZFNet,OverFeat,VGG,GoogleNet,ResNet)
前言 本文主要介绍2012-2015年的一些经典CNN结构,从AlexNet,ZFNet,OverFeat到VGG,GoogleNetv1-v4,ResNetv1-v2. 在论文笔记:CNN经典结构2 ...
- AI理论学习笔记(一):深度学习的前世今生
AI理论学习笔记(一):深度学习的前世今生 大家还记得以深度学习技术为基础的电脑程序AlphaGo吗?这是人类历史中在某种意义的第一次机器打败人类的例子,其最大的魅力就是深度学习(Deep Learn ...
- Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现(转)
Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现 zouxy09@qq.com http://blog.csdn.net/zouxy09 自己平时看了一些论文, ...
- 论文笔记之:Visual Tracking with Fully Convolutional Networks
论文笔记之:Visual Tracking with Fully Convolutional Networks ICCV 2015 CUHK 本文利用 FCN 来做跟踪问题,但开篇就提到并非将其看做 ...
- Twitter 新一代流处理利器——Heron 论文笔记之Heron架构
Twitter 新一代流处理利器--Heron 论文笔记之Heron架构 标签(空格分隔): Streaming-process realtime-process Heron Architecture ...
随机推荐
- web端测试总结
1.数值型输入框: 条件:demcial(x,y) ,界面显示小数点到y位 通常要检查以下几点: (1)边界值:最大值.最小值.最大值+1.最小值-1 (2)位数:最小位数.最大位数.最小位数-1最 ...
- Java调用动态链接库so文件(传参以及处理返回值问题)
刚来到公司,屁股还没坐稳,老板把我叫到办公室,就让我做一个小程序.我瞬间懵逼了.对小程序一窍不通,还好通过学习小程序视频,两天的时间就做了一个云开发的小程序,但是领导不想核心的代码被别人看到,给了我一 ...
- 【idea】全局搜索、替换只显示100条的问题
没有修改之前 修改之后 如果用的是idea默认的快捷键,按Ctrl+Shift+A,然后输入Registry 如果是eclipse的快捷键
- GPU机器安装paddle
安装基础包 yum -y install epel-release yum -y install kernel-devel yum -y install dkms 编辑文件 /etc/default/ ...
- [转帖]TimesTen与Redis的对比
TimesTen与Redis的对比 2017-02-23 17:25:27 dingdingfish 阅读数 3682更多 分类专栏: TimesTen Oracle Redis In-Memory ...
- AtCoder-abc147 (题解)
A - Blackjack (水题) 题目链接 大致思路: 水题 B - Palindrome-philia (水题) 题目链接 大致思路: 由于整个串是回文串,只要判断前一半和后一半有多少个不同即可 ...
- Python基础(七)——文件和异常
1.1 读取整个文件 我们可以创建一个 test.txt 并写入一些内容,使用 Python 读文件操作,读出文本内容. with open(r'E:\test.txt') as file_objec ...
- python_dict json读写文件
命令汇总: json.dumps(obj) 将python数据转化为json Indent实现缩进,ensure_ascii 是否用ascii解析 json.loads(s) 将json数据转换为py ...
- 解决fiddler不能抓取firefox浏览器包的问题(转)
转自:https://blog.csdn.net/jimmyandrushking/article/details/80819103
- .NET 使用 JustAssembly 比较两个不同版本程序集的 API 变化
原文:.NET 使用 JustAssembly 比较两个不同版本程序集的 API 变化 最近我大幅度重构了我一个库的项目结构,使之使用最新的项目文件格式(基于 Microsoft.NET.Sdk)并使 ...