从头学pytorch(二十二):全连接网络dense net
DenseNet
论文传送门,这篇论文是CVPR 2017的最佳论文.
resnet一文里说了,resnet是具有里程碑意义的.densenet就是受resnet的启发提出的模型.
resnet中是把不同层的feature map相应元素的值直接相加.而densenet是将channel维上的feature map直接concat在一起,从而实现了feature的复用.如下所示:
注意,是连接dense block内输出层前面所有层的输出,不是只有输出层的前一层
网络结构
首先实现DenseBlock
先解释几个名词
bottleneck layer
即上图中红圈的1x1卷积核.主要目的是对输入在channel维度做降维.减少运算量.
卷积核的数量为4k,k为该layer输出的feature map的数量(也就是3x3卷积核的数量)growth rate
即上图中黑圈处3x3卷积核的数量.假设3x3卷积核的数量为k,则每个这种3x3卷积后,都得到一个channel=k的输出.假如一个denseblock有m组这种结构,输入的channel为n的话,则做完一次连接操作后得到的输出的channel为n + k + k +...+k = n+m*k.所以又叫做growth rate.conv
论文里的conv指的是BN-ReLU-Conv
实现DenseBlock
DenseLayer
class DenseLayer(nn.Module):
def __init__(self,in_channels,bottleneck_size,growth_rate):
super(DenseLayer,self).__init__()
count_of_1x1 = bottleneck_size
self.bn1 = nn.BatchNorm2d(in_channels)
self.relu1 = nn.ReLU(inplace=True)
self.conv1x1 = nn.Conv2d(in_channels,count_of_1x1,kernel_size=1)
self.bn2 = nn.BatchNorm2d(count_of_1x1)
self.relu2 = nn.ReLU(inplace=True)
self.conv3x3 = nn.Conv2d(count_of_1x1,growth_rate,kernel_size=3,padding=1)
def forward(self,*prev_features):
# for f in prev_features:
# print(f.shape)
input = torch.cat(prev_features,dim=1)
# print(input.device,input.shape)
# for param in self.bn1.parameters():
# print(param.device)
# print(list())
bottleneck_output = self.conv1x1(self.relu1(self.bn1(input)))
out = self.conv3x3(self.relu2(self.bn2(bottleneck_output)))
return out
首先是1x1卷积,然后是3x3卷积.3x3卷积核的数量即growth_rate,bottleneck_size即1x1卷积核数量.论文里是bottleneck_size=4xgrowth_rate的关系. 注意forward函数的实现
def forward(self,*prev_features):
# for f in prev_features:
# print(f.shape)
input = torch.cat(prev_features,dim=1)
# print(input.device,input.shape)
# for param in self.bn1.parameters():
# print(param.device)
# print(list())
bottleneck_output = self.conv1x1(self.relu1(self.bn1(input)))
out = self.conv3x3(self.relu2(self.bn2(bottleneck_output)))
return out
我们传进来的是一个元祖,其含义是[block的输入,layer1输出,layer2输出,...].前面说过了,一个dense block内的每一个layer的输入是前面所有layer的输出和该block的输入在channel维度上的连接.这样就使得不同layer的feature map得到了充分的利用.
tips:
函数参数带*表示可以传入任意多的参数,这些参数被组织成元祖的形式,比如
## var-positional parameter
## 定义的时候,我们需要添加单个星号作为前缀
def func(arg1, arg2, *args):
print arg1, arg2, args
## 调用的时候,前面两个必须在前面
## 前两个参数是位置或关键字参数的形式
## 所以你可以使用这种参数的任一合法的传递方法
func("hello", "Tuple, values is:", 2, 3, 3, 4)
## Output:
## hello Tuple, values is: (2, 3, 3, 4)
## 多余的参数将自动被放入元组中提供给函数使用
## 如果你需要传递元组给函数
## 你需要在传递的过程中添加*号
## 请看下面例子中的输出差异:
func("hello", "Tuple, values is:", (2, 3, 3, 4))
## Output:
## hello Tuple, values is: ((2, 3, 3, 4),)
func("hello", "Tuple, values is:", *(2, 3, 3, 4))
## Output:
## hello Tuple, values is: (2, 3, 3, 4)
DenseBlock
class DenseBlock(nn.Module):
def __init__(self,in_channels,layer_counts,growth_rate):
super(DenseBlock,self).__init__()
self.layer_counts = layer_counts
self.layers = []
for i in range(layer_counts):
curr_input_channel = in_channels + i*growth_rate
bottleneck_size = 4*growth_rate #论文里设置的1x1卷积核是3x3卷积核的4倍.
layer = DenseLayer(curr_input_channel,bottleneck_size,growth_rate).cuda()
self.layers.append(layer)
def forward(self,init_features):
features = [init_features]
for layer in self.layers:
layer_out = layer(*features) #注意参数是*features不是features
features.append(layer_out)
return torch.cat(features, 1)
一个Dense Block由多个Layer组成.这里注意forward的实现,init_features即该block的输入,然后每个layer都会得到一个输出.第n个layer的输入由输入和前n-1个layer的输出在channel维度上连接组成.
最后,该block的输出为各个layer的输出为输入以及各个layer的输出在channel维度上连接而成.
TransitionLayer
很显然,dense block的计算方式会使得channel维度过大,所以每一个dense block之后要通过1x1卷积在channel维度降维.
class TransitionLayer(nn.Sequential):
def __init__(self, in_channels, out_channels):
super(TransitionLayer, self).__init__()
self.add_module('norm', nn.BatchNorm2d(in_channels))
self.add_module('relu', nn.ReLU(inplace=True))
self.add_module('conv', nn.Conv2d(in_channels, out_channels,kernel_size=1, stride=1, bias=False))
self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
Dense Net
dense net的基本组件我们已经实现了.下面就可以实现dense net了.
class DenseNet(nn.Module):
def __init__(self,in_channels,num_classes,block_config):
super(DenseNet,self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels,64,kernel_size=7,stride=2,padding=3),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True)
)
self.pool1 = nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
self.dense_block_layers = nn.Sequential()
block_in_channels = in_channels
growth_rate = 32
for i,layers_counts in enumerate(block_config):
block = DenseBlock(in_channels=block_in_channels,layer_counts=layers_counts,growth_rate=growth_rate)
self.dense_block_layers.add_module('block%d' % (i+1),block)
block_out_channels = block_in_channels + layers_counts*growth_rate
transition = TransitionLayer(block_out_channels,block_out_channels//2)
if i != len(block_config): #最后一个dense block后没有transition layer
self.dense_block_layers.add_module('transition%d' % (i+1),transition)
block_in_channels = block_out_channels // 2 #更新下一个dense block的in_channels
self.avg_pool = nn.AdaptiveAvgPool2d(output_size=(1,1))
self.fc = nn.Linear(block_in_channels,num_classes)
def forward(self,x):
out = self.conv1(x)
out = self.pool1(x)
for layer in self.dense_block_layers:
out = layer(out)
# print(out.shape)
out = self.avg_pool(out)
out = torch.flatten(out,start_dim=1) #相当于out = out.view((x.shape[0],-1))
out = self.fc(out)
return out
首先和resnet一样,首先是7x7卷积接3x3,stride=2的最大池化,然后就是不断地dense block + tansition.得到feature map以后用全局平均池化得到n个feature.然后给全连接层做分类使用.
可以用
X=torch.randn(1,3,224,224).cuda()
block_config = [6,12,24,16]
net = DenseNet(3,10,block_config)
net = net.cuda()
out = net(X)
print(out.shape)
测试一下,输出如下,可以看出feature map的变化情况.最终得到508x7x7的feature map.全局平均池化后,得到508个特征,通过线性回归得到10个类别.
torch.Size([1, 195, 112, 112])
torch.Size([1, 97, 56, 56])
torch.Size([1, 481, 56, 56])
torch.Size([1, 240, 28, 28])
torch.Size([1, 1008, 28, 28])
torch.Size([1, 504, 14, 14])
torch.Size([1, 1016, 14, 14])
torch.Size([1, 508, 7, 7])
torch.Size([1, 10])
总结:
核心就是dense block内每一个layer都复用了之前的layer得到的feature map,因为底层细节的feature被复用,所以使得模型的特征提取能力更强. 当然坏处就是计算量大,显存消耗大.
从头学pytorch(二十二):全连接网络dense net的更多相关文章
- 智课雅思词汇---二十二、-al即是名词性后缀又是形容词后缀
智课雅思词汇---二十二.-al即是名词性后缀又是形容词后缀 一.总结 一句话总结: 后缀:-al ②[名词后缀] 1.构成抽象名词,表示行为.状况.事情 refusal 拒绝 proposal 提议 ...
- [分享] IT天空的二十二条军规
Una 发表于 2014-9-19 20:25:06 https://www.itsk.com/thread-335975-1-1.html IT天空的二十二条军规 第一条.你不是什么都会,也不是什么 ...
- Bootstrap <基础二十二>超大屏幕(Jumbotron)
Bootstrap 支持的另一个特性,超大屏幕(Jumbotron).顾名思义该组件可以增加标题的大小,并为登陆页面内容添加更多的外边距(margin).使用超大屏幕(Jumbotron)的步骤如下: ...
- Web 前端开发精华文章推荐(HTML5、CSS3、jQuery)【系列二十二】
<Web 前端开发精华文章推荐>2014年第一期(总第二十二期)和大家见面了.梦想天空博客关注 前端开发 技术,分享各类能够提升网站用户体验的优秀 jQuery 插件,展示前沿的 HTML ...
- 二十二、OGNL的一些其他操作
二十二.OGNL的一些其他操作 投影 ?判断满足条件 动作类代码: ^ $ public class Demo2Action extends ActionSupport { public ...
- WCF技术剖析之二十二: 深入剖析WCF底层异常处理框架实现原理[中篇]
原文:WCF技术剖析之二十二: 深入剖析WCF底层异常处理框架实现原理[中篇] 在[上篇]中,我们分别站在消息交换和编程的角度介绍了SOAP Fault和FaultException异常.在服务执行过 ...
- VMware vSphere 服务器虚拟化之二十二桌面虚拟化之创建View Composer链接克隆的虚拟桌面池
VMware vSphere 服务器虚拟化之二十二桌面虚拟化之创建View Composer链接克隆的虚拟桌面池 在上一节我们创建了完整克隆的自动专有桌面池,在创建过程比较缓慢,这次我们将学习创建Vi ...
- Bootstrap入门(二十二)组件16:列表组
Bootstrap入门(二十二)组件16:列表组 列表组是灵活又强大的组件,不仅能用于显示一组简单的元素,还能用于复杂的定制的内容. 1.默认样式列表组 2.加入徽章 3.链接 4.禁用的列表组 5. ...
- JAVA之旅(二十二)——Map概述,子类对象特点,共性方法,keySet,entrySet,Map小练习
JAVA之旅(二十二)--Map概述,子类对象特点,共性方法,keySet,entrySet,Map小练习 继续坚持下去吧,各位骚年们! 事实上,我们的数据结构,只剩下这个Map的知识点了,平时开发中 ...
- 备忘录模式 Memento 快照模式 标记Token模式 行为型 设计模式(二十二)
备忘录模式 Memento 沿着脚印,走过你来时的路,回到原点. 苦海翻起爱恨 在世间难逃避命运 相亲竟不可接近 或我应该相信是缘份 一首<一生所爱>触动了多少 ...
随机推荐
- 百度DMA+小度App的蓝牙语音解决方案入局
前记 人机交互经历了三个阶段键鼠.触屏和语音交互.在国外,谷歌.亚马逊.苹果等巨头的竞争已经到达白热化状态:在国内,百度的DuerOS凭借着入局早,投入大,已经成为国内语音互交的一面旗帜.无论是从 ...
- mysql主从之配置验证
实验环境: master 192.168.132.121 主库 slave 192.168.132.122 从库 一 mysql主从复制的配置 1.1 mysql主库给从库复制的权限 mys ...
- docker容器内存占用过高(例如mysql)
简介 该文章适用于配置低,特别是内存低的服务器,在用容器部署服务时有可能会因为容器占用内存过高导致服务挂掉时参考解决(不是运行在容器里的话,也是可以修改mysql的配置文件限制内存占用) 最近用doc ...
- 洛谷$P2150\ [NOI2015]$寿司晚宴 $dp$
正解:$dp$ 解题报告: 传送门$QwQ$. 遇事不决写$dp$($bushi$.讲道理这题一看就感觉除了$dp$也没啥很好的算法能做了,于是考虑$dp$呗 先看部分分?$30pts$发现质因数个数 ...
- [Oracle]Oracle的闪回归档
Oracle的闪回归档 场景需求,由于管理数据库的一些核心表,在实施初期会有人为误删除的问题.Oracle 11gR2提供了闪回归档的特性,可以保证不用RMAN来恢复误删除的数据.实践如下: 1.创建 ...
- 端口扫描器--利用python的nmap模块
安装nmap模块挺麻烦的,搞了半天 不仅要安装pip install nmap 还要sudo apt install nmap 给出代码,没有设多线程,有点慢,注意端口的类型转换,搞了很久 #!/us ...
- 《图解机器学习-杉山将著》读书笔记---CH1
CH1 什么是机器学习 重点提炼 机器学习的种类: 常分为:监督学习.无监督学习.强化学习等 监督学习是学生从老师那获得知识,老师提供对错指示 无监督学习是在没有老师的情况下,学生自习 强化学习是在没 ...
- es6种for循环中let和var区别
let和var区别: for(var i=0;i<5;i++){ setTimeout(()=>{ console.log(i);//5个5 },100) } console.log(i) ...
- 【记】创建 VirtualBoxClient COM 对象失败. 应用程序将被中断
1. 在本地64位win7系统安装VirtualBox完,启动时提示错误 原因:兼容性造成的 按照下图显示修改VirtualBox快捷方式的兼容性 2. 启动虚拟机时,提示 点击弹出框的确定按钮后,接 ...
- ASP.Net MVC 引用动态 js 脚本
希望可以动态生成 js 发送给客户端使用. layout页引用: <script type="text/javascript" src="@Url.Action( ...