1.文章原文地址

U-Net: Convolutional Networks for Biomedical Image Segmentation

2.文章摘要

普遍认为成功训练深度神经网络需要大量标注的训练数据。在本文中,我们提出了一个网络结构,以及使用数据增强的策略来训练网络使得可用的标注样本更加有效的被使用。这个网络是由一个捕捉上下文信息的收缩部分和与之相对称的放大部分,后者能够准确的定位。我们的结果展示了这个网络可以进行端到端的训练,使用非常少的数据就可以达到非常好的结果,并且超过了当前的最佳方法(滑动窗网络)在ISBII挑战赛上电子显微镜下神经结构的分割的结果。利用透射光显微镜图像使用相同网络进行训练,我们大幅度的赢得了2015年的ISBI细胞追踪挑战赛。而且,这个网络非常快,在一个当前的GPU上,分割一个512x512的图像所花费的时间少于一秒。完整的代码以及训练好的网络可见(基于Caffe)http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.

3.网络结构

4.Pytorch实现

 import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary class unetConv2(nn.Module):
def __init__(self,in_size,out_size,is_batchnorm):
super(unetConv2,self).__init__() if is_batchnorm:
self.conv1=nn.Sequential(
nn.Conv2d(in_size,out_size,kernel_size=3,stride=1,padding=0),
nn.BatchNorm2d(out_size),
nn.ReLU(inplace=True),
)
self.conv2=nn.Sequential(
nn.Conv2d(out_size,out_size,kernel_size=3,stride=1,padding=0),
nn.BatchNorm2d(out_size),
nn.ReLU(inplace=True),
)
else:
self.conv1=nn.Sequential(
nn.Conv2d(in_size,out_size,kernel_size=3,stride=1,padding=0),
nn.ReLU(inplace=True),
)
self.conv2=nn.Sequential(
nn.Conv2d(out_size,out_size,kernel_size=3,stride=1,padding=0),
nn.ReLU(inplace=True)
)
def forward(self, inputs):
outputs=self.conv1(inputs)
outputs=self.conv2(outputs) return outputs class unetUp(nn.Module):
def __init__(self,in_size,out_size,is_deconv):
super(unetUp,self).__init__()
self.conv=unetConv2(in_size,out_size,False)
if is_deconv:
self.up=nn.ConvTranspose2d(in_size,out_size,kernel_size=2,stride=2)
else:
self.up=nn.UpsamplingBilinear2d(scale_factor=2) def forward(self, inputs1,inputs2):
outputs2=self.up(inputs2)
offset=outputs2.size()[2]-inputs1.size()[2]
padding=2*[offset//2,offset//2]
outputs1=F.pad(inputs1,padding) #padding is negative, size become smaller return self.conv(torch.cat([outputs1,outputs2],1)) class unet(nn.Module):
def __init__(self,feature_scale=4,n_classes=21,is_deconv=True,in_channels=3,is_batchnorm=True):
super(unet,self).__init__()
self.is_deconv=is_deconv
self.in_channels=in_channels
self.is_batchnorm=is_batchnorm
self.feature_scale=feature_scale filters=[64,128,256,512,1024]
filters=[int(x/self.feature_scale) for x in filters] #downsample
self.conv1=unetConv2(self.in_channels,filters[0],self.is_batchnorm)
self.maxpool1=nn.MaxPool2d(kernel_size=2) self.conv2=unetConv2(filters[0],filters[1],self.is_batchnorm)
self.maxpool2=nn.MaxPool2d(kernel_size=2) self.conv3=unetConv2(filters[1],filters[2],self.is_batchnorm)
self.maxpool3=nn.MaxPool2d(kernel_size=2) self.conv4=unetConv2(filters[2],filters[3],self.is_batchnorm)
self.maxpool4=nn.MaxPool2d(kernel_size=2) self.center=unetConv2(filters[3],filters[4],self.is_batchnorm) #umsampling
self.up_concat4=unetUp(filters[4],filters[3],self.is_deconv)
self.up_concat3=unetUp(filters[3],filters[2],self.is_deconv)
self.up_concat2=unetUp(filters[2],filters[1],self.is_deconv)
self.up_concat1=unetUp(filters[1],filters[0],self.is_deconv) #final conv (without and concat)
self.final=nn.Conv2d(filters[0],n_classes,kernel_size=1) def forward(self, inputs):
conv1=self.conv1(inputs)
maxpool1=self.maxpool1(conv1) conv2=self.conv2(maxpool1)
maxpool2=self.maxpool2(conv2) conv3=self.conv3(maxpool2)
maxpool3=self.maxpool3(conv3) conv4=self.conv4(maxpool3)
maxpool4=self.maxpool4(conv4) center=self.center(maxpool4)
up4=self.up_concat4(conv4,center)
up3=self.up_concat3(conv3,up4)
up2=self.up_concat2(conv2,up3)
up1=self.up_concat1(conv1,up2) final=self.final(up1) return final if __name__=="__main__":
model=unet(feature_scale=1)
print(summary(model,(3,572,572)))
 ----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 570, 570] 1,792
