[AI] 论文笔记 - U-Net 简单而又接近本质的分割网络
越简单越接近本质。

参考资料
U-Net: Convolutional Networks for Biomedical Image Segmentation
Abstract & Introduction
论文中有几个关键词:
- contracting path 收缩路径;
- expansive path 扩张路径;
- precise localization 更精确的位置信息;
- overlap-tile 边界镜像翻转;
- random elastic deformations 随机弹性形变;
- invariance 尺度不变性;
- touching cells 指距离很近的两个细胞;
- seamless tilling 无缝拼接;
好了,说完这些关键词我们来看看这篇论文,这篇论文和他的结构一样简单易懂,很能说明问题。
首先,作者主要拿自己的网络和一个基于sliding window的方法做对比,作者先diss了一下这个方法存在以下问题:
Deep neural networks segment neuronal membranes in electron microscopy images (NIPS2012)
- 非常慢,计算冗余(sliding window的毛病大家都懂);
- 在位置精确性和特征提取之间存在一个平衡,因为更多的特征意味着更多的max-pooling,则会丢失掉更多位置信息。
作者的输入多层特征的思想是受以下论文启发:
- Hypercolumns for object segmentation and fine-grained localization (2014)
- Image segmentation with cascaded hierarchical models and logistic disjunctive normal networks (2013)
这两篇论文指出把多层特征(the features from multiple layers)输入到classifier能够得到更好的特征提取和更好的位置信息(good localization and the use of context are possible at the same time)。
U-Net和其他网络的不同之处在于,上采样(Upsampling)过程中也有很多维特征,让特征流向更高分辨率的卷积层。
由于网络使用的卷积是3x3 unpadded convolutions,所以特征图会缩小,为了让输出的图像和输入图像的大小无缝拼接(seamless tilling),则要用到边界镜像翻转(overlap-tile),具体做法如下图:

Architecture
网络结构
使用3x3 unpadded convolutions,所以特征图会不断缩小,在横向拼接特征的时候,也要对特征图进行裁剪,以保持特征图大小一致。
全部使用ReLU激活函数。
权值初始化使用何恺明的方法:
Surpassing humanlevel performance on imagenet classification
具体做法就是一个标准差满足sqrt(2/N)的高斯分布,其中的N代表一个神经元的输入节点数(例如一个3x3卷积核的输入是64维的话,那么N=9x64=576)
训练
在训练时作者更倾向于更大的图像输入,所以干脆将batch_size设置为1,所以在优化器的使用方面,使用到了带有动量的优化器,并且动量设置的很大(0.99),这么做是为了让以前的样本可以决定当前梯度更新的方向(因为batch_size太小啦,可以理解)。
损失函数就是pixel-wise soft-max + cross_entropy了。
数据增强
随机弹性形变和weight map:
随机弹性形变就是先用3x3的粗网格初始化随机形变,然后从标准差为10pixel的高斯分布中初始化随机位移矢量,再用bicubic双三次插值来计算每个像素的位移。
随机弹性形变的目的是让网络有invariance(尺度不变性)。
那么weight map是为了强制让网络学习touching cells之间的背景,这些位于touching cells之间的背景在损失函数的权重很高,如下图:

weight map的具体计算方式如下:

