PSPnet模型结构的实现代码
该模块融合了4种不同金字塔尺度的特征,第一行红色是最粗糙的特征–全局池化生成单个bin输出,后面三行是不同尺度的池化特征。
为了保证全局特征的权重,如果金字塔共有N个级别,则在每个级别后使用1×1 1×11×1的卷积将对于级别通道降为原本的1/N。再通过双线性插值获得未池化前的大小,最终concat到一起。
1 import torch
import torch.nn.functional as F
from torch import nn
from torchvision import models from utils import initialize_weights
from utils.misc import Conv2dDeformable
from .config import res101_path //金字塔模块,将从前面卷积结构提取的特征分别进行不同的池化操作,得到不同感受野以及全局语境信息(或者叫做不同层级的信息)
class _PyramidPoolingModule(nn.Module):
def __init__(self, in_dim, reduction_dim, setting):
super(_PyramidPoolingModule, self).__init__()
self.features = []
for s in setting: //对应不同的池化操作,单个bin,多个bin
self.features.append(nn.Sequential(
nn.AdaptiveAvgPool2d(s),
nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
nn.BatchNorm2d(reduction_dim, momentum=.95),
nn.ReLU(inplace=True)
))
self.features = nn.ModuleList(self.features) def forward(self, x):
x_size = x.size()
out = [x]
for f in self.features:
out.append(F.upsample(f(x), x_size[2:], mode='bilinear'))
out = torch.cat(out, 1)
return out //整个pspnet网络的结构
class PSPNet(nn.Module):
def __init__(self, num_classes, pretrained=True, use_aux=True):
super(PSPNet, self).__init__()
self.use_aux = use_aux
resnet = models.resnet101() //采用resnet101作为骨干模型,提取特征
if pretrained:
resnet.load_state_dict(torch.load(res101_path))
self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4
//设置带洞卷积的参数(dilation),以及卷积的参数
for n, m in self.layer3.named_modules():
if 'conv2' in n:
m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
elif 'downsample.0' in n:
m.stride = (1, 1)
for n, m in self.layer4.named_modules():
if 'conv2' in n:
m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
elif 'downsample.0' in n:
m.stride = (1, 1)
//加入ppm模块,以及最后的连接层(卷积)
self.ppm = _PyramidPoolingModule(2048, 512, (1, 2, 3, 6))
self.final = nn.Sequential(
nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(512, momentum=.95),
nn.ReLU(inplace=True),
nn.Dropout(0.1),
nn.Conv2d(512, num_classes, kernel_size=1)
) if use_aux:
self.aux_logits = nn.Conv2d(1024, num_classes, kernel_size=1)
initialize_weights(self.aux_logits)
# 初始化权重
initialize_weights(self.ppm, self.final) def forward(self, x):
x_size = x.size()
x = self.layer0(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
if self.training and self.use_aux:
aux = self.aux_logits(x)
x = self.layer4(x)
x = self.ppm(x)
x = self.final(x)
if self.training and self.use_aux:
return F.upsample(x, x_size[2:], mode='bilinear'), F.upsample(aux, x_size[2:], mode='bilinear')
return F.upsample(x, x_size[2:], mode='bilinear')
PSPnet模型结构的实现代码的更多相关文章
- 关于DPM(Deformable Part Model)算法中模型结构的解释
关于可变部件模型的描写叙述在作者[2010 PAMI]Object Detection with Discriminatively Trained Part Based Models的论文中已经有说明 ...
- 第六章 Odoo 12开发之模型 - 结构化应用数据
在本系列文章第三篇Odoo 12 开发之创建第一个 Odoo 应用中,我们概览了创建 Odoo 应用所需的所有组件.本文及接下来的一篇我们将深入到组成应用的每一层:模型层.视图层和业务逻辑层. 本文中 ...
- BS模式的模型结构详解
编号:1004时间:2016年4月12日16:59:17功能:BS模式的模型结构详解 URL:http://blog.csdn.net/icerock2000/article/details/4000 ...
- 卷积神经网络(CNN)模型结构
在前面我们讲述了DNN的模型与前向反向传播算法.而在DNN大类中,卷积神经网络(Convolutional Neural Networks,以下简称CNN)是最为成功的DNN特例之一.CNN广泛的应用 ...
- 随想:目标识别中,自适应样本均衡设计,自适应模型结构(参数可变自适应,模型结构自适应,数据类别or分布自适应)
在现在的机器学习中,很多人都在研究自适应的参数,不需要人工调参,但是仅仅是自动调参就不能根本上解决 ai识别准确度达不到实际生产的要求和落地困难的问题吗?结论可想而知.如果不改变参数,那就得从算法的结 ...
- 3. RNN神经网络-LSTM模型结构
1. RNN神经网络模型原理 2. RNN神经网络模型的不同结构 3. RNN神经网络-LSTM模型结构 1. 前言 之前我们对RNN模型做了总结.由于RNN也有梯度消失的问题,因此很难处理长序列的数 ...
- 漫谈四种神经网络序列解码模型【附示例代码】 glimpse attention
漫谈四种神经网络序列解码模型[附示例代码] http://jacoxu.com/encoder_decoder/ [视觉注意力的循环神经网络模型]http://blog.csdn.net/leo_xu ...
- MES系统的模型结构和主要功能(二)
上一节,我们主要说了Mes系统是什么,以及它的特点和难点,本节,再来讨论一下一个合格的MES系统的模型结构和基本功能. 现代工厂的快速发展,对MES系统提出了更高的要求,其必须满足范围广泛的任务要求, ...
- pytorch模型结构可视化,可显示每层的尺寸
最近在学习一些检测方面的网络,使用的是pytorch.模型结构可视化是学习网络的有用的部分,pytorch没有原生支持这个功能,需要找一些其他方式,下面总结几种方法(推荐用4). 1. torch . ...
随机推荐
- 运行yarn的时候提示 node不是内部或外部命令
背景:准备react native 搭建,装完nodejs npm 重启cmd,再次管理员运行即可!
- 接口和多态都为JAVA技术的核心。
类必须实现接口中的方法,否则其为一抽象类. 实现中接口和类相同. 接口中可不写public,但在子类中实现接口的过程中public不可省. (如果剩去public则在编译的时候提示出错:对象无法从 ...
- yii2 緩存
1.Yii框架的缓存 主要就是“memcache” 和 “cache”两种 Yii的自带缓存都继承CCache 类, 在使用上基本没有区别 2.使用方法 (1)在config配置文件main.php文 ...
- yii2的数据库读写分离配置
简介 数据库读写分离是在网站遇到性能瓶颈的时候最先考虑优化的步骤,那么yii2是如何做数据库读写分离的呢?本节教程来给大家普及一下yii2的数据库读写分离配置. 两个服务器的数据同步是读写分离的前提条 ...
- python开发工具
要用到爬虫,网上推荐crapy,自己在pycharm上没有安装成功,于是使用anaconda,但是在gui界面安装crapy总是失败,且没有提示信息.于是使用命令行的方式,提示PermissionEr ...
- 1.6 flask应用: 代码统计系统
2019-1-6 15:57:18 今天的是做了一个代码统计的demo 使用了数据库的连接池 参考连接 https://www.cnblogs.com/wupeiqi/articles/8184686 ...
- vue项目实战中的增、删、改、查
参考:https://blog.csdn.net/xr510002594/article/details/81665762?utm_source=blogxgwz0 https://blog.csdn ...
- Oracle课程档案,第十五天
restore:恢复数据文件 recover:写日志 1.redo(roll forward)重做 (前进) 2.undo(roll back) 撤销 (回滚) cp -r:删除一个目录 archiv ...
- SpringBoot定时任务说明
1. 定时任务实现方式 定时任务实现方式: Java自带的java.util.Timer类,这个类允许你调度一个java.util.TimerTask任务.使用这种方式可以让你的程序按照某一个频度执行 ...
- python全栈开发 * 13知识点汇总 * 180619
13 迭代器和⽣成器一.迭代器 1.以通过dir函数来查看类中定义好的所有⽅法 2.__iter__ 用来获取当前对象的迭代器 3.__next__ 获取可迭代对象的元素s="我爱吃火锅&q ...
