自推测解码是一种新颖的文本生成方法,它结合了推测解码 (Speculative Decoding) 的优势和大语言模型 (LLM) 的提前退出 (Early Exit) 机制。该方法出自论文 LayerSkip: Enabling Early-Exit Inference and Self-Speculative Decoding。它通过使用 同一个模型 的早期层来生成候选词元 (token),并使用后期层进行验证,从而实现高效生成。

这项技术不仅加快了文本生成速度,还显著节省了内存并降低了计算延迟。为了实现端到端的加速,早期层的输出需要与最终层的输出足够接近。正如论文中所述,这可以通过一种训练方法来实现,该方法可以在预训练期间应用,也可以在特定领域进行微调时应用。自推测解码对于实际应用特别高效,它可以在较小的 GPU 上部署,并降低 大规模推理 所需的整体硬件资源。

在本博客中,我们将探讨自推测解码的概念、其实现方式以及在 transformers 库中的实际应用。您将了解到其技术原理,包括 提前退出层 (Early-Exit Layers)反嵌入 (Unembedding)训练修改 (Training Modifications)。为了将这些概念付诸实践,我们提供了代码示例、与传统推测解码的基准比较,以及对性能权衡的见解。

您还可以直接查看以下 Hugging Face 资源,了解更多关于该方法的信息并亲自尝试:

  1. Hugging Face 论文讨论论坛
  2. LayerSkip 模型集合
  3. 展示自推测解码深入工作原理的 Colab 笔记本

推测解码与自推测解码

facebook/layerskip-llama2-7B 上的 LayerSkip 推理演示 (使用 LayerSkip 方法持续预训练的 Llama2 7B)。

传统的推测解码 使用 两个 模型: 一个较小的模型 (草稿模型) 用于生成一系列候选词元,一个较大的模型 (验证模型) 用于验证草稿的准确性。较小的模型执行大部分生成工作,而较大的模型则负责改进结果。这提高了文本生成速度,因为较大的模型一次性验证完整序列,而不是逐个生成词元。

在自推测解码中,作者在此概念的基础上,使用大模型的早期层来生成草稿词元,然后由模型的更深层进行验证。这种推测解码的“自洽”特性需要特定的训练,使模型能够同时执行草稿生成和验证。这反过来又比传统的推测解码提高了速度并降低了计算成本。

transformers 中的使用

为了在 transformers 库中启用提前退出自推测解码,我们只需在 generate() 函数中添加 assistant_early_exit 参数。

以下是一个简单的代码片段,展示了该功能:

pip install transformers

from transformers import AutoTokenizer, AutoModelForCausalLM

early_exit_layer = 4
prompt = "Alice and Bob"
checkpoint = "facebook/layerskip-llama2-7B" tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer(prompt, return_tensors="pt").to("cuda") model = AutoModelForCausalLM.from_pretrained(checkpoint).to("cuda")
outputs = model.generate(**inputs, assistant_early_exit=early_exit_layer)

注意: 虽然 assistant_early_exit 参数可以为任何仅解码器的 transformer 启用提前退出自推测解码,但除非模型经过专门训练,否则无法反嵌入 (通过 LM 头进行解码的过程,在博客文章后面有描述) 中间层的 logits。只有对检查点进行这样的训练,以提高早期层的准确性,您才能获得加速。LayerSkip 论文提出了一种训练方法来实现这一点 (即应用提前退出损失,并逐步增加层丢弃率)。这里 提供了使用 LayerSkip 训练方法持续预训练的 Llama2、Llama3 和 Code Llama 检查点的集合。

基准测试

我们进行了一系列广泛的基准测试,以衡量 LayerSkip 的自推测解码相对于自回归解码在各种模型上的加速情况。我们还将自推测解码 (基于提前退出) 与标准推测解码技术进行了比较。要复现这些结果,您可以在 这里 找到代码,并在 此电子表格 中找到运行每个实验的命令。所有实验均在单个 80GB A100 GPU 上运行,除了 Llama2 70B 实验在 8 个 A100 GPU 的节点上运行。

Llama3.2 1B

Model Variant (模型变体) Layers (层数) Assistant Model (辅助模型) Assistant Layers (辅助层数) Task (任务) Total Layers (总层数) FLOPs/Input (G) (输入 FLOPs) Time/Input (s) (输入时间) FLOPs/Output (G) (输出 FLOPs) Time/Output (s) (输出时间) Efficiency (效率)
facebook/layerskip-llama3.2-1B 1 Early Exit @ Layer 4 summarization 1 1195.28 9.96 2147.7 17.9 1.80

