为了更好的阅读体验,请点击这里

device_map

以下内容参考 Huggingface Accelerate文档:超大模型推理方法

在 HuggingFace 中有个重要的关键字是 device_map,它可以简单控制模型层部署在哪些硬件上。

设置参数 device_map="auto",Accelerate会自动检测在哪个设备放置模型的哪层参数(自动根据你的硬件资源分配模型参数)。其规则如下:

  • 首先充分利用GPU上的显存资源
  • 如果GPU上资源不够了,那么就将权重存储到内存
  • 如果内存还不够用了,将会使用内存映射的技术,将剩余的参数存储到硬盘上

设置参数 no_split_module_classes=["GPTJBlock"] 表示,模型中的 GPTJBlock 模块不会被切分移到不同的设备上。如果有些层包含残差操作,那么请做这样的设置。

可以通过 model.hf_device_map 来观察这个 device_map 的具体内容。

也可以自定义 device_map,也可以显示的设置每一层与设备的对应关系,然后使用如下代码加载 checkpoint。

model = load_checkpoint_and_dispatch(model, "sharded-gpt-j-6B", device_map=my_device_map)

设计 device_map

可以使用其他的 device map 映射方式,通过设置 device_map 参数(例如 "auto", "balanced", "balanced_low_0", "sequential"),或者手工设置这个 device map 字典。

读者可以操控模型在meta设备上的所有层(计算 device_map)。

当读者没有足够GPU显存来加载完整的模型(因为都会按照先占满GPU显存,再占满CPU/内存资源,最后占据硬盘的顺序来完成模型加载),上面所有的选项得到的层和设备对应结果将会相同。

当读者有足够的GPU资源来加载模型,那么上面4个选项得到的结果会有所不同。

  • "auto""balanced" 将会在所有的GPU上平衡切分模型,那么可以计算批尺寸大于 \(1\) 的输入
  • "balanced_low_0" 会在除了第一个GPU上的其它GPU上平衡划分模型,并且在第一个 GPU 上占据较少资源。这个选项符合需要在第一个 GPU 上进行额外操作的需求,例如需要在第一个 GPU 执行 generate 函数(迭代过程)。
  • "sequential" 按照GPU的顺序分配模型分片,从 GPU 0 开始,直到最后的 GPU(那么最后的 GPU 往往不会被占满,和 "balanced_low_0" 的区别就是第一个还是最后一个,以及非均衡填充)

这里 "auto""balanced" 会得到相同的结果,但是未来 "auto" 模式可能会被改变,主要是有可能发现更高效的分配策略。"balanced" 参数的功能则保持稳定。

还有需要注意的地方是,可以通过设置参数 max_memory 来限制每个GPU的显存使用数额(在函数 infer_auto_device_map() 函数设置)。当设置了参数 max_memory,需要构建一个包含 GPU 索引的字典(例如 \(0, 1\) 等),并且 cpu 这个 key 是指 CPU 离线加载的最大内存(可以设置占用内存,不然加载后电脑会卡)。字典的 value 值,可以是整数(单位是字节),也可以是字符串,例如 "10GiB""10GB"(表示最大占用10GB的显存)。

下面是一个例子。表示第 \(0\) 号和第 \(1\) 号 GPU 上最多提供 10GB 的显存、以及不超过 30GB 的内存给模型权重加载使用。

from accelerate import infer_auto_device_map

device_map = infer_auto_device_map(my_model, max_memory={0: "10GiB", 1: "10GiB", "cpu": "30GiB"})

当 PyTorch 加载模型时,他会先加载 CUDA 内核,这个就占据了 1-2GB 的显存(根据 GPU 的不同会略有区别)。因此能够使用的 GPU 显存要小于实际标定显存。可以使用代码 torch.ones(1).cuda() 来看看你的 GPU 上的 CUDA kernel 占用显存大小。

因此可以通过待 max_memory 参数的存储空间映射,来防止 out-of-memory 错误出现。

另外,如果有这样的需求,一些操作的输出需要在GPU上运行(例如 generate 函数来生成文本),或者将输入放到 GPU 上,那么在某个 GPU 上需要留一些显存(Accelerate 会将输出返回作为一些设备的输入)。又如果需要优化最大的批尺寸以及有很多的 GPU,那么可以在第一个的 GPU 留下足够的显存。例如在 8x80 A100 GPU 上运行 BLOOM-176B 的理想内存设置如下:

max_memory = {0: "30GIB", 1: "46GIB", 2: "46GIB", 3: "46GIB", 4: "46GIB", 5: "46GIB", 6: "46GIB", 7: "46GIB"}

除了 GPU 0 上留了足够的显存,在其他 \(7\) 个 GPU 上也留了将近 \(50\%\) 的显存。

如果自定义 device_map,那么字典中的 key 必须是模型中的模块名,且 value 是合法的设备索引(例如对于 GPU,是从 \(0\) 开始的整数)或 "cpu"(内存)或 "disk"(硬盘)。并且 key 要覆盖所有的模型模块名,满足这些规则,你可以随心所欲的定义你的 device map。例如你的模型有两个模块(就叫 block1block2 吧),每个模块包含三个线性层(称为 linear1linear2linear3 吧,还真是有顺序)。一个合法的 device map 如下:

device_map = {"block1": 0, "block2": 1}

另外一个合法的 device map 如下:

device_map = {"block1": 0, "block2.linear1": 0, "block2.linear2": 1, "block2.linear3": 1}

相反,下面的device map就不是合法,因为它的key没有覆盖模型的所有模块名。

device_map = {"block1": 0, "block2.linear1": 1, "block2.linear2": 1}

为了提升效率,最好是你的 device map 是按照 GPU 顺序,序贯的配置参数(例如不要在 GPU 0 上加载第一个全在,在 GPU 1 上再加载,然后权重又回到 GPU 0,交叉来回会降低模型推理速度)。主要是减少数据的 GPU 之间切换次数,提高效率。

一些注意事项

  • infer_auto_device_map() (或在load_checkpoint_and_dispatch()中设置device_map="auto")会最大化使用 GPU 显存和 CPU 内存。当 PyTorch 会高效的管理 GPU 显存(并且会释放不适用的显存),但是管理内存不怎么高效。因此,使用 auto 模式,在 CPU 内存管理上不是很高效。一个好的解决方法,就是将部分在内存上的模块移到硬盘上。
  • infer_auto_device_map() (或在load_checkpoint_and_dispatch()中设置device_map="auto")是按照 GPU、CPU 和硬盘的顺序分配模型模块(防止循环操作),因此如果你的第一个层需要的 GPU 显存空间大于 GPU 显存时,有可能在 CPU/硬盘上出先奇怪的东西(第一个层不要太大,不然会发生奇怪的事情)。
  • load_checkpoint_and_dispatch()load_checkpoint_in_model() 不会对 checkpoint 中的 state_dict 和模型进行检测(后面会修正这个),如果加载的 checkpoint 和 model 不匹配,又会发生奇怪的事情(报错啦!)。
  • 没有做模型并行,意思是你的模型被分割到多个 GPU ,然后被顺序执行。换言之,一次只有一个 GPU 在运行,然后等待其他 GPU 运行完毕。

torch_dtype

这个形参可以设置模型中全部 Linear 层的数据格式,可以使用如下把 \(32\) 位的模型线性层参数转换为 \(16\) 位的模型参数

model = AutoModelForCausalLM.from_pretrained("./Llama-2-7b-hf", torch_dtype=torch.float16)

但是除线性层之外的所有参数仍为 \(32\) 位浮点数。

bitsandbytes (load_in_8bit / load_in_4bit)

主要参考 HuggingFace 的文档 Quantize Transformers models

这个是 bitsandbytes 包中特有的内容,必须安装了这个包才能使用这个形参。

具体作用是使用 \(8\) 位或 \(4\) 位来读入模型参数。具体算法可以参考 LLM.int8() 这篇论文。

自 0.39.0 版发布以来,可以利用 FP4 数据类型,使用 \(4\) 位量化加载任何支持 device_map 的模型。

如果你想量化自己的 PyTorch 模型,请查看 Accelerate 库的文档。

以下是使用 bitsandbytes 集成可以做的事情:

正常用法

如果你的模型支持使用 HF 的 Accelerate 导入,并包含线性层,那么可以在调用 from_pretrained() 函数时使用 load_in_8bit 或者 load_in_4bit

from transformers import AutoModelForCausalLM

model_8bit = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", load_in_8bit=True)
model_4bit = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", load_in_4bit=True)

默认情况下其他的模块(例如 torch.nn.LayerNorm)会被转化为 torch.float16,但是其实你也可以使用上文中提及的 torch_dtype 强行改成 \(32\) 位。

import torch
from transformers import AutoModelForCausalLM model_8bit = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", load_in_8bit=True, torch_dtype=torch.float32)
model_8bit.model.decoder.layers[-1].final_layer_norm.weight.dtype

可以检查你的模型的内存占用量 footprint 使用 get_memory_footprint 方法.

print(model.get_memory_footprint())

4/8 位浮点数

需要三个最新的包:bitsandbytes、accelerate、transformers。

目前无法将 \(4\) 位的模型上传到 HF 的 Hub 里,\(8\) 位的模型如果使用最新的包可以上传到 Hub 中。训练 \(4/8\) 位的模型目前仍不支持。用 \(4/8\) 可以训练额外的参数。(截止至 2023-9-7)