BatchNorm2d-2 [-1, 64, 570, 570] 128
ReLU-3 [-1, 64, 570, 570] 0
Conv2d-4 [-1, 64, 568, 568] 36,928
BatchNorm2d-5 [-1, 64, 568, 568] 128
ReLU-6 [-1, 64, 568, 568] 0
unetConv2-7 [-1, 64, 568, 568] 0
MaxPool2d-8 [-1, 64, 284, 284] 0
Conv2d-9 [-1, 128, 282, 282] 73,856
BatchNorm2d-10 [-1, 128, 282, 282] 256
ReLU-11 [-1, 128, 282, 282] 0
Conv2d-12 [-1, 128, 280, 280] 147,584
BatchNorm2d-13 [-1, 128, 280, 280] 256
ReLU-14 [-1, 128, 280, 280] 0
unetConv2-15 [-1, 128, 280, 280] 0
MaxPool2d-16 [-1, 128, 140, 140] 0
Conv2d-17 [-1, 256, 138, 138] 295,168
BatchNorm2d-18 [-1, 256, 138, 138] 512
ReLU-19 [-1, 256, 138, 138] 0
Conv2d-20 [-1, 256, 136, 136] 590,080
BatchNorm2d-21 [-1, 256, 136, 136] 512
ReLU-22 [-1, 256, 136, 136] 0
unetConv2-23 [-1, 256, 136, 136] 0
MaxPool2d-24 [-1, 256, 68, 68] 0
Conv2d-25 [-1, 512, 66, 66] 1,180,160
BatchNorm2d-26 [-1, 512, 66, 66] 1,024
ReLU-27 [-1, 512, 66, 66] 0
Conv2d-28 [-1, 512, 64, 64] 2,359,808
BatchNorm2d-29 [-1, 512, 64, 64] 1,024
ReLU-30 [-1, 512, 64, 64] 0
unetConv2-31 [-1, 512, 64, 64] 0
MaxPool2d-32 [-1, 512, 32, 32] 0
Conv2d-33 [-1, 1024, 30, 30] 4,719,616
BatchNorm2d-34 [-1, 1024, 30, 30] 2,048
ReLU-35 [-1, 1024, 30, 30] 0
Conv2d-36 [-1, 1024, 28, 28] 9,438,208
BatchNorm2d-37 [-1, 1024, 28, 28] 2,048
ReLU-38 [-1, 1024, 28, 28] 0
unetConv2-39 [-1, 1024, 28, 28] 0
ConvTranspose2d-40 [-1, 512, 56, 56] 2,097,664
Conv2d-41 [-1, 512, 54, 54] 4,719,104
ReLU-42 [-1, 512, 54, 54] 0
Conv2d-43 [-1, 512, 52, 52] 2,359,808
ReLU-44 [-1, 512, 52, 52] 0
unetConv2-45 [-1, 512, 52, 52] 0
unetUp-46 [-1, 512, 52, 52] 0
ConvTranspose2d-47 [-1, 256, 104, 104] 524,544
Conv2d-48 [-1, 256, 102, 102] 1,179,904
ReLU-49 [-1, 256, 102, 102] 0
Conv2d-50 [-1, 256, 100, 100] 590,080
ReLU-51 [-1, 256, 100, 100] 0
unetConv2-52 [-1, 256, 100, 100] 0
unetUp-53 [-1, 256, 100, 100] 0
ConvTranspose2d-54 [-1, 128, 200, 200] 131,200
Conv2d-55 [-1, 128, 198, 198] 295,040
ReLU-56 [-1, 128, 198, 198] 0
Conv2d-57 [-1, 128, 196, 196] 147,584
ReLU-58 [-1, 128, 196, 196] 0
unetConv2-59 [-1, 128, 196, 196] 0
unetUp-60 [-1, 128, 196, 196] 0
ConvTranspose2d-61 [-1, 64, 392, 392] 32,832
Conv2d-62 [-1, 64, 390, 390] 73,792
ReLU-63 [-1, 64, 390, 390] 0
Conv2d-64 [-1, 64, 388, 388] 36,928
ReLU-65 [-1, 64, 388, 388] 0
unetConv2-66 [-1, 64, 388, 388] 0
unetUp-67 [-1, 64, 388, 388] 0
Conv2d-68 [-1, 21, 388, 388] 1,365
================================================================
Total params: 31,040,981
Trainable params: 31,040,981
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.74
Forward/backward pass size (MB): 3158.15
Params size (MB): 118.41
Estimated Total Size (MB): 3280.31