Llama3 8B

Model Variant (模型变体) Layers (层数) Assistant Model (辅助模型) Assistant Layers (辅助层数) Task (任务) Total Layers (总层数) FLOPs/Input (G) (输入 FLOPs) Time/Input (s) (输入时间) FLOPs/Output (G) (输出 FLOPs) Time/Output (s) (输出时间) Efficiency (效率)
meta-llama/Meta-Llama-3-8B 8 meta-llama/Llama-3.2-1B 1 summarization 9 1872.46 19.04 2859.35 29.08 1.53
meta-llama/Meta-Llama-3-8B 8 meta-llama/Llama-3.2-3B 3 summarization 11 2814.82 28.63 2825.36 28.73 1.00
facebook/layerskip-llama3-8B 8 Early Exit @ Layer 4 summarization 8 1949.02 15.75 3571.81 28.87 1.83

Llama2 70B

Model Variant (模型变体) Layers (层数) Assistant Model (辅助模型) Assistant Layers (辅助层数) Task (任务) Total Layers (总层数) FLOPs/Input (G) (输入 FLOPs) Time/Input (s) (输入时间) FLOPs/Output (G) (输出 FLOPs) Time/Output (s) (输出时间) Efficiency (效率)
meta-llama/Llama-2-70b-hf 70 meta-llama/Llama-2-13b-hf 13 summarization 83 5036.54 46.3 12289.01 112.97 2.44
meta-llama/Llama-2-70b-hf 70 meta-llama/Llama-2-7b-hf 7 summarization 77 4357.55 40.06 12324.19 113.3 2.83
meta-llama/Llama-2-70b-hf 70 TinyLlama/TinyLlama_v1.1 1 summarization 71 4356.21 40.05 12363.22 113.66 2.84
facebook/layerskip-llama2-70B 70 Early Exit @ Layer 10 summarization 70 6012.04 54.96 1283.34 113.2 2.06

Llama2 13B

Model Variant (模型变体) Layers (层数) Assistant Model (辅助模型) Assistant Layers (辅助层数) Task (任务) Total Layers (总层数) FLOPs/Input (G) (输入 FLOPs) Time/Input (s) (输入时间) FLOPs/Output (G) (输出 FLOPs) Time/Output (s) (输出时间) Efficiency (效率)
meta-llama/Llama-2-13b-hf 13 meta-llama/Llama-2-7b-hf 7 summarization 20 3557.07 27.79 4088.48 31.94 1.15
meta-llama/Llama-2-13b-hf 13 TinyLlama/TinyLlama_v1.1 1 summarization 14 2901.92 22.67 4190.42 32.74 1.44
meta-llama/Llama-2-13b-hf 13 apple/OpenELM-270M 0.27 summarization 13.27 2883.33 22.53 4521.12 35.32 1.57
meta-llama/Llama-2-13b-hf 13 apple/OpenELM-450M 0.45 summarization 13.45 3267.69 25.53 4321.75 33.76 1.32
facebook/layerskip-llama2-13B 13 Early Exit @ Layer 4 summarization 13 4238.45 33.11 4217.78 32.95 0.995
facebook/layerskip-llama2-13B 13 Early Exit @ Layer 8 summarization 13 2459.61 19.22 4294.98 33.55 1.746

Llama2 7B

Model Variant (模型变体) Layers (层数) Assistant Model (辅助模型) Assistant Layers (辅助层数) Task (任务) Total Layers (总层数) FLOPs/Input (G) (输入 FLOPs) Time/Input (s) (输入时间) FLOPs/Output (G) (输出 FLOPs) Time/Output (s) (输出时间) Efficiency (效率)
meta-llama/Llama-2-7b-hf 7 TinyLlama/TinyLlama_v1.1 1 summarization 8 2771.54 21.65 3368.48 26.32 1.22
meta-llama/Llama-2-7b-hf 7 apple/OpenELM-270M 0.27 summarization 7.27 2607.82 20.37 4221.14 32.98 1.62
meta-llama/Llama-2-7b-hf 7 apple/OpenELM-450M 0.45 summarization 7.45 3324.68 25.97 4178.66 32.65 1.26
facebook/layerskip-llama2-7B 7 Early Exit @ Layer 4 summarization 7 2548.4 19.91 3306.73 25.83 1.297

