总览

主要以PPO为基础来学习VeRL的整体训练流程. 在PPO里主要有4个模型:

  • Actor Model: 要训练的目标模型.
  • Critic Model: 用于在RL训练中评估总收益, 在训练过程中需要进行参数更新
  • Reference Model: SFT完的freeze模型, 不更新. 主要作用是作为base模型避免RL训练偏离SFT太远导致效果变差
  • Reward model: 用于在RL中评估即时收益, 也是参数freeze不更新. 这个模型存在的意义是, Critic产出的总收益\(V_t = R_t + \gamma V_{t+1}\) 是这么计算的, 在某步的总收益如果全靠Critic凭空预估, 精确度肯定不如使用已知事实数据去逼近. \(R_t\) 的作用就是提供已知即时收益.

训练步骤主要分成3步:

  1. Generation: Actor 在一批prompt样本上进行forward推理

  2. Preparation: Critic/Reward/Reference 分别通过一次前向计算,对Actor的结果进行评分, 计算各个token的Reward和KL散度.

    使用Critic/Reward的输出计算GAE, 把这部分的推理结果和评分放到经验池里, 再通过采样拿到用于下一轮训练的minibatch

  3. Training: 用产出的训练样本更新 actor 和 critic 模型

VeRL和其他框架的区别:

  • Single Controller: 优点在于协调执行顺序灵活, 分配管理资源映射灵活. 缺点在于LLM场景下连接庞大数量的worker会有明显的调度开销.
  • Multi Controller: 目前megatron/deepspeed采用的模式, 调度开销很低. 但因为没有集中控制节点各个数据流配置非常不灵活, 修改其中的一个节点需要变更其他节点的实现.
  • VeRL: 采用的方案是在每个模型节点内(比如Actor的所有worker视为1个节点)采用Multi的方式, 在节点间的数据流管理采用中心节点控制的方式.

RLHF的特点

Heterogeneous model workloads: actor、critic、ref和reward四类模型有差异较大的显存占用和计算要求.比如ref和reward只需要推理, 可以只存模型参数, 但actor和critic除了推理还需要训练, 参数/梯度/optimizer都需要load进显存. 另外每个模型的规模也不一样, 比如可以用超大的critic/reward来对齐一个小的actor模型, 所以每个类型需要设置不同的并行策略和优化方案.

Unbalanced computation between actor training and generation: 这个很好理解, actor模型在training阶段是计算密集型的, 而在generation阶段是纯推理, 如果和training阶段同样的并行设置(模型并行开的很大), 其实变成了内存密集型导致资源利用率变低. 但是改成不同的并行策略会带来参数通信的额外开销, 需要综合考量.

Diverse model placement requirements: 根据数据依赖和负载的关系, 把不同类型的model放到不同的device上, 从而可以并行执行. 比如下图, 把ref和reward放一块串行推理, 同时和critic推理并行执行. 如果耗时接近可以做到overlap. 但是放置策略需要配合算法设计, 尽可能的避免GPU空窗期提高利用率.

框架架构

图4描述了HybridFlow的架构,它由三个主要组件组成:混合编程模型、3D-HybridEngine和自动映射算法。混合编程模型包括一组分层API,以实现RLHF数据流的灵活表达和数据流中模型的高效计算(§4)。3D-HybridEngine是专门为Actor模型的高效训练和生成而设计的,允许在两个阶段采用不同的三维并行配置,并在两个阶段之间的转换过程中实现零内存冗余和最小化通信开销(§5)。自动映射算法确定每个模型的优化设备放置,以最大限度地提高RLHF的吞吐量(§6)。

User Input

  1. ModelConfig: Actor/Critic/Ref/Reward Model模型结构
  2. DeviceConfig: 模型在device上的放置配置, 根据这个配置再通过AutoMapping实现物理device的分配
  3. DataFlow graph: 各个model的并行策略配置. 中心节点通过这些初始化RLHF数据流, 并把这些并行operation和model分配到对应的device上.

ParallelWorker

每个model内部的MultipleController实现. 通过调用各个子model对应的训练/推理引擎来完成model的执行. 比如Ref/Reward只需要推理, 可以通过vllm/sglang等推理引擎实现高效率的推理. Actor/Critic通过Megatron等训练框架完成训练.

