MegEngine亚线性显存优化
MegEngine亚线性显存优化
MegEngine经过工程扩展和优化,发展出一套行之有效的加强版亚线性显存优化技术,既可在计算存储资源受限的条件下,轻松训练更深的模型,又可使用更大batch size,进一步提升模型性能,稳定batchwise算子。使用MegEngine训练ResNet18/ResNet50,显存占用分别最高降低23%/40%;在更大的Bert模型上,降幅更是高达75%,而额外的计算开销几乎不变。
基于梯度检查点的亚线性显存优化方法[1]由于较高的计算/显存性价比受到关注。MegEngine经过工程扩展和优化,发展出一套行之有效的加强版亚线性显存优化技术,既可在计算存储资源受限的条件下,轻松训练更深的模型,又可使用更大batch size,进一步提升模型性能,稳定batchwise算子。使用MegEngine训练ResNet18/ResNet50,显存占用分别最高降低23%/40%;在更大的Bert模型上,降幅更是高达75%,而额外的计算开销几乎不变。该技术已在MegEngine开源,欢迎大家上手使用:https://github.com/MegEngine。
深度神经网络训练是一件复杂的事情,它体现为模型的时间复杂度和空间复杂度,分别对应着计算和内存;而训练时内存占用问题是漂浮在深度学习社区上空的一块乌云,如何拨云见日,最大降低神经网络训练的内存占用,是一个绕不开的课题。
GPU显卡等硬件为深度学习提供了必需的算力,但硬件自身有限的存储,限制了可训练模型的尺寸,尤其是大型深度网络,由此诞生出一系列相关技术,比如亚线性显存优化、梯度累加、混合精度训练、分布式训练,进行GPU显存优化。
其中,亚线性显存优化方法[1]由于较高的计算/显存性价比备受关注;旷视基于此,经过工程扩展和优化,发展出加强版的MegEngine亚线性显存优化技术,轻松把大模型甚至超大模型装进显存,也可以毫无压力使用大batch训练模型。
这里将围绕着深度学习框架MegEngine亚线性显存优化技术的工程实现和实验数据,从技术背景、原理、使用、展望等多个方面进行首次深入解读。
背景
在深度学习领域中,随着训练数据的增加,需要相应增加模型的尺寸和复杂度,进行模型「扩容」;而ResNet [2] 等技术的出现在算法层面扫清了训练深度模型的障碍。不断增加的数据和持续创新的算法给深度学习框架带来了新挑战,能否在模型训练时有效利用有限的计算存储资源,尤其是减少GPU显存占用,是评估深度学习框架性能的重要指标。
在计算存储资源一定的情况下,深度学习框架有几种降低显存占用的常用方法,其示例如下:
- 通过合适的梯度定义,让算子的梯度计算不再依赖于前向计算作为输入,从而in-place地完成算子的前向计算,比如Sigmoid、Relu等;
- 在生命周期没有重叠的算子之间共享显存;
- 通过额外的计算减少显存占用,比如利用梯度检查点重新计算中间结果的亚线性显存优化方法[1];
- 通过额外的数据传输减少显存占用,比如把暂时不用的数据从GPU交换到CPU,需要时再从CPU交换回来。
上述显存优化技术在MegEngine中皆有不同程度的实现,这里重点讨论基于梯度检查点的亚线性显存优化技术。
原理
一个神经网络模型所占用的显存空间大体分为两个方面:1)模型本身的参数,2)模型训练临时占用的空间,包括参数的梯度、特征图等。其中最大占比是 2)中以特征图形式存在的中间结果,比如,从示例[1]可知,根据实现的不同,从70%到90%以上的显存用来存储特征图。这里的训练过程又可分为前向计算,反向计算和优化三个方面,其中前向计算的中间结果最占显存,还有反向计算的梯度。第 1)方面模型自身的参数内存占用最小。 MegEngine加强版亚线性显存优化技术借鉴了[1]的方法,尤其适用于计算存储资源受限的情况,比如一张英伟达2080Ti,只有11G的显存;而更贵的Tesla V100,最大显存也只有32G。

