1.文章原文地址

SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation

2.文章摘要

语义分割具有非常广泛的应用,从场景理解、目标相互关系推断到自动驾驶。早期依赖于低水平视觉线索的方法已经快速的被流行的机器学习算法所取代。特别是最近的深度学习在手写数字识别、语音、图像中的分类和目标检测上取得巨大成功。如今有一个活跃的领域是语义分割(对每个像素进行归类)。然而,最近有一些方法直接采用了为图像分类而设计的网络结构来进行语义分割任务。虽然结果十分鼓舞人心,但还是比较粗糙。这首要的原因是最大池化和下采样减小了特征图的分辨率。我们设计SegNet的动机来自于分割任务需要将低分辨率的特征图映射到输入的分辨率并进行像素级分类,这个映射必须产生对准确边界定位有用的特征。

3.网络结构

4.Pytorch实现

 import torch.nn as nn
import torch class conv2DBatchNormRelu(nn.Module):
def __init__(self,in_channels,out_channels,kernel_size,stride,padding,
bias=True,dilation=1,is_batchnorm=True):
super(conv2DBatchNormRelu,self).__init__()
if is_batchnorm:
self.cbr_unit=nn.Sequential(
nn.Conv2d(in_channels,out_channels,kernel_size=kernel_size,stride=stride,padding=padding,
bias=bias,dilation=dilation),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
else:
self.cbr_unit=nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
bias=bias, dilation=dilation),
nn.ReLU(inplace=True)
) def forward(self,inputs):
outputs=self.cbr_unit(inputs)
return outputs class segnetDown2(nn.Module):
def __init__(self,in_channels,out_channels):
super(segnetDown2,self).__init__()
self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.maxpool_with_argmax=nn.MaxPool2d(kernel_size=2,stride=2,return_indices=True) def forward(self,inputs):
outputs=self.conv1(inputs)
outputs=self.conv2(outputs)
unpooled_shape=outputs.size()
outputs,indices=self.maxpool_with_argmax(outputs)
return outputs,indices,unpooled_shape class segnetDown3(nn.Module):
def __init__(self,in_channels,out_channels):
super(segnetDown3,self).__init__()
self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv3=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.maxpool_with_argmax=nn.MaxPool2d(kernel_size=2,stride=2,return_indices=True) def forward(self,inputs):
outputs=self.conv1(inputs)
outputs=self.conv2(outputs)
outputs=self.conv3(outputs)
unpooled_shape=outputs.size()
outputs,indices=self.maxpool_with_argmax(outputs)
return outputs,indices,unpooled_shape class segnetUp2(nn.Module):
def __init__(self,in_channels,out_channels):
super(segnetUp2,self).__init__()
self.unpool=nn.MaxUnpool2d(2,2)
self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1) def forward(self,inputs,indices,output_shape):
outputs=self.unpool(inputs,indices=indices,output_size=output_shape)
outputs=self.conv1(outputs)
outputs=self.conv2(outputs)
return outputs class segnetUp3(nn.Module):
def __init__(self,in_channels,out_channels):
super(segnetUp3,self).__init__()
self.unpool=nn.MaxUnpool2d(2,2)
self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv3=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1) def forward(self,inputs,indices,output_shape):
outputs=self.unpool(inputs,indices=indices,output_size=output_shape)
outputs=self.conv1(outputs)
outputs=self.conv2(outputs)
outputs=self.conv3(outputs)
return outputs class segnet(nn.Module):
def __init__(self,in_channels=3,num_classes=21):
super(segnet,self).__init__()
self.down1=segnetDown2(in_channels=in_channels,out_channels=64)
self.down2=segnetDown2(64,128)
self.down3=segnetDown3(128,256)
self.down4=segnetDown3(256,512)
self.down5=segnetDown3(512,512) self.up5=segnetUp3(512,512)
self.up4=segnetUp3(512,256)
self.up3=segnetUp3(256,128)
self.up2=segnetUp2(128,64)
self.up1=segnetUp2(64,64)
self.finconv=conv2DBatchNormRelu(64,num_classes,3,1,1) def forward(self,inputs):
down1,indices_1,unpool_shape1=self.down1(inputs)
down2,indices_2,unpool_shape2=self.down2(down1)
down3,indices_3,unpool_shape3=self.down3(down2)
down4,indices_4,unpool_shape4=self.down4(down3)
down5,indices_5,unpool_shape5=self.down5(down4) up5=self.up5(down5,indices=indices_5,output_shape=unpool_shape5)
up4=self.up4(up5,indices=indices_4,output_shape=unpool_shape4)
up3=self.up3(up4,indices=indices_3,output_shape=unpool_shape3)
up2=self.up2(up3,indices=indices_2,output_shape=unpool_shape2)
up1=self.up1(up2,indices=indices_1,output_shape=unpool_shape1)
outputs=self.finconv(up1) return outputs if __name__=="__main__":
inputs=torch.ones(1,3,224,224)
model=segnet()
print(model(inputs).size())
print(model)