利用 device_map、torch.dtype、bitsandbytes 压缩模型参数控制使用设备的更多相关文章

  1. MXNET:深度学习计算-模型参数

    我们将深入讲解模型参数的访问和初始化,以及如何在多个层之间共享同一份参数. 之前我们一直在使用默认的初始函数,net.initialize(). from mxnet import init, nd ...

  2. 深度学习方法(七):最新SqueezeNet 模型详解,CNN模型参数降低50倍,压缩461倍!

    欢迎转载,转载请注明:本文出自Bin的专栏blog.csdn.net/xbinworld. 技术交流QQ群:433250724,欢迎对算法.技术感兴趣的同学加入. 继续前面关于深度学习CNN经典模型的 ...

  3. 莫烦python教程学习笔记——利用交叉验证计算模型得分、选择模型参数

    # View more python learning tutorial on my Youtube and Youku channel!!! # Youtube video tutorial: ht ...

  4. 深度学习网络压缩模型方法总结(model compression)

    两派 1. 新的卷机计算方法 这种是直接提出新的卷机计算方式,从而减少参数,达到压缩模型的效果,例如SqueezedNet,mobileNet SqueezeNet: AlexNet-level ac ...

  5. 人脸检测及识别python实现系列(5)——利用keras库训练人脸识别模型

    人脸检测及识别python实现系列(5)——利用keras库训练人脸识别模型 经过前面稍显罗嗦的准备工作,现在,我们终于可以尝试训练我们自己的卷积神经网络模型了.CNN擅长图像处理,keras库的te ...

  6. 支持向量机(SVM)利用网格搜索和交叉验证进行参数选择

    上一回有个读者问我:回归模型与分类模型的区别在哪?有什么不同,我在这里给他回答一下 : : : : 回归问题通常是用来预测一个值,如预测房价.未来的天气情况等等,例如一个产品的实际价格为500元,通过 ...

  7. 『MXNet』第三弹_Gluon模型参数

    MXNet中含有init包,它包含了多种模型初始化方法. from mxnet import init, nd from mxnet.gluon import nn net = nn.Sequenti ...

  8. 深度学习原理与框架-卷积神经网络-cifar10分类(图片分类代码) 1.数据读入 2.模型构建 3.模型参数训练

    卷积神经网络:下面要说的这个网络,由下面三层所组成 卷积网络:卷积层 + 激活层relu+ 池化层max_pool组成 神经网络:线性变化 + 激活层relu 神经网络: 线性变化(获得得分值) 代码 ...

  9. TensorFlow Object Detection API中的Faster R-CNN /SSD模型参数调整

    关于TensorFlow Object Detection API配置,可以参考之前的文章https://becominghuman.ai/tensorflow-object-detection-ap ...

  10. pytorch_模型参数-保存,加载,打印

    1.保存模型参数(gen-我自己的模型名字) torch.save(self.gen.state_dict(), os.path.join(self.gen_save_path, 'gen_%d.pt ...

随机推荐

  1. 记联软 UniAccess 导致 NSIS 安装包启动进程失效

    本文记录联软 UniAccess 注入的 C:\Window\LVUAAgentInstBaseRoot\syswow64\MozartBreathCore.dll 导致 NSIS 安装包启动进程失效 ...

  2. dotnet 使用 ConfigureAwait.Fody 库设置默认的 await 同步上下文切换配置

    在 dotnet 里面,使用 await 进行异步逻辑,默认是会尝试切换回调用 await 的线程同步上下文.这个机制对于大多数的上层应用来说都是符合逻辑且方便的逻辑,例如对于带 UI 线程的 WPF ...

  3. python之爬虫基础

    1.爬虫概念 其实就是模拟浏览器发送请求获取相应的数据 1.模拟请求 2.获取数据 3.筛选数据 4.保存数据 爬虫仅仅是将浏览器可以访问到的数据通过代码的方式加速访问 用于更加快速的获取数据,提升工 ...

  4. 修改element,vant,mint等ui框架的样式

    vant和mint移动端常见,引入单独的css文件,在main.js中引入下即可,直接在对应的vue文件的css通过控制台查看中修改也行,再不济加!important element: 1.vue框架 ...

  5. vue使用vant的van-tabs+tag在选项卡展示该内容有几条的提示

    1.直接写用v-if判断下标展示,会滚动.pass! 2.定位,各种定位,相对各种父元素各种定位,还是会滚,因为tab内容一定滚动,pass 3.手写选项卡+v-if判断,这肯定可行,但本着能用ui组 ...

  6. uniapp+vue3聊天室|uni-app+vite4+uv-ui跨端仿微信app聊天语音/朋友圈

    原创研发uniapp+vue3+pinia2跨三端仿微信app聊天模板Uniapp-Wechat. uni-vue3-wchat基于uni-app+vue3+pinia2+uni-ui+uv-ui等技 ...

  7. Mybatis-plus把List数据分页

    一.编写工具类: /** * @project * @Description 多表联查-分页 * @Author songwp * @Date 2022/8/8 10:31 * @Version 1. ...

  8. get pull报错 Please commit your changes or stash them before you merge

    当本地分支和远程修改了同一个文件代码,pull远程分支的代码的时候会出现文件冲突 出现这个错误 Please commit your changes or stash them before you ...

  9. grads 同时读取多个ctl文件方法

    1.不同的文件进行不同的设置:'set dfile 2' 2.读取不同文件的变量:qv.2 实例如下:'reinit''open e:\tskt.CTL''open e:\uwnd.CTL''open ...

  10. C语言:如何让printf输出更加美化(用游戏英雄属性作例子)

    #include <stdlib.h> /* run this program using the console pauser or add your own getch, system ...