概述

首发自个人公众号:阿郎小哥的随笔驿站

DeepSeek R1系列建议阅读之前的系列文章:

聊聊DeepSeek R1的一些总结

聊聊DeepSeek R1的开源复现库——Open R1之合成数据

聊聊DeepSeek R1的知识蒸馏与应用思考

简介

GRPO 是一种在线学习算法,这意味着它通过在训练期间使用受训模型自身生成的数据来迭代改进。GRPO 目标背后的直觉是最大化生成补全的优势,同时确保模型保持接近参考策略。

GRPO 的四个主要步骤:生成补全计算优势估计 KL 散度计算损失

与传统的RL方法不同,后者通常依赖外部评估者(批评者)来引导学习,GRPO通过评估一组响应之间的相对关系来优化模型。这种方法提高了训练效率,使GRPO在需要复杂问题解决和长链思维的推理任务中表现尤为出色。

步骤分解

步骤1:选择查询

• 从训练数据集$ P(Q) $中选择一个查询$ (q) $。

• 示例:假设查询是“8 + 5的和是多少?”

步骤2:生成一组响应

• 模型针对该查询生成一组$ G $个响应。

• 示例:模型生成以下响应:

• o1:“答案是13。”

• o2:“十三。”

• o3:“是12。”

• o4:“和是13。”

步骤3:计算每个响应的奖励

• 什么是奖励?奖励通过量化响应的质量来引导模型的学习。

• GRPO中的奖励类型:

• 准确性奖励:基于响应的正确性(例如,解答数学题)。

• 格式奖励:确保响应符合结构化要求(例如,推理过程需要包含在标签中)。

• 语言一致性奖励:惩罚语言混杂或格式不一致的响应。

• 根据每个响应的好坏,赋予一个奖励($ r_i $)。

例如,奖励可能取决于:

• 准确性:答案是否正确?

• 格式:响应是否结构良好?

示例:

• r1 = 1.0(正确且格式良好)

• r2 = 0.9(正确但较不正式)

• r3 = 0.0(错误答案)

• r4 = 1.0(正确且格式良好)

步骤4:比较响应(群体优势)

• 计算每个响应相对于群体的优势$ (A_i) $,paper中相关术语如下:

用简单的方式理解,就是这样:

• 比较结果优于群体平均水平的响应会获得正分,而表现较差的响应会得到负分。

• 这种方式在群体内部激发竞争,推动模型生成更好的响应。

步骤5:使用裁剪更新策略

示例:如果新策略开始给o1分配过高的概率,裁剪机制确保不会过度强调这个响应。

这种方式保证了即使在像推理这样复杂的任务中,策略优化也能保持稳定和可靠。

步骤6:通过KL散度惩罚偏差

GRPO实现

Open R1

在Open R1的复现路径中

实现了基于GRPO算法的训练,脚本如下

ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml --num_processes=7 src/open_r1/grpo.py --config recipes/qwen/Qwen2.5-1.5B-Instruct/grpo/confg_full.yaml

confg_full.yaml

# 基座模型
model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
model_revision: main
torch_dtype: bfloat16 # 训练数据集
dataset_name: AI-MO/NuminaMath-TIR
dataset_configs:
- all
# Num processes is less by 1 as vLLM is using 1 GPU
num_processes: 7 # GRPO训练器参数
bf16: true
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
do_eval: true
eval_strategy: steps
eval_steps: 100
gradient_accumulation_steps: 16
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: Qwen2.5-1.5B-Open-R1-GRPO
hub_strategy: every_save
learning_rate: 2.0e-05
log_level: info
logging_steps: 10
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 512
max_completion_length: 1024
max_steps: -1
num_train_epochs: 1
output_dir: data/Qwen2.5-1.5B-Open-R1-GRPO
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 1
push_to_hub: true
report_to:
- wandb
save_strategy: "no"
seed: 42
warmup_ratio: 0.1

Open R1提供了grpo算法的实现——grpo.py,删减了部分无关代码,关键的程序逻辑如下:

