【小白学PyTorch】13 EfficientNet详解及PyTorch实现
参考目录:
文章来自微信公众号【机器学习炼丹术】。我是炼丹兄,如果有疑问或者想要和炼丹兄交流的可以加微信:cyx645016617.
efficientNet的论文原文链接: https://arxiv.org/pdf/1905.11946.pdf
模型扩展Model scaling一直以来都是提高卷积神经网络效果的重要方法。
比如说,ResNet可以增加层数从ResNet18扩展到ResNet200。这次,我们要介绍的是最新的网络结构——EfficientNet,就是一种标准化的模型扩展结果,通过下面的图,我们可以i只管的体会到EfficientNet b0-b7在ImageNet上的效果:对于ImageNet历史上的各种网络而言,可以说EfficientNet在效果上实现了碾压
1 EfficientNet
1.1 概述
一般我们在扩展网络的时候,一般通过调成输入图像的大小、网络的深度和宽度(卷积通道数,也就是channel数)。在EfficientNet之前,没有研究工作只是针对这三个维度中的某一个维度进行调整,因为没钱啊!!有限的计算能力,很少有研究对这三个维度进行综合调整的。
EfficientNet的设想就是能否设计一个标准化的卷积网络扩展方法,既可以实现较高的准确率,又可以充分的节省算力资源。因而问题可以描述成,如何平衡分辨率、深度和宽度这三个维度,来实现拘拿及网络在效率和准确率上的优化
EfficientNet给出的解决方案是提出了这个模型复合缩放方法 (compound scaling methed)
图a是一个基线网络,也就是我们所说的baseline,图b,c,d三个网络分别对该基线网络的宽度、深度、和输入分辨率进行了扩展,而最右边的e图,就是EfficientNet的主要思想,综合宽度、深度和分辨率对网络进行符合扩展。
1.2 把扩展问题用数学来描述
首先,我们把整个卷积网络称为N,他的第i个卷积层可以看作下面的函数映射:
Yi是输出张量,Xi是输入张量,假设这个Xi的维度是<Hi,Wi,Ci>(这里省略了Batch的维度),那么这个整个卷积网络N,是由k个卷积层组成的,因此可以表示为:
通常情况,一个神经网络会有多个相同的卷积层存在,因此,我们称多个结构相同的卷积层为一个stage。举个例子:ResNet可以分为5个stage,每一个stage中的卷积层结构相同(除了第一层为降采样层),前四个stage都是baseblock,第五个stage是fc层。不太理解的可以看这个:【从零学习PyTorch】 如何残差网络resnet作为pre-model +代码讲解+残差网络resnet是个啥
总之,我们以stage为单位,将上面的卷积网络N改成为:
其中,下表1...s表示stage的讯号,Fi表示对第i层的卷积运算,Li的意思是Fi在第i个stage中有Li个一样结构的卷积层。<Hi, Wi, Ci>表示第i层输入的shape。
为了减小搜索空间,作者先固定了网络的基本结构,而只改变上面公式中的三个缩放维度。还记得之前我们提高的分辨率,宽度,深度吗?
- Li就是深度,Li越大重复的卷积层越多,网络越深;
- Ci就是channel数目,也就是网络的宽度
- Hi和Wi就是图片的分辨率
就算如此,这也有三个参数要调整,搜索空间也是非常的大,因此EfficientNet的设想是一个卷积网络所有的卷积层必须通过相同的比例常数进行统一扩展,这句话的意思是,三个参数乘上常熟倍率。所以个一个模型的扩展问题,就用数学语言描述为:
其中,d、w和r分别表示网络深度、宽度和分辨率的倍率。这个算式表现为在给定计算内存和效率的约束下,如何优化参数d、w和r来实现最好的模型准确率。
1.3 实验内容
上面问题的难点在于,三个倍率之间是由内在联系的,比如更高分辨率的图片就需要更深的网络来增大感受野的捕捉特征。因此作者做了两个实验(实际应该是做了很多的实验)来说明:
(1) 第一个实验,对三个维度固定了两个,只方法其中一个,得到的结果如下:
从左到右分别是只放大了网络宽度(width,w为放大倍率)、网络深度(depth,d为放大倍率)和图像分辨率(resolution, r为放大倍率)。我们可以看到,单个维度的放大最高精度只有80左右,本次实验,作者得出一个管带你:三个维度中任一维度的放大都可以带来精度的提升,但是随着倍率的越来越大,提升越来越小。
(2)于是作者做了第二个实验,尝试在不同的d,r组合下变动w,得到下图:
从实验结果来看,最高精度相比之前已经有所提升,突破了80大关。而且组合不同,效果不同。作者又得到了一个观点:得到了更高的精度以及效率的关键是平衡网络的宽度,网络深度,网络分辨率三个维度的缩放倍率
1.4 compound scaling method
这时候作者提出了这个方法
EfficientNet的规范化复合调参方法使用了一个复合系数\(\phi\),来对三个参数进行符合调整:
其中的\(\alpha, \beta, \gamma\)都是常数,可以通过网格搜索获得。复合系数通过人工调节。考虑到如果网络深度翻番那么对应的计算量翻番,网络宽度和图像分辨率翻番对应的计算量会翻4番,卷积操作的计算量与\(d,w^2 ,r^2\)成正比,。在这个约束下,网络的计算量大约是之前的\(2^\phi\)倍
以上就是EfficientNet的复合扩展的方式,但是这仅仅是一种模型扩展方式,我们还没有讲到EfficientNet到底是一个什么样的网络。
1.5 EfficientNet的基线模型
EfficientNet使用了MobileNet V2中的MBCConv作为模型的主干网络,同时也是用了SENet中的squeeze and excitation方法对网络结构进行了优化。 MBCConv是mobileNet中的基本结构,关于什么是MBCconv在百度上很少有解释,通过阅读论文和Google这里有一个比较好的解释:
The MBConv block is nothing fancy but an Inverted Residual Block (used in MobileNetV2) with a Squeeze and Excite block injected sometimes.
MBCconv就是一个MobileNet的倒残差模块,但是这个模块中还封装了Squeeze and Excite的方法。
总之呢,综合了MBConv和squeeze and excitation方法的EfficientNet-B0的网络结构如下表所示:
对于EfficientNet-B0这样的一个基线网络,如何使用复合扩展发对该网络进行扩展呢?这里主要是分两步走:还记得这个规划问题吗?
(1)第一步,先将复合系数\(\phi\)固定为1,先假设有两倍以上的计算资源可以用,然后对\(\alpha, \beta, \gamma\)进行网络搜索。对于EfficientNet-B0网络,在约束条件为
\]
时,\(\alpha, \beta, \gamma\)分别取1.2,1.1和1.15时效果最佳。第二步是固定\(\alpha, \beta, \gamma\),通过复合调整公式对基线网络进行扩展,得到B1到B7网络。于是就有了开头的这一张图片,EfficientNet在ImageNet上的效果碾压,而且模型规模比此前的GPipe小了8.4倍。
普通人来训练和扩展EfficientNet实在过于昂贵,所以对于我们来说,最好的方法就是迁移学习,下面我们来看如何用PyTorch来做迁移学习。
2 PyTorch实现
之前也提到了,在torchvision中并没有加入efficientNet所以这里我们使用某一位大佬贡献的API。有一个这样的文件Efficient_PyTorch,里面存放了b0到b8的预训练模型存储文件,我们将会调用这个API。因为这里我们没有直接使用pip进行安装,所以需要将这个库函数设置成系统路径。Pycharm中很多朋友会踩着个坑,不知道如何设置成系统路径:
点击Sources Root之后,就可以直接import了。
整个代码非常少,因为都写成API接口了嘛:
from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_name('efficientnet-b0')
print(model)
打印的模型可以看,我加了详细的注解(快夸我):
整个b0的结构和论文中的结构相同:
从上图中可以知道,总共有16个MBConv模块;在第16个时候的输出通道为320个通道;
从运行结果来看,结构相同。总之这就是EfficientNet的结构,原理和调用方式。
【小白学PyTorch】13 EfficientNet详解及PyTorch实现的更多相关文章
- 【小白学PyTorch】11 MobileNet详解及PyTorch实现
文章来自微信公众号[机器学习炼丹术].我是炼丹兄,欢迎加我微信好友交流学习:cyx645016617. @ 目录 1 背景 2 深度可分离卷积 2.2 一般卷积计算量 2.2 深度可分离卷积计算量 2 ...
- Pytorch autograd,backward详解
平常都是无脑使用backward,每次看到别人的代码里使用诸如autograd.grad这种方法的时候就有点抵触,今天花了点时间了解了一下原理,写下笔记以供以后参考.以下笔记基于Pytorch1.0 ...
- 【小白学PyTorch】10 pytorch常见运算详解
参考目录: 目录 1 矩阵与标量 2 哈达玛积 3 矩阵乘法 4 幂与开方 5 对数运算 6 近似值运算 7 剪裁运算 这一课主要是讲解PyTorch中的一些运算,加减乘除这些,当然还有矩阵的乘法这些 ...
- 【小白学PyTorch】12 SENet详解及PyTorch实现
文章来自微信公众号[机器学习炼丹术].我是炼丹兄,有什么问题都可以来找我交流,近期建立了微信交流群,也在朋友圈抽奖赠书十多本了.我的微信是cyx645016617,欢迎各位朋友. 参考目录: @ 目录 ...
- 在CentOS上编译安装MySQL 5.7.13步骤详解
MySQL 5.7主要特性 更好的性能 对于多核CPU.固态硬盘.锁有着更好的优化,每秒100W QPS已不再是MySQL的追求,下个版本能否上200W QPS才是用户更关心的. 更好的InnoDB存 ...
- 一起学SpringMVC之RequestMapping详解
本文以一个简单的小例子,简述SpringMVC开发中RequestMapping的相关应用,仅供学习分享使用,如有不足之处,还请指正. 什么是RequestMapping? RequestMappin ...
- Pytorch数据读取详解
原文:http://studyai.com/article/11efc2bf#%E9%87%87%E6%A0%B7%E5%99%A8%20Sampler%20&%20BatchSampler ...
- 【Linux】一步一步学Linux——Linux系统目录详解(09)
目录 00. 目录 01. 文件系统介绍 02. 常用目录介绍 03. /etc目录文件 04. /dev目录文件 05. /usr目录文件 06. /var目录文件 07. /proc 08. 比较 ...
- Appium自动化(13) - 详解 Keyboard 类里的方法和源码分析
如果你还想从头学起Appium,可以看看这个系列的文章哦! https://www.cnblogs.com/poloyy/category/1693896.html 前言 Keyboard 类在 a ...
随机推荐
- 基于 abp vNext 微服务开发的敏捷应用构建平台 - 设计构想
许多中小企业的管理模式都是在自身的发展过程中不断摸索,逐步建立起来的,每一家都有其独有的管理模式,而且随着企业的不断发展,管理模式也在不断变化中.企业在发展壮大的过程中离不开信息化系统的支撑,企业在构 ...
- Java多线程_同步工具CyclicBarrier
CyclicBarrier概念:CyclicBarrier是多线程中的一个同步工具,它允许一组线程互相等待,直到到达某个公共屏障点.形象点儿说,CyclicBarrier就是一个屏障,要求这一组线程中 ...
- 尝试写一写4gl与4fd
4gl DATABASE ds GLOBALS "../../config/top.global" DEFINE g_curs_index LIKE t ...
- 洛谷P3817 小A的糖果 贪心思想
一直觉得洛谷的背景故事很....直接题解吧 #include <bits/stdc++.h> //万能头文件 using namespace std; int a[100002]; // ...
- JAVA集合类简要笔记 - 内部类 包装类 Object类 String类 BigDecimal类 system类
常用类 内部类 成员内部类.静态内部类.局部内部类.匿名内部类 概念:在一个类的内部再定义一个完整的类 特点: 编译之后可生成独立的字节码文件 内部类可直接访问外部类私有成员,而不破坏封装 可为外部类 ...
- JetBrain破解
https://blog.csdn.net/u014044812/article/details/78727496 https://jetlicense.nss.im/ https://zhile.i ...
- 面试【JAVA基础】锁
1.锁状态 锁的状态只能升级不能降级. 无锁 没有锁对资源进行锁定,所有线程都能访问并修改同一个资源,但同时只有一个线程能修改成功.其他修改失败的线程会不断重试,直到修改成功,如CAS原理和应用是无锁 ...
- 不支持原子性的 Redis 事务也叫事务吗?
文章收录在 GitHub JavaKeeper ,N线互联网开发必备技能兵器谱 假设现在有这样一个业务,用户获取的某些数据来自第三方接口信息,为避免频繁请求第三方接口,我们往往会加一层缓存,缓存肯定要 ...
- Apache Jmter 压力测试教程
1.官网下载安装包,地址:http://jmeter.apache.org/download_jmeter. 2.下载得到解压包,双击解压. 3.点击/bin目录下面的jmeter.bat 启动软件 ...
- TP6.0多应用模式隐藏路由中的应用名
本文默认采用的是多应用模式 PHP技术群: 159789818 ThinkPHP技术群: 828567087 1. 多应用模式中隐藏路由中的应用名的三种方式 域名绑定应用 增加应用入口 入口文件绑定应 ...