[AI] 论文笔记 - U-Net 简单而又接近本质的分割网络
越简单越接近本质。

参考资料
U-Net: Convolutional Networks for Biomedical Image Segmentation
Abstract & Introduction
论文中有几个关键词:
- contracting path 收缩路径;
- expansive path 扩张路径;
- precise localization 更精确的位置信息;
- overlap-tile 边界镜像翻转;
- random elastic deformations 随机弹性形变;
- invariance 尺度不变性;
- touching cells 指距离很近的两个细胞;
- seamless tilling 无缝拼接;
好了,说完这些关键词我们来看看这篇论文,这篇论文和他的结构一样简单易懂,很能说明问题。
首先,作者主要拿自己的网络和一个基于sliding window的方法做对比,作者先diss了一下这个方法存在以下问题:
Deep neural networks segment neuronal membranes in electron microscopy images (NIPS2012)
- 非常慢,计算冗余(sliding window的毛病大家都懂);
- 在位置精确性和特征提取之间存在一个平衡,因为更多的特征意味着更多的max-pooling,则会丢失掉更多位置信息。
作者的输入多层特征的思想是受以下论文启发:
- Hypercolumns for object segmentation and fine-grained localization (2014)
- Image segmentation with cascaded hierarchical models and logistic disjunctive normal networks (2013)
这两篇论文指出把多层特征(the features from multiple layers)输入到classifier能够得到更好的特征提取和更好的位置信息(good localization and the use of context are possible at the same time)。
U-Net和其他网络的不同之处在于,上采样(Upsampling)过程中也有很多维特征,让特征流向更高分辨率的卷积层。
由于网络使用的卷积是3x3 unpadded convolutions,所以特征图会缩小,为了让输出的图像和输入图像的大小无缝拼接(seamless tilling),则要用到边界镜像翻转(overlap-tile),具体做法如下图:

Architecture
网络结构
使用3x3 unpadded convolutions,所以特征图会不断缩小,在横向拼接特征的时候,也要对特征图进行裁剪,以保持特征图大小一致。
全部使用ReLU激活函数。
权值初始化使用何恺明的方法:
Surpassing humanlevel performance on imagenet classification
具体做法就是一个标准差满足sqrt(2/N)的高斯分布,其中的N代表一个神经元的输入节点数(例如一个3x3卷积核的输入是64维的话,那么N=9x64=576)
训练
在训练时作者更倾向于更大的图像输入,所以干脆将batch_size设置为1,所以在优化器的使用方面,使用到了带有动量的优化器,并且动量设置的很大(0.99),这么做是为了让以前的样本可以决定当前梯度更新的方向(因为batch_size太小啦,可以理解)。
损失函数就是pixel-wise soft-max + cross_entropy了。
数据增强
随机弹性形变和weight map:
随机弹性形变就是先用3x3的粗网格初始化随机形变,然后从标准差为10pixel的高斯分布中初始化随机位移矢量,再用bicubic双三次插值来计算每个像素的位移。
随机弹性形变的目的是让网络有invariance(尺度不变性)。
那么weight map是为了强制让网络学习touching cells之间的背景,这些位于touching cells之间的背景在损失函数的权重很高,如下图:

weight map的具体计算方式如下:

