1.文章原文地址

SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation

2.文章摘要

语义分割具有非常广泛的应用,从场景理解、目标相互关系推断到自动驾驶。早期依赖于低水平视觉线索的方法已经快速的被流行的机器学习算法所取代。特别是最近的深度学习在手写数字识别、语音、图像中的分类和目标检测上取得巨大成功。如今有一个活跃的领域是语义分割(对每个像素进行归类)。然而,最近有一些方法直接采用了为图像分类而设计的网络结构来进行语义分割任务。虽然结果十分鼓舞人心,但还是比较粗糙。这首要的原因是最大池化和下采样减小了特征图的分辨率。我们设计SegNet的动机来自于分割任务需要将低分辨率的特征图映射到输入的分辨率并进行像素级分类,这个映射必须产生对准确边界定位有用的特征。

3.网络结构

4.Pytorch实现

 import torch.nn as nn
import torch class conv2DBatchNormRelu(nn.Module):
def __init__(self,in_channels,out_channels,kernel_size,stride,padding,
bias=True,dilation=1,is_batchnorm=True):
super(conv2DBatchNormRelu,self).__init__()
if is_batchnorm:
self.cbr_unit=nn.Sequential(
nn.Conv2d(in_channels,out_channels,kernel_size=kernel_size,stride=stride,padding=padding,
bias=bias,dilation=dilation),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
else:
self.cbr_unit=nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
bias=bias, dilation=dilation),
nn.ReLU(inplace=True)
) def forward(self,inputs):
outputs=self.cbr_unit(inputs)
return outputs class segnetDown2(nn.Module):
def __init__(self,in_channels,out_channels):
super(segnetDown2,self).__init__()
self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.maxpool_with_argmax=nn.MaxPool2d(kernel_size=2,stride=2,return_indices=True) def forward(self,inputs):
outputs=self.conv1(inputs)
outputs=self.conv2(outputs)
unpooled_shape=outputs.size()
outputs,indices=self.maxpool_with_argmax(outputs)
return outputs,indices,unpooled_shape class segnetDown3(nn.Module):
def __init__(self,in_channels,out_channels):
super(segnetDown3,self).__init__()
self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv3=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.maxpool_with_argmax=nn.MaxPool2d(kernel_size=2,stride=2,return_indices=True) def forward(self,inputs):
outputs=self.conv1(inputs)
outputs=self.conv2(outputs)
outputs=self.conv3(outputs)
unpooled_shape=outputs.size()
outputs,indices=self.maxpool_with_argmax(outputs)
return outputs,indices,unpooled_shape class segnetUp2(nn.Module):
def __init__(self,in_channels,out_channels):
super(segnetUp2,self).__init__()
self.unpool=nn.MaxUnpool2d(2,2)
self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1) def forward(self,inputs,indices,output_shape):
outputs=self.unpool(inputs,indices=indices,output_size=output_shape)
outputs=self.conv1(outputs)
outputs=self.conv2(outputs)
return outputs class segnetUp3(nn.Module):
def __init__(self,in_channels,out_channels):
super(segnetUp3,self).__init__()
self.unpool=nn.MaxUnpool2d(2,2)
self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv3=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1) def forward(self,inputs,indices,output_shape):
outputs=self.unpool(inputs,indices=indices,output_size=output_shape)
outputs=self.conv1(outputs)
outputs=self.conv2(outputs)
outputs=self.conv3(outputs)
return outputs class segnet(nn.Module):
def __init__(self,in_channels=3,num_classes=21):
super(segnet,self).__init__()
self.down1=segnetDown2(in_channels=in_channels,out_channels=64)
self.down2=segnetDown2(64,128)
self.down3=segnetDown3(128,256)
self.down4=segnetDown3(256,512)
self.down5=segnetDown3(512,512) self.up5=segnetUp3(512,512)
self.up4=segnetUp3(512,256)
self.up3=segnetUp3(256,128)
self.up2=segnetUp2(128,64)
self.up1=segnetUp2(64,64)
self.finconv=conv2DBatchNormRelu(64,num_classes,3,1,1) def forward(self,inputs):
down1,indices_1,unpool_shape1=self.down1(inputs)
down2,indices_2,unpool_shape2=self.down2(down1)
down3,indices_3,unpool_shape3=self.down3(down2)
down4,indices_4,unpool_shape4=self.down4(down3)
down5,indices_5,unpool_shape5=self.down5(down4) up5=self.up5(down5,indices=indices_5,output_shape=unpool_shape5)
up4=self.up4(up5,indices=indices_4,output_shape=unpool_shape4)
up3=self.up3(up4,indices=indices_3,output_shape=unpool_shape3)
up2=self.up2(up3,indices=indices_2,output_shape=unpool_shape2)
up1=self.up1(up2,indices=indices_1,output_shape=unpool_shape1)
outputs=self.finconv(up1) return outputs if __name__=="__main__":
inputs=torch.ones(1,3,224,224)
model=segnet()
print(model(inputs).size())
print(model)

