算法优化

并行注意力机制

\[串行版本: y = x + MLP(LayerNorm(x + Attention(LayerNorm(x))))
\]
\[并行版本: y = x + MLP(LayerNorm(x)) + Attention(LayerNorm(x))))
\]

乍一看确实不是等价的, attention那块的后置mlp去哪了..这个其实没有理论证明, Palm论文里提到把mlp融合到attention里实验62B模型上性能没有下降. 主要对应的是下图网络结构的并行化改造.

滑动窗口Attention

通过堆叠不同大小的窗口来捕获句子中的信息,所需要的计算量会比直接计算整个输入文本长度的计算量要小很多

滑动窗口attention的原理参考这个文章的解释:因为模型都是多层叠加的,所以层级越高,attend的视野域就越广。如果w=3,那么第一层只能注意3个位置,但到第二层能注意到第一层输出的三个位置,换算到第一层的输入,就是5个位置。所以随着层级越高,理论上每个位置注意到的区域就越大,所能存储的信息就越接近全局attention时的状态

AdamW优化(LAMB):

adamW对比adam是把权重衰减项从梯度的计算中拿出来直接加在了最后的权重更新步骤上, 为了把权重衰减和梯度计算解耦(如果加到梯度计算里会影响到动量的滑动平均), 从而提升优化效果.

这里做的优化是新增了一个 \(\phi\)截断函数, 主要目的是为了防止batch_size太大的时候导致优化过程中动量出现极端值影响bp. 这个方法论文里说可以把batch_size增大4倍从而加速训练.

\[W_t\leftarrow W_{t-1}-\alpha \cdot\phi(\frac{||W_{t-1}||}{||r_t+\lambda W_{t-1}||})(r_t+\lambda W_{t-1})
\]

3D并行优化

张量并行优化

序列并行(SP)主要有2个目的: 平摊LayerNorm和Dropout的计算开销, 而且Activation占用显存也很多, 能够平摊显存消耗.

[!NOTE]

这里有个疑问: LayerNorm不是要算全局均值和方差么..这个拆分后是只算该设备内部的均值还是说需要进行额外的allReduce?

AllGather优化

序列并行(SP)后, 在进行张量并行(TP)前需要在fp的时候需要先通过gather把之前层的切片从其他节点copy汇聚过来. 如果等gather完成再跑mlp和attention就会让gpu在通信这段时间空置等待, 这里可以优化成每通信完成一个切片后, 进行这个切片的MLP列切分计算, 同时直接把gather结果送给attention并行计算, 最后再把切片计算结果concat到一起. 比如在copy完A0后, A0的前向计算就和A1的通信并行起来了, 这样就能尽量的隐藏通信

另外对矩阵做切片后再进行矩阵乘法, 计算效率要也比2个超大的矩阵乘法要高.

Reduce-Scatter优化

这块是需要把汇聚计算完成的tensor在重新进行切分发送到序列并行的节点里, 这里是把MLP的第二次行切分和attention结果加和给merge到了一起, 完成一个切片的计算后就发送出去, 同步进行下一个切片的计算使计算和通信异步进行.

流水线优化

回顾一下交错式1F1B, 每个节点fp前需要等recv之前layer的结果, 在当前层fp完后, 通过allGather send出去计算完成的数据, 在bp的时候需要通过Reduce-scatter发送出去计算完的grad.

在warm-up/cool-down过程里, 都是必须等通信完成才能进行计算的. 为了缩短等待时间megascale把allGather的recv/send拆分开, recv优先级高于send, recv后就能直接开始计算, 不需要等send的长尾. 从而缩短等待时间.

在稳定状态的时候应该和megatron一样, 通信都会和计算异步. 实际情况里通信一般都会被隐藏掉(这里我没看懂为啥上面画的对比图是个纯串行的流程)

数据加载优化