@dataclass
class GRPOScriptArguments(ScriptArguments):
reward_funcs: list[str] = field(
default_factory=lambda: ["accuracy", "format"],
metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
) def accuracy_reward(completions, solution, **kwargs):
"""Reward function that checks if the completion is the same as the ground truth."""
contents = [completion[0]["content"] for completion in completions]
rewards = []
for content, sol in zip(contents, solution):
gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
if len(gold_parsed) != 0:
# We require the answer to be provided in correct latex (no malformed operators)
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed=True,
units=True,
),
# Ensures that boxed is tried first
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
# Reward 1 if the content is the same as the ground truth, 0 otherwise
reward = float(verify(answer_parsed, gold_parsed))
else:
# If the gold solution is not parseable, we reward 1 to skip this example
reward = 1.0
print("Failed to parse gold solution: ", sol)
rewards.append(reward) return rewards def format_reward(completions, **kwargs):
"""Reward function that checks if the completion has a specific format."""
pattern = r"^<think>.*?</think><answer>.*?</answer>$"
completion_contents = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, content) for content in completion_contents]
return [1.0 if match else 0.0 for match in matches] reward_funcs_registry = {
"accuracy": accuracy_reward,
"format": format_reward,
} SYSTEM_PROMPT = (
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
"<think> reasoning process here </think><answer> answer here </answer>"
) def main(script_args, training_args, model_args): # Load the dataset
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) # Get reward functions
reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs] # Format into conversation
def make_conversation(example):
return {
"prompt": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": example["problem"]},
],
} dataset = dataset.map(make_conversation)
for split in dataset:
if "messages" in dataset[split].column_names:
dataset[split] = dataset[split].remove_columns("messages") logger.info("*** Initializing model kwargs ***")
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
)
training_args.model_init_kwargs = model_kwargs #############################
# Initialize the GRPO trainer
#############################
trainer = GRPOTrainer(
model=model_args.model_name_or_path,
reward_funcs=reward_funcs,
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
peft_config=get_peft_config(model_args),
callbacks=get_callbacks(training_args, model_args),
) ###############
# Training loop
###############
logger.info("*** Train ***")
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state() ##################################
# Save model and create model card
##################################
trainer.save_model(training_args.output_dir) # Save everything else on main process
kwargs = {
"dataset_name": script_args.dataset_name,
"tags": ["open-r1"],
}
if trainer.accelerator.is_main_process:
trainer.create_model_card(**kwargs)
# Restore k,v cache for fast inference
trainer.model.config.use_cache = True
trainer.model.config.save_pretrained(training_args.output_dir) ##########
# Evaluate
##########
if training_args.do_eval:
logger.info("*** Evaluate ***")
metrics = trainer.evaluate()
metrics["eval_samples"] = len(dataset[script_args.dataset_test_split])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics) #############
# push to hub
#############
if training_args.push_to_hub:
logger.info("Pushing to hub...")
trainer.push_to_hub(**kwargs) if __name__ == "__main__":
parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
main(script_args, training_args, model_args)

代码分析如下:

首先就是加载数据集,但数据集在加载时,会有指定的提示词,即代码中的make_conversation函数,该函数构造指定的prompt引导模型的输出,格式如下:

{
"prompt": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": example["problem"]},
],
}

对于SYSTEM_PROMPT,描述如下:

"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
"<think> reasoning process here </think><answer> answer here </answer>"

总的来说就是,引导模型先思考推理过程,再按格式将推理过程与回复放入指定标签<think>、<answer>内。

接下来是reward函数,grpo算法有两种奖励:准确性奖励与格式正确奖励;如下

