U-Net网络的Pytorch实现
1.文章原文地址
U-Net: Convolutional Networks for Biomedical Image Segmentation
2.文章摘要
普遍认为成功训练深度神经网络需要大量标注的训练数据。在本文中,我们提出了一个网络结构,以及使用数据增强的策略来训练网络使得可用的标注样本更加有效的被使用。这个网络是由一个捕捉上下文信息的收缩部分和与之相对称的放大部分,后者能够准确的定位。我们的结果展示了这个网络可以进行端到端的训练,使用非常少的数据就可以达到非常好的结果,并且超过了当前的最佳方法(滑动窗网络)在ISBII挑战赛上电子显微镜下神经结构的分割的结果。利用透射光显微镜图像使用相同网络进行训练,我们大幅度的赢得了2015年的ISBI细胞追踪挑战赛。而且,这个网络非常快,在一个当前的GPU上,分割一个512x512的图像所花费的时间少于一秒。完整的代码以及训练好的网络可见(基于Caffe)http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.
3.网络结构

4.Pytorch实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary class unetConv2(nn.Module):
def __init__(self,in_size,out_size,is_batchnorm):
super(unetConv2,self).__init__() if is_batchnorm:
self.conv1=nn.Sequential(
nn.Conv2d(in_size,out_size,kernel_size=3,stride=1,padding=0),
nn.BatchNorm2d(out_size),
nn.ReLU(inplace=True),
)
self.conv2=nn.Sequential(
nn.Conv2d(out_size,out_size,kernel_size=3,stride=1,padding=0),
nn.BatchNorm2d(out_size),
nn.ReLU(inplace=True),
)
else:
self.conv1=nn.Sequential(
nn.Conv2d(in_size,out_size,kernel_size=3,stride=1,padding=0),
nn.ReLU(inplace=True),
)
self.conv2=nn.Sequential(
nn.Conv2d(out_size,out_size,kernel_size=3,stride=1,padding=0),
nn.ReLU(inplace=True)
)
def forward(self, inputs):
outputs=self.conv1(inputs)
outputs=self.conv2(outputs) return outputs class unetUp(nn.Module):
def __init__(self,in_size,out_size,is_deconv):
super(unetUp,self).__init__()
self.conv=unetConv2(in_size,out_size,False)
if is_deconv:
self.up=nn.ConvTranspose2d(in_size,out_size,kernel_size=2,stride=2)
else:
self.up=nn.UpsamplingBilinear2d(scale_factor=2) def forward(self, inputs1,inputs2):
outputs2=self.up(inputs2)
offset=outputs2.size()[2]-inputs1.size()[2]
padding=2*[offset//2,offset//2]
outputs1=F.pad(inputs1,padding) #padding is negative, size become smaller return self.conv(torch.cat([outputs1,outputs2],1)) class unet(nn.Module):
def __init__(self,feature_scale=4,n_classes=21,is_deconv=True,in_channels=3,is_batchnorm=True):
super(unet,self).__init__()
self.is_deconv=is_deconv
self.in_channels=in_channels
self.is_batchnorm=is_batchnorm
self.feature_scale=feature_scale filters=[64,128,256,512,1024]
filters=[int(x/self.feature_scale) for x in filters] #downsample
self.conv1=unetConv2(self.in_channels,filters[0],self.is_batchnorm)
self.maxpool1=nn.MaxPool2d(kernel_size=2) self.conv2=unetConv2(filters[0],filters[1],self.is_batchnorm)
self.maxpool2=nn.MaxPool2d(kernel_size=2) self.conv3=unetConv2(filters[1],filters[2],self.is_batchnorm)
self.maxpool3=nn.MaxPool2d(kernel_size=2) self.conv4=unetConv2(filters[2],filters[3],self.is_batchnorm)
self.maxpool4=nn.MaxPool2d(kernel_size=2) self.center=unetConv2(filters[3],filters[4],self.is_batchnorm) #umsampling
self.up_concat4=unetUp(filters[4],filters[3],self.is_deconv)
self.up_concat3=unetUp(filters[3],filters[2],self.is_deconv)
self.up_concat2=unetUp(filters[2],filters[1],self.is_deconv)
self.up_concat1=unetUp(filters[1],filters[0],self.is_deconv) #final conv (without and concat)
self.final=nn.Conv2d(filters[0],n_classes,kernel_size=1) def forward(self, inputs):
conv1=self.conv1(inputs)
maxpool1=self.maxpool1(conv1) conv2=self.conv2(maxpool1)
maxpool2=self.maxpool2(conv2) conv3=self.conv3(maxpool2)
maxpool3=self.maxpool3(conv3) conv4=self.conv4(maxpool3)
maxpool4=self.maxpool4(conv4) center=self.center(maxpool4)
up4=self.up_concat4(conv4,center)
up3=self.up_concat3(conv3,up4)
up2=self.up_concat2(conv2,up3)
up1=self.up_concat1(conv1,up2) final=self.final(up1) return final if __name__=="__main__":
model=unet(feature_scale=1)
print(summary(model,(3,572,572)))
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 570, 570] 1,792
BatchNorm2d-2 [-1, 64, 570, 570] 128
ReLU-3 [-1, 64, 570, 570] 0
Conv2d-4 [-1, 64, 568, 568] 36,928
BatchNorm2d-5 [-1, 64, 568, 568] 128
ReLU-6 [-1, 64, 568, 568] 0
unetConv2-7 [-1, 64, 568, 568] 0
MaxPool2d-8 [-1, 64, 284, 284] 0
Conv2d-9 [-1, 128, 282, 282] 73,856
BatchNorm2d-10 [-1, 128, 282, 282] 256
ReLU-11 [-1, 128, 282, 282] 0
Conv2d-12 [-1, 128, 280, 280] 147,584
BatchNorm2d-13 [-1, 128, 280, 280] 256
ReLU-14 [-1, 128, 280, 280] 0
unetConv2-15 [-1, 128, 280, 280] 0
MaxPool2d-16 [-1, 128, 140, 140] 0
Conv2d-17 [-1, 256, 138, 138] 295,168
BatchNorm2d-18 [-1, 256, 138, 138] 512
ReLU-19 [-1, 256, 138, 138] 0
Conv2d-20 [-1, 256, 136, 136] 590,080
BatchNorm2d-21 [-1, 256, 136, 136] 512
ReLU-22 [-1, 256, 136, 136] 0
unetConv2-23 [-1, 256, 136, 136] 0
MaxPool2d-24 [-1, 256, 68, 68] 0
Conv2d-25 [-1, 512, 66, 66] 1,180,160
BatchNorm2d-26 [-1, 512, 66, 66] 1,024
ReLU-27 [-1, 512, 66, 66] 0
Conv2d-28 [-1, 512, 64, 64] 2,359,808
BatchNorm2d-29 [-1, 512, 64, 64] 1,024
ReLU-30 [-1, 512, 64, 64] 0
unetConv2-31 [-1, 512, 64, 64] 0
MaxPool2d-32 [-1, 512, 32, 32] 0
Conv2d-33 [-1, 1024, 30, 30] 4,719,616
BatchNorm2d-34 [-1, 1024, 30, 30] 2,048
ReLU-35 [-1, 1024, 30, 30] 0
Conv2d-36 [-1, 1024, 28, 28] 9,438,208
BatchNorm2d-37 [-1, 1024, 28, 28] 2,048
ReLU-38 [-1, 1024, 28, 28] 0
unetConv2-39 [-1, 1024, 28, 28] 0
ConvTranspose2d-40 [-1, 512, 56, 56] 2,097,664
Conv2d-41 [-1, 512, 54, 54] 4,719,104
ReLU-42 [-1, 512, 54, 54] 0
Conv2d-43 [-1, 512, 52, 52] 2,359,808
ReLU-44 [-1, 512, 52, 52] 0
unetConv2-45 [-1, 512, 52, 52] 0
unetUp-46 [-1, 512, 52, 52] 0
ConvTranspose2d-47 [-1, 256, 104, 104] 524,544
Conv2d-48 [-1, 256, 102, 102] 1,179,904
ReLU-49 [-1, 256, 102, 102] 0
Conv2d-50 [-1, 256, 100, 100] 590,080
ReLU-51 [-1, 256, 100, 100] 0
unetConv2-52 [-1, 256, 100, 100] 0
unetUp-53 [-1, 256, 100, 100] 0
ConvTranspose2d-54 [-1, 128, 200, 200] 131,200
Conv2d-55 [-1, 128, 198, 198] 295,040
ReLU-56 [-1, 128, 198, 198] 0
Conv2d-57 [-1, 128, 196, 196] 147,584
ReLU-58 [-1, 128, 196, 196] 0
unetConv2-59 [-1, 128, 196, 196] 0
unetUp-60 [-1, 128, 196, 196] 0
ConvTranspose2d-61 [-1, 64, 392, 392] 32,832
Conv2d-62 [-1, 64, 390, 390] 73,792
ReLU-63 [-1, 64, 390, 390] 0
Conv2d-64 [-1, 64, 388, 388] 36,928
ReLU-65 [-1, 64, 388, 388] 0
unetConv2-66 [-1, 64, 388, 388] 0
unetUp-67 [-1, 64, 388, 388] 0
Conv2d-68 [-1, 21, 388, 388] 1,365
================================================================
Total params: 31,040,981
Trainable params: 31,040,981
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.74
Forward/backward pass size (MB): 3158.15
Params size (MB): 118.41
Estimated Total Size (MB): 3280.31
参考
https://github.com/meetshah1995/pytorch-semseg
U-Net网络的Pytorch实现的更多相关文章
- 群等变网络的pytorch实现
CNN对于旋转不具有等变性,对于平移有等变性,data augmentation的提出就是为了解决这个问题,但是data augmentation需要很大的模型容量,更多的迭代次数才能够在训练数据集合 ...
- ResNet网络的Pytorch实现
1.文章原文地址 Deep Residual Learning for Image Recognition 2.文章摘要 神经网络的层次越深越难训练.我们提出了一个残差学习框架来简化网络的训练,这些 ...
- GoogLeNet网络的Pytorch实现
1.文章原文地址 Going deeper with convolutions 2.文章摘要 我们提出了一种代号为Inception的深度卷积神经网络,它在ILSVRC2014的分类和检测任务上都取得 ...
- AlexNet网络的Pytorch实现
1.文章原文地址 ImageNet Classification with Deep Convolutional Neural Networks 2.文章摘要 我们训练了一个大型的深度卷积神经网络用于 ...
- VGG网络的Pytorch实现
1.文章原文地址 Very Deep Convolutional Networks for Large-Scale Image Recognition 2.文章摘要 在这项工作中,我们研究了在大规模的 ...
- SegNet网络的Pytorch实现
1.文章原文地址 SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation 2.文章摘要 语义分 ...
- 【转载】PyTorch系列 (二):pytorch数据读取
原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...
- pytorch预训练
Pytorch预训练模型以及修改 pytorch中自带几种常用的深度学习网络预训练模型,torchvision.models包中包含alexnet.densenet.inception.resnet. ...
- PyTorch使用总览
PyTorch使用总览 https://blog.csdn.net/u014380165/article/details/79222243 深度学习框架训练模型时的代码主要包含数据读取.网络构建和其他 ...
随机推荐
- Django的小记
大致按流程列出来 在pycham中创建Django project时要确定机器上的版本及你要用的版本,机器上一般情况下默认最新版本2.1(2018年11月),根据需要下载相应版本 创建好工程后就要创建 ...
- idea设置创建类的注释模板
打开settings>>Editor>>File and Code Templates>>Includes>>File Header
- webpack的配置 @3.6.0
1.下载对应版本的webpack npm install webpack@3.6.0 -save --dev 2.新建webpack.config.js文件,目录结构↑ 3. >>webp ...
- Zero-shot Learning / One-shot Learning / Few-shot Learning
Zero-shot Learning / One-shot Learning / Few-shot Learning Learning类型:Zero-shot Learning.One-shot Le ...
- Asp.Net Core中创建多DbContext并迁移到数据库
在我们的项目中我们有时候需要在我们的项目中创建DbContext,而且这些DbContext之间有明显的界限,比如系统中两个DbContext一个是和整个数据库的权限相关的内容而另外一个DbConte ...
- pycharm 使用black
pycharm 使用black The Uncompromising Code Formatter By using Black, you agree to cede control over min ...
- Python进阶: Decorator 装饰器你太美
函数 -> 装饰器 函数的4个核心概念 1.函数可以赋与变量 def func(message): print('Got a message: {}'.format(message)) send ...
- k8s 启动pod的问题
版本: k8s 1.5 docker 1.3 CentOS 7.6 使用命令 kubectl get pods输出no resources.解决方法是修改 apiserver 的配置文件 vim /e ...
- Vue使用指南(一)
Vue Vue:前台框架 渐进式JavaScript框架 渐进式:vue可以控制页面的一个局部,vue也可以控制整个页面,vue也能控制整个前端项目 -- 根据项目需求,来决定vue控制项目的 ...
- 解决batik使用JScrollPane显示svg图滚动条不显示的问题
// 必须使用batik提供的JSVGScrollPane,使用swing自己的组件JScrollPane初始化时滚动条不会显示. JSVGScrollPane svgJScrollPane = ne ...