这章的主要思想工作中经常用到就不细看了, 主要有2部分:

  1. 在bp完同步梯度的时候, 所有前向相关的数据就没用了, 就可以直接释放回池预加载下一轮fp需要的embed
  2. 避免单机内多张卡重复读相同的冗余数据(这里可能指的是embed集合么?), 先在内存里去好重再copy到显存

网络通信优化

TODO待补充..网络这块基本都忘完了.

集群容错

错误检测

主要思想和flux-cpu有很多相似点, 主要有以下几个点

  1. 每个worker定期上报心跳给中心节点, 确保当前状态正常
  2. 状态异常时的自动化诊断(NCCL allToAll, allReduce. 同主机RDMA网卡间的连接和带宽, 网卡到GPU/MEM的连接和带宽), 完成诊断后上报给中心节点.
  3. 中心节点向k8s申请失败节点的拉黑和重分配替换

状态恢复

  • checkpoint保存: 这个看着实现方法和async_patch是一样的, 先把参数copy到内存, 模型继续训练. 同步再起一个异步线程用来把内存里的参数写到hdfs. 这样就可以把非常耗时的hdfs写入给隐藏掉.
  • checkpoint读取: 主要优化手段是在同一数据并行组里的卡, 只选一个GPU对应的训练线程读hdfs后写内存, 然后通过broadcast给这个数据并行组里的其他卡. 可以降低hdfs的读取压力.

LLM的状态恢复感觉还挺复杂的, 如果有一个节点挂了在重分配后是所有节点全部回滚到上一个checkpoint还是有更快的方法..pipeline并行应该是在根据节点rank在启动的时候就分好了层, 节点重入后要替换原来的rank_id.

状态监控

基于cuda_event的timeline可视化, 算是老熟人了. 这里的难点感觉在于超多卡的实时日志收集, 根据DP来画出卡和卡的数据流依赖关系

参考:

megascale: https://arxiv.org/abs/2402.15627

Palm(并行attention): https://public.agent-matrix.com/publish/shared/Paper/Palm.pdf

滑动窗口注意力解释: https://zhuanlan.zhihu.com/p/223430086

LLM并行训练4-megascale论文学习的更多相关文章

  1. ML2021 | (腾讯)PatrickStar:通过基于块的内存管理实现预训练模型的并行训练

    ​  前言  目前比较常见的并行训练是数据并行,这是基于模型能够在一个GPU上存储的前提,而当这个前提无法满足时,则需要将模型放在多个GPU上.现有的一些模型并行方案仍存在许多问题,本文提出了一种名为 ...

  2. [源码解析] 分布式训练Megatron (1) --- 论文 & 基础

    [源码解析] 分布式训练Megatron (1) --- 论文 & 基础 目录 [源码解析] 分布式训练Megatron (1) --- 论文 & 基础 0x00 摘要 0x01 In ...

  3. PyTorch如何加速数据并行训练?分布式秘籍大揭秘

    PyTorch 在学术圈里已经成为最为流行的深度学习框架,如何在使用 PyTorch 时实现高效的并行化? 在芯片性能提升有限的今天,分布式训练成为了应对超大规模数据集和模型的主要方法.本文将向你介绍 ...

  4. Faster RCNN论文学习

    Faster R-CNN在Fast R-CNN的基础上的改进就是不再使用选择性搜索方法来提取框,效率慢,而是使用RPN网络来取代选择性搜索方法,不仅提高了速度,精确度也更高了 Faster R-CNN ...

  5. Fast RCNN论文学习

    Fast RCNN建立在以前使用深度卷积网络有效分类目标proposals的工作的基础上.使用了几个创新点来改善训练和测试的速度,同时还能增加检测的精确度.Fast RCNN训练VGG16网络的速度是 ...

  6. 《Explaining and harnessing adversarial examples》 论文学习报告

    <Explaining and harnessing adversarial examples> 论文学习报告 组员:裴建新   赖妍菱    周子玉 2020-03-27 1 背景 Sz ...

  7. 论文学习笔记 - 高光谱 和 LiDAR 融合分类合集

    A³CLNN: Spatial, Spectral and Multiscale Attention ConvLSTM Neural Network for Multisource Remote Se ...

  8. Pytorch:单卡多进程并行训练

    1 导引 我们在博客<Python:多进程并行编程与进程池>中介绍了如何使用Python的multiprocessing模块进行并行编程.不过在深度学习的项目中,我们进行单机多进程编程时一 ...

  9. tensorflow中使用mnist数据集训练全连接神经网络-学习笔记

    tensorflow中使用mnist数据集训练全连接神经网络 ——学习曹健老师“人工智能实践:tensorflow笔记”的学习笔记, 感谢曹老师 前期准备:mnist数据集下载,并存入data目录: ...

  10. BicycleGAN: Toward Multimodal Image-to-Image Translation - 1 - 论文学习,成对数据

    Abstract 许多图像到图像的翻译问题是有歧义的,因为一个输入图像可能对应多个可能的输出.在这项工作中,我们的目标是在一个条件生成模型设置中建立可能的输出分布.将模糊度提取到一个低维潜在向量中,在 ...

