一、DenseNet的优点

  • 减轻梯度消失问题
  • 加强特征的传递
  • 充分利用特征
  • 减少了参数量

二、网络结构公式

对于每一个DenseBlock中的每一个层,

[x0,x1,…,xl-1]表示将0到l-1层的输出feature map做concatenation。concatenation是做通道的合并,就像Inception那样。而前面resnet是做值的相加,通道数是不变的。Hl包括BN,ReLU和3*3的卷积。

而在ResNet中的每一个残差块,

三、Growth Rate

指的是DenseBlock中每一个非线性变换Hl(BN,ReLU和3*3的卷积)的输出,这个输出与输入Concate.一个DenseBlock的输出=输入+Hl数×growth_rate。在要给DenseBlock中,Feature Map的size保持不变。

四、Bottleneck

这个组件位于DenseBlock中,当一个DenseBlock包含的非线性变换Hl较多时(如nHl=48),此时的grow rate为k=32,那么第48层的输入变成input+47×32,这是一个很大的数,如果不用bottleneck进行降维,那么计算量很大。

因此,使用4×k个1x1卷积进行降维。使得3×3线性变换的输入通道变成4×k。同时,bottleneck起到特征融合的效果。

五、Transition

这个组件位于DenseBlock之间,使用1×1卷积进行降维,降维后的通道数为input_channels*reduction. 参数reduction默认为0.5,后接池化层进行下采样,减小Feature Map 分辨率。

六、网络结构

 

七、代码实现(Pytorch)

import torch
import torch.nn as nn
import torch.nn.functional as F
import math class Bottleneck(nn.Module):
def __init__(self,nChannels,growthRate):
super(Bottleneck,self).__init__()
interChannels = 4*growthRate
self.bn1 = nn.BatchNorm2d(nChannels)
self.conv1 = nn.Conv2d(nChannels,interChannels,kernel_size=1,
stride=1,bias=False)
self.bn2 = nn.BatchNorm2d(interChannels)
self.conv2 = nn.Conv2d(interChannels,growthRate,kernel_size=3,
stride=1,padding=1,bias=False) def forward(self, *input):
#先进行BN(pytorch的BN已经包含了Scale),然后进行relu,conv1起到bottleneck的作用
out = self.conv1(F.relu(self.bn1(input)))
out = self.conv2(F.relu(self.bn2(out)))
out = torch.cat(input,out)
return out class SingleLayer(nn.Module):
def __init__(self,nChannels,growthRate):
super(SingleLayer,self).__init__()
self.bn1 = nn.BatchNorm2d(nChannels)
self.conv1 = nn.Conv2d(nChannels,growthRate,kernel_size=3,
padding=1,bias=False) def forward(self, *input):
out = self.conv1(F.relu(self.bn1(input)))
out = torch.cat(input,out)
return out class Transition(nn.Module):
def __int__(self,nChannels,nOutChannels):
super(Transition,self).__init__() self.bn1 = nn.BatchNorm2d(nChannels)
self.conv1 = nn.Conv2d(nChannels,nOutChannels,kernel_size=1,bias=False) def forward(self, *input):
out = self.conv1(F.relu(self.bn1(input)))
out = F.avg_pool2d(out,2)
return out class DenseNet(nn.Module):
def __init__(self,growthRate,depth,reduction,nClasses,bottleneck):
super(DenseNet,self).__init__()
#DenseBlock中非线性变换模块的个数
nNoneLinears = (depth-4)//3
if bottleneck:
nNoneLinears //=2 nChannels = 2*growthRate
self.conv1 = nn.Conv2d(3,nChannels,kernel_size=3,padding=1,bias=False)
self.denseblock1 = self._make_dense(nChannels,growthRate,nNoneLinears,bottleneck)
nChannels += nNoneLinears*growthRate
nOutChannels = int(math.floor(nChannels*reduction)) #向下取整
self.transition1 = Transition(nChannels,nOutChannels) nChannels = nOutChannels
self.denseblock2 = self._make_dense(nChannels,growthRate,nNoneLinears,bottleneck)
nChannels += nNoneLinears*growthRate
nOutChannels = int(math.floor(nChannels*reduction))
self.transition2 = Transition(nChannels, nOutChannels) nChannels = nOutChannels
self.denseblock3 = self._make_dense(nChannels, growthRate, nNoneLinears, bottleneck)
nChannels += nNoneLinears * growthRate self.bn1 = nn.BatchNorm2d(nChannels)
self.fc = nn.Linear(nChannels,nClasses) #参数初始化
for m in self.modules():
if isinstance(m,nn.Conv2d):
n = m.kernel_size[0]*m.kernel_size[1]*m.out_channels
m.weight.data.normal_(0,math.sqrt(2./n))
elif isinstance(m,nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m,nn.Linear):
m.bias.data.zero_() def _make_dense(self,nChannels,growthRate,nDenseBlocks,bottleneck):
layers = []
for i in range(int(nDenseBlocks)):
if bottleneck:
layers.append(Bottleneck(nChannels,growthRate))
else:
layers.append(SingleLayer(nChannels,growthRate))
nChannels+=growthRate
return nn.Sequential(*layers) def forward(self, *input):
out = self.conv1(input)
out = self.transition1(self.denseblock1(out))
out = self.transition2(self.denseblock2(out))
out = self.denseblock3(out)
out = torch.squeeze(F.avg_pool2d(F.relu(self.bn1(out)),8))
out = F.log_softmax(self.fc(out))
return out