参考

https://github.com/meetshah1995/pytorch-semseg

SegNet网络的Pytorch实现的更多相关文章

  1. 群等变网络的pytorch实现

    CNN对于旋转不具有等变性,对于平移有等变性,data augmentation的提出就是为了解决这个问题,但是data augmentation需要很大的模型容量,更多的迭代次数才能够在训练数据集合 ...

  2. U-Net网络的Pytorch实现

    1.文章原文地址 U-Net: Convolutional Networks for Biomedical Image Segmentation 2.文章摘要 普遍认为成功训练深度神经网络需要大量标注 ...

  3. ResNet网络的Pytorch实现

    1.文章原文地址 Deep Residual Learning for  Image Recognition 2.文章摘要 神经网络的层次越深越难训练.我们提出了一个残差学习框架来简化网络的训练,这些 ...

  4. GoogLeNet网络的Pytorch实现

    1.文章原文地址 Going deeper with convolutions 2.文章摘要 我们提出了一种代号为Inception的深度卷积神经网络,它在ILSVRC2014的分类和检测任务上都取得 ...

  5. AlexNet网络的Pytorch实现

    1.文章原文地址 ImageNet Classification with Deep Convolutional Neural Networks 2.文章摘要 我们训练了一个大型的深度卷积神经网络用于 ...

  6. VGG网络的Pytorch实现

    1.文章原文地址 Very Deep Convolutional Networks for Large-Scale Image Recognition 2.文章摘要 在这项工作中,我们研究了在大规模的 ...

  7. 【转载】PyTorch系列 (二):pytorch数据读取

    原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...

  8. pytorch预训练

    Pytorch预训练模型以及修改 pytorch中自带几种常用的深度学习网络预训练模型,torchvision.models包中包含alexnet.densenet.inception.resnet. ...

  9. PyTorch使用总览

    PyTorch使用总览 https://blog.csdn.net/u014380165/article/details/79222243 深度学习框架训练模型时的代码主要包含数据读取.网络构建和其他 ...

随机推荐

  1. 微服务Consul系列之服务注册与服务发现

    在进行服务注册之前先确认集群是否建立,关于服务注册可以看上篇微服务Consul系列之集群搭建的介绍,两种注册方式:一种是注册HTTP API.另一种是通过配置文件定义,下面讲解的是基于后者配置文件定义 ...

  2. 同时使用Redis缓存和Google Guava本地缓存注意事项(深拷贝和浅拷贝)

    目录 1.问题场景及说明 2.Redis 缓存是深拷贝 3.Guava本地缓存直接获取则是浅拷贝 4.如何实现Guava获取本地缓存是深拷贝? 1.问题场景及说明 系统中同时使用 Redis 缓存和 ...

  3. 任务调度之Quartz.Net可视化界面

    上一篇关于任务调度Quartz.Net的文章中介绍了其三个核心对象IScheduler.IJob和ITrigger,我们已经知道了其基本的使用方法,可以在控制台当中运行监控.但是在实际中我们往往需要有 ...

  4. samtools获取uniq reads

    参考地址: https://www.biostars.org/p/56246/ -q INT only include reads with mapping quality >= INT [0] ...

  5. jdk 7&8 new features

    7 Diamond Operator(菱形操作符) You can omitted the type declaration of the right when working with Generi ...

  6. golang使用一个二叉树来实现一个插入排序

    思路不太好理解,请用断点 package main import "fmt" type tree struct { value int left, right *tree } fu ...

  7. 数据结构-链式栈c++

    栈的最基本特点先进后出,本文简单介绍一下用c++写的链式栈 头文件 #ifndef LINKEDSTACK_H #define LINKEDSTACK_H template<class T> ...

  8. vue 写一个炫酷的轮播图

    效果如上图: 原理: 1.利用css 的 transform 和一些其他的属性,先选五张将图片位置拍列好,剩余的隐藏 2.利用 js 动态切换类名,达到切换效果 css代码如下 .swiper-cer ...

  9. Spring Cloud Alibaba学习笔记(2) - Nacos服务发现

    1.什么是Nacos Nacos的官网对这一问题进行了详细的介绍,通俗的来说: Nacos是一个服务发现组件,同时也是一个配置服务器,它解决了两个问题: 1.服务A如何发现服务B 2.管理微服务的配置 ...

  10. springboot之手动控制事务

    一.事务的重要性,相信在实际开发过程中,都有很深的了解了.但是存在一个问题我们经常在开发的时候一般情况下都是用的注解的方式来进行事务的控制,说白了基于spring的7种事务控制方式来进行事务的之间的协 ...