随机推荐

  1. 仿网易云音乐-微信小程序开发

    1.很多时候要找到完整的API接口很难,但网易云音乐的数据API是可以得到完整的. 安装API:https://github.com/Binaryify/NeteaseCloudMusicApi,只需 ...

  2. JavaScript面向对象的继承应用

    面向对象语言的三大特征:继承.封装.多态 <!DOCTYPE html> <html> <head> <title>Extend-OPP</tit ...

  3. 用 C 语言开发一门编程语言 — 字符串与文件加载

    目录 文章目录 目录 前文列表 字符串 读取字符串 注释 文件加载函数 命令行参数 打印函数 报错函数 源代码 前文列表 <用 C 语言开发一门编程语言 - 交互式解析器> <用 C ...

  4. Kubernetes集群中配置Ingress支持HTTPS访问(一):cfssl

    目录 一.系统环境 二.前言 三.对称加密和非对称加密简介 四.什么是HTTPS 五.Ingress简介 六.配置ingress对外发布服务 6.1 安装NGINX ingress controlle ...

  5. mvc5接口报错:The JSON request was too large to be deserialized的一种原因

    是mvc5版本的接口,接口使用了dynamic接收数组,json对象数组只有56个,length长度不到10万,但是提交就报The JSON request was too large to be d ...

  6. MyBatis数据源模块源码分析

    数据源对象是比较复杂的对象,其创建过程相对比较复杂,对于 MyBatis 创建数据源,具体来讲有如下难点: MyBatis 不但要能集成第三方的数据源组件,自身也提供了数据源的实现: 数据源的初始化参 ...

  7. Jenkins获取gitlab源代码

    Jenkins获取gitlab源代码 Jenkins权限获取 在日常工作做由于Jenkins启动用户是Jenkins,在执行脚本时系统命令是无法让Jenkins执行的,如果需要Jenkins权限有两种 ...

  8. 7.28考试总结(NOIP模拟26)[神炎皇·降雷皇·幻魔皇]

    或许只需一滴露水,便能守护这绽放的花朵. 前言 疯狂挂分,本来T2是想用树状数组优化一下的不知道为啥后来看了一下就少看了一层循环, 然后就想,我都 n 的复杂度了,足以搞过第一问了,还优化啥呀.... ...

  9. 7.14考试总结(NOIP模拟15)[夜莺与玫瑰·影子·玫瑰花精]

    梦总是有会醒来的时候,不会醒的梦总有一天会变成悲伤. 前言 这次考试的思维含量有一点大(此时距离考试还有 7min 而我的总结还没写完..) 但是对于以前的考试来讲还是有所进步的,毕竟在考试的时候还是 ...

  10. C++笔记(5)浅拷贝和深拷贝

    1. 定义 浅拷贝(shallow copy):多个对象共用同一块资源,同一块资源释放多次,崩溃或者内存泄漏 深拷贝(deep copy):每个对象共同拥有自己的资源,必须显式提供拷贝构造函数和赋值运 ...