图1:亚线性显存优化原理,其中 (b) 保存了Relu结果,实际中Relu结果可用in-place计算
图 1(a) 给出了卷积神经网络的基本单元,它由Conv-BN-Relu组成。可以看到,反向计算梯度的过程依赖于前向计算获取的中间结果,一个网络需要保存的中间结果与其大小成正比,即显存复杂度为O(n)。本质上,亚线性显存优化方法是以时间换空间,以计算换显存,如图 1(b) 所示,它的算法原理如下:
- 选取神经网络中k个检查点,从而把网络分成k个block,需要注意的是,初始输入也作为一个检查点;前向计算过程中只保存检查点处的中间结果;
- 反向计算梯度的过程中,首先从相应检查点出发,重新计算单个block需要的中间结果,然后计算block内部各个block的梯度;不同block的中间结果计算共享显存。
这种方法有着明显的优点,即大幅降低了模型的空间复杂度,同时缺点是增加了额外的计算:
- 显存占用从O(n)变成O(n/k)+ O(k),O(n/k)代表计算单个节点需要的显存,O(k)代表k个检查点需要的显存, 取k=sqrt(n),O(n/k)+ O(k)~O(sqrt(n)),可以看到显存占用从线性变成了亚线性;
- 因为在反向梯度的计算过程中需要从检查点恢复中间结果,整体需要额外执行一次前向计算。
工程
在[1]的基础上,MegEngine结合自身实践,做了工程扩展和优化,把亚线性显存优化方法扩展至任意的计算图,并结合其它常见的显存优化方法,发展出一套行之有效的加强版亚线性显存优化技术。亚线性优化方法采用简单的网格搜索(grid search)选择检查点,MegEngine在此基础上增加遗传算法,采用边界移动、块合并、块分裂等策略,实现更细粒度的优化,进一步降低了显存占用。
如图2所示,采用型号为2080Ti的GPU训练ResNet50,分别借助基准、亚线性、亚线性+遗传算法三种显存优化策略,对比了可使用的最大batch size。仅使用亚线性优化,batch size从133增至211,是基准的1.6x;而使用亚线性+遗传算法联合优化,batch size进一步增至262,较基准提升2x。

图2:三种显存优化方法优化batch size的对比:ResNet50
通过选定同一模型、给定batch size,可以更好地观察遗传算法优化显存占用的情况。如图3所示,随着迭代次数的增加,遗传算法逐渐收敛显存占用,并在第5次迭代之后达到一个较稳定的状态。

图3:遗传算法收敛示意图
此外,MegEngine亚线性优化技术通过工程改良,不再局限于简单的链状结构和同质计算节点, 可用于任意的计算图,计算节点也可异质,从而拓展了技术的适用场景;并可配合上述显存优化方法,进一步降低模型的显存占用。
实验
MegEngine基于亚线性显存技术开展了相关实验,这里固定batch size=64,在ResNet18和ResNet50两个模型上,考察模型训练时的显存占用和计算时间。如图4所示,相较于基准实现,使用MegEngine亚线性显存技术训练ResNet18时,显存占用降低32%, 计算时间增加24%;在较大的ReNet50上,显存占用降低40%,计算时间增加25%。同时经过理论分析可知,模型越大,亚线性显存优化的效果越明显,额外的计算时间则几乎不变。