参考

https://github.com/meetshah1995/pytorch-semseg

U-Net网络的Pytorch实现的更多相关文章

  1. 群等变网络的pytorch实现

    CNN对于旋转不具有等变性,对于平移有等变性,data augmentation的提出就是为了解决这个问题,但是data augmentation需要很大的模型容量,更多的迭代次数才能够在训练数据集合 ...

  2. ResNet网络的Pytorch实现

    1.文章原文地址 Deep Residual Learning for  Image Recognition 2.文章摘要 神经网络的层次越深越难训练.我们提出了一个残差学习框架来简化网络的训练,这些 ...

  3. GoogLeNet网络的Pytorch实现

    1.文章原文地址 Going deeper with convolutions 2.文章摘要 我们提出了一种代号为Inception的深度卷积神经网络,它在ILSVRC2014的分类和检测任务上都取得 ...

  4. AlexNet网络的Pytorch实现

    1.文章原文地址 ImageNet Classification with Deep Convolutional Neural Networks 2.文章摘要 我们训练了一个大型的深度卷积神经网络用于 ...

  5. VGG网络的Pytorch实现

    1.文章原文地址 Very Deep Convolutional Networks for Large-Scale Image Recognition 2.文章摘要 在这项工作中,我们研究了在大规模的 ...

  6. SegNet网络的Pytorch实现

    1.文章原文地址 SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation 2.文章摘要 语义分 ...

  7. 【转载】PyTorch系列 (二):pytorch数据读取

    原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...

  8. pytorch预训练

    Pytorch预训练模型以及修改 pytorch中自带几种常用的深度学习网络预训练模型,torchvision.models包中包含alexnet.densenet.inception.resnet. ...

  9. PyTorch使用总览

    PyTorch使用总览 https://blog.csdn.net/u014380165/article/details/79222243 深度学习框架训练模型时的代码主要包含数据读取.网络构建和其他 ...

随机推荐

  1. Python unittest框架实现appium登录

    import unittest from appium.webdriver import webdriver from ddt import data,ddt,unpack class MyTestC ...

  2. 高级UI-MD动画

    MD动画是谷歌推出的一种动画效果,其实现的效果能让用户看着很是舒服,符合MD动画的动画,有很强的用户交互体验 Touch Feedback(触摸反馈) 在触摸反馈这一块,用的最多的就是水波纹效果,而水 ...

  3. Teaset-React Native UI 组件库

    GitHub地址 https://github.com/rilyu/teaset/blob/master/docs/cn/README.md React Native UI 组件库, 超过 20 个纯 ...

  4. JIRA数据库的迁移,从HSQL到MYSQL/Oracle

    Jira数据库迁移,从HSQL到MYSQL 通过JIRA管理员登录,进入“管理员页面”,“系统”--“导入&导出”,以XML格式备份数据. 在MySQL中创建Schema,命名为jira 关闭 ...

  5. PHP中文名加密

    <?php function encryptName($name) { $encrypt_name = ''; //判断是否包含中文字符 if(preg_match("/[\x{4e0 ...

  6. [转帖]Java 2019 生态圈使用报告,这结果你赞同吗?

    Java 2019 生态圈使用报告,这结果你赞同吗? http://www.51testing.com/html/94/n-4462794.html 发表于:2019-10-15 17:10  作者: ...

  7. 对比JPA 和Hibernate 和 Mybatis的区别

    1.JPA.Hibernate.Mybatis简单了解 1.JPA:本身是一种ORM规范,不是ORM框架.由各大ORM框架提供实现. 2.Hibernate:目前最流行的ORM框架,设计灵巧,文档丰富 ...

  8. Python07之分支和循环2(if...else、if...elif...else)

    一:if语句具体语法: if 表达式: 语句块 (表达式可以是一个布尔值或变量,也可以为一个逻辑表达式或比较表达式,表达式为真(即不为0即可,见下方实例),则运行语句块:表达式为假,则跳过语句块,继续 ...

  9. Java 总结篇1

    初始Java 1.Java的特点: ① 跨平台(字节码文件可以在任何具有Java虚拟机的计算机或者电子设备上运行,Java虚拟机中的Java解释器负责将字节码文件解释成特定的机器码进行运行) ② 简单 ...

  10. Go语言学习笔记(9)——接口类型

    接口 Go 语言提供了另外一种数据类型即接口,它把所有的具有共性的方法定义在一起,任何其他类型只要实现了这些方法就是实现了这个接口. /* 定义接口 */ type interface_name in ...