Following verl's implementation of DAPO, let's start with the entry .sh and .py files. The ./recipe/dapo/ directory contains the following:

.
├── config
│ ├── dapo_megatron_trainer.yaml
│ └── dapo_trainer.yaml
├── dapo_ray_trainer.py
├── main_dapo.py
├── prepare_dapo_data.sh
├── README.md
├── run_dapo_qwen2.5_32b.sh

The overall execution order:

  • main_dapo.py: initialize the data loading, the actor_rollout model and the rm model, and load the reward_manager
  • dapo_ray_trainer.py: the RL training loop
    • repeat the batch so that each prompt q is sampled n times
    • record each sample's log-probs, along with the corresponding reward_score and advantage
      • filter out any prompt whose samples all score 1 or all score 0, and keep fetching and sampling new prompts until the number of qualifying prompts reaches train_prompt_bsz (a minimal sketch of this dynamic-sampling loop follows the list). Note that the generation batch size is gen_prompt_bsz = 3 * train_prompt_bsz: over-sampling prompts avoids ending up with fewer than train_prompt_bsz qualifying prompts.
    • update the model once per mini_batch of data
      • run the forward pass (token-mean loss) and gradient computation once per micro_batch of data
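To make the dynamic-sampling step concrete, here is a minimal, self-contained sketch of the accumulation logic under simplified assumptions. prompt_batches and sample_group are hypothetical placeholders for the dataloader and for rollout-plus-reward; this is not verl's actual code:

import numpy as np

def accumulate_train_batch(prompt_batches, sample_group, n, train_prompt_bsz, max_num_gen_batches=0):
    """Minimal sketch of DAPO dynamic sampling: keep only prompts whose n sampled
    rewards are not all identical (reward std > 0), and keep pulling generation
    batches until train_prompt_bsz such prompts are accumulated.

    prompt_batches: iterable of generation batches (each a list of prompts)
    sample_group:   callable(prompt, n) -> list of n scalar rewards (placeholder for rollout + reward)
    """
    kept, num_gen_batches = [], 0
    for gen_batch in prompt_batches:          # each generation batch holds gen_prompt_bsz prompts
        num_gen_batches += 1
        for prompt in gen_batch:
            rewards = sample_group(prompt, n)
            if np.std(rewards) > 0:           # drop all-correct / all-wrong groups
                kept.append((prompt, rewards))
        if len(kept) >= train_prompt_bsz:     # enough prompts to form one training batch
            return kept[:train_prompt_bsz]
        if 0 < max_num_gen_batches <= num_gen_batches:
            raise ValueError("Generated too many batches without filling the train batch")
    return kept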

Now for the concrete code:

main_dapo.py

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
""" import os
import socket import hydra
import ray
from omegaconf import OmegaConf from verl.trainer.ppo.reward import load_reward_manager
from verl.utils.device import is_cuda_available from .dapo_ray_trainer import RayDAPOTrainer @hydra.main(config_path="config", config_name="dapo_trainer", version_base=None)
def main(config):
run_ppo(config) #################################################################
# RL训练入口
#################################################################
def run_ppo(config) -> None:
if not ray.is_initialized():
# this is for local ray cluster
default_runtime_env = {
"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"}
}
ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
print(f"ray init kwargs: {ray_init_kwargs}")
ray.init(**OmegaConf.to_container(ray_init_kwargs)) try:
if (
is_cuda_available
and config.global_profiler.tool == "nsys"
and OmegaConf.select(config.global_profiler, "steps") is not None
and len(OmegaConf.select(config.global_profiler, "steps")) > 0
):
nsight_options = OmegaConf.to_container(
config.global_profiler.global_tool_config.nsys.controller_nsight_options
)
runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote()
else:
runner = TaskRunner.remote()
ray.get(runner.run.remote(config))
finally:
if ray.is_initialized():
ray.shutdown() @ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
class TaskRunner:
def run(self, config):
# print initial config
from pprint import pprint from omegaconf import OmegaConf from verl.utils.fs import copy_to_local print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config) # download the checkpoint from hdfs
local_path = copy_to_local(config.actor_rollout_ref.model.path) # instantiate tokenizer
from verl.utils import hf_processor, hf_tokenizer tokenizer = hf_tokenizer(local_path)
processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none from verl.single_controller.ray import RayWorkerGroup #################################################################
# 加载actor worker
#################################################################
# define worker classes
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
assert config.critic.strategy in {"fsdp", "fsdp2"} from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker ray_worker_group_cls = RayWorkerGroup elif config.actor_rollout_ref.actor.strategy == "megatron":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker ray_worker_group_cls = RayWorkerGroup else:
raise NotImplementedError from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role role_worker_mapping = {
Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
Role.Critic: ray.remote(CriticWorker),
} global_pool_id = "global_pool"
resource_pool_spec = {
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
}
mapping = {
Role.ActorRollout: global_pool_id,
Role.Critic: global_pool_id,
} # we should adopt a multi-source reward function here
# - for rule-based rm, we directly call a reward score
# - for model-based rm, we call a model
# - for code related prompt, we send to a sandbox if there are test cases
# - finally, we combine all the rewards together
# - The reward type depends on the tag of the data
if config.reward_model.enable:
if config.reward_model.strategy in {"fsdp", "fsdp2"}:
from verl.workers.fsdp_workers import RewardModelWorker
elif config.reward_model.strategy == "megatron":
from verl.workers.megatron_workers import RewardModelWorker
else:
raise NotImplementedError
role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
mapping[Role.RewardModel] = global_pool_id # reference model
if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
mapping[Role.RefPolicy] = global_pool_id #################################################################
# 加载reward manager函数。用于根据data计算对应的reward score
#################################################################
reward_fn = load_reward_manager(
config,
tokenizer,
0,
max_resp_len=config.data.max_response_length,
overlong_buffer_cfg=config.reward_model.overlong_buffer,
) # Note that we always use function-based RM for validation
val_reward_fn = load_reward_manager(
config,
tokenizer,
1,
max_resp_len=config.data.max_response_length,
overlong_buffer_cfg=config.reward_model.overlong_buffer,
)
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) #################################################################
# 加载主要的DAPO RL训练类,并运行.fit()
#################################################################
trainer = RayDAPOTrainer(
config=config,
tokenizer=tokenizer,
processor=processor,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn,
)
trainer.init_workers()
trainer.fit() if __name__ == "__main__":
main()

Next, let's look at from verl.trainer.ppo.reward import load_reward_manager.

The launch script verl/recipe/dapo/run_dapo_qwen2.5_32b.sh configures the reward type:

enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4)) # overlong soft
overlong_penalty_factor=1.0

reward_model.reward_manager=dapo \
reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
reward_model.overlong_buffer.len=${overlong_buffer_len} \
reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
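With these settings, the overlong soft penalty that the DAPO reward manager adds later (the formula appears in dapo.py further below) behaves as in this small sketch; max_resp_len=20480 in the comments is only an assumed example value:

def overlong_penalty(valid_response_length, max_resp_len, overlong_buffer_len=4096, penalty_factor=1.0):
    """Soft length penalty: 0 while the response stays below max_resp_len - buffer,
    then decreasing linearly down to -penalty_factor at max_resp_len."""
    expected_len = max_resp_len - overlong_buffer_len
    exceed_len = valid_response_length - expected_len
    return min(-exceed_len / overlong_buffer_len * penalty_factor, 0)

# e.g. with an assumed max_resp_len of 20480 and the 4096-token buffer above:
# overlong_penalty(16000, 20480) == 0.0    (within the expected length)
# overlong_penalty(18432, 20480) == -0.5   (half way into the buffer)
# overlong_penalty(20480, 20480) == -1.0   (at the cap, full penalty)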

verl/trainer/ppo/reward.py

def load_reward_manager(
    config: DictConfig, tokenizer: Any, num_examine: int, **reward_kwargs: Any
) -> AbstractRewardManager:
    """
    Load and initialize a reward manager based on the configuration.

    Args:
        config: PPO trainer configuration object containing reward_model fields.
        tokenizer: Tokenizer object used for processing text.
        num_examine: Number of samples to examine.
        **reward_kwargs: Additional keyword arguments for the reward manager.

    Returns:
        An instance of the specified reward manager class.
    """
    # Try to get a custom reward function based on the configuration
    # user defined reward manager can be registered in custom_reward_fn
    compute_score = get_custom_reward_fn(config)
    final_compute_score = compute_score

    # The list of pre-defined reward managers are defined in `verl/workers/reward_manager/`:
    # naive: NaiveRewardManager
    # prime: PrimeRewardManager
    # batch: BatchRewardManager
    # dapo: DAPORewardManager
    # Note(haibin.lin): For custom reward managers, please make sure they are imported and
    # registered via `verl.workers.reward_manager.register`
    # By default reward_manager is set to naive (NaiveRewardManager)
    #################################################################
    # Load the concrete reward manager class here
    #################################################################
    reward_manager_name = config.reward_model.get("reward_manager", "naive")
    reward_manager_cls = get_reward_manager_cls(reward_manager_name)

    if compute_score is None:
        sandbox_config = config.reward_model.get("sandbox_fusion")
        sandbox_url = sandbox_config.get("url") if sandbox_config else None
        memory_limit_mb = sandbox_config.get("memory_limit_mb", 1024)
        if sandbox_url:
            sandbox_manager = multiprocessing.Manager()
            # Create a semaphore to control concurrent access to the sandbox
            _concurrent_semaphore = sandbox_manager.Semaphore(sandbox_config.get("max_concurrent", 64))
            final_compute_score = partial(
                default_compute_score,
                sandbox_fusion_url=sandbox_url,
                concurrent_semaphore=_concurrent_semaphore,
                memory_limit_mb=memory_limit_mb,
            )
        else:
            final_compute_score = default_compute_score

    #################################################################
    # For DAPO, reward_manager_cls here is the DAPORewardManager
    #################################################################
    # Instantiate and return the reward manager with the specified parameters
    return reward_manager_cls(
        tokenizer=tokenizer,
        num_examine=num_examine,
        compute_score=final_compute_score,
        reward_fn_key=config.data.reward_fn_key,
        **reward_kwargs,
    )

At this point we need to know what DAPO's reward_manager_cls actually is. Since the reward can only be computed over batch data, we will set the reward manager aside for now (DAPO's reward_manager_cls lives in verl/verl/workers/reward_manager/dapo.py), first look at how batches are sampled in dapo_ray_trainer.py, and then come back to study the reward computation in detail.
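As a rough picture of how that registration works: verl keeps a name-to-class registry for reward managers, so that register("dapo") plus get_reward_manager_cls("dapo") resolve to DAPORewardManager. A minimal sketch of such a registry (not verl's actual implementation) looks like this:

# Minimal sketch of a name -> class registry, analogous to
# verl.workers.reward_manager.register / get_reward_manager_cls (not the actual verl code).
REWARD_MANAGER_REGISTRY: dict[str, type] = {}

def register(name: str):
    def decorator(cls: type) -> type:
        REWARD_MANAGER_REGISTRY[name] = cls
        return cls
    return decorator

def get_reward_manager_cls(name: str) -> type:
    if name not in REWARD_MANAGER_REGISTRY:
        raise ValueError(f"Unknown reward manager: {name}")
    return REWARD_MANAGER_REGISTRY[name]

@register("dapo")
class DAPORewardManager:  # stand-in for the real class shown later
    pass

assert get_reward_manager_cls("dapo") is DAPORewardManager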

dapo_ray_trainer.py

#################################################################
# RayDAPOTrainer inherits from RayPPOTrainer.
# fit(): runs the DAPO training loop, including (1) dynamic sampling,
#        (2) the overlong soft reward, and (3) the token-level loss
#################################################################
class RayDAPOTrainer(RayPPOTrainer):
    """
    Note that this trainer runs on the driver process on a single CPU/GPU node.
    """

    def fit(self):
        """
        The training loop of PPO.
        The driver process only need to call the compute functions of the worker group through RPC
        to construct the PPO dataflow.
        The light-weight advantage computation is done on the driver process.
        """
        from omegaconf import OmegaConf

        from verl.utils.tracking import Tracking

        logger = Tracking(
            project_name=self.config.trainer.project_name,
            experiment_name=self.config.trainer.experiment_name,
            default_backend=self.config.trainer.logger,
            config=OmegaConf.to_container(self.config, resolve=True),
        )

        self.global_steps = 0
        self.gen_steps = 0

        # load checkpoint before doing anything
        self._load_checkpoint()

        # perform validation before training
        # currently, we only support validation using the reward_function.
        if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
            val_metrics = self._validate()
            assert val_metrics, f"{val_metrics=}"
            pprint(f"Initial validation metrics: {val_metrics}")
            logger.log(data=val_metrics, step=self.global_steps)
            if self.config.trainer.get("val_only", False):
                return

        if self.config.actor_rollout_ref.rollout.get("skip_rollout", False):
            rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg)
            rollout_skip.wrap_generate_sequences()

        # add tqdm
        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")

        # we start from step 1
        self.global_steps += 1
        self.gen_steps += 1
        last_val_metrics = None

        prev_step_profile = False
        curr_step_profile = (
            self.global_steps in self.config.global_profiler.steps
            if self.config.global_profiler.steps is not None
            else False
        )
        next_step_profile = False

        timing_raw = defaultdict(float)
        batch = None
        #################################################################
        # num_prompt_in_batch: number of prompts kept after filtering
        #   (reward std != 0); reset to 0 after each model update
        # num_gen_batches: number of generation batches consumed so far;
        #   reset to 0 after each model update
        #################################################################
        num_prompt_in_batch = 0
        num_gen_batches = 0
        #################################################################
        # Training starts here: loop over epochs, and within each epoch
        # over generation batches
        #################################################################
        for epoch in range(self.config.trainer.total_epochs):
            for batch_dict in self.train_dataloader:
                metrics = {}

                with marked_timer("start_profile", timing_raw):
                    self._start_profiling(
                        not prev_step_profile and curr_step_profile
                        if self.config.global_profiler.profile_continuous_steps
                        else curr_step_profile
                    )

                #################################################################
                # new_batch is a DataProto (see verl/verl/protocol.py);
                # new_batch.batch is a TensorDict.
                # new_batch holds 3x as many prompts as the trainable batch size
                # (the generation batch is over-sampled on purpose).
                #################################################################
                new_batch: DataProto = DataProto.from_single_dict(batch_dict)
                num_gen_batches += 1
                # pop those keys for generation
                if "multi_modal_data" in new_batch.non_tensor_batch.keys():
                    gen_batch = new_batch.pop(
                        batch_keys=["input_ids", "attention_mask", "position_ids"],
                        non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
                    )
                else:
                    # extract the generation keys from new_batch to build gen_batch
                    gen_batch = new_batch.pop(
                        batch_keys=["input_ids", "attention_mask", "position_ids"],
                        non_tensor_batch_keys=["raw_prompt_ids"],
                    )

                # Why repeat here? Each prompt is sampled n times, so the batch is
                # repeated n times with interleave=True.
                # gen_batch: (bsz, response_length),
                # gen_batch_output: (bsz*n, response_length)
                gen_batch_output = gen_batch.repeat(
                    repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
                )

                is_last_step = self.global_steps >= self.total_training_steps

                with marked_timer("step", timing_raw):
                    # generate a batch
                    with marked_timer("gen", timing_raw, "red"):
                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)
                        timing_raw.update(gen_batch_output.meta_info["timing"])
                        gen_batch_output.meta_info.pop("timing", None)

                    # This branch can be skipped on a first read. REMAX first runs a
                    # greedy rollout, whose reward serves as the baseline for the
                    # later advantage computation.
                    if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
                        with marked_timer("gen_max", timing_raw, "red"):
                            gen_baseline_batch = deepcopy(gen_batch)
                            # greedy baseline rollout: do_sample = False
                            gen_baseline_batch.meta_info["do_sample"] = False
                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)

                            new_batch = new_batch.union(gen_baseline_output)
                            # compute reward model score on new_batch
                            rm_scores = None
                            if self.use_rm and "rm_scores" not in new_batch.batch.keys():
                                rm_scores = self.rm_wg.compute_rm_score(new_batch)
                                new_batch = new_batch.union(rm_scores)
                            reward_baseline_tensor, _ = compute_reward(new_batch, self.reward_fn)
                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)

                            keys_to_pop = set(gen_baseline_output.batch.keys())
                            if rm_scores is not None:
                                keys_to_pop.update(rm_scores.batch.keys())
                            new_batch.pop(batch_keys=list(keys_to_pop))

                            new_batch.batch["reward_baselines"] = reward_baseline_tensor

                            del rm_scores, gen_baseline_batch, gen_baseline_output

                    #################################################################
                    # new_batch holds gen_prompt_bsz prompts.
                    # Assign a unique uid to every prompt: the uid is needed later to
                    # normalize the rewards of the n samples sharing the same prompt.
                    #################################################################
                    new_batch.non_tensor_batch["uid"] = np.array(
                        [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object
                    )
                    # repeat every key in the batch (mainly so that uid is repeated too)
                    # repeat to align with repeated responses in rollout
                    new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                    # merge the generated sequences into new_batch
                    new_batch = new_batch.union(gen_batch_output)

                    with marked_timer("reward", timing_raw, "yellow"):
                        # compute scores. Support both model and function-based.
                        # We first compute the scores using reward model. Then, we call reward_fn to combine
                        # the results from reward model and rule-based results.
                        if self.use_rm and "rm_scores" not in new_batch.batch.keys():
                            # we first compute reward model score
                            reward_tensor = self.rm_wg.compute_rm_score(new_batch)
                            new_batch = new_batch.union(reward_tensor)

                        # compute the reward of every sample in new_batch with the configured self.reward_fn
                        # we combine with rule-based rm
                        reward_tensor, reward_extra_infos_dict = compute_reward(new_batch, self.reward_fn)
                        new_batch.batch["token_level_scores"] = reward_tensor

                        if reward_extra_infos_dict:
                            new_batch.non_tensor_batch.update(
                                {k: np.array(v) for k, v in reward_extra_infos_dict.items()}
                            )

                        # compute rewards. apply_kl_penalty if available
                        if self.config.algorithm.use_kl_in_reward:
                            new_batch, kl_metrics = apply_kl_penalty(
                                new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
                            )
                            metrics.update(kl_metrics)  # TODO: This will be cleared if we use multiple generation batches
                        else:
                            new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"]

                    #################################################################
                    # DAPO's filtering (dynamic sampling) step
                    #################################################################
                    if not self.config.algorithm.filter_groups.enable:
                        batch = new_batch
                    else:  # NOTE: When prompts after filtering is less than train batch size,
                        # we skip to the next generation batch
                        metric_name = self.config.algorithm.filter_groups.metric
                        if metric_name == "seq_final_reward":
                            # Turn to numpy for easier filtering
                            new_batch.non_tensor_batch["seq_final_reward"] = (
                                new_batch.batch["token_level_rewards"].sum(dim=-1).numpy()
                            )
                        elif metric_name == "seq_reward":
                            new_batch.non_tensor_batch["seq_reward"] = (
                                new_batch.batch["token_level_scores"].sum(dim=-1).numpy()
                            )

                        # {uid: [r1, r2, ..., rn], uid: [...], ...}: the rewards of all samples of each prompt
                        # Collect the sequence reward for each trajectory
                        prompt_uid2metric_vals = defaultdict(list)
                        for uid, metric_val in zip(
                            new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name], strict=True
                        ):
                            prompt_uid2metric_vals[uid].append(metric_val)

                        # the reward std of each prompt
                        prompt_uid2metric_std = {}
                        for prompt_uid, metric_vals in prompt_uid2metric_vals.items():
                            prompt_uid2metric_std[prompt_uid] = np.std(metric_vals)

                        # keep the uids of prompts whose reward std is not 0
                        kept_prompt_uids = [
                            uid
                            for uid, std in prompt_uid2metric_std.items()
                            if std > 0 or len(prompt_uid2metric_vals[uid]) == 1
                        ]
                        # accumulate the number of prompts with non-zero std
                        num_prompt_in_batch += len(kept_prompt_uids)

                        # record the indices of the samples belonging to the kept prompts
                        kept_traj_idxs = []
                        for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]):
                            if traj_from_prompt_uid in kept_prompt_uids:
                                kept_traj_idxs.append(idx)

                        # index new_batch with the kept trajectory ids
                        new_batch = new_batch[kept_traj_idxs]
                        # batch accumulates the kept trajectories
                        batch = new_batch if batch is None else DataProto.concat([batch, new_batch])

                        # the minimum trainable batch size (number of prompts) configured in the .sh file
                        prompt_bsz = self.config.data.train_batch_size
                        # if the accumulated number of filtered prompts is still below the configured
                        # minimum, continue and accumulate with the next generation batch
                        if num_prompt_in_batch < prompt_bsz:
                            print(f"{num_prompt_in_batch=} < {prompt_bsz=}")
                            max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
                            # max_num_gen_batches is the maximum number of generation batches allowed;
                            # <= 0 means no limit; if num_gen_batches < max_num_gen_batches, keep generating
                            if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:
                                print(f"{num_gen_batches=}. Keep generating...")
                                self.gen_steps += 1
                                is_last_step = self.global_steps >= self.total_training_steps
                                continue
                            else:
                                raise ValueError(
                                    f"{num_gen_batches=} >= {max_num_gen_batches=}."
                                    + " Generated too many. Please check if your data are too difficult."
                                    + " You could also try set max_num_gen_batches=0 to enable endless trials."
                                )
                        # the number of accumulated qualifying prompts >= the minimum trainable batch size
                        else:
                            # Align the batch
                            traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n
                            #################################################################
                            # Trim to size: surplus trajectories are simply dropped. It is
                            # unclear whether this hurts sample efficiency, or whether some
                            # trajectories never get trained on at all.
                            #################################################################
                            batch = batch[:traj_bsz]
                    #################################################################
                    # Actor model update
                    #################################################################
                    # === Updating ===
                    batch.batch["response_mask"] = compute_response_mask(batch)

                    # Balance the number of valid tokens across DP ranks.
                    # NOTE: This usually changes the order of data in the `batch`,
                    # which won't affect the advantage calculation (since it's based on uid),
                    # but might affect the loss calculation (due to the change of mini-batching).
                    # TODO: Decouple the DP balancing and mini-batching.
                    if self.config.trainer.balance_batch:
                        self._balance_batch(batch, metrics=metrics)

                    # compute global_valid tokens
                    batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()

                    #################################################################
                    # Record the rollout-time (token-level) log-probs of every
                    # trajectory in the filtered batch; they are used for the
                    # importance-sampling ratio.
                    #################################################################
                    # recompute old_log_probs
                    with marked_timer("old_log_prob", timing_raw, "blue"):
                        old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
                        entropys = old_log_prob.batch["entropys"]
                        response_masks = batch.batch["response_mask"]
                        loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
                        # for DAPO, loss_agg_mode is "token-mean"
                        entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
                        old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
                        metrics.update(old_log_prob_metrics)
                        old_log_prob.batch.pop("entropys")
                        batch = batch.union(old_log_prob)

                    if self.use_reference_policy:
                        # compute reference log_prob
                        with marked_timer("ref", timing_raw, "olive"):
                            ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
                            batch = batch.union(ref_log_prob)

                    # compute values
                    if self.use_critic:
                        with marked_timer("values", timing_raw, "cyan"):
                            values = self.critic_wg.compute_values(batch)
                            batch = batch.union(values)

                    # compute the token-level rollout importance-sampling weights
                    # Compute rollout IS weights and mismatch metrics (inherited from RayPPOTrainer)
                    batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch)
                    # IS and mismatch metrics already have mismatch/ prefix
                    metrics.update(is_metrics)

                    #################################################################
                    # Compute advantages
                    #################################################################
                    with marked_timer("adv", timing_raw, "brown"):
                        # compute advantages, executed on the driver process
                        norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True)
                        batch = compute_advantage(
                            batch,
                            adv_estimator=self.config.algorithm.adv_estimator,
                            gamma=self.config.algorithm.gamma,
                            lam=self.config.algorithm.lam,
                            num_repeat=self.config.actor_rollout_ref.rollout.n,
                            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
                        )

                    # update critic
                    if self.use_critic:
                        with marked_timer("update_critic", timing_raw, "pink"):
                            critic_output = self.critic_wg.update_critic(batch)
                        critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
                        metrics.update(critic_output_metrics)

                    # implement critic warmup
                    if self.config.trainer.critic_warmup <= self.global_steps:
                        #################################################################
                        # Update the actor model (the batch holds train_prompt_bsz prompts).
                        # Parameters are updated once per mini_batch (using the accumulated
                        # gradients); gradients are accumulated once per micro_batch.
                        #################################################################
                        # update actor
                        with marked_timer("update_actor", timing_raw, "red"):
                            actor_output = self.actor_rollout_wg.update_actor(batch)
                        actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
                        metrics.update(actor_output_metrics)

                    # Log rollout generations if enabled
                    rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
                    if rollout_data_dir:
                        self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir)

                    # validate
                    if (
                        self.val_reward_fn is not None
                        and self.config.trainer.test_freq > 0
                        and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
                    ):
                        with marked_timer("testing", timing_raw, "green"):
                            val_metrics: dict = self._validate()
                            if is_last_step:
                                last_val_metrics = val_metrics
                        metrics.update(val_metrics)

                    if self.config.trainer.save_freq > 0 and (
                        is_last_step or self.global_steps % self.config.trainer.save_freq == 0
                    ):
                        with marked_timer("save_checkpoint", timing_raw, "green"):
                            self._save_checkpoint()

                with marked_timer("stop_profile", timing_raw):
                    next_step_profile = (
                        self.global_steps + 1 in self.config.global_profiler.steps
                        if self.config.global_profiler.steps is not None
                        else False
                    )
                    self._stop_profiling(
                        curr_step_profile and not next_step_profile
                        if self.config.global_profiler.profile_continuous_steps
                        else curr_step_profile
                    )
                    prev_step_profile = curr_step_profile
                    curr_step_profile = next_step_profile

                # collect metrics
                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
                # TODO: implement actual tflpo and theoretical tflpo
                n_gpus = self.resource_pool_manager.get_n_gpus()
                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
                timing_raw = defaultdict(float)  # clear timing

                metrics["train/num_gen_batches"] = num_gen_batches
                batch = None
                num_prompt_in_batch = 0
                num_gen_batches = 0

                # TODO: make a canonical logger that supports various backend
                logger.log(data=metrics, step=self.global_steps)

                if is_last_step:
                    pprint(f"Final validation metrics: {last_val_metrics}")
                    progress_bar.close()
                    return

                progress_bar.update(1)
                self.global_steps += 1
                self.gen_steps += 1

        # check if last step checkpoint exists
        checkpoint_dir = os.path.join(self.config.trainer.default_local_dir, f"global_step_{self.global_steps}")
        if not os.path.exists(checkpoint_dir):
            # save last step checkpoint
            timing_raw = defaultdict(float)
            with marked_timer("save_checkpoint", timing_raw, "green"):
                self._save_checkpoint()
            metrics = {f"timing/{k}": v for k, v in timing_raw.items()}
            logger.log(data=metrics, step=self.global_steps)
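For intuition, the GRPO-style advantage that compute_advantage produces groups trajectories by uid and normalizes each group's sequence rewards; a simplified sketch (not verl's exact implementation, which also broadcasts the scalar over the response tokens) is:

import numpy as np
from collections import defaultdict

def grpo_outcome_advantages(seq_rewards, uids, norm_adv_by_std_in_grpo=True, eps=1e-6):
    """Simplified GRPO-style advantage: for each prompt group (same uid),
    advantage_i = r_i - mean(group), optionally divided by (std(group) + eps)."""
    groups = defaultdict(list)
    for r, uid in zip(seq_rewards, uids):
        groups[uid].append(r)
    stats = {uid: (np.mean(v), np.std(v)) for uid, v in groups.items()}
    advs = []
    for r, uid in zip(seq_rewards, uids):
        mean, std = stats[uid]
        adv = r - mean
        if norm_adv_by_std_in_grpo:
            adv = adv / (std + eps)
        advs.append(adv)
    return np.array(advs)

# e.g. two prompts, n=2 samples each:
# grpo_outcome_advantages([1.0, 0.0, 1.0, 0.0], ["a", "a", "b", "b"])
# -> roughly [ 1., -1.,  1., -1.]  (each group centered and scaled by its own std)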

Now let's look at DAPO's reward manager implementation. The main difference from PPO is the overlong_buffer, which adds a length-based reward.

verl/verl/workers/reward_manager/dapo.py

#################################################################
# The class is registered under the name "dapo", so it can be obtained via
# reward_manager_cls = get_reward_manager_cls(reward_manager_name)
#################################################################
@register("dapo")
class DAPORewardManager(AbstractRewardManager):
    """The reward manager."""

    def __init__(
        self,
        tokenizer,
        num_examine,
        compute_score=None,
        reward_fn_key="data_source",
        max_resp_len=None,
        overlong_buffer_cfg=None,
    ) -> None:
        self.tokenizer = tokenizer
        self.num_examine = num_examine  # the number of batches of decoded responses to print to the console
        self.compute_score = compute_score or default_compute_score
        self.reward_fn_key = reward_fn_key
        self.overlong_buffer_cfg = overlong_buffer_cfg
        self.max_resp_len = max_resp_len

        if self.overlong_buffer_cfg is not None:
            assert self.max_resp_len is not None, (
                f"max_resp_len must be provided if {overlong_buffer_cfg=}, but got None"
            )
            assert self.max_resp_len >= self.overlong_buffer_cfg.len, (
                "max_resp_len must be larger than overlong_buffer.len"
            )

    #################################################################
    # The main entry point of the DAPO reward manager
    #################################################################
    def __call__(self, data: DataProto, return_dict: bool = False):
        """We will expand this function gradually based on the available datasets"""

        # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
        if "rm_scores" in data.batch.keys():
            if return_dict:
                reward_extra_keys = data.meta_info.get("reward_extra_keys", [])
                reward_extra_info = {key: data.non_tensor_batch[key] for key in reward_extra_keys}
                return {"reward_tensor": data.batch["rm_scores"], "reward_extra_info": reward_extra_info}
            else:
                return data.batch["rm_scores"]

        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
        reward_extra_info = defaultdict(list)

        already_print_data_sources = {}

        for i in range(len(data)):
            data_item = data[i]  # DataProtoItem

            prompt_ids = data_item.batch["prompts"]
            prompt_length = prompt_ids.shape[-1]

            ########################################################
            # Note that prompt_ids are left-padded,
            # while response_ids are right-padded
            ########################################################
            valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum()
            valid_prompt_ids = prompt_ids[-valid_prompt_length:]

            response_ids = data_item.batch["responses"]
            valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
            valid_response_ids = response_ids[:valid_response_length]

            # decode
            prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True)
            response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)
            eos_token = self.tokenizer.eos_token
            if response_str.endswith(eos_token):
                response_str = response_str[: -len(eos_token)]

            ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"]
            data_source = data_item.non_tensor_batch[self.reward_fn_key]
            extra_info = data_item.non_tensor_batch.get("extra_info", {})
            rollout_reward_scores = data_item.non_tensor_batch.get("reward_scores", {})
            extra_info["rollout_reward_scores"] = rollout_reward_scores

            result = self.compute_score(
                data_source=data_source,
                solution_str=response_str,
                ground_truth=ground_truth,
                extra_info=extra_info,
            )

            score: float
            if isinstance(result, dict):
                score = result["score"]
                # Store the information including original reward
                for key, value in result.items():
                    reward_extra_info[key].append(value)
            else:
                score = result
                reward_extra_info["acc"].append(score)

            reward = score

            ########################################################
            # The overlong (soft) reward computation
            ########################################################
            if self.overlong_buffer_cfg.enable:
                overlong_buffer_len = self.overlong_buffer_cfg.len
                expected_len = self.max_resp_len - overlong_buffer_len
                exceed_len = valid_response_length - expected_len
                overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor
                overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0)
                reward += overlong_reward
                if self.overlong_buffer_cfg.log:
                    reward_extra_info["overlong_reward"].append(overlong_reward)
                    reward_extra_info["overlong"].append(overlong_reward < 0)

            reward_tensor[i, valid_response_length - 1] = reward

            if data_source not in already_print_data_sources:
                already_print_data_sources[data_source] = 0

            if already_print_data_sources[data_source] < self.num_examine:
                already_print_data_sources[data_source] += 1
                print("[prompt]", prompt_str)
                print("[response]", response_str)
                print("[ground_truth]", ground_truth)
                if isinstance(result, dict):
                    for key, value in result.items():
                        print(f"[{key}]", value)
                else:
                    print("[score]", score)

        if return_dict:
            return {
                "reward_tensor": reward_tensor,
                "reward_extra_info": reward_extra_info,
            }
        else:
            return reward_tensor
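One detail worth noting: reward_tensor has the same shape as responses, and the scalar reward is written only at the last valid response token (reward_tensor[i, valid_response_length - 1] = reward), so token_level_scores is zero everywhere except that position. A tiny illustration with made-up numbers:

import torch

# Suppose a batch of 2 responses padded to length 6, with valid lengths 4 and 6.
responses = torch.zeros(2, 6, dtype=torch.long)
reward_tensor = torch.zeros_like(responses, dtype=torch.float32)

valid_response_lengths = [4, 6]
rewards = [1.0, -0.5]  # final scalar reward per trajectory (after the overlong penalty)

for i, (vlen, r) in enumerate(zip(valid_response_lengths, rewards)):
    reward_tensor[i, vlen - 1] = r  # only the last valid token carries the sequence reward

print(reward_tensor)
# -> all zeros except reward_tensor[0, 3] == 1.0 and reward_tensor[1, 5] == -0.5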

For the detailed differences between DAPO and PPO, see the DAPO README.