我们可以观察到以下几点:

  • 从“ 总参数数量”列可以看出,自推测解码消耗的内存更少,因为它不需要单独的草稿模型,并且草稿阶段层的权重被重用。
  • 对于除 Llama2 70B 之外的所有模型大小和生成,提前退出自推测解码比常规的两模型推测解码更快。
  • 与其它模型相比,Llama2 70B 的自推测解码速度提升相对有限,可能有不同的原因,例如,Llama2 70B 的 LayerSkip 检查点持续预训练的 token 较少 (Llama2 70B 为 328M token,而 Llama2 7B 为 52B token)。但这是未来研究需要改进的一个方面。尽管如此,70B 的自推测解码明显快于自回归解码。

自生成和自验证

自推测解码过程从自生成开始,其中词元是通过从某个中间层提前退出来生成的。推测词元的数量定义了在此阶段生成多少草稿词元,而我们退出的层定义了草稿阶段的规模和准确性。这两个参数都可以在推理时根据草稿阶段的速度和准确性之间的权衡来指定。

下一步是自验证,其中使用完整模型来验证草稿词元。验证模型重用草稿模型中的缓存部分。如果草稿词元与验证的词元一致,则将它们添加到最终输出中,从而更好地利用我们系统中的内存带宽,因为使用完整模型生成一系列词元比验证草稿要昂贵得多,只要有几个词元匹配即可。

在自验证阶段,只有剩余的层才会被计算以进行验证,因为早期层的结果在草稿阶段已被缓存。

提前退出和反嵌入

自推测解码中的一项关键技术是提前退出,即生成过程可以在预先指定的层停止。为了实现这一点,我们通过将这些层的 logits 投影到语言模型 (LM) 头上来反嵌入它们,以预测下一个词元。这允许模型跳过后续层并提高推理时间。

可以在任何 transformer 层执行反嵌入,将提前退出转变为一种高效的词元预测机制。一个自然而然的问题出现了: 当 LM 头最初被训练为仅与最终层一起工作时,如何使其适应反嵌入较早层的 logits?这就是训练修改发挥作用的地方。

训练修改

在训练阶段,我们引入了层丢弃,它允许模型在训练期间跳过某些层。丢弃率在较深的层中逐渐增加,使模型不太依赖其后面的层,并增强模型的泛化能力并加快训练速度。

除了层丢弃之外,还应用了提前退出损失,以确保 LM 头学习反嵌入不同的层。使用每个出口 (中间层) 的归一化损失的总和来给出使用提前出口训练模型的总损失函数。这种技术通过在所有层之间分配学习任务来实现高效训练。

优化: 共享权重、共享 KV 缓存和共享计算

自推测解码显著受益于缓存重用,特别是 KV 缓存,它存储在草稿阶段计算的键值对。此缓存允许模型跳过冗余计算,因为草稿和验证阶段都使用相同的早期层。此外,退出查询缓存存储来自退出层的查询向量,允许验证从草稿阶段无缝继续。

与传统的双模型推测解码相比,提前退出自推测解码可以从以下节省中受益:

  • 共享权重: 为草稿和验证重用前 E 层 的权重。
  • 共享 KV 缓存: 为草稿和验证重用前 E 层的键值对
  • 共享计算: 通过使用仅保存退出层 E-1 的查询向量的退出查询缓存来重用前 E 层的计算,以便验证过程无需计算层 0 到 E-1。

KV 和退出查询缓存的组合称为 KVQ 缓存,可减少内存开销并提高推理延迟。

到目前为止, transformers 库已在此 pull request 中实现了第一个优化 (共享权重)。随着使用此方法的模型数量增加,我们将考虑其他优化。如果您有兴趣,请随时提出 PR!

提前退出层的选择策略

草稿阶段的提前退出层是一个超参数,我们可以在推理期间调整或修改:

  • 我们越早退出,生成草稿词元的速度就越快,但它们的准确性就越低。
  • 我们越晚退出,生成的草稿词元就越准确,但它们的速度就越慢。

我们编写了一个脚本来遍历不同的提前退出层并测量 A100 GPU 上的每秒词元数。在下面的表格中,我们绘制了针对不同 Llama 模型的 LayerSkip 和基线检查点的每秒词元数与提前退出层的关系图 (您可以在 此处 查看完整日志)。

Llama3.2 1B

Normal (常规模型) LayerSkip (LayerSkip 模型)

Llama3 8B

Normal (常规模型) LayerSkip (LayerSkip 模型)

Code Llama3 34B