def accuracy_reward(completions, solution, **kwargs):
"""Reward function that checks if the completion is the same as the ground truth."""
contents = [completion[0]["content"] for completion in completions]
rewards = []
for content, sol in zip(contents, solution):
gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
if len(gold_parsed) != 0:
# We require the answer to be provided in correct latex (no malformed operators)
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed=True,
units=True,
),
# Ensures that boxed is tried first
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
# Reward 1 if the content is the same as the ground truth, 0 otherwise
reward = float(verify(answer_parsed, gold_parsed))
else:
# If the gold solution is not parseable, we reward 1 to skip this example
reward = 1.0
print("Failed to parse gold solution: ", sol)
rewards.append(reward) return rewards def format_reward(completions, **kwargs):
"""Reward function that checks if the completion has a specific format."""
pattern = r"^<think>.*?</think><answer>.*?</answer>$"
completion_contents = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, content) for content in completion_contents]
return [1.0 if match else 0.0 for match in matches] reward_funcs_registry = {
"accuracy": accuracy_reward,
"format": format_reward,
}

最后就是训练,GRPOTrainer是transformers库提供的基于Trainer的训练类,传入指定的参数即可实现基于GRPO算法的实现;其中比较关键的是reward、train_dataset。

#############################
# Initialize the GRPO trainer
#############################
trainer = GRPOTrainer(
model=model_args.model_name_or_path,
reward_funcs=reward_funcs,
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
peft_config=get_peft_config(model_args),
callbacks=get_callbacks(training_args, model_args),
)

计算训练的checkpoint与循环周期,则会在Trainer类中通过gradient_accumulation_steps(梯度累积步数)、num_train_epochs(训练轮数)以及 per_device_train_batch_size(每个设备的训练批次大小)这些参数计算训练周期。

###############
# Training loop
###############
logger.info("*** Train ***")
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)

小结

总的来说,Open R1的GRPO训练,是基于GRPOTrainer指定prompt/datasetreward等参数实现GRPO的训练。也就是说,在指定的训练数据集下,通过prompt引导模型的输出,然后基于grpo算法及其reward对 模型的输出与训练数据集的output 做奖惩打分(通过KL散度比较),计算loss,再反向传播。循环反复;最终完成模型的RL训练,达到让模型能做到CoT式的回复,即生成补全计算优势估计 KL 散度计算损失的步骤,如最开始的图所示。

对于GRPOTrainer类的源码及文档可参考:

首发自个人公众号:阿郎小哥的随笔驿站

