DenseNet笔记
一、DenseNet的优点
- 减轻梯度消失问题
- 加强特征的传递
- 充分利用特征
- 减少了参数量
二、网络结构公式
对于每一个DenseBlock中的每一个层,

[x0,x1,…,xl-1]表示将0到l-1层的输出feature map做concatenation。concatenation是做通道的合并,就像Inception那样。而前面resnet是做值的相加,通道数是不变的。Hl包括BN,ReLU和3*3的卷积。
而在ResNet中的每一个残差块,

三、Growth Rate
指的是DenseBlock中每一个非线性变换Hl(BN,ReLU和3*3的卷积)的输出,这个输出与输入Concate.一个DenseBlock的输出=输入+Hl数×growth_rate。在要给DenseBlock中,Feature Map的size保持不变。
四、Bottleneck
这个组件位于DenseBlock中,当一个DenseBlock包含的非线性变换Hl较多时(如nHl=48),此时的grow rate为k=32,那么第48层的输入变成input+47×32,这是一个很大的数,如果不用bottleneck进行降维,那么计算量很大。
因此,使用4×k个1x1卷积进行降维。使得3×3线性变换的输入通道变成4×k。同时,bottleneck起到特征融合的效果。
五、Transition
这个组件位于DenseBlock之间,使用1×1卷积进行降维,降维后的通道数为input_channels*reduction. 参数reduction默认为0.5,后接池化层进行下采样,减小Feature Map 分辨率。
六、网络结构