DenseNet笔记的更多相关文章

  1. 论文笔记——DenseNet

    <Densely Connected Convolutional Networks>阅读笔记 代码地址:https://github.com/liuzhuang13/DenseNet 首先 ...

  2. 论文笔记:CNN经典结构2(WideResNet,FractalNet,DenseNet,ResNeXt,DPN,SENet)

    前言 在论文笔记:CNN经典结构1中主要讲了2012-2015年的一些经典CNN结构.本文主要讲解2016-2017年的一些经典CNN结构. CIFAR和SVHN上,DenseNet-BC优于ResN ...

  3. DenseNet 论文阅读笔记

    Densely Connected Convolutional Networks 原文链接 摘要 研究表明,如果卷积网络在接近输入和接近输出地层之间包含较短地连接,那么,该网络可以显著地加深,变得更精 ...

  4. tensorflow学习笔记——DenseNet

    完整代码及其数据,请移步小编的GitHub地址 传送门:请点击我 如果点击有误:https://github.com/LeBron-Jian/DeepLearningNote 这里结合网络的资料和De ...

  5. 论文笔记系列-Neural Network Search :A Survey

    论文笔记系列-Neural Network Search :A Survey 论文 笔记 NAS automl survey review reinforcement learning Bayesia ...

  6. 论文笔记:CNN经典结构1(AlexNet,ZFNet,OverFeat,VGG,GoogleNet,ResNet)

    前言 本文主要介绍2012-2015年的一些经典CNN结构,从AlexNet,ZFNet,OverFeat到VGG,GoogleNetv1-v4,ResNetv1-v2. 在论文笔记:CNN经典结构2 ...

  7. 转载:DenseNet算法详解

    原文连接:http://blog.csdn.net/u014380165/article/details/75142664 参考连接:http://blog.csdn.net/u012938704/a ...

  8. Dual Path Networks(DPN)——一种结合了ResNet和DenseNet优势的新型卷积网络结构。深度残差网络通过残差旁支通路再利用特征,但残差通道不善于探索新特征。密集连接网络通过密集连接通路探索新特征,但有高冗余度。

    如何评价Dual Path Networks(DPN)? 论文链接:https://arxiv.org/pdf/1707.01629v1.pdf在ImagNet-1k数据集上,浅DPN超过了最好的Re ...

  9. DenseNet算法详解——思路就是highway,DneseNet在训练时十分消耗内存

    论文笔记:Densely Connected Convolutional Networks(DenseNet模型详解) 2017年09月28日 11:58:49 阅读数:1814 [ 转载自http: ...

随机推荐

  1. 【刷题】洛谷 P4209 学习小组

    题目描述 共有n个学生,m个学习小组,每个学生只愿意参加其中的一些学习小组,且一个学生最多参加k个学习小组.每个学生参加学习小组财务处都收一定的手续费,不同的学习小组有不同的手续费.若有a个学生参加第 ...

  2. Alpha 冲刺 —— 十分之四

    队名 火箭少男100 组长博客 林燊大哥 作业博客 Alpha 冲鸭鸭鸭鸭! 成员冲刺阶段情况 林燊(组长) 过去两天完成了哪些任务 协调各成员之间的工作 协助前后端接口的开发 测试项目运行的服务器环 ...

  3. 【ARC075F】Mirror

    Description ​ 给定正整数\(D\),求有多少个正整数\(N\),满足\(rev(N)=N+D\). ​ 其中\(rev(N)\)表示将\(N\)的十进制表示翻转来读得到的数(翻转后忽略前 ...

  4. 洛谷 P1850 换教室 解题报告

    P1850 换教室 题目描述 对于刚上大学的牛牛来说,他面临的第一个问题是如何根据实际情况申请合适的课程. 在可以选择的课程中,有\(2n\)节课程安排在\(n\)个时间段上.在第\(i(1≤i≤n) ...

  5. js 判断js函数,变量是否存在

    //是否存在指定函数 function isExitsFunction(funcName) {//这里的代码需要用try一下,因为当判断的函数是未定义时 浏览器会报错 try { if (typeof ...

  6. Android Intent 总结

    //打开指定网页Intent intent = new Intent(Intent.ACTION_VIEW);intent.setData(Uri.parse("http://www.goo ...

  7. js 弹出新页面,避免被浏览器、ad拦截的一种办法

    以绑定click弹窗的方式,改为普通的链接,即 a[target=_blank],在点击打开新窗口之前,修改其href. 绑定mousedown,鼠标点击执行完成前修改href. 绑定focus,保证 ...

  8. process.nextTick,Promise.then,setTimeout,setImmediate执行顺序

    1. 同步代码执行顺序优先级高于异步代码执行顺序优先级: 2. new Promise(fn)中的fn是同步执行: 3. process.nextTick()>Promise.then()> ...

  9. 2018.9.20 Educational Codeforces Round 51

    蒟蒻就切了四道水题,然后EF看着可写然而并不会,中间还WA了一次,我太菜了.jpg =.= A.Vasya And Password 一开始看着有点虚没敢立刻写,后来写完第二题发现可以暴力讨论,因为保 ...

  10. python之旅:模块与包

    一.模块介绍 前言:引用廖雪峰大神的,说的很好!!! 在计算机程序的开发过程中,随着程序代码越写越多,在一个文件里代码就会越来越长,越来越不容易维护. 为了编写可维护的代码,我们把很多函数分组,分别放 ...