[megatron代码阅读] 1. 初始化和组网
以pretrain_gpt.py为例, 看megatron的整体逻辑. 本章主要包括megatron初始化相关逻辑, 核心函数为initialize_megatron, setup_model_and_optimizer两个
initialize_megatron
parse_args
从argparse中直接读取超参数配置. 如学习率, 正则化等. 从环境变量中获取rank等
load_args_from_checkpoint
优先从未被持久化的ckpt加载, 并且只加载rank0的args
_load_non_persistent_base_checkpoint
find_checkpoint_rank_0
在不知道是否使用pp/ep策略的情况下, 尝试拼装出rank0 ckpt的名称, 如果存在就能定位到实际的存放目录
verify_checkpoint_and_load_strategy
根据是zarr还是 torch_dist选择不同的加载策略
TorchCommonLoadStrategy->torch.load()
如果没有非持久化的, 加载远端ckpt
从ckpt里的args替换掉之前解析的部分args, 比如tp/pp/vp等超参数
校验yaml/args, 全局变量设置
_initialize_distributed
pytorch里的get_world_size 返回的是gpu总卡数
初始化torch.distributed
mpu.initialize_model_parallel (并行设置,核心函数)
RankGenerator:
- 在每块GPU上启动一个进程(process),每个进程独立执行自己所维护的那部分模型的计算,实现并行训练
- 存储tp/pp/dp/ep/cp 各种并行度配置大小. 并且能够从 tp-dp str格式的并行配置里获取 tp/dp对应的mask和并行度大小设置.
get_ranks: 根据parallel_size和mask, 计算各种并行策略拆分后的rank group.
[!NOTE]
举例: 假定有2个8卡机器,node1: rank 0-7,node2: rank 8-15 tp-pp-dp: [2,4,2]
- _TENSOR_MODEL_PARALLEL_GROUP :[g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]。
- _PIPELINE_MODEL_PARALLEL_GROUP : [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]。
- _MODEL_PARALLEL_GROUP :tp-pp = 2 * 4 = 8 [0, 1, 4, 5, 8, 9, 12, 13],[2, 3, 6, 7, 10, 11, 14, 15]
- _DATA_PARALLEL_GROUP :[g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]。

注意在PP内输入层和输出层共享一个word_embedding,PP组中的第一个和最后一个rank需要通讯,保证word_embedding完全一致
group全局变量赋值: 每个并行模式有一个分组全局变量.通过 generator_wrapper生成, 自己的进程rank如果在group内, 初始化对应的nccl/gloo torch.distributed.new_group
GlobalMemoryBuffer: 保存每个已经分配出的tensor, 避免显存重分配.
setup_model_and_optimizer
主要逻辑是配置模型组网和优化器.
model_provider: torch gpt组网
megatron/core/transformer, transformer组网核心逻辑, 基于torch.nn.Module, 将涉及到的子模型结构进行了抽象. 通过subModule的方式嵌入自定义module, 便于代码复用
例如
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": attn_mask_type},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
)
在attention.py里读到之前moduleSpec中的对应linear_qkv的实现, 即TP列并行的Linear实现. 加上TransformerConfig, 就能定义出最终的网络逻辑. TP相关逻辑在后续专门看的时候再细写.
self.linear_qkv = build_module(
submodules.linear_qkv,
self.config.hidden_size,
self.query_projection_size + 2 * self.kv_projection_size,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear or self.config.add_qkv_bias,
skip_bias_add=False,
is_expert=False,
tp_comm_buffer_name='qkv',
)
torch里实现module时, 主要关注__init__()和forward(), bp通过自动微分生成.
配置
配置类 ModelParallelConfig, TransformerConfig
ModelParallelConfig: 主要包括 模型并行/PP/通信overlap相关优化开关/cpuOffload 等相关配置
TransformerConfig: 主要包括 模型结构/MOE/算子fusion加速/激活重计算/Context并行 等配置
models/gpt/gpt_model.py
preprocess
分为word_emb和pos_emb两部分. 输出为 word_emb(b,s,h) + pos_emb(s,h) + tokentype_emb(b,s,h)(需要转置适配)
注意在embedding最后要进行dropout处理, 应该是为了减少模型过拟合的风险
WordEmbeddings
tensor_parallel.VocabParallelEmbedding
vocab_size表示词表维度, 例如分词预处理后保留能查到的几千个常用单词. 将vocab_size个embed均分存储到global_world_size张卡上, embedding lookup时从对应的存储卡上拉取. 这里把非自身rank的emb通过[start_idx, end_idx)的mask操作置0, 然后通过reduce就能获取完整的词表.
如果配置开了序列并行, reduce操作会变为reduceScatter操作, lookup之后直接分配好sp的输入.
RoPE(旋转位置编码)
位置编码需要满足几个性质: 1. 不能满足交换律, 第m个token与第n个token的位置关系,和第n个token与第m个token的位置关系一定要有区分度。 2.需要有远程衰减性