七、代码实现(Pytorch)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math class Bottleneck(nn.Module):
def __init__(self,nChannels,growthRate):
super(Bottleneck,self).__init__()
interChannels = 4*growthRate
self.bn1 = nn.BatchNorm2d(nChannels)
self.conv1 = nn.Conv2d(nChannels,interChannels,kernel_size=1,
stride=1,bias=False)
self.bn2 = nn.BatchNorm2d(interChannels)
self.conv2 = nn.Conv2d(interChannels,growthRate,kernel_size=3,
stride=1,padding=1,bias=False) def forward(self, *input):
#先进行BN(pytorch的BN已经包含了Scale),然后进行relu,conv1起到bottleneck的作用
out = self.conv1(F.relu(self.bn1(input)))
out = self.conv2(F.relu(self.bn2(out)))
out = torch.cat(input,out)
return out class SingleLayer(nn.Module):
def __init__(self,nChannels,growthRate):
super(SingleLayer,self).__init__()
self.bn1 = nn.BatchNorm2d(nChannels)
self.conv1 = nn.Conv2d(nChannels,growthRate,kernel_size=3,
padding=1,bias=False) def forward(self, *input):
out = self.conv1(F.relu(self.bn1(input)))
out = torch.cat(input,out)
return out class Transition(nn.Module):
def __int__(self,nChannels,nOutChannels):
super(Transition,self).__init__() self.bn1 = nn.BatchNorm2d(nChannels)
self.conv1 = nn.Conv2d(nChannels,nOutChannels,kernel_size=1,bias=False) def forward(self, *input):
out = self.conv1(F.relu(self.bn1(input)))
out = F.avg_pool2d(out,2)
return out class DenseNet(nn.Module):
def __init__(self,growthRate,depth,reduction,nClasses,bottleneck):
super(DenseNet,self).__init__()
#DenseBlock中非线性变换模块的个数
nNoneLinears = (depth-4)//3
if bottleneck:
nNoneLinears //=2 nChannels = 2*growthRate
self.conv1 = nn.Conv2d(3,nChannels,kernel_size=3,padding=1,bias=False)
self.denseblock1 = self._make_dense(nChannels,growthRate,nNoneLinears,bottleneck)
nChannels += nNoneLinears*growthRate
nOutChannels = int(math.floor(nChannels*reduction)) #向下取整
self.transition1 = Transition(nChannels,nOutChannels) nChannels = nOutChannels
self.denseblock2 = self._make_dense(nChannels,growthRate,nNoneLinears,bottleneck)
nChannels += nNoneLinears*growthRate
nOutChannels = int(math.floor(nChannels*reduction))
self.transition2 = Transition(nChannels, nOutChannels) nChannels = nOutChannels
self.denseblock3 = self._make_dense(nChannels, growthRate, nNoneLinears, bottleneck)
nChannels += nNoneLinears * growthRate self.bn1 = nn.BatchNorm2d(nChannels)
self.fc = nn.Linear(nChannels,nClasses) #参数初始化
for m in self.modules():
if isinstance(m,nn.Conv2d):
n = m.kernel_size[0]*m.kernel_size[1]*m.out_channels
m.weight.data.normal_(0,math.sqrt(2./n))
elif isinstance(m,nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m,nn.Linear):
m.bias.data.zero_() def _make_dense(self,nChannels,growthRate,nDenseBlocks,bottleneck):
layers = []
for i in range(int(nDenseBlocks)):
if bottleneck:
layers.append(Bottleneck(nChannels,growthRate))
else:
layers.append(SingleLayer(nChannels,growthRate))
nChannels+=growthRate
return nn.Sequential(*layers) def forward(self, *input):
out = self.conv1(input)
out = self.transition1(self.denseblock1(out))
out = self.transition2(self.denseblock2(out))
out = self.denseblock3(out)
out = torch.squeeze(F.avg_pool2d(F.relu(self.bn1(out)),8))
out = F.log_softmax(self.fc(out))
return out
DenseNet笔记的更多相关文章
- 论文笔记——DenseNet
<Densely Connected Convolutional Networks>阅读笔记 代码地址:https://github.com/liuzhuang13/DenseNet 首先 ...
- 论文笔记:CNN经典结构2(WideResNet,FractalNet,DenseNet,ResNeXt,DPN,SENet)
前言 在论文笔记:CNN经典结构1中主要讲了2012-2015年的一些经典CNN结构.本文主要讲解2016-2017年的一些经典CNN结构. CIFAR和SVHN上,DenseNet-BC优于ResN ...
- DenseNet 论文阅读笔记
Densely Connected Convolutional Networks 原文链接 摘要 研究表明,如果卷积网络在接近输入和接近输出地层之间包含较短地连接,那么,该网络可以显著地加深,变得更精 ...
- tensorflow学习笔记——DenseNet
完整代码及其数据,请移步小编的GitHub地址 传送门:请点击我 如果点击有误:https://github.com/LeBron-Jian/DeepLearningNote 这里结合网络的资料和De ...
- 论文笔记系列-Neural Network Search :A Survey
论文笔记系列-Neural Network Search :A Survey 论文 笔记 NAS automl survey review reinforcement learning Bayesia ...
- 论文笔记:CNN经典结构1(AlexNet,ZFNet,OverFeat,VGG,GoogleNet,ResNet)
前言 本文主要介绍2012-2015年的一些经典CNN结构,从AlexNet,ZFNet,OverFeat到VGG,GoogleNetv1-v4,ResNetv1-v2. 在论文笔记:CNN经典结构2 ...
- 转载:DenseNet算法详解
原文连接:http://blog.csdn.net/u014380165/article/details/75142664 参考连接:http://blog.csdn.net/u012938704/a ...
- Dual Path Networks(DPN)——一种结合了ResNet和DenseNet优势的新型卷积网络结构。深度残差网络通过残差旁支通路再利用特征,但残差通道不善于探索新特征。密集连接网络通过密集连接通路探索新特征,但有高冗余度。
如何评价Dual Path Networks(DPN)? 论文链接:https://arxiv.org/pdf/1707.01629v1.pdf在ImagNet-1k数据集上,浅DPN超过了最好的Re ...
- DenseNet算法详解——思路就是highway,DneseNet在训练时十分消耗内存
论文笔记:Densely Connected Convolutional Networks(DenseNet模型详解) 2017年09月28日 11:58:49 阅读数:1814 [ 转载自http: ...
随机推荐
- c++11 继承控制:final和override
c++11 继承控制:final和override #define _CRT_SECURE_NO_WARNINGS #include <iostream> #include <str ...
- BZOJ4727 [POI2017]Turysta 【竞赛图哈密顿路径/回路】
题目链接 BZOJ4727 题解 前置芝士 1.竞赛图存在哈密顿路径 2.竞赛图存在哈密顿回路,当且仅当它是强联通的 所以我们将图缩点后,拓扑排序后一定是一条链,且之前的块内的点和之后块内的点的边一定 ...
- 解决Android SDK Manager更新时出现问题
使用SDK Manager更新时出现问题Failed to fetch URL https://dl-ssl.google.com/android/repository/repository-6.xm ...
- 解题:CQOI 2015 选数
题面 神仙题,不需要反演 首先上下界同时除以$k$,转换成取$n$个$gcd$为$1$的数的方案数,其中上界向下取整,下界向上取整 然后设$f[i]$表示选$n$个互不相同的数$gcd$为$i$的方案 ...
- 解题:JSOI 2007 重要的城市
题面 考虑一个点$x$,如果某两个点$u,v$间的所有最短路都经过$x$,那么$x$肯定是重要的.这个题$n$比较小,所以我们直接跑floyd,在过程中记录 当发生松弛时,我们具体讨论: 如果这个长度 ...
- MySQL基本了解与使用
MySQL的相关概念介绍 MySQL 为关系型数据库(Relational Database Management System), 这种所谓的"关系型"可以理解为"表格 ...
- BTC钱包对接流程
BTC钱包对接流程: 部署钱包节点 分析钱包的API 通过JSON-RPC访问钱包API 部署测试 1.部署钱包节点 虚拟币交易平台对接所有的虚拟币之前,都要在自己的服务器上部署一个钱包节点,首先要找 ...
- PostgreSQL 修改字段类型从int到bigint
由于现在pg的版本,修改int到bigint仍然需要rewrite表,会导致表阻塞,无法使用.但可以考虑其他方式来做.此问题是排查现网pg使用序列的情况时遇到的. 由于int的最大值只有21亿左右,而 ...
- 300. Longest Increasing Subsequence_算法有误
300. Longest Increasing Subsequence 300. Longest Increasing Subsequence Given an unsorted array of i ...
- shell实例浅谈之一产生随机数七种方法
一.问题 Shell下有时需要使用随机数,在此总结产生随机数的方法.计算机产生的的只是“伪随机数”,不会产生绝对的随机数(是一种理想随机数).伪随机数在大量重现时也并不一定保持唯一,但一个好的伪随机产 ...