DenseNet笔记
一、DenseNet的优点
- 减轻梯度消失问题
- 加强特征的传递
- 充分利用特征
- 减少了参数量
二、网络结构公式
对于每一个DenseBlock中的每一个层,
[x0,x1,…,xl-1]表示将0到l-1层的输出feature map做concatenation。concatenation是做通道的合并,就像Inception那样。而前面resnet是做值的相加,通道数是不变的。Hl包括BN,ReLU和3*3的卷积。
而在ResNet中的每一个残差块,
三、Growth Rate
指的是DenseBlock中每一个非线性变换Hl(BN,ReLU和3*3的卷积)的输出,这个输出与输入Concate.一个DenseBlock的输出=输入+Hl数×growth_rate。在要给DenseBlock中,Feature Map的size保持不变。
四、Bottleneck
这个组件位于DenseBlock中,当一个DenseBlock包含的非线性变换Hl较多时(如nHl=48),此时的grow rate为k=32,那么第48层的输入变成input+47×32,这是一个很大的数,如果不用bottleneck进行降维,那么计算量很大。
因此,使用4×k个1x1卷积进行降维。使得3×3线性变换的输入通道变成4×k。同时,bottleneck起到特征融合的效果。
五、Transition
这个组件位于DenseBlock之间,使用1×1卷积进行降维,降维后的通道数为input_channels*reduction. 参数reduction默认为0.5,后接池化层进行下采样,减小Feature Map 分辨率。
六、网络结构
七、代码实现(Pytorch)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math class Bottleneck(nn.Module):
def __init__(self,nChannels,growthRate):
super(Bottleneck,self).__init__()
interChannels = 4*growthRate
self.bn1 = nn.BatchNorm2d(nChannels)
self.conv1 = nn.Conv2d(nChannels,interChannels,kernel_size=1,
stride=1,bias=False)
self.bn2 = nn.BatchNorm2d(interChannels)
self.conv2 = nn.Conv2d(interChannels,growthRate,kernel_size=3,
stride=1,padding=1,bias=False) def forward(self, *input):
#先进行BN(pytorch的BN已经包含了Scale),然后进行relu,conv1起到bottleneck的作用
out = self.conv1(F.relu(self.bn1(input)))
out = self.conv2(F.relu(self.bn2(out)))
out = torch.cat(input,out)
return out class SingleLayer(nn.Module):
def __init__(self,nChannels,growthRate):
super(SingleLayer,self).__init__()
self.bn1 = nn.BatchNorm2d(nChannels)
self.conv1 = nn.Conv2d(nChannels,growthRate,kernel_size=3,
padding=1,bias=False) def forward(self, *input):
out = self.conv1(F.relu(self.bn1(input)))
out = torch.cat(input,out)
return out class Transition(nn.Module):
def __int__(self,nChannels,nOutChannels):
super(Transition,self).__init__() self.bn1 = nn.BatchNorm2d(nChannels)
self.conv1 = nn.Conv2d(nChannels,nOutChannels,kernel_size=1,bias=False) def forward(self, *input):
out = self.conv1(F.relu(self.bn1(input)))
out = F.avg_pool2d(out,2)
return out class DenseNet(nn.Module):
def __init__(self,growthRate,depth,reduction,nClasses,bottleneck):
super(DenseNet,self).__init__()
#DenseBlock中非线性变换模块的个数
nNoneLinears = (depth-4)//3
if bottleneck:
nNoneLinears //=2 nChannels = 2*growthRate
self.conv1 = nn.Conv2d(3,nChannels,kernel_size=3,padding=1,bias=False)
self.denseblock1 = self._make_dense(nChannels,growthRate,nNoneLinears,bottleneck)
nChannels += nNoneLinears*growthRate
nOutChannels = int(math.floor(nChannels*reduction)) #向下取整
self.transition1 = Transition(nChannels,nOutChannels) nChannels = nOutChannels
self.denseblock2 = self._make_dense(nChannels,growthRate,nNoneLinears,bottleneck)
nChannels += nNoneLinears*growthRate
nOutChannels = int(math.floor(nChannels*reduction))
self.transition2 = Transition(nChannels, nOutChannels) nChannels = nOutChannels
self.denseblock3 = self._make_dense(nChannels, growthRate, nNoneLinears, bottleneck)
nChannels += nNoneLinears * growthRate self.bn1 = nn.BatchNorm2d(nChannels)
self.fc = nn.Linear(nChannels,nClasses) #参数初始化
for m in self.modules():
if isinstance(m,nn.Conv2d):
n = m.kernel_size[0]*m.kernel_size[1]*m.out_channels
m.weight.data.normal_(0,math.sqrt(2./n))
elif isinstance(m,nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m,nn.Linear):
m.bias.data.zero_() def _make_dense(self,nChannels,growthRate,nDenseBlocks,bottleneck):
layers = []
for i in range(int(nDenseBlocks)):
if bottleneck:
layers.append(Bottleneck(nChannels,growthRate))
else:
layers.append(SingleLayer(nChannels,growthRate))
nChannels+=growthRate
return nn.Sequential(*layers) def forward(self, *input):
out = self.conv1(input)
out = self.transition1(self.denseblock1(out))
out = self.transition2(self.denseblock2(out))
out = self.denseblock3(out)
out = torch.squeeze(F.avg_pool2d(F.relu(self.bn1(out)),8))
out = F.log_softmax(self.fc(out))
return out
DenseNet笔记的更多相关文章
- 论文笔记——DenseNet
<Densely Connected Convolutional Networks>阅读笔记 代码地址:https://github.com/liuzhuang13/DenseNet 首先 ...
- 论文笔记:CNN经典结构2(WideResNet,FractalNet,DenseNet,ResNeXt,DPN,SENet)
前言 在论文笔记:CNN经典结构1中主要讲了2012-2015年的一些经典CNN结构.本文主要讲解2016-2017年的一些经典CNN结构. CIFAR和SVHN上,DenseNet-BC优于ResN ...
- DenseNet 论文阅读笔记
Densely Connected Convolutional Networks 原文链接 摘要 研究表明,如果卷积网络在接近输入和接近输出地层之间包含较短地连接,那么,该网络可以显著地加深,变得更精 ...
- tensorflow学习笔记——DenseNet
完整代码及其数据,请移步小编的GitHub地址 传送门:请点击我 如果点击有误:https://github.com/LeBron-Jian/DeepLearningNote 这里结合网络的资料和De ...
- 论文笔记系列-Neural Network Search :A Survey
论文笔记系列-Neural Network Search :A Survey 论文 笔记 NAS automl survey review reinforcement learning Bayesia ...
- 论文笔记:CNN经典结构1(AlexNet,ZFNet,OverFeat,VGG,GoogleNet,ResNet)
前言 本文主要介绍2012-2015年的一些经典CNN结构,从AlexNet,ZFNet,OverFeat到VGG,GoogleNetv1-v4,ResNetv1-v2. 在论文笔记:CNN经典结构2 ...
- 转载:DenseNet算法详解
原文连接:http://blog.csdn.net/u014380165/article/details/75142664 参考连接:http://blog.csdn.net/u012938704/a ...
- Dual Path Networks(DPN)——一种结合了ResNet和DenseNet优势的新型卷积网络结构。深度残差网络通过残差旁支通路再利用特征,但残差通道不善于探索新特征。密集连接网络通过密集连接通路探索新特征,但有高冗余度。
如何评价Dual Path Networks(DPN)? 论文链接:https://arxiv.org/pdf/1707.01629v1.pdf在ImagNet-1k数据集上,浅DPN超过了最好的Re ...
- DenseNet算法详解——思路就是highway,DneseNet在训练时十分消耗内存
论文笔记:Densely Connected Convolutional Networks(DenseNet模型详解) 2017年09月28日 11:58:49 阅读数:1814 [ 转载自http: ...
随机推荐
- Hive权限管理
最近遇到一个hive权限的问题,先简单记录一下,目前自己的理解不一定对,后续根据自己的理解程度更新 一.hive用户的概念 hive本身没有创建用户的命令,hive的用户就是Linux用户,若当前是用 ...
- 【BZOJ4197】【Noi2015】寿司晚宴
Description 为了庆祝 NOI 的成功开幕,主办方为大家准备了一场寿司晚宴.小 G 和小 W 作为参加 NOI 的选手,也被邀请参加了寿司晚宴. 在晚宴上,主办方为大家提供了 n−1 种不同 ...
- 书架 bookshelf
书架 bookshelf 题目描述 当Farmer John闲下来的时候,他喜欢坐下来读一本好书. 多年来,他已经收集了N本书 (1 <= N <= 100,000). 他想要建立一个多层 ...
- go语言操作mongodb
Install the MongoDB Go Driver The MongoDB Go Driver is made up of several packages. If you are just ...
- [转载]hzwer的bzoj题单
counter: 664BZOJ1601 BZOJ1003 BZOJ1002 BZOJ1192 BZOJ1303 BZOJ1270 BZOJ3039 BZOJ1191 BZOJ1059 BZOJ120 ...
- WinterCamp2017吃饭睡觉记
noip考完后励志好好学习进HE队然后Au,就这样每天勤奋刻苦发愤图强不知不觉就到冬令营了. 除了我之外的大佬们都是以上经历. 我呢……一个很爱浪的蒟蒻. 冬令营到了,伟大的CCF本着报一个录一个的原 ...
- R语言数据整理
基本操作 读入csv数据 data <- read.csv("D:/Project/180414/data.csv", header = TRUE) 写出csv数据 writ ...
- GC的时机
说到JVM,GC(垃圾回收)是非常重要的机制. 那么首先的问题是: GC在什么时候会发生? GC的触发包括两种情况:1.程序调用System.gc()的时候.2.系统自身决定是否需要GC. 系统进行G ...
- MongoDB 之 aggregate $group 巧妙运用
有这样一组数据: { "campaign_id": "A", "campaign_name": "A", "s ...
- Spyder之Object Inspector组件
Spyder之Object Inspector组件 最新版的Spyder已经把它修改为Help组件了. Quick access to documentation is a must for ever ...