Normal (常规模型) LayerSkip (LayerSkip 模型)

Code Llama3 7B

Normal (常规模型) LayerSkip (LayerSkip 模型)

Llama2 70B

Normal (常规模型) LayerSkip (LayerSkip 模型)

Llama2 13B

Normal (常规模型) LayerSkip (LayerSkip 模型)

Llama2 7B

Normal (常规模型) LayerSkip (LayerSkip 模型)

我们可以观察到以下几点:

  • 对于没有使用 LayerSkip 训练方法进行预训练或持续预训练的基线检查点,提前退出自推测解码比自回归解码更慢。这是因为在大多数 LLM 的训练过程中,早期层并没有被激励去学习预测输出,因此使用早期层生成词元的接受率会非常低。
  • 另一方面,对于使用 LayerSkip 训练方法持续预训练的 Llama 检查点,提前退出自推测解码在至少一部分层中比自回归解码具有更高的加速比。
    • 对于大多数模型 (除了 Llama3.2 1B),当我们遍历各层时,我们注意到一个规律模式: 加速比在前几层较低,逐渐增加到一个最佳点,然后再次下降。
    • 提前退出层的最佳点是在预测的高准确性和生成词元的低开销之间达到最佳权衡时。这个最佳点取决于每个模型,也可能取决于提示或提示的领域。

这些观察为进一步的实验和探索提供了有趣的机会。我们鼓励读者在这些想法的基础上进行构建,测试变体,并进行自己的研究。这些努力可以带来有价值的见解,并为该领域做出有意义的贡献。

结论

LayerSkip 利用提前退出、层丢弃和缓存重用之间的协同作用,创建了一个快速高效的文本生成流程。通过训练模型从不同层反嵌入输出,并使用缓存优化验证过程,这种方法在速度和准确性之间取得了平衡。因此,它显著改善了大语言模型的推理时间,同时保持了高质量的输出。由于使用单个模型作为草稿和验证模型,它还比传统的推测解码技术减少了内存使用。

自推测是一个令人兴奋的领域,同一个 LLM 可以创建草稿词元并自我修正。其他自推测方法包括:

  • Draft & Verify: 其中草稿阶段涉及跳过预定的注意力和前馈层。
  • MagicDec: 其中草稿阶段使用 KV 缓存的子集,这对长上下文输入很有用。
  • Jacobi DecodingLookahead Decoding: 其中草稿阶段是一系列“猜测词元”,可以是随机的或从 n-gram 查找表中获得的。

英文原文: https://huggingface.co/blog/layerskip

原文作者: Aritra Roy Gosthipaty, Mostafa Elhoushi, Pedro Cuenca, Vaibhav Srivastav

译者: smartisan

LayerSkip: 使用自推测解码加速大模型推理的更多相关文章

  1. 优化故事: BLOOM 模型推理

    经过"九九八十一难",大模型终于炼成.下一步就是架设服务,准备开门营业了.真这么简单?恐怕未必!行百里者半九十,推理优化又是新的雄关漫道.如何进行延迟优化?如何进行成本优化 (别忘 ...

  2. DeepSpeed Chat: 一键式RLHF训练,让你的类ChatGPT千亿大模型提速省钱15倍

    DeepSpeed Chat: 一键式RLHF训练,让你的类ChatGPT千亿大模型提速省钱15倍 1. 概述 近日来,ChatGPT及类似模型引发了人工智能(AI)领域的一场风潮. 这场风潮对数字世 ...

  3. 无插件的大模型浏览器Autodesk Viewer开发培训-武汉-2014年8月28日 9:00 – 12:00

    武汉附近的同学们有福了,这是全球第一次关于Autodesk viewer的教室培训. :) 你可能已经在各种场合听过或看过Autodesk最新推出的大模型浏览器,这是无需插件的浏览器模型,支持几十种数 ...

  4. PowerDesigner 学习:十大模型及五大分类

    个人认为PowerDesigner 最大的特点和优势就是1)提供了一整套的解决方案,面向了不同的人员提供不同的模型工具,比如有针对企业架构师的模型,有针对需求分析师的模型,有针对系统分析师和软件架构师 ...

  5. PowerDesigner 15学习笔记:十大模型及五大分类

    个人认为PowerDesigner 最大的特点和优势就是1)提供了一整套的解决方案,面向了不同的人员提供不同的模型工具,比如有针对企业架构师的模型,有针对需求分析师的模型,有针对系统分析师和软件架构师 ...

  6. 华为高级研究员谢凌曦:下一代AI将走向何方?盘古大模型探路之旅

    摘要:为了更深入理解千亿参数的盘古大模型,华为云社区采访到了华为云EI盘古团队高级研究员谢凌曦.谢博士以非常通俗的方式为我们娓娓道来了盘古大模型研发的"前世今生",以及它背后的艰难 ...

  7. 文心大模型api使用

    文心大模型api使用 首先,我们要获取硅谷社区的连个key 复制两个api备用 获取Access Token 获取access_token示例代码 之后就会输出 作文创作 作文创作:作文创作接口基于文 ...

  8. offset新探索:双管齐下,加速大数据量查询

    摘要:随着offset的增加,查询的时长也会越来越长.当offset达到百万级别的时候查询时长通常是业务所不能容忍的. 本文分享自华为云社区<offset新探索:双管齐下,加速大数据量查询> ...

  9. AI大模型学习了解

    # 百度文心 上线时间:2019年3月 官方介绍:https://wenxin.baidu.com/ 发布地点: 参考资料: 2600亿!全球最大中文单体模型鹏城-百度·文心发布 # 华为盘古 上线时 ...

  10. 千亿参数开源大模型 BLOOM 背后的技术

    假设你现在有了数据,也搞到了预算,一切就绪,准备开始训练一个大模型,一显身手了,"一朝看尽长安花"似乎近在眼前 -- 且慢!训练可不仅仅像这两个字的发音那么简单,看看 BLOOM ...