聊聊GRPO算法——从Open R1来看如何训练DeepSeek R1模型的更多相关文章

  1. EM算法(2):GMM训练算法

    目录 EM算法(1):K-means 算法 EM算法(2):GMM训练算法 EM算法(3):EM算法运用 EM算法(4):EM算法证明 EM算法(2):GMM训练算法 1. 简介 GMM模型全称为Ga ...

  2. 又一重要进展发布!OpenMMLab算法仓支持昇腾AI训练加速

    摘要:上海人工智能实验室的浦视开源算法体系(OpenMMLab)团队基于昇腾AI发布了MMDeploy 0.10.0版本,该版本已支持OpenMMLab算法仓库在昇腾异构计算架构CANN上的推理部署. ...

  3. 聊聊dmClock算法

    作者:吴香伟 发表于 2017/01/08 版权声明:可以任意转载,转载时务必以超链接形式标明文章原始出处和作者信息以及版权声明 人们常常容易忽略一些不起眼但特别重要的事物.曾经跟同事聊Python, ...

  4. 一个基于特征向量的近似网页去重算法——term用SVM人工提取训练,基于term的特征向量,倒排索引查询相似文档,同时利用cos计算相似度

    摘  要  在搜索引擎的检索结果页面中,用户经常会得到内容相似的重复页面,它们中大多是由于网站之间转载造成的.为提高检索效率和用户满意度,提出一种基于特征向量的大规模中文近似网页检测算法DDW(Det ...

  5. 聊聊同步、异步、阻塞、非阻塞以及IO模型

    前言 在使用Netty改造手写RPC框架的时候,需要给大家介绍一些相关的知识,这样很多东西大家就可以看明白了,手写RPC是一个支线任务,后续重点仍然是Kubernetes相关内容. 阻塞与非阻塞 同步 ...

  6. 基于AdaBoost算法——世纪晟结合Haar-like特征训练人脸检测识别

      AdaBoost 算法是一种快速人脸检测算法,它将根据弱学习的反馈,适应性地调整假设的错误率,使在效率不降低的情况下,检测正确率得到了很大的提高.   系统在技术上的三个贡献: 1.用简单的Haa ...

  7. 推荐算法之用矩阵分解做协调过滤——LFM模型

    隐语义模型(Latent factor model,以下简称LFM),是推荐系统领域上广泛使用的算法.它将矩阵分解应用于推荐算法推到了新的高度,在推荐算法历史上留下了光辉灿烂的一笔.本文将对 LFM ...

  8. Tensorflow r1.12及tensorflow serving r1.12 GPU版本编译遇到的问题

    1.git clone tensorflow serving 及tensorflow代码 2. ERROR: /root/.cache/bazel/_bazel_root/f71d782da17fd8 ...

  9. PHP微信红包生成算法的程序源码(用抛物线的模型实现)

    代码如下: <?php /* * 红包生成随机算法 */ header("Content-type:text/html;charset=utf-8"); date_defau ...

  10. 1.K近邻算法

    (一)K近邻算法基础 K近邻(KNN)算法优点 思想极度简单 应用数学知识少(近乎为0) 效果好 可以解释机器学习算法使用过程中的很多细节问题 更完整的刻画机器学习应用的流程 图解K近邻算法 上图是以 ...

随机推荐

  1. Windows下搭建Linux开发环境(vagrant)

    [下载] vagrant软件:https://www.virtualbox.org/wiki/Downloads centos镜像: http://isoredirect.centos.org/cen ...

  2. Adobe PS 2024 软件分享 torrent

    Adobe-Photoshop-2024-25.5.0.375 下载工具建议使用 qBittorrent-enhance,qBittorrent, Transmission, uTorrent 等. ...

  3. k8s 实战 3----副本集

    副本集是什么?我们在前文中讲过什么是pod,简单来说pod就是k8s直接操作的基本单位.不了解的同学可以参考前文: k8s 实战 1 ---- 初识 (https://www.cnblogs.com/ ...

  4. 二进制安装Kubernetes(k8s)v1.32.0

    二进制安装Kubernetes(k8s)v1.32.0 介绍 https://github.com/cby-chen/Kubernetes 开源不易,帮忙点个star,谢谢了 kubernetes(k ...

  5. Gitbook的docker安装配置

    创建目录:/gitbook/gitbook 和 /gitbook/html /gitbook/gitbook目录下,touch新建README.md docker安装gitbook docker ru ...

  6. 6.MySQL性能优化

    参数 作用范围 全局:对实例的所有会话起作用 会话级:只对当前会话起作用 set session binlog_rows_query_log_events = on; set global binlo ...

  7. vba interpreter 结束

    https://github.com/inshua/vba-interpreter 已覆盖几乎 VB 所有的特性,只是库还不够全. VB 语言自身较为落后,语法也有诸多设计不当.最严重的莫过于函数和数 ...

  8. 断言、drf之请求与响应

    目录 一.断言 二.drf之请求 2.1 Request能够解析的前端传入的编码格式 2.2 Request类有哪些属性和方法(学过) 常用参数 Response类的实例化参数 三.drf之响应 3. ...

  9. Shiro-BasicHttpAuthenticationFilter 鉴权过滤器的使用方式

    它的作用是用来根据路径匹配结果,调用相应过滤器 onPreHandle 这里是正在的执行逻辑,之前的都是判断,它返回了两个方法: isAccessAllowed() onAccessDenied() ...

  10. Hive数据库【操作】+ 【分区】+【分桶】+【查询】+【运算】+【函数】

    目录 键值对信息 数据库表操作 内部表操作 外部表操作 分区表的操作 分桶表操作 数据查询 关系运算 数学运算 逻辑运算 数值计算 日期函数 键值对信息 添加数据库的描述信息(添加键值对信息) cre ...