参考

https://github.com/meetshah1995/pytorch-semseg

SegNet网络的Pytorch实现的更多相关文章

  1. 群等变网络的pytorch实现

    CNN对于旋转不具有等变性,对于平移有等变性,data augmentation的提出就是为了解决这个问题,但是data augmentation需要很大的模型容量,更多的迭代次数才能够在训练数据集合 ...

  2. U-Net网络的Pytorch实现

    1.文章原文地址 U-Net: Convolutional Networks for Biomedical Image Segmentation 2.文章摘要 普遍认为成功训练深度神经网络需要大量标注 ...

  3. ResNet网络的Pytorch实现

    1.文章原文地址 Deep Residual Learning for  Image Recognition 2.文章摘要 神经网络的层次越深越难训练.我们提出了一个残差学习框架来简化网络的训练,这些 ...

  4. GoogLeNet网络的Pytorch实现

    1.文章原文地址 Going deeper with convolutions 2.文章摘要 我们提出了一种代号为Inception的深度卷积神经网络,它在ILSVRC2014的分类和检测任务上都取得 ...

  5. AlexNet网络的Pytorch实现

    1.文章原文地址 ImageNet Classification with Deep Convolutional Neural Networks 2.文章摘要 我们训练了一个大型的深度卷积神经网络用于 ...

  6. VGG网络的Pytorch实现

    1.文章原文地址 Very Deep Convolutional Networks for Large-Scale Image Recognition 2.文章摘要 在这项工作中,我们研究了在大规模的 ...

  7. 【转载】PyTorch系列 (二):pytorch数据读取

    原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...

  8. pytorch预训练

    Pytorch预训练模型以及修改 pytorch中自带几种常用的深度学习网络预训练模型,torchvision.models包中包含alexnet.densenet.inception.resnet. ...

  9. PyTorch使用总览

    PyTorch使用总览 https://blog.csdn.net/u014380165/article/details/79222243 深度学习框架训练模型时的代码主要包含数据读取.网络构建和其他 ...

随机推荐

  1. 【VS开发】程序如何捕捉signal函数参数中指定的信号

    当说到signal的功能时,我们都知道它会捕捉我们所指定的信号,然后调用我们所指定的信号处理函数.但它是如何捕捉我们指定的信号的呢?下面我就以msdn上关于signal的example为例,说明sig ...

  2. [SQL] - Attempted to read or write protected memory. This is often an indication that other memory is corrupt. 问题之解决

    场景: 使用 Oracle.DataAccess.dll 访问数据库时,OracleDataAdapter 执行失败. 异常: System.AccessViolationException was ...

  3. P5200 [USACO19JAN]Sleepy Cow Sorting

    P5200 [USACO19JAN]Sleepy Cow Sorting 题目描述 Farmer John正在尝试将他的N头奶牛(1≤N≤10^5),方便起见编号为1…N,在她们前往牧草地吃早餐之前排 ...

  4. 打开python 交互式模式

    pip install jupyter jupyter notebook --ip=127.0.0.1 --port=8888

  5. Python进阶:并发编程之Futures

    区分并发和并行 并发(Concurrency). 由于Python 的解释器并不是线程安全的,为了解决由此带来的 race condition 等问题,Python 便引入了全局解释器锁,也就是同一时 ...

  6. matplotlib实例笔记

    下面的图型是在一幅画布上建立的四个球员相关数据的极坐标图 关于这个图的代码如下: #_*_coding:utf-8_*_ import numpy as np import matplotlib.py ...

  7. C# 字符串补位方法

    string i=9; 方法1:Console.WriteLine(i.ToString("D5")); 方法2:Console.WriteLine(i.ToString().Pa ...

  8. 遇到 GLFW 我的demo可以运行 但是公司的程序调用我的so运行不起来

    //to do 原       因:  发现 自身demo的程序的shaders更新了  但是公司程序却没有更新 解决办法:更新公司程序的shaders 为最新版本 吸取的教训: 不仅仅要更新公司程序 ...

  9. [ICCV 2019] Weakly Supervised Object Detection With Segmentation Collaboration

    新在ICCV上发的弱监督物体检测文章,偷偷高兴一下,贴出我的poster,最近有点忙,话不多说,欢迎交流- https://arxiv.org/pdf/1904.00551.pdf http://op ...

  10. SDcms1.8代码审计

    由于工作原因,分析了很多的cms也都写过文章,不过觉得好像没什么骚操作都是网上的基本操作,所以也就没发表在网站上,都保存在本地.最近突然发现自己博客中实战的东西太少了,决定将以前写的一些文章搬过来,由 ...