为了便于加速计算, 可以等价优化为下面这种向量乘法的形式:


tokentype_embedding
类型嵌入层,用于区分输入中不同类型的token, 例如,在BERT中用于区分两个句子,而在某些GPT变种或特定任务中可能用于区分不同类型的输入数据,如对话中的提问和回答.
transformer
self.decoder就是上面通过ModuleSpec获得的module, 可以根据配置选择普通的selfAttention, 还是MLA.
- MLA原理: 在模型能力不变基础上,通过KV低秩压缩, 使得推理的KVcache显存占用和计算效率上对比MHA性能有明显提升.

postprocess
1.output_layer & loss
训练时output可以并行, 这里是个TP列并行的方式, 训练方式如下例子:
<s>
<s> i
<s> i love
<s> i love maching
<s> i love maching learning <eos/>
训练阶段将这个矩阵直接输入到decoder,分别得到 5个输出 \(O_i, i\in [1,2,3,4,5]\), 理想的输出应该是[i, love, maching, learning, ] ,然后 比较\(O_i\)和理想输出的交叉熵,得到loss. 而且这五个序列可以放在一个batch内并行计算.
optimizer
_get_param_groups_and_buffers
从多个model_chunks中遍历所有的param向量, 对其中某些param进行特殊的处理
- decoupled_lr是为input/output layer单独设置的lr
no_weight_decay_cond: 配置参数是否应该执行权重衰减。- scale_lr_cond: 对某些指定层的参数进行学习率缩放, 匹配到对应的param_map后执行.
_get_megatron_optimizer_based_on_param_groups
主要逻辑是混合精度optimizer的设置(MixedPrecisionOptimizer), TODO: 细看Apex.FusedAdam, 和torch.adamW的区别在哪里
梯度缩放: DynamicGradScaler
混合精度训练的时候, 用于动态调整梯度缩放比例,以处理梯度爆炸或消失问题.
主要逻辑是有一个初始化scale值, 当连续hysteresis次迭代中出现NaN,torch.max(scale * backoff_factor, min_scale) 用来减小scale\(backoff\_factor \in (0, 1)\).
当连续growth_interval次没出现NaN, 按照_scale * growth_factor_, 放大scale, \(growth\_factor > 1\)
DistributedOptimizer
接口继承自torch.optimizer, 核心逻辑在step(self), 有3个类: FP32Optimizer, ChainedOptimizer, MixedPrecisionOptimizer
FP32Optimizer: fp32训练使用到的, 主要功能是配置了clip_grad后进行normalization, norm分两种, 一种是取max_grad, 一种是l2范数, 通过all_reduce拿到total_norm, 最后用这个值分别对每个param tensor进行scale. 在scale之后就调用的是torch.optimizer.step进行正常的Adam更新.
MixedPrecisionOptimizer: 混合精度训练使用
- prepare_grads: 先从param.grad copy到 param.main_grad, 这一步同时做了fp16->fp32的转换, 然后检查所有的grad, 先unscale, 再看是否存在NaN. 注意只有fp16需要, bf16不需要.
- clip_grad_norm: 与FP32Optimizer一样的方法scale grad.
- step_with_ready_grads: optimizer.step后, 再把fp32的main_param copy回用于下一轮bp的fp16 param里面.
ChainedOptimizer: 用于moe场景, 每个分块子模型配置不同的optimizer时使用. 多个optimizer之间串行执行.
下一节看megatron的模型保存&加载, 并行训练相关代码.
参考链接
[megatron代码阅读] 1. 初始化和组网的更多相关文章
- [置顶] Linux协议栈代码阅读笔记(一)
Linux协议栈代码阅读笔记(一) (基于linux-2.6.21.7) (一)用户态通过诸如下面的C库函数访问协议栈服务 int socket(int domain, int type, int p ...
- Python - 关于代码阅读的一些建议
初始能力 让阅读思路保持清晰连贯,主力关注在流程架构和逻辑实现上,不被语法.技巧和业务流程等频繁地阻碍和打断. 建议基本满足以下条件,再开始进行代码阅读: 具备一定的语言基础:熟悉基础语法,常用的函数 ...
- Bleve代码阅读(二)——Index Mapping
引言 Bleve是Golang实现的一个全文检索库,类似Lucene之于Java.在这里通过阅读其代码,来学习如何使用及定制检索功能.也是为了通过阅读代码,学习在具体环境下Golang的一些使用方式. ...
- 脚本病毒分析扫描专题2-Powershell代码阅读扫盲
4.2.PowerShell 为了保障木马样本的体积很小利于传播.攻击者会借助宏->WMI->Powershell的方式下载可执行文件恶意代码.最近也经常会遇见利用Powershell通过 ...
- Jafka Broker代码阅读之总览
从本文开始,笔者将尝试从源码角度解读Jafka(Kafka)的特性,探究其背后的实现原理与技术.前面讲解Jafka Broker的文章中有提到下面这段启动服务端的代码,我们就从这里开始. Proper ...
- vnpy源码阅读学习(5):关于MainEngine的代码阅读
关于MainEngine的代码阅读 在入口文件中,我们看到了除了窗体界面的产生,还有关于MainEngine和EventEngin部分.今天来学习下MainEngine的代码. 首先在run代码中,我 ...
- 软光栅-uraster代码阅读(入门极品)
软光栅-uraster代码阅读(入门极品) 代码链接:https://github.com/Steve132/uraster 所有的代码都在uraster.hpp中.代码非常简单,适合初学者学习软光栅 ...
- Python代码阅读(第12篇):初始化二维数组
Python 代码阅读合集介绍:为什么不推荐Python初学者直接看项目源码 本篇阅读的代码实现了二维数组的初始化功能,根据给定的宽高初始化二维数组. 本篇阅读的代码片段来自于30-seconds-o ...
- Linux Kernel代码艺术——数组初始化
前几天看内核中系统调用代码,在系统调用向量表初始化中,有下面这段代码写的让我有点摸不着头脑: const sys_call_ptr_t sys_call_table[__NR_syscall_max+ ...
- 代码阅读分析工具Understand 2.0试用
Understand 2.0是一款源代码阅读分析软件,功能强大.试用过一段时间后,感觉相当不错,确实可以大大提高代码阅读效率.由于Understand功能十分强大,本文不可能详尽地介绍它的所有功能,所 ...
随机推荐
- Ubuntu安装Edge浏览器,好用的浏览器!!
秉持着简介的原则,我这里把重要的步骤记录下来,减少废话的使用量,大大缩短你们看的时间,好吧.. 步骤 首先,使用以下命令更新您的系统: sudo apt update 然后,使用以下命令安装Micro ...
- luogu P3842 [TJOI2007] 线段
link 好题,考虑如何设定状态. 设\(dp_{i,0/1}\)表示到了第\(i\)行走完后停在这一行的最左侧/最右侧. 设定\(l_i\)表示这一行该线段的最左侧,\(r_i\)表示这一行的最右侧 ...
- Java后端请求想接收多个对象入参的数据方法
在Java后端开发中,如果我们希望接收多个对象作为HTTP请求的入参,可以使用Spring Boot框架来简化这一过程.Spring Boot提供了强大的RESTful API支持,能够方便地处理各种 ...
- 一个专注推荐.Net开源项目的榜单
大家好,我是编程乐趣,从7月份开始推荐开源项目,已经推荐了接近100个开源项目了,其中绝大部分是有关.Net的开源项目,也受到大家非常多人的喜欢. 由于公众号不方便查询,很多人又想了解更多的开源项目, ...
- RL 基础 | 如何复现 PPO,以及一些踩坑经历
最近在复现 PPO 跑 MiniGrid,记录一下- 这里跑的环境是 Empty-5x5 和 8x8,都是简单环境,主要验证 PPO 实现是否正确. 01 Proximal policy Optimi ...
- DevEco Studio 实战第一节:字符串拼接与组件构建
DevEco Studio 实战第一节:字符串拼接与组件构建 引言 在现代软件开发中,TypeScript 提供了强类型的优势,而 DevEco Studio 作为华为推出的开发集成环境,提供了便捷的 ...
- 全网最适合入门的面向对象编程教程:60 Python面向对象综合实例-传感器数据实时绘图器
全网最适合入门的面向对象编程教程:60 Python 面向对象综合实例-传感器数据实时绘图器 摘要: 本文将结合之前内容实现模拟一个传感器系统软件,包括三个线程:传感器线程生成数据并通过串口发送给主机 ...
- ClickHouse 物化视图学习总结
物化视图 物化视图源表--基础数据源 创建源表,因为我们的目标涉及报告聚合数据而不是单条记录,所以我们可以解析它,将信息传递给物化视图,并丢弃实际传入的数据.这符合我们的目标并节省了存储空间,因此我们 ...
- WxPython跨平台开发框架之表格数据导出到Excel并打开
在 Python 中使用 wxPython 导出实体类列表数据到 Excel,通常可以借助 openpyxl 或 pandas 库来实现.本篇随笔由浅入深,逐步介绍导出Excel文件的操作,然后结合跨 ...
- Windows安装redis并将redis设置成服务开机自启
Redis 作为一种缓存工具,主要用于解决高并发的问题,在分布式系统中有着极其广泛的应用,Redis 本身是应用于 Linux/Unix 平台的(部署在服务器上边),官方并没有提供 Windows 平 ...