总览

主要以PPO为基础来学习VeRL的整体训练流程. 在PPO里主要有4个模型:

  • Actor Model: 要训练的目标模型.
  • Critic Model: 用于在RL训练中评估总收益, 在训练过程中需要进行参数更新
  • Reference Model: SFT完的freeze模型, 不更新. 主要作用是作为base模型避免RL训练偏离SFT太远导致效果变差
  • Reward model: 用于在RL中评估即时收益, 也是参数freeze不更新. 这个模型存在的意义是, Critic产出的总收益\(V_t = R_t + \gamma V_{t+1}\) 是这么计算的, 在某步的总收益如果全靠Critic凭空预估, 精确度肯定不如使用已知事实数据去逼近. \(R_t\) 的作用就是提供已知即时收益.

训练步骤主要分成3步:

  1. Generation: Actor 在一批prompt样本上进行forward推理

  2. Preparation: Critic/Reward/Reference 分别通过一次前向计算,对Actor的结果进行评分, 计算各个token的Reward和KL散度.

    使用Critic/Reward的输出计算GAE, 把这部分的推理结果和评分放到经验池里, 再通过采样拿到用于下一轮训练的minibatch

  3. Training: 用产出的训练样本更新 actor 和 critic 模型

VeRL和其他框架的区别:

  • Single Controller: 优点在于协调执行顺序灵活, 分配管理资源映射灵活. 缺点在于LLM场景下连接庞大数量的worker会有明显的调度开销.
  • Multi Controller: 目前megatron/deepspeed采用的模式, 调度开销很低. 但因为没有集中控制节点各个数据流配置非常不灵活, 修改其中的一个节点需要变更其他节点的实现.
  • VeRL: 采用的方案是在每个模型节点内(比如Actor的所有worker视为1个节点)采用Multi的方式, 在节点间的数据流管理采用中心节点控制的方式.

RLHF的特点

Heterogeneous model workloads: actor、critic、ref和reward四类模型有差异较大的显存占用和计算要求.比如ref和reward只需要推理, 可以只存模型参数, 但actor和critic除了推理还需要训练, 参数/梯度/optimizer都需要load进显存. 另外每个模型的规模也不一样, 比如可以用超大的critic/reward来对齐一个小的actor模型, 所以每个类型需要设置不同的并行策略和优化方案.

Unbalanced computation between actor training and generation: 这个很好理解, actor模型在training阶段是计算密集型的, 而在generation阶段是纯推理, 如果和training阶段同样的并行设置(模型并行开的很大), 其实变成了内存密集型导致资源利用率变低. 但是改成不同的并行策略会带来参数通信的额外开销, 需要综合考量.

Diverse model placement requirements: 根据数据依赖和负载的关系, 把不同类型的model放到不同的device上, 从而可以并行执行. 比如下图, 把ref和reward放一块串行推理, 同时和critic推理并行执行. 如果耗时接近可以做到overlap. 但是放置策略需要配合算法设计, 尽可能的避免GPU空窗期提高利用率.

框架架构

图4描述了HybridFlow的架构,它由三个主要组件组成:混合编程模型、3D-HybridEngine和自动映射算法。混合编程模型包括一组分层API,以实现RLHF数据流的灵活表达和数据流中模型的高效计算(§4)。3D-HybridEngine是专门为Actor模型的高效训练和生成而设计的,允许在两个阶段采用不同的三维并行配置,并在两个阶段之间的转换过程中实现零内存冗余和最小化通信开销(§5)。自动映射算法确定每个模型的优化设备放置,以最大限度地提高RLHF的吞吐量(§6)。

User Input

  1. ModelConfig: Actor/Critic/Ref/Reward Model模型结构
  2. DeviceConfig: 模型在device上的放置配置, 根据这个配置再通过AutoMapping实现物理device的分配
  3. DataFlow graph: 各个model的并行策略配置. 中心节点通过这些初始化RLHF数据流, 并把这些并行operation和model分配到对应的device上.

ParallelWorker

每个model内部的MultipleController实现. 通过调用各个子model对应的训练/推理引擎来完成model的执行. 比如Ref/Reward只需要推理, 可以通过vllm/sglang等推理引擎实现高效率的推理. Actor/Critic通过Megatron等训练框架完成训练.

TransferProtocol