代码
最后来看看代码吧:https://github.com/milesial/Pytorch-UNet
整体模型:
class UNet(nn.Module):
def __init__(self, n_channels, n_classes):
super(UNet, self).__init__()
self.inc = inconv(n_channels, 64)
self.down1 = down(64, 128)
self.down2 = down(128, 256)
self.down3 = down(256, 512)
self.down4 = down(512, 512)
self.up1 = up(1024, 256)
self.up2 = up(512, 128)
self.up3 = up(256, 64)
self.up4 = up(128, 64)
self.outc = outconv(64, n_classes)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
x = self.outc(x)
return F.sigmoid(x)
细节部分:
class double_conv(nn.Module):
'''(conv => BN => ReLU) * 2'''
def __init__(self, in_ch, out_ch):
super(double_conv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
)
def forward(self, x):
x = self.conv(x)
return x
class inconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(inconv, self).__init__()
self.conv = double_conv(in_ch, out_ch)
def forward(self, x):
x = self.conv(x)
return x
class down(nn.Module):
def __init__(self, in_ch, out_ch):
super(down, self).__init__()
self.mpconv = nn.Sequential(
nn.MaxPool2d(2),
double_conv(in_ch, out_ch)
)
def forward(self, x):
x = self.mpconv(x)
return x
class up(nn.Module):
def __init__(self, in_ch, out_ch, bilinear=True):
super(up, self).__init__()
# would be a nice idea if the upsampling could be learned too,
# but my machine do not have enough memory to handle all those weights
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
else:
self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)
self.conv = double_conv(in_ch, out_ch)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
diffY // 2, diffY - diffY//2))
x = torch.cat([x2, x1], dim=1)
x = self.conv(x)
return x
class outconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(outconv, self).__init__()
self.conv = nn.Conv2d(in_ch, out_ch, 1)
def forward(self, x):
x = self.conv(x)
return x
训练:
optimizer = optim.SGD(net.parameters(),
lr=lr,
momentum=0.9,
weight_decay=0.0005)
criterion = nn.BCELoss()
[AI] 论文笔记 - U-Net 简单而又接近本质的分割网络的更多相关文章
- [AI] 论文笔记 - CVPR2018 Super SloMo: High Quality Estimation of Multiple Intermediate Frames for Video Interpolation
写在前面 原始视频(30fps) 补帧后的视频(240fps) 本文是博主在做实验的过程中使用到的方法,刚好也做为了本科毕设的翻译文章,现在把它搬运到博客上来,因为觉得这篇文章的思路真的不错. 这篇文 ...
- 【论文笔记】Learning Fashion Compatibility with Bidirectional LSTMs
论文:<Learning Fashion Compatibility with Bidirectional LSTMs> 论文地址:https://arxiv.org/abs/1707.0 ...
- Deep Learning论文笔记之(八)Deep Learning最新综述
Deep Learning论文笔记之(八)Deep Learning最新综述 zouxy09@qq.com http://blog.csdn.net/zouxy09 自己平时看了一些论文,但老感觉看完 ...
- 论文笔记:Mastering the game of Go with deep neural networks and tree search
Mastering the game of Go with deep neural networks and tree search Nature 2015 这是本人论文笔记系列第二篇 Nature ...
- 论文笔记:CNN经典结构1(AlexNet,ZFNet,OverFeat,VGG,GoogleNet,ResNet)
前言 本文主要介绍2012-2015年的一些经典CNN结构,从AlexNet,ZFNet,OverFeat到VGG,GoogleNetv1-v4,ResNetv1-v2. 在论文笔记:CNN经典结构2 ...
- AI理论学习笔记(一):深度学习的前世今生
AI理论学习笔记(一):深度学习的前世今生 大家还记得以深度学习技术为基础的电脑程序AlphaGo吗?这是人类历史中在某种意义的第一次机器打败人类的例子,其最大的魅力就是深度学习(Deep Learn ...
- Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现(转)
Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现 zouxy09@qq.com http://blog.csdn.net/zouxy09 自己平时看了一些论文, ...
- 论文笔记之:Visual Tracking with Fully Convolutional Networks
论文笔记之:Visual Tracking with Fully Convolutional Networks ICCV 2015 CUHK 本文利用 FCN 来做跟踪问题,但开篇就提到并非将其看做 ...
- Twitter 新一代流处理利器——Heron 论文笔记之Heron架构
Twitter 新一代流处理利器--Heron 论文笔记之Heron架构 标签(空格分隔): Streaming-process realtime-process Heron Architecture ...
随机推荐
- [数据结构 - 第3章] 线性表之单链表(C++实现)
一.类定义 单链表类的定义如下: #ifndef SIGNALLIST_H #define SIGNALLIST_H typedef int ElemType; /* "ElemType类型 ...
- 局域网-断网&劫持(kali)
1.查看局域网中的主机 fping –asg 192.168.1.0/24 2.断网 arpspoof -i wlan0 -t 192.168.100 192.168.1.1 (arpspoof - ...
- Mongodb 的事务在python中的操作
代码实现如下: import pymongo mgClient = pymongo.MongoClient("ip", "port") session = mg ...
- Android 系统架构 和 各个版本代号介绍
一.Android 系统架构: 1. linux内核层Android 基于Linux内核,为Android设备的各种硬件提供底层驱动 比如: 显示驱动.音频.照相机.蓝牙.Wi-Fi驱动,电源管理等 ...
- mPass多租户系统微服务开发平台
目录 项目总体架构图 基于SpringBoot2.x.SpringCloud并采用前后端分离的企业级微服务,多租户系统架构微服务开发平台 mPaaS(Microservice PaaS)为租户业务开发 ...
- java之hibernate之组合主键映射
1.在应用中经常会有主键是由2个或多个字段组合而成的.比如成绩表: 第一种方式:把主键写为单独的类 2.类的设计:studentId,subjectId ,这两个主键是一个组件.所以可以采用组件映射的 ...
- 2019 小米java面试笔试题 (含面试题解析)
本人5年开发经验.18年年底开始跑路找工作,在互联网寒冬下成功拿到阿里巴巴.今日头条.小米等公司offer,岗位是Java后端开发,因为发展原因最终选择去了小米,入职一年时间了,也成为了面试官,之 ...
- loj#10067 构造完全图(最小生成树)
题目 loj#10067 构造完全图 解析 和kruscal类似,我们要构造一个完全图,考虑往这颗最小生成树里加边 我们先把每一条边存下来, 把两个端点分别放在不同的集合内,记录每个集合的大小,然后做 ...
- aria2 添加任务后一直在等待,不进行下载是什么情况?
https://www.v2ex.com/t/567014 跑 aria2 的机器配置比较低,是 j1900+4G 的小机器,系统是 ubuntu18.04 ,所有的任务都是 bt 下载.aria2 ...
- 练习bloc , 动画
有点意思, import 'package:flutter/material.dart'; import 'package:rxdart/rxdart.dart'; main()=>runApp ...