SegNet网络的Pytorch实现
1.文章原文地址
SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation
2.文章摘要
语义分割具有非常广泛的应用,从场景理解、目标相互关系推断到自动驾驶。早期依赖于低水平视觉线索的方法已经快速的被流行的机器学习算法所取代。特别是最近的深度学习在手写数字识别、语音、图像中的分类和目标检测上取得巨大成功。如今有一个活跃的领域是语义分割(对每个像素进行归类)。然而,最近有一些方法直接采用了为图像分类而设计的网络结构来进行语义分割任务。虽然结果十分鼓舞人心,但还是比较粗糙。这首要的原因是最大池化和下采样减小了特征图的分辨率。我们设计SegNet的动机来自于分割任务需要将低分辨率的特征图映射到输入的分辨率并进行像素级分类,这个映射必须产生对准确边界定位有用的特征。
3.网络结构

4.Pytorch实现
import torch.nn as nn
import torch class conv2DBatchNormRelu(nn.Module):
def __init__(self,in_channels,out_channels,kernel_size,stride,padding,
bias=True,dilation=1,is_batchnorm=True):
super(conv2DBatchNormRelu,self).__init__()
if is_batchnorm:
self.cbr_unit=nn.Sequential(
nn.Conv2d(in_channels,out_channels,kernel_size=kernel_size,stride=stride,padding=padding,
bias=bias,dilation=dilation),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
else:
self.cbr_unit=nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
bias=bias, dilation=dilation),
nn.ReLU(inplace=True)
) def forward(self,inputs):
outputs=self.cbr_unit(inputs)
return outputs class segnetDown2(nn.Module):
def __init__(self,in_channels,out_channels):
super(segnetDown2,self).__init__()
self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.maxpool_with_argmax=nn.MaxPool2d(kernel_size=2,stride=2,return_indices=True) def forward(self,inputs):
outputs=self.conv1(inputs)
outputs=self.conv2(outputs)
unpooled_shape=outputs.size()
outputs,indices=self.maxpool_with_argmax(outputs)
return outputs,indices,unpooled_shape class segnetDown3(nn.Module):
def __init__(self,in_channels,out_channels):
super(segnetDown3,self).__init__()
self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv3=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.maxpool_with_argmax=nn.MaxPool2d(kernel_size=2,stride=2,return_indices=True) def forward(self,inputs):
outputs=self.conv1(inputs)
outputs=self.conv2(outputs)
outputs=self.conv3(outputs)
unpooled_shape=outputs.size()
outputs,indices=self.maxpool_with_argmax(outputs)
return outputs,indices,unpooled_shape class segnetUp2(nn.Module):
def __init__(self,in_channels,out_channels):
super(segnetUp2,self).__init__()
self.unpool=nn.MaxUnpool2d(2,2)
self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1) def forward(self,inputs,indices,output_shape):
outputs=self.unpool(inputs,indices=indices,output_size=output_shape)
outputs=self.conv1(outputs)
outputs=self.conv2(outputs)
return outputs class segnetUp3(nn.Module):
def __init__(self,in_channels,out_channels):
super(segnetUp3,self).__init__()
self.unpool=nn.MaxUnpool2d(2,2)
self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
self.conv3=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1) def forward(self,inputs,indices,output_shape):
outputs=self.unpool(inputs,indices=indices,output_size=output_shape)
outputs=self.conv1(outputs)
outputs=self.conv2(outputs)
outputs=self.conv3(outputs)
return outputs class segnet(nn.Module):
def __init__(self,in_channels=3,num_classes=21):
super(segnet,self).__init__()
self.down1=segnetDown2(in_channels=in_channels,out_channels=64)
self.down2=segnetDown2(64,128)
self.down3=segnetDown3(128,256)
self.down4=segnetDown3(256,512)
self.down5=segnetDown3(512,512) self.up5=segnetUp3(512,512)
self.up4=segnetUp3(512,256)
self.up3=segnetUp3(256,128)
self.up2=segnetUp2(128,64)
self.up1=segnetUp2(64,64)
self.finconv=conv2DBatchNormRelu(64,num_classes,3,1,1) def forward(self,inputs):
down1,indices_1,unpool_shape1=self.down1(inputs)
down2,indices_2,unpool_shape2=self.down2(down1)
down3,indices_3,unpool_shape3=self.down3(down2)
down4,indices_4,unpool_shape4=self.down4(down3)
down5,indices_5,unpool_shape5=self.down5(down4) up5=self.up5(down5,indices=indices_5,output_shape=unpool_shape5)
up4=self.up4(up5,indices=indices_4,output_shape=unpool_shape4)
up3=self.up3(up4,indices=indices_3,output_shape=unpool_shape3)
up2=self.up2(up3,indices=indices_2,output_shape=unpool_shape2)
up1=self.up1(up2,indices=indices_1,output_shape=unpool_shape1)
outputs=self.finconv(up1) return outputs if __name__=="__main__":
inputs=torch.ones(1,3,224,224)
model=segnet()
print(model(inputs).size())
print(model)
参考
https://github.com/meetshah1995/pytorch-semseg
SegNet网络的Pytorch实现的更多相关文章
- 群等变网络的pytorch实现
CNN对于旋转不具有等变性,对于平移有等变性,data augmentation的提出就是为了解决这个问题,但是data augmentation需要很大的模型容量,更多的迭代次数才能够在训练数据集合 ...
- U-Net网络的Pytorch实现
1.文章原文地址 U-Net: Convolutional Networks for Biomedical Image Segmentation 2.文章摘要 普遍认为成功训练深度神经网络需要大量标注 ...
- ResNet网络的Pytorch实现
1.文章原文地址 Deep Residual Learning for Image Recognition 2.文章摘要 神经网络的层次越深越难训练.我们提出了一个残差学习框架来简化网络的训练,这些 ...
- GoogLeNet网络的Pytorch实现
1.文章原文地址 Going deeper with convolutions 2.文章摘要 我们提出了一种代号为Inception的深度卷积神经网络,它在ILSVRC2014的分类和检测任务上都取得 ...
- AlexNet网络的Pytorch实现
1.文章原文地址 ImageNet Classification with Deep Convolutional Neural Networks 2.文章摘要 我们训练了一个大型的深度卷积神经网络用于 ...
- VGG网络的Pytorch实现
1.文章原文地址 Very Deep Convolutional Networks for Large-Scale Image Recognition 2.文章摘要 在这项工作中,我们研究了在大规模的 ...
- 【转载】PyTorch系列 (二):pytorch数据读取
原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...
- pytorch预训练
Pytorch预训练模型以及修改 pytorch中自带几种常用的深度学习网络预训练模型,torchvision.models包中包含alexnet.densenet.inception.resnet. ...
- PyTorch使用总览
PyTorch使用总览 https://blog.csdn.net/u014380165/article/details/79222243 深度学习框架训练模型时的代码主要包含数据读取.网络构建和其他 ...
随机推荐
- EC11编码器的使用方法
1. EC11编码器的原理图如下 2. 旋转的时候,波形如下,EC11转1格,产生一个上升沿的中断,思路就是检测AX4-1的上升沿中断(平时是低电平),进入中断服务函数,检测AX4-2的电平,低电平逆 ...
- AI - TensorFlow - 示例05:保存和恢复模型
保存和恢复模型(Save and restore models) 官网示例:https://www.tensorflow.org/tutorials/keras/save_and_restore_mo ...
- docker 安装 tomcat8
docker hub 查找 tomcat meiya@meiya:/etc/docker$ docker search tomcat NAME DESCRIPTION STARS OFFICIAL A ...
- java的ReentrantLock类详解
ReentrantLock 能用于更精细化的加锁的Java类, 通过它能更清楚了解Java的锁机制 ReentrantLock 类的集成关系有点复杂, 既有内部类, 还有多重继承关系 类的定义 pub ...
- Python基础--软件Anaconda的下载与安装
1.Anaconda软件的优点: Anaconda指的是一个开源的Python发行版本开发平台,在进行Python开发上方便简洁,有利于初步学习和实践深度学习. 2.Anaconda软件的下载: An ...
- [转帖]Hive 快速入门(全面)
Hive 快速入门(全面) 2018-07-30 16:11:56 琅琊山二当家 阅读数 4343更多 分类专栏: hadoop 大数据 转载: https://www.codercto.com/ ...
- ASP.NET请求过程-从源码角度研究MVC路由、Handler、控制器
路由常用对象 RouteBase 用作表示 ASP.NET 路由的所有类的基类. 就是路由的一个基础抽象类. // // 摘要: // 用作表示 ASP.NET 路由的所有类的基类. [ ...
- python 并发的开端
目录 网络并发 进程的基础 2.操作系统 操作系统的发展史 多道技术 第二代 1955~1965 磁带存储--批处理系统 第三代集成电路,多道程序系统(1955~1965) 进程的理论(重点) 2.操 ...
- 在做爬虫或者自动化测试时新打开一个新标签页,必须使用windows切换
在做爬虫或者自动化测试时,有时会打开一个新的标签页或者新的窗口,直接使用xpath定位元素会发现找不到元素,在firefox中定位了元素还是找不到, 经过多次发现,在眼睛视野内看到这个窗口是在最前面, ...
- Java安装和环境配置
Java安装和环境配置 从事Java开发第一关就是安装JAVA环境. 我们要安装JDK, 全称Java开发全套. 其中包含了JRE(运行时环境), 如果你打游戏的时候可能会提示你缺少JRE. 我们要做 ...