图4:MegEngine亚线性优化技术实验显存/时间对比:ReNet18/ReNet50
在更大模型Bert上实验数据表明,借助MegEngine亚线性显存技术,显存占用最高降低75%,而计算时间仅增加23%,这与理论分析相一致。有兴趣的同学可前往MegEngine ModeHub试手更多模型实验:https://megengine.org.cn/model-hub/。
使用
MegEngine官网提供了亚线性显存优化技术的使用文档。当你的GPU显存有限,苦于无法训练较深、较大的神经网络模型,或者无法使用大batch进一步提升深度神经网络的性能,抑或想要使batchwise算子更加稳定,那么,MegEngine亚线性显存优化技术正是你需要的解决方案。
上手MegEngine亚线性优化技术非常便捷,无需手动设定梯度检查点,通过几个简单的参数,轻松控制遗传算法的搜索策略。具体使用时,在MegEngine静态图接口中调用SublinearMemoryConfig设置trace的参数sublinear_memory_config,即可打开亚线性显存优化:
from megengine.jit import trace, SublinearMemoryConfig
config = SublinearMemoryConfig()
@trace(symbolic=True, sublinear_memory_config=config)
def train_func(data, label, *, net, optimizer):
...
MegEngine在编译计算图和训练模型时,虽有少量的额外时间开销,但会显著缓解显存不足问题。下面以ResNet50为例,说明MegEngine可有效突破显存瓶颈,训练batch size从100最高增至200:
import os
from multiprocessing import Process
def train_resnet_demo(batch_size, enable_sublinear, genetic_nr_iter=0):
import megengine as mge
import megengine.functional as F
import megengine.hub as hub
import megengine.optimizer as optim
from megengine.jit import trace, SublinearMemoryConfig
import numpy as np
print(
"Run with batch_size={}, enable_sublinear={}, genetic_nr_iter={}".format(
batch_size, enable_sublinear, genetic_nr_iter
)
)
# 使用GPU运行这个例子
assert mge.is_cuda_available(), "Please run with GPU"
try:
# 从 megengine hub 中加载一个 resnet50 模型。
resnet = hub.load("megengine/models", "resnet50")
optimizer = optim.SGD(resnet.parameters(), lr=0.1,)
config = None
if enable_sublinear:
config = SublinearMemoryConfig(genetic_nr_iter=genetic_nr_iter)
@trace(symbolic=True, sublinear_memory_config=config)
def train_func(data, label, *, net, optimizer):
pred = net(data)
loss = F.cross_entropy_with_softmax(pred, label)
optimizer.backward(loss)
resnet.train()
for i in range(10):
batch_data = np.random.randn(batch_size, 3, 224, 224).astype(np.float32)
batch_label = np.random.randint(1000, size=(batch_size,)).astype(np.int32)
optimizer.zero_grad()
train_func(batch_data, batch_label, net=resnet, optimizer=optimizer)
optimizer.step()
except:
print("Failed")
return
print("Sucess")
# 以下示例结果在2080Ti GPU运行得到,显存容量为 11 GB
# 不使用亚线性内存优化,允许的batch_size最大为 100 左右
p = Process(target=train_resnet_demo, args=(100, False))
p.start()
p.join()
# 报错显存不足
p = Process(target=train_resnet_demo, args=(200, False))
p.start()
p.join()
# 使用亚线性内存优化,允许的batch_size最大为 200 左右
p = Process(target=train_resnet_demo, args=(200, True, 20))
p.start()
p.join()
展望
如上所述,MegEngine的亚线性显存优化技术通过额外做一次前向计算,即可达到O(sqrt(n))的空间复杂度。如果允许做更多次的前向计算,对整个网络递归地调用亚线性显存算法,有望在时间复杂度为O(n log n)的情况下,达到 O(log n)的空间复杂度。
更进一步,MegEngine还将探索亚线性显存优化技术与数据并行/模型并行、混合精度训练的组合使用问题,以期获得更佳的集成效果。最后,在RNN以及GNN、Transformer等其他类型网络上的使用问题,也是MegEngine未来的一个探索方向。
欢迎访问
- MegEngine GitHub(欢迎star):https://github.com/MegEngine
- MegEngine官网:https://megengine.org.cn
- MegEngine ModelHub:https://megengine.org.cn/model-hub/
参考文献
- Chen, T., Xu, B., Zhang, C., & Guestrin, C. (2016). Training deep nets with sublinear memory cost. arXiv preprint arXiv:1604.06174.
- He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 770-778).
MegEngine亚线性显存优化的更多相关文章
- [Pytorch]深度模型的显存计算以及优化
原文链接:https://oldpan.me/archives/how-to-calculate-gpu-memory 前言 亲,显存炸了,你的显卡快冒烟了! torch.FatalError: cu ...
- Cpu Gpu 内存 显存 数据流
[精]从CPU架构和技术的演变看GPU未来发展 http://www.pcpop.com/doc/0/521/521832_all.shtml 显存与纹理内存详解 http://blog.csdn.n ...
- OpenGL8-直接分配显存-极速绘制(Opengl1.5版本才有)
视频教程请关注 http://edu.csdn.net/lecturer/lecturer_detail?lecturer_id=440 /** * 这个例子介绍如何使用显卡内存进行绘制 下载地址 : ...
- [自制操作系统] 图形界面&VBE工具&MMIO显存&图形库/字库
本文记录了在JOS(或在任意OS)上实现图形界面的方法与一些图形库的实现. 本文中支持的新特性: 支持基本图形显示 支持中英文显示(中英文点阵字库) 相关:VBE VESA MMIO 点阵字库 Git ...
- OpenCL将数组从内存copy到显存
本来想对上一篇博客做优化,优化效果不明显.但知识点还是要记一下. 初衷是想把上一篇博客中定义域的计算搬到CPU来计算,因为定义域的计算对于每一个kernel都是一样的,所以直接读取应该是可以进一步减小 ...
- 显卡、显卡驱动、显存、GPU、CUDA、cuDNN
显卡 Video card,Graphics card,又叫显示接口卡,是一个硬件概念(相似的还有网卡),执行计算机到显示设备的数模信号转换任务,安装在计算机的主板上,将计算机的数字信号转换成模拟 ...
- 深度学习中GPU和显存分析
刚入门深度学习时,没有显存的概念,后来在实验中才渐渐建立了这个意识. 下面这篇文章很好的对GPU和显存总结了一番,于是我转载了过来. 作者:陈云 链接:https://zhuanlan.zhihu. ...
- 我的Keras使用总结(5)——Keras指定显卡且限制显存用量,常见函数的用法及其习题练习
Keras 是一个高层神经网络API,Keras是由纯Python编写而成并基于TensorFlow,Theano以及CNTK后端.Keras为支持快速实验而生,能够将我们的idea迅速转换为结果.好 ...
- 【算法学习笔记】Meissel-Lehmer 算法 (亚线性时间找出素数个数)
「Meissel-Lehmer 算法」是一种能在亚线性时间复杂度内求出 \(1\sim n\) 内质数个数的一种算法. 在看素数相关论文时发现了这个算法,论文链接:Here. 算法的细节来自 OI w ...
随机推荐
- WindowsPE 第五章 导出表编程-1(枚举导出表)
导出表编程-1-枚举导出表 开始前先回忆一下导出表: 1.枚举dll函数的导出函数名字: 思路: (1)加载dll到内存里. (2)获取PE头,逐步找到可选头部. (3)然后找到里面的第一个结构(导出 ...
- metasploit console 命令解释
折腾几天,总算是在虚拟机中安装好了4.2版本的metasploit,能够成功打开console,这里将metasploit console的指令用中文翻一下: 原文及翻译: back Move bac ...
- 【SpringMVC】添加操作时返回400
本博客老魏原创,如需转载请留言 问题描述: springmvc向数据库添加新的记录时,发生400错误,控制台没有抛出异常. 问题原因: 视图中的提交数据的某一个字段不不匹配导致. 解决方法: 不要怀疑 ...
- python中的的异步IO
asyncio 是干什么的? 异步网络操作 并发 协程 python3.0 时代,标准库里的异步网络模块:select(非常底层) python3.0时代,第三方异步网络库:Tornado pytho ...
- 006-Java的访问权限控制符和包导入机制
目录 一.Java的访问权限控制符 一.访问控制符的作用 二.访问控制符的分类 二.Java的包导入机制 一.为什么要使用package? 二.package怎么用? 三.对于带有package的ja ...
- Asp.NetCore Web开发之ADO.Net
Asp.NetCore可以说是.Net平台开发网站的一大利器,最近的一大段时间,就要跟大家分享,如何使用这一利器开发网站项目. 要学习网站开发,首先要学习如何使用ADO.Net进行数据库数据的增删改 ...
- 【Azure Developer】使用Microsoft Graph API 批量创建用户,先后遇见的三个错误及解决办法
问题描述 在先前的一篇博文中,介绍了如何使用Microsoft Graph API来创建Azure AD用户(博文参考:[Azure Developer]使用Microsoft Graph API 如 ...
- JavaScrip条件表达式优化
目录 1,前言 2,多条件if语句优化 3,参数默认值 4,Switch语句优化 1,前言 今早看了一篇文章<JavaScrip实现:如何写出漂亮的条件表达式>,原创于:华为云开发者社区, ...
- 二分查找确定lower_bound和upper_bound
lower_bound当target存在时, 返回它出现的第一个位置,如果不存在,则返回这样一个下标i:在此处插入target后,序列仍然有序. 代码如下: int lower_bound(int* ...
- CSS层叠性
比较id,类,标签的数量 谁多就谁在上面 255个类的权重等于一个id 当权重一样时,以后设置的为准 通过继承而来的,权重为0 !important (设置权重无限大)可以影响权重,但只能影响选中的, ...