TransferProtocol

各个model之间的数据通信, 例如Ref/Reward infer出的结果用于经验回放样本的构建, 从而支持Actor/Critic模型训练. 通过@register装饰update函数, 每个协议包括一个collect接口和一个distribute接口, 以3D_PROTO为例:

  • Collect: 收集函数的返回数据回中心节点, 例如update_actor产出的loss张量 (类似于coordinator的metrics收集)
  • distribute: 派发各个DP中函数的输入数据, 例如update_actor的输入

对于包含data resharding(并行策略不一致)的场景, 流程图如下:

  1. 中心节点向actor发generate请求
  2. 看着像是actor通过allgather的方式把所有数据产出的future汇聚到一起? future怎么能通信得细看代码..目前想不到原理, allGather待确认
  3. 汇聚好的future发给中心节点
  4. 中心节点给critic发Prepare请求
  5. critic通过scatter的方式把future分配给各个DP组. scatter待确认
  6. 等future完成, 异步从Actor拉取推理结果. 因为DP从3个变成2个, 所以critic的每个DP相当于要推理Actor的一个半DP的结果.

3D-HybridEngine

主要解决的问题是上一节提到的, Actor在training和generation阶段, 因为推理和训练的不同特性导致需要配置不同的并行策略. 这个类的主要功能就是保证两个阶段相互切换时尽量少的通信与冗余数据的存储.

下面这张图对整个流程描述的非常清楚, 而且这个优化和zero++那个量化通信的allToall优化实现确实是有异曲同工之妙hh, 详细步骤:

之前分析RLHF特点的时候说了, 训练阶段计算密集型模型并行大,推理阶段模型并行小. 在这张图里, 训练阶段的配置是TP4,DP2, 而推理是DP4,TP2.

Hybridflow-v: 采用的方法是在train完成后, 通过allgather的方式把TP里的全量参数完成通信, 然后在gen的时候每个device抛弃掉自己不需要的那部分. 但存在的问题是在G3上训练需要part3的参数, 但推理不需要, 如果推理的这部分参数释放掉就需要再进行一次集合通信. 为了避免这个情况通过冗余参数的方式给存储下来.

Hybridflow: 通过把DP进行进一步拆分, 把microDP内的进行allgather后, 再把不同microDP组合到一起就能在避免冗余存储的同时还进一步缩减通信. 只是过程中会动态更改每个group内的节点rank.

参考

论文:https://arxiv.org/pdf/2409.19256

PPO: https://zhuanlan.zhihu.com/p/677607581