代码
最后来看看代码吧:https://github.com/milesial/Pytorch-UNet
整体模型:
class UNet(nn.Module):
def __init__(self, n_channels, n_classes):
super(UNet, self).__init__()
self.inc = inconv(n_channels, 64)
self.down1 = down(64, 128)
self.down2 = down(128, 256)
self.down3 = down(256, 512)
self.down4 = down(512, 512)
self.up1 = up(1024, 256)
self.up2 = up(512, 128)
self.up3 = up(256, 64)
self.up4 = up(128, 64)
self.outc = outconv(64, n_classes)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
x = self.outc(x)
return F.sigmoid(x)
细节部分:
class double_conv(nn.Module):
'''(conv => BN => ReLU) * 2'''
def __init__(self, in_ch, out_ch):
super(double_conv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
)
def forward(self, x):
x = self.conv(x)
return x
class inconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(inconv, self).__init__()
self.conv = double_conv(in_ch, out_ch)
def forward(self, x):
x = self.conv(x)
return x
class down(nn.Module):
def __init__(self, in_ch, out_ch):
super(down, self).__init__()
self.mpconv = nn.Sequential(
nn.MaxPool2d(2),
double_conv(in_ch, out_ch)
)
def forward(self, x):
x = self.mpconv(x)
return x
class up(nn.Module):
def __init__(self, in_ch, out_ch, bilinear=True):
super(up, self).__init__()
# would be a nice idea if the upsampling could be learned too,
# but my machine do not have enough memory to handle all those weights
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
else:
self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)
self.conv = double_conv(in_ch, out_ch)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
diffY // 2, diffY - diffY//2))
x = torch.cat([x2, x1], dim=1)
x = self.conv(x)
return x
class outconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(outconv, self).__init__()
self.conv = nn.Conv2d(in_ch, out_ch, 1)
def forward(self, x):
x = self.conv(x)
return x
训练:
optimizer = optim.SGD(net.parameters(),
lr=lr,
momentum=0.9,
weight_decay=0.0005)
criterion = nn.BCELoss()
[AI] 论文笔记 - U-Net 简单而又接近本质的分割网络的更多相关文章
- [AI] 论文笔记 - CVPR2018 Super SloMo: High Quality Estimation of Multiple Intermediate Frames for Video Interpolation
写在前面 原始视频(30fps) 补帧后的视频(240fps) 本文是博主在做实验的过程中使用到的方法,刚好也做为了本科毕设的翻译文章,现在把它搬运到博客上来,因为觉得这篇文章的思路真的不错. 这篇文 ...
- 【论文笔记】Learning Fashion Compatibility with Bidirectional LSTMs
论文:<Learning Fashion Compatibility with Bidirectional LSTMs> 论文地址:https://arxiv.org/abs/1707.0 ...
- Deep Learning论文笔记之(八)Deep Learning最新综述
Deep Learning论文笔记之(八)Deep Learning最新综述 zouxy09@qq.com http://blog.csdn.net/zouxy09 自己平时看了一些论文,但老感觉看完 ...
- 论文笔记:Mastering the game of Go with deep neural networks and tree search
Mastering the game of Go with deep neural networks and tree search Nature 2015 这是本人论文笔记系列第二篇 Nature ...
- 论文笔记:CNN经典结构1(AlexNet,ZFNet,OverFeat,VGG,GoogleNet,ResNet)
前言 本文主要介绍2012-2015年的一些经典CNN结构,从AlexNet,ZFNet,OverFeat到VGG,GoogleNetv1-v4,ResNetv1-v2. 在论文笔记:CNN经典结构2 ...
- AI理论学习笔记(一):深度学习的前世今生
AI理论学习笔记(一):深度学习的前世今生 大家还记得以深度学习技术为基础的电脑程序AlphaGo吗?这是人类历史中在某种意义的第一次机器打败人类的例子,其最大的魅力就是深度学习(Deep Learn ...
- Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现(转)
Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现 zouxy09@qq.com http://blog.csdn.net/zouxy09 自己平时看了一些论文, ...
- 论文笔记之:Visual Tracking with Fully Convolutional Networks
论文笔记之:Visual Tracking with Fully Convolutional Networks ICCV 2015 CUHK 本文利用 FCN 来做跟踪问题,但开篇就提到并非将其看做 ...
- Twitter 新一代流处理利器——Heron 论文笔记之Heron架构
Twitter 新一代流处理利器--Heron 论文笔记之Heron架构 标签(空格分隔): Streaming-process realtime-process Heron Architecture ...
随机推荐
- Mac Pro 2015休眠掉电解决办法
硬件:Mac Pro 2015 系统:MacOs Mojave 10.14.3 问题:合盖的时候,休眠1小时掉电10%,由于之前是128G原装盘不会有这个问题,后面购买了M.2转接卡,更换1T Int ...
- c++生成数据程序模板
in.cpp: #include<bits/stdc++.h> #define random(a,b) rand()%(b-a+1)+a using namespace std; cons ...
- C# .NET “公钥证书” (.cer .pem)转换为 RSACryptoServiceProvider 对象。导出“公钥”
“公钥证书” .cer 文件是直接可以用X509Certificate2 对象来读取的,但 .cer 文件 不便于存储. “公钥证书” .pem 文件内容如下: -----BEGIN CERTIFIC ...
- 搭建mqtt服务器apollo
使用的apollo,官网太慢,附上百度云下载地址: 链接:https://pan.baidu.com/s/1NIq6R71hlyPuaUBwPoMPNg 提取码:36vw 原文链接:https://b ...
- RocketMQ 4.5.1 单机环境搭建以及生产消费测试
为了学习和方便测试,总是要启动一个单机版的.下载 http://rocketmq.apache.org/dowloading/releases/ 1. 要先配置环境变量 ROCKETMQ_HOME E ...
- Word 查找替换高级玩法系列之 -- 通配符大全B篇
未完 ...... 点击访问原文(进入后根据右侧标签,快速定位到本文)
- spring整合RabbitMQ
今天就来康康spring怎么整合RabbitMQ 注意一点,在发送消息的时候对template进行配置mandatory=true保证监听有效 生产端还可以配置其他属性,比如发送重试,超时时间.次数. ...
- 前台调用微信接口成功还报Network Error
前台 vue+springboot项目 this.api({ url:"https://.....",//微信路径 method:"post", param ...
- git 学习笔记 --从远程库克隆
上次我们讲了先有本地库,后有远程库的时候,如何关联远程库. 现在,假设我们从零开发,那么最好的方式是先创建远程库,然后,从远程库克隆. 首先,登陆GitHub,创建一个新的仓库,名字叫gitskill ...
- ASP.NET MVC 页面静态化操作的思路
本文主要讲述了在asp.net mvc中,页面静态化的几种思路和方法.对于网站来说,生成纯html静态页面除了有利于seo外,还可以减轻网站的负载能力和提高网站性能.在asp.net mvc中,视图的 ...