SegNet网络的Pytorch实现
1.文章原文地址
SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation
2.文章摘要
语义分割具有非常广泛的应用,从场景理解、目标相互关系推断到自动驾驶。早期依赖于低水平视觉线索的方法已经快速的被流行的机器学习算法所取代。特别是最近的深度学习在手写数字识别、语音、图像中的分类和目标检测上取得巨大成功。如今有一个活跃的领域是语义分割(对每个像素进行归类)。然而,最近有一些方法直接采用了为图像分类而设计的网络结构来进行语义分割任务。虽然结果十分鼓舞人心,但还是比较粗糙。这首要的原因是最大池化和下采样减小了特征图的分辨率。我们设计SegNet的动机来自于分割任务需要将低分辨率的特征图映射到输入的分辨率并进行像素级分类,这个映射必须产生对准确边界定位有用的特征。
3.网络结构

4.Pytorch实现
import torch.nn as nn
import torch class conv2DBatchNormRelu(nn.Module):
def __init__(self,in_channels,out_channels,kernel_size,stride,padding,
bias=True,dilation=1,is_batchnorm=True):
super(conv2DBatchNormRelu,self).__init__()
if is_batchnorm:
self.cbr_unit=nn.Sequential(
nn.Conv2d(in_channels,out_channels,kernel_size=kernel_size,stride=stride,padding=padding,
bias=bias,dilation=dilation),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
else:
self.cbr_unit=nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
bias=bias, dilation=dilation),
nn.ReLU(inplace=True)
) def forward(self,inputs):
outputs=self.cbr_unit(inputs)
return outputs class segnetDown2(nn.Module):
def __init__(self,in_channels,out_channels):
super(segnetDown2,self).__init__()
self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.maxpool_with_argmax=nn.MaxPool2d(kernel_size=2,stride=2,return_indices=True) def forward(self,inputs):
outputs=self.conv1(inputs)
outputs=self.conv2(outputs)
unpooled_shape=outputs.size()
outputs,indices=self.maxpool_with_argmax(outputs)
return outputs,indices,unpooled_shape class segnetDown3(nn.Module):
def __init__(self,in_channels,out_channels):
super(segnetDown3,self).__init__()
self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv3=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.maxpool_with_argmax=nn.MaxPool2d(kernel_size=2,stride=2,return_indices=True) def forward(self,inputs):
outputs=self.conv1(inputs)
outputs=self.conv2(outputs)
outputs=self.conv3(outputs)
unpooled_shape=outputs.size()
outputs,indices=self.maxpool_with_argmax(outputs)
return outputs,indices,unpooled_shape class segnetUp2(nn.Module):
def __init__(self,in_channels,out_channels):
super(segnetUp2,self).__init__()
self.unpool=nn.MaxUnpool2d(2,2)
self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1) def forward(self,inputs,indices,output_shape):
outputs=self.unpool(inputs,indices=indices,output_size=output_shape)
outputs=self.conv1(outputs)
outputs=self.conv2(outputs)
return outputs class segnetUp3(nn.Module):
def __init__(self,in_channels,out_channels):
super(segnetUp3,self).__init__()
self.unpool=nn.MaxUnpool2d(2,2)
self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv3=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1) def forward(self,inputs,indices,output_shape):
outputs=self.unpool(inputs,indices=indices,output_size=output_shape)
outputs=self.conv1(outputs)
outputs=self.conv2(outputs)
outputs=self.conv3(outputs)
return outputs class segnet(nn.Module):
def __init__(self,in_channels=3,num_classes=21):
super(segnet,self).__init__()
self.down1=segnetDown2(in_channels=in_channels,out_channels=64)
self.down2=segnetDown2(64,128)
self.down3=segnetDown3(128,256)
self.down4=segnetDown3(256,512)
self.down5=segnetDown3(512,512) self.up5=segnetUp3(512,512)
self.up4=segnetUp3(512,256)
self.up3=segnetUp3(256,128)
self.up2=segnetUp2(128,64)
self.up1=segnetUp2(64,64)
self.finconv=conv2DBatchNormRelu(64,num_classes,3,1,1) def forward(self,inputs):
down1,indices_1,unpool_shape1=self.down1(inputs)
down2,indices_2,unpool_shape2=self.down2(down1)
down3,indices_3,unpool_shape3=self.down3(down2)
down4,indices_4,unpool_shape4=self.down4(down3)
down5,indices_5,unpool_shape5=self.down5(down4) up5=self.up5(down5,indices=indices_5,output_shape=unpool_shape5)
up4=self.up4(up5,indices=indices_4,output_shape=unpool_shape4)
up3=self.up3(up4,indices=indices_3,output_shape=unpool_shape3)
up2=self.up2(up3,indices=indices_2,output_shape=unpool_shape2)
up1=self.up1(up2,indices=indices_1,output_shape=unpool_shape1)
outputs=self.finconv(up1) return outputs if __name__=="__main__":
inputs=torch.ones(1,3,224,224)
model=segnet()
print(model(inputs).size())
print(model)
参考
https://github.com/meetshah1995/pytorch-semseg
SegNet网络的Pytorch实现的更多相关文章
- 群等变网络的pytorch实现
CNN对于旋转不具有等变性,对于平移有等变性,data augmentation的提出就是为了解决这个问题,但是data augmentation需要很大的模型容量,更多的迭代次数才能够在训练数据集合 ...
- U-Net网络的Pytorch实现
1.文章原文地址 U-Net: Convolutional Networks for Biomedical Image Segmentation 2.文章摘要 普遍认为成功训练深度神经网络需要大量标注 ...
- ResNet网络的Pytorch实现
1.文章原文地址 Deep Residual Learning for Image Recognition 2.文章摘要 神经网络的层次越深越难训练.我们提出了一个残差学习框架来简化网络的训练,这些 ...
- GoogLeNet网络的Pytorch实现
1.文章原文地址 Going deeper with convolutions 2.文章摘要 我们提出了一种代号为Inception的深度卷积神经网络,它在ILSVRC2014的分类和检测任务上都取得 ...
- AlexNet网络的Pytorch实现
1.文章原文地址 ImageNet Classification with Deep Convolutional Neural Networks 2.文章摘要 我们训练了一个大型的深度卷积神经网络用于 ...
- VGG网络的Pytorch实现
1.文章原文地址 Very Deep Convolutional Networks for Large-Scale Image Recognition 2.文章摘要 在这项工作中,我们研究了在大规模的 ...
- 【转载】PyTorch系列 (二):pytorch数据读取
原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...
- pytorch预训练
Pytorch预训练模型以及修改 pytorch中自带几种常用的深度学习网络预训练模型,torchvision.models包中包含alexnet.densenet.inception.resnet. ...
- PyTorch使用总览
PyTorch使用总览 https://blog.csdn.net/u014380165/article/details/79222243 深度学习框架训练模型时的代码主要包含数据读取.网络构建和其他 ...
随机推荐
- Java创建线程的两种方法
大多数情况,通过实例化一个Thread对象来创建一个线程.Java定义了两种方式: 实现Runnable 接口: 可以继承Thread类. 下面的两小节依次介绍了每一种方式. 实现Runnable接口 ...
- PHP如何访问数据库集群
一般常见的有三种做法, 1,自动判断sql是否为读,来选择数据库的连接: 实例化php DB类的时候,需要一次连接两台服务器,然后根据slq选择不同的连接,举个例子: $link_w = mysql_ ...
- [转帖]Linux系统进程的知识总结,进程与线程之间的纠葛...
Linux系统进程的知识总结,进程与线程之间的纠葛... https://cloud.tencent.com/developer/article/1500509 当一个程序开始执行后,在开始执行到执行 ...
- WUSTOJ 1327: Lucky Numbers(Java)
题目链接:1327: Lucky Numbers Description A lucky number is made by the following rules: Given a positive ...
- 列主元消去法&全主元消去法——Java实现
Gauss.java package Gauss; /** * @description TODO 父类,包含高斯列主元消去法和全主元消去法的共有属性和方法 * @author PengHao * @ ...
- Cortex_m7内核cache深入了解和应用
一,cache概述 从下图可以看出,从M7内核才开始有的cache,这对于从M0,M3,M4一路走来的小伙伴来说,多了一个cache就多了一个障碍. Cortex-M7 core with 32K/3 ...
- JPA 一对一 一对多 多对一 多对多配置
1 JPA概述 1.1 JPA是什么 JPA (Java Persistence API) Java持久化API.是一套Sun公司 Java官方制定的ORM 方案,是规范,是标准 ,sun公司自己并没 ...
- 将图片画到canvas 上的几种方法(转)
转自:https://blog.csdn.net/qq_15009739/article/details/82809525
- C#避免WinForm窗体假死
WinForm窗体在使用过程中如果因为程序等待时间太久而导致窗体本身假死无法控制,会严重影响用户的体验,这种情况大多是UI线程被耗时长的代码操作占用所致,可以新开一个线程用来完成耗时长的操作,然后再将 ...
- Atcoder&CodeForces杂题11.6
Preface NOIP前突然不知道做什么,感觉思维有点江僵化,就在vjudge上随便组了6道ABC D+CF Div2 C/D做,发现比赛质量还不错,知识点涉及广,难度有梯度,码量稍小,思维较多. ...