veRL代码阅读-1.论文原理的更多相关文章

  1. 【生成对抗网络学习 其三】BiGAN论文阅读笔记及其原理理解

    参考资料: 1.https://github.com/dragen1860/TensorFlow-2.x-Tutorials 2.<Adversarial Feature Learning> ...

  2. 脚本病毒分析扫描专题2-Powershell代码阅读扫盲

    4.2.PowerShell 为了保障木马样本的体积很小利于传播.攻击者会借助宏->WMI->Powershell的方式下载可执行文件恶意代码.最近也经常会遇见利用Powershell通过 ...

  3. 脚本病毒分析扫描专题1-VBA代码阅读扫盲、宏病毒分析

    1.Office Macor MS office宏的编程语言是Visual Basic For Applications(VBA). 微软在1994年发行的Excel5.0版本中,即具备了VBA的宏功 ...

  4. Jafka Broker代码阅读之总览

    从本文开始,笔者将尝试从源码角度解读Jafka(Kafka)的特性,探究其背后的实现原理与技术.前面讲解Jafka Broker的文章中有提到下面这段启动服务端的代码,我们就从这里开始. Proper ...

  5. C++11的简单线程池代码阅读

    这是一个简单的C++11实现的线程池,代码很简单. 原理就是管理一个任务队列和一个工作线程队列. 工作线程不断的从任务队列取任务,然后执行.如果没有任务就等待新任务的到来.添加新任务的时候先添加到任务 ...

  6. 代码阅读分析工具Understand 2.0试用

    Understand 2.0是一款源代码阅读分析软件,功能强大.试用过一段时间后,感觉相当不错,确实可以大大提高代码阅读效率.由于Understand功能十分强大,本文不可能详尽地介绍它的所有功能,所 ...

  7. Android 上的代码阅读器 CoderBrowserHD 修改支持 go 语言代码

    我在Android上的代码阅读器用的是 https://github.com/zerob13/CoderBrowserHD 改造的版本,改造后的版本我放在 https://github.com/ghj ...

  8. Linux协议栈代码阅读笔记(二)网络接口的配置

    Linux协议栈代码阅读笔记(二)网络接口的配置 (基于linux-2.6.11) (一)用户态通过C库函数ioctl进行网络接口的配置 例如,知名的ifconfig程序,就是通过C库函数sys_io ...

  9. [置顶] Linux协议栈代码阅读笔记(一)

    Linux协议栈代码阅读笔记(一) (基于linux-2.6.21.7) (一)用户态通过诸如下面的C库函数访问协议栈服务 int socket(int domain, int type, int p ...

  10. 图形化代码阅读工具——Scitools Understand

    Scitools出品的Understand 2.0.用了很多年了,比Source Insight强大很多.以前的名字叫Understand for C/C++,Understand for Java, ...

随机推荐

  1. 【Esp32】为 idf 定制本地 Arduino 组件

    在开始今天的水文前,老周先要奉劝一下国内某些嵌入式砖家和穴者,不要看不起 Arduino,它不是一种开发板,而是一种规范.Arduino 的思想是正确的,把各种开发板封装为统一的 API,让许多开源库 ...

  2. RealSense .bag文件彩色图,深度图提取

    RealSense .bag文件彩色图,深度图提取 代码 import roslib import rosbag import rospy import cv2 import os from sens ...

  3. ZeroTier简单使用

    在 CentOS 系统下,你可以使用以下命令行操作来管理 ZeroTier 网络和设备.首先,确保已经正确安装 ZeroTier 软件,你可以按照以下步骤进行安装: 安装 ZeroTier: Zero ...

  4. 【JDBC第3章】使用PreparedStatement实现CRUD操作

    第3章:使用PreparedStatement实现CRUD操作 3.1 操作和访问数据库 数据库连接被用于向数据库服务器发送命令和 SQL 语句,并接受数据库服务器返回的结果.其实一个数据库连接就是一 ...

  5. 【Java】Math类的基本操作

    Math类 Math 类是数学操作类,提供了一系列的数学操作方法,包括求绝对值.三角函数等,在 Math 类中提供的一切方法都是静态方法(类方法),所以直接由类名称调用即可. Math类的基本操作: ...

  6. halcon 入门教程(三) 边缘检测

    原文作者:aircraft 原文链接:halcon 入门教程(三) 边缘检测 有兴趣可以多看其他的halcon教程 halcon 学习教程目录 本篇讲一下边缘检测(边缘提取),因为这个我发现也是比较常 ...

  7. Hyperledger Fabric - 自定义network.sh脚本

    引言:依据hyperledger fabric提供的测试网络脚本搭建自己的网络环境 该系列参考:https://blog.csdn.net/ling1998?type=blog 执行./network ...

  8. 为什么 Java 8 移除了永久代(PermGen)并引入了元空间(Metaspace)?

    为什么 Java 8 移除了永久代(PermGen)并引入了元空间(Metaspace)? 在 Java 8 中,JVM 移除了 永久代(PermGen)并引入了 元空间(Metaspace),这一改 ...

  9. 记一次 .NET某旅行社酒店管理系统 卡死分析

    一:背景 1. 讲故事 年初有位朋友找到我,说他们的管理系统不响应了,让我帮忙看下到底咋回事? 手上也有dump,那就来分析吧. 二:为什么没有响应 1. 线程池队列有积压吗? 朋友的系统是一个web ...

  10. 特殊符号大全,特殊字符、emoji符号收藏,可复制直接使用

    收藏包含:特殊符号.emoji符号.编号序号.数学符号.上标下标.标点符号.货币符号.箭头符号.国旗符号等 ❥웃유☮☏☢☠♚▲♪✞÷↑↓◆◇⊙■□△▽¿─│❣♂♀☿Ⓐ✍☣☤✘☒♛▼♫⌘☪≈←→◈◎☉★ ...