随机推荐

  1. 第1章04节 | 常见开源OLAP技术架构对比

    https://zhuanlan.zhihu.com/p/266402829 1. 什么是OLAP OLAP(On-line Analytical Processing,联机分析处理)是在基于数据仓库 ...

  2. 【转载】Spring Cloud Gateway-过滤器工厂详解(GatewayFilter Factories)

    http://www.imooc.com/article/290816 TIPS 本文基于 Spring Cloud Greenwich SR2 ,理论支持 Spring Cloud Greenwic ...

  3. Qt/C++动态启用地图功能/地图拖曳/键盘操作/滚轮缩放/双击放大/连续缩放等

    一.前言说明 地图组件为了方便用户的操作,一般会满足各种需求场景,比如用鼠标拖曳地图,实体键盘按键上下左右移动,鼠标滚轮缩放地图大小,双击放大地图,这些常规的操作可以极大的方便用户操作,问题是,有时候 ...

  4. Qt/C++编写的mqtt调试助手使用说明

    一.使用说明 第一步,选择协议前缀,可选mqtt://.mqtts://.ws://.wss://四种,带s结尾的是走ssl通信,ws表示走websocket通信.一般选默认的mqtt://就好. 第 ...

  5. [转]BeanUtils.copyProperties使用总结以及注意事项

    1.前言开发过程中,讲一个对象的属性和值赋值到另一个对象上,大量使用了get.set方法,看着很臃肿,思考下肯定不只有我有这种想法,所以技术上肯定有方法能解决这个问题,所以查阅了一些资料发现了Bean ...

  6. JVM实战—8.如何分析jstat统计来定位GC

    大纲 1.使用jstat了解线上系统的JVM运行状况 2.使用jmap和jhat了解线上系统的对象分布 3.如何分析JVM运行状况并合理优化 4.使用jstat分析模拟的BI系统JVM运行情况 5.使 ...

  7. LOL(英雄联盟) API 接口

    /*LOL(英雄联盟) API 接口 By wgscd /*LOL(英雄联盟) API 接口 By wgscd QQ:1009374598 */ GET https://127.0.0.1:58182 ...

  8. verilog 编写猫狗过河实验

    源代码地址:https://github.com/penggeon/catanddog 效果演示见: https://www.bilibili.com/video/BV1n24y147S1 警告: 仅 ...

  9. docker搭建rabbitmq镜像集群

    Rabbitmq普通集群模式,是将交换机.绑定.队列的元数据复制到集群里的任何一个节点,但队列内容只存在于特定的节点中,客户端通过连接集群中任意一个节点,即可以生产和消费集群中的任何队列内容(因为每个 ...

  10. Linux 部署DVWA靶场

    Linux 部署DVWA靶场 DVWA是一款开源的网络安全漏洞实践平台,专为安全学习者设计.它涵盖了XXS.SQL注入.文件上传.文件包含.CSRF和暴力破解等多种安全漏洞环境,每个漏洞都有从简单到复 ...