各个model之间的数据通信, 例如Ref/Reward infer出的结果用于经验回放样本的构建, 从而支持Actor/Critic模型训练. 通过@register装饰update函数, 每个协议包括一个collect接口和一个distribute接口, 以3D_PROTO为例:

  • Collect: 收集函数的返回数据回中心节点, 例如update_actor产出的loss张量 (类似于coordinator的metrics收集)
  • distribute: 派发各个DP中函数的输入数据, 例如update_actor的输入

对于包含data resharding(并行策略不一致)的场景, 流程图如下:

  1. 中心节点向actor发generate请求
  2. 看着像是actor通过allgather的方式把所有数据产出的future汇聚到一起? future怎么能通信得细看代码..目前想不到原理, allGather待确认
  3. 汇聚好的future发给中心节点
  4. 中心节点给critic发Prepare请求
  5. critic通过scatter的方式把future分配给各个DP组. scatter待确认
  6. 等future完成, 异步从Actor拉取推理结果. 因为DP从3个变成2个, 所以critic的每个DP相当于要推理Actor的一个半DP的结果.

3D-HybridEngine

主要解决的问题是上一节提到的, Actor在training和generation阶段, 因为推理和训练的不同特性导致需要配置不同的并行策略. 这个类的主要功能就是保证两个阶段相互切换时尽量少的通信与冗余数据的存储.

下面这张图对整个流程描述的非常清楚, 而且这个优化和zero++那个量化通信的allToall优化实现确实是有异曲同工之妙hh, 详细步骤:

之前分析RLHF特点的时候说了, 训练阶段计算密集型模型并行大,推理阶段模型并行小. 在这张图里, 训练阶段的配置是TP4,DP2, 而推理是DP4,TP2.

Hybridflow-v: 采用的方法是在train完成后, 通过allgather的方式把TP里的全量参数完成通信, 然后在gen的时候每个device抛弃掉自己不需要的那部分. 但存在的问题是在G3上训练需要part3的参数, 但推理不需要, 如果推理的这部分参数释放掉就需要再进行一次集合通信. 为了避免这个情况通过冗余参数的方式给存储下来.

Hybridflow: 通过把DP进行进一步拆分, 把microDP内的进行allgather后, 再把不同microDP组合到一起就能在避免冗余存储的同时还进一步缩减通信. 只是过程中会动态更改每个group内的节点rank.

参考

论文:https://arxiv.org/pdf/2409.19256

PPO: https://zhuanlan.zhihu.com/p/677607581

veRL代码阅读-1.论文原理的更多相关文章

  1. 【生成对抗网络学习 其三】BiGAN论文阅读笔记及其原理理解

    参考资料: 1.https://github.com/dragen1860/TensorFlow-2.x-Tutorials 2.<Adversarial Feature Learning> ...

  2. 脚本病毒分析扫描专题2-Powershell代码阅读扫盲

    4.2.PowerShell 为了保障木马样本的体积很小利于传播.攻击者会借助宏->WMI->Powershell的方式下载可执行文件恶意代码.最近也经常会遇见利用Powershell通过 ...

  3. 脚本病毒分析扫描专题1-VBA代码阅读扫盲、宏病毒分析

    1.Office Macor MS office宏的编程语言是Visual Basic For Applications(VBA). 微软在1994年发行的Excel5.0版本中,即具备了VBA的宏功 ...

  4. Jafka Broker代码阅读之总览

    从本文开始,笔者将尝试从源码角度解读Jafka(Kafka)的特性,探究其背后的实现原理与技术.前面讲解Jafka Broker的文章中有提到下面这段启动服务端的代码,我们就从这里开始. Proper ...

  5. C++11的简单线程池代码阅读

    这是一个简单的C++11实现的线程池,代码很简单. 原理就是管理一个任务队列和一个工作线程队列. 工作线程不断的从任务队列取任务,然后执行.如果没有任务就等待新任务的到来.添加新任务的时候先添加到任务 ...

  6. 代码阅读分析工具Understand 2.0试用

    Understand 2.0是一款源代码阅读分析软件,功能强大.试用过一段时间后,感觉相当不错,确实可以大大提高代码阅读效率.由于Understand功能十分强大,本文不可能详尽地介绍它的所有功能,所 ...

  7. Android 上的代码阅读器 CoderBrowserHD 修改支持 go 语言代码

    我在Android上的代码阅读器用的是 https://github.com/zerob13/CoderBrowserHD 改造的版本,改造后的版本我放在 https://github.com/ghj ...

  8. Linux协议栈代码阅读笔记(二)网络接口的配置

    Linux协议栈代码阅读笔记(二)网络接口的配置 (基于linux-2.6.11) (一)用户态通过C库函数ioctl进行网络接口的配置 例如,知名的ifconfig程序,就是通过C库函数sys_io ...

  9. [置顶] Linux协议栈代码阅读笔记(一)

    Linux协议栈代码阅读笔记(一) (基于linux-2.6.21.7) (一)用户态通过诸如下面的C库函数访问协议栈服务 int socket(int domain, int type, int p ...

  10. 图形化代码阅读工具——Scitools Understand

    Scitools出品的Understand 2.0.用了很多年了,比Source Insight强大很多.以前的名字叫Understand for C/C++,Understand for Java, ...

随机推荐

  1. MQ 如何保证数据一致性?

    前言 上个月,我们有个电商系统出了个灵异事件:用户支付成功了,但订单状态死活不改成"已发货". 折腾了半天才定位到问题:订单服务的MQ消息,像人间蒸发一样消失了. 这个Bug让我明 ...

  2. C 冒泡排序和选择排序

    冒泡排序 理论概念: 从第一个数开始,将相邻的两个数比较,第一个数和第二个数比较,比如说是从小到大的排序,要是后面的数比前面的小则交换两个的位置,这样第一轮比较基数后最大的数就到了最后面,接着进行第二 ...

  3. 关于wireshark抓包工具抓取登录数据的一点心得

    研究这个软件很久了,一直处于门外汉状态,今天终于用它抓到点有用的东西,做个简单的笔记吧,后面再继续完善. 最近研究跨域自动登录时一直不太顺利,今天就仿照网上前辈们的方法,用wireshark先抓一下手 ...

  4. requirejs的简单使用,requirejs报错Uncaught Error: Mismatched anonymous define() module: …

    requirejs的简单使用 define()方法的3个参数: 参数1为模块名称(不填则以当前js的文件名定义一个匿名模块), 参数2为依赖项数组(可不填), 参数3为模块的实现 引入jQuery: ...

  5. Java编程--简单的Factory程序(工厂设计模式)

    Factory类不是接口.抽象类,就是普通的类. Factory就像一个工厂一样,可以返回很多对象. 子类在继承.实现抽象类和接口后由Factory类处理,由于子类可能会有多个,Factory根据客户 ...

  6. 使用DbUtils和dbcp连接池写的通用的CRUD工具类

    目录 1 项目目录结构 2 工具类需要的jar包 2.1 Dbutils需要的jar包 2.2 dbcp需要的jar包 2.3 数据库jar包 3 代码部分 3.1 dbcp.properties 3 ...

  7. 【经验】Git仓库多账号管理与部署|SSH密钥设置

    生成 SSH 密钥 先打开一个git窗口,生成ssh密钥. 如果打开的不是git窗口,而是cmd窗口,则需要先切换到C:\Users\用户名\.ssh目录下. 下面这条指令的your_email和yo ...

  8. 『Plotly实战指南』--在科学数据可视化中的应用(下)

    科学数据往往涉及多个维度,例如分子结构中的空间坐标.物理实验中的时间序列以及化学反应中的温度变化等. 传统的二维可视化方法已经难以满足这些复杂数据的展示需求. 而Plotly,作为一种强大的可视化库, ...

  9. 信息资源管理综合题之“SPD属于知识管理工具那一类 与 管理工具与知识库的区别 以及 使用知识地图是否可以用SynchroFLOW替代”

    一.案例:1995年10月,微软开发了一项"技能规划与开发(SPD)"的计划,他们把每个系统开发人员的工作能力和这些特定工作需要的知识制作成地图,让那个员工与团队间的配合更加默契, ...

  10. P2779 [AHOI2016初中组] 黑白序列题解

    题意: 小可可准备了一个未完成的黑白序列,用 B 和 W 表示黑色和白色,用 ? 表示尚未确定. 他希望知道一共有多少种不同的方法,在决定了每一个 ? 位置的颜色后可以得到一个小雪喜欢的黑白序列. 其 ...