概述

首发自个人公众号:阿郎小哥的随笔驿站

DeepSeek R1系列建议阅读之前的系列文章:

聊聊DeepSeek R1的一些总结

聊聊DeepSeek R1的开源复现库——Open R1之合成数据

聊聊DeepSeek R1的知识蒸馏与应用思考

简介

GRPO 是一种在线学习算法,这意味着它通过在训练期间使用受训模型自身生成的数据来迭代改进。GRPO 目标背后的直觉是最大化生成补全的优势,同时确保模型保持接近参考策略。

GRPO 的四个主要步骤:生成补全计算优势估计 KL 散度计算损失

与传统的RL方法不同,后者通常依赖外部评估者(批评者)来引导学习,GRPO通过评估一组响应之间的相对关系来优化模型。这种方法提高了训练效率,使GRPO在需要复杂问题解决和长链思维的推理任务中表现尤为出色。

步骤分解

步骤1:选择查询

• 从训练数据集$ P(Q) $中选择一个查询$ (q) $。

• 示例:假设查询是“8 + 5的和是多少?”

步骤2:生成一组响应

• 模型针对该查询生成一组$ G $个响应。

• 示例:模型生成以下响应:

• o1:“答案是13。”

• o2:“十三。”

• o3:“是12。”

• o4:“和是13。”

步骤3:计算每个响应的奖励

• 什么是奖励?奖励通过量化响应的质量来引导模型的学习。

• GRPO中的奖励类型:

• 准确性奖励:基于响应的正确性(例如,解答数学题)。

• 格式奖励:确保响应符合结构化要求(例如,推理过程需要包含在标签中)。

• 语言一致性奖励:惩罚语言混杂或格式不一致的响应。

• 根据每个响应的好坏,赋予一个奖励($ r_i $)。

例如,奖励可能取决于:

• 准确性:答案是否正确?

• 格式:响应是否结构良好?

示例:

• r1 = 1.0(正确且格式良好)

• r2 = 0.9(正确但较不正式)

• r3 = 0.0(错误答案)

• r4 = 1.0(正确且格式良好)

步骤4:比较响应(群体优势)

• 计算每个响应相对于群体的优势$ (A_i) $,paper中相关术语如下:

用简单的方式理解,就是这样:

• 比较结果优于群体平均水平的响应会获得正分,而表现较差的响应会得到负分。

• 这种方式在群体内部激发竞争,推动模型生成更好的响应。

步骤5:使用裁剪更新策略

示例:如果新策略开始给o1分配过高的概率,裁剪机制确保不会过度强调这个响应。

这种方式保证了即使在像推理这样复杂的任务中,策略优化也能保持稳定和可靠。

步骤6:通过KL散度惩罚偏差

GRPO实现

Open R1

在Open R1的复现路径中

实现了基于GRPO算法的训练,脚本如下

ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml --num_processes=7 src/open_r1/grpo.py --config recipes/qwen/Qwen2.5-1.5B-Instruct/grpo/confg_full.yaml

confg_full.yaml

# 基座模型
model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
model_revision: main
torch_dtype: bfloat16 # 训练数据集
dataset_name: AI-MO/NuminaMath-TIR
dataset_configs:
- all
# Num processes is less by 1 as vLLM is using 1 GPU
num_processes: 7 # GRPO训练器参数
bf16: true
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
do_eval: true
eval_strategy: steps
eval_steps: 100
gradient_accumulation_steps: 16
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: Qwen2.5-1.5B-Open-R1-GRPO
hub_strategy: every_save
learning_rate: 2.0e-05
log_level: info
logging_steps: 10
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 512
max_completion_length: 1024
max_steps: -1
num_train_epochs: 1
output_dir: data/Qwen2.5-1.5B-Open-R1-GRPO
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 1
push_to_hub: true
report_to:
- wandb
save_strategy: "no"
seed: 42
warmup_ratio: 0.1

Open R1提供了grpo算法的实现——grpo.py,删减了部分无关代码,关键的程序逻辑如下:

@dataclass
class GRPOScriptArguments(ScriptArguments):
reward_funcs: list[str] = field(
default_factory=lambda: ["accuracy", "format"],
metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
) def accuracy_reward(completions, solution, **kwargs):
"""Reward function that checks if the completion is the same as the ground truth."""
contents = [completion[0]["content"] for completion in completions]
rewards = []
for content, sol in zip(contents, solution):
gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
if len(gold_parsed) != 0:
# We require the answer to be provided in correct latex (no malformed operators)
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed=True,
units=True,
),
# Ensures that boxed is tried first
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
# Reward 1 if the content is the same as the ground truth, 0 otherwise
reward = float(verify(answer_parsed, gold_parsed))
else:
# If the gold solution is not parseable, we reward 1 to skip this example
reward = 1.0
print("Failed to parse gold solution: ", sol)
rewards.append(reward) return rewards def format_reward(completions, **kwargs):
"""Reward function that checks if the completion has a specific format."""
pattern = r"^<think>.*?</think><answer>.*?</answer>$"
completion_contents = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, content) for content in completion_contents]
return [1.0 if match else 0.0 for match in matches] reward_funcs_registry = {
"accuracy": accuracy_reward,
"format": format_reward,
} SYSTEM_PROMPT = (
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
"<think> reasoning process here </think><answer> answer here </answer>"
) def main(script_args, training_args, model_args): # Load the dataset
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) # Get reward functions
reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs] # Format into conversation
def make_conversation(example):
return {
"prompt": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": example["problem"]},
],
} dataset = dataset.map(make_conversation)
for split in dataset:
if "messages" in dataset[split].column_names:
dataset[split] = dataset[split].remove_columns("messages") logger.info("*** Initializing model kwargs ***")
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
)
training_args.model_init_kwargs = model_kwargs #############################
# Initialize the GRPO trainer
#############################
trainer = GRPOTrainer(
model=model_args.model_name_or_path,
reward_funcs=reward_funcs,
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
peft_config=get_peft_config(model_args),
callbacks=get_callbacks(training_args, model_args),
) ###############
# Training loop
###############
logger.info("*** Train ***")
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state() ##################################
# Save model and create model card
##################################
trainer.save_model(training_args.output_dir) # Save everything else on main process
kwargs = {
"dataset_name": script_args.dataset_name,
"tags": ["open-r1"],
}
if trainer.accelerator.is_main_process:
trainer.create_model_card(**kwargs)
# Restore k,v cache for fast inference
trainer.model.config.use_cache = True
trainer.model.config.save_pretrained(training_args.output_dir) ##########
# Evaluate
##########
if training_args.do_eval:
logger.info("*** Evaluate ***")
metrics = trainer.evaluate()
metrics["eval_samples"] = len(dataset[script_args.dataset_test_split])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics) #############
# push to hub
#############
if training_args.push_to_hub:
logger.info("Pushing to hub...")
trainer.push_to_hub(**kwargs) if __name__ == "__main__":
parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
main(script_args, training_args, model_args)

代码分析如下:

首先就是加载数据集,但数据集在加载时,会有指定的提示词,即代码中的make_conversation函数,该函数构造指定的prompt引导模型的输出,格式如下:

{
"prompt": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": example["problem"]},
],
}

对于SYSTEM_PROMPT,描述如下:

"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
"<think> reasoning process here </think><answer> answer here </answer>"

总的来说就是,引导模型先思考推理过程,再按格式将推理过程与回复放入指定标签<think>、<answer>内。

接下来是reward函数,grpo算法有两种奖励:准确性奖励与格式正确奖励;如下

def accuracy_reward(completions, solution, **kwargs):
"""Reward function that checks if the completion is the same as the ground truth."""
contents = [completion[0]["content"] for completion in completions]
rewards = []
for content, sol in zip(contents, solution):
gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
if len(gold_parsed) != 0:
# We require the answer to be provided in correct latex (no malformed operators)
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed=True,
units=True,
),
# Ensures that boxed is tried first
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
# Reward 1 if the content is the same as the ground truth, 0 otherwise
reward = float(verify(answer_parsed, gold_parsed))
else:
# If the gold solution is not parseable, we reward 1 to skip this example
reward = 1.0
print("Failed to parse gold solution: ", sol)
rewards.append(reward) return rewards def format_reward(completions, **kwargs):
"""Reward function that checks if the completion has a specific format."""
pattern = r"^<think>.*?</think><answer>.*?</answer>$"
completion_contents = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, content) for content in completion_contents]
return [1.0 if match else 0.0 for match in matches] reward_funcs_registry = {
"accuracy": accuracy_reward,
"format": format_reward,
}

最后就是训练,GRPOTrainer是transformers库提供的基于Trainer的训练类,传入指定的参数即可实现基于GRPO算法的实现;其中比较关键的是reward、train_dataset。

#############################
# Initialize the GRPO trainer
#############################
trainer = GRPOTrainer(
model=model_args.model_name_or_path,
reward_funcs=reward_funcs,
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
peft_config=get_peft_config(model_args),
callbacks=get_callbacks(training_args, model_args),
)

计算训练的checkpoint与循环周期,则会在Trainer类中通过gradient_accumulation_steps(梯度累积步数)、num_train_epochs(训练轮数)以及 per_device_train_batch_size(每个设备的训练批次大小)这些参数计算训练周期。

###############
# Training loop
###############
logger.info("*** Train ***")
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)

小结

总的来说,Open R1的GRPO训练,是基于GRPOTrainer指定prompt/datasetreward等参数实现GRPO的训练。也就是说,在指定的训练数据集下,通过prompt引导模型的输出,然后基于grpo算法及其reward对 模型的输出与训练数据集的output 做奖惩打分(通过KL散度比较),计算loss,再反向传播。循环反复;最终完成模型的RL训练,达到让模型能做到CoT式的回复,即生成补全计算优势估计 KL 散度计算损失的步骤,如最开始的图所示。

对于GRPOTrainer类的源码及文档可参考:

首发自个人公众号:阿郎小哥的随笔驿站

聊聊GRPO算法——从Open R1来看如何训练DeepSeek R1模型的更多相关文章

  1. EM算法(2):GMM训练算法

    目录 EM算法(1):K-means 算法 EM算法(2):GMM训练算法 EM算法(3):EM算法运用 EM算法(4):EM算法证明 EM算法(2):GMM训练算法 1. 简介 GMM模型全称为Ga ...

  2. 又一重要进展发布!OpenMMLab算法仓支持昇腾AI训练加速

    摘要:上海人工智能实验室的浦视开源算法体系(OpenMMLab)团队基于昇腾AI发布了MMDeploy 0.10.0版本,该版本已支持OpenMMLab算法仓库在昇腾异构计算架构CANN上的推理部署. ...

  3. 聊聊dmClock算法

    作者:吴香伟 发表于 2017/01/08 版权声明:可以任意转载,转载时务必以超链接形式标明文章原始出处和作者信息以及版权声明 人们常常容易忽略一些不起眼但特别重要的事物.曾经跟同事聊Python, ...

  4. 一个基于特征向量的近似网页去重算法——term用SVM人工提取训练,基于term的特征向量,倒排索引查询相似文档,同时利用cos计算相似度

    摘  要  在搜索引擎的检索结果页面中,用户经常会得到内容相似的重复页面,它们中大多是由于网站之间转载造成的.为提高检索效率和用户满意度,提出一种基于特征向量的大规模中文近似网页检测算法DDW(Det ...

  5. 聊聊同步、异步、阻塞、非阻塞以及IO模型

    前言 在使用Netty改造手写RPC框架的时候,需要给大家介绍一些相关的知识,这样很多东西大家就可以看明白了,手写RPC是一个支线任务,后续重点仍然是Kubernetes相关内容. 阻塞与非阻塞 同步 ...

  6. 基于AdaBoost算法——世纪晟结合Haar-like特征训练人脸检测识别

      AdaBoost 算法是一种快速人脸检测算法,它将根据弱学习的反馈,适应性地调整假设的错误率,使在效率不降低的情况下,检测正确率得到了很大的提高.   系统在技术上的三个贡献: 1.用简单的Haa ...

  7. 推荐算法之用矩阵分解做协调过滤——LFM模型

    隐语义模型(Latent factor model,以下简称LFM),是推荐系统领域上广泛使用的算法.它将矩阵分解应用于推荐算法推到了新的高度,在推荐算法历史上留下了光辉灿烂的一笔.本文将对 LFM ...

  8. Tensorflow r1.12及tensorflow serving r1.12 GPU版本编译遇到的问题

    1.git clone tensorflow serving 及tensorflow代码 2. ERROR: /root/.cache/bazel/_bazel_root/f71d782da17fd8 ...

  9. PHP微信红包生成算法的程序源码(用抛物线的模型实现)

    代码如下: <?php /* * 红包生成随机算法 */ header("Content-type:text/html;charset=utf-8"); date_defau ...

  10. 1.K近邻算法

    (一)K近邻算法基础 K近邻(KNN)算法优点 思想极度简单 应用数学知识少(近乎为0) 效果好 可以解释机器学习算法使用过程中的很多细节问题 更完整的刻画机器学习应用的流程 图解K近邻算法 上图是以 ...

随机推荐

  1. element-ui resetFields 无效的问题

    1.问题 触发bug的条件是先打开,编辑进行赋值,后打开新增 先点开编辑 再打开新增 这个时候你会发现刚刚赋值过的数据还遗留在表单里面 即使在打开  dialog  的时候执行了重置也没有效果 res ...

  2. Lucene 源代码剖析-2 Lucene是什么

    转载自 http://download.csdn.net/source/858994 源地址下是 Word 文档,这里转换成HTML 格式 1           Lucene是什么 Apache L ...

  3. 创建一个具有商业品质的 Eclipse IDE

    创建具有商业品质且可插入 Eclipse 的专业 IDE Prashant Deva (pdeva@placidsystems.com), 创始人, Placid Systems 简介:  " ...

  4. 七、FreeRTOS学习笔记-中断管理

    FreeRTOS学习笔记-中断管理 中断:让CPU打断正常运行的程序,转而去处理紧急的事件(程序) 中断执行机制,可简单概括为三步: 1.中断请求:外设产生中断请求(GPIO外部中断.定时器中断等) ...

  5. golang之浮点数处理库decimal

    decimal库包是用来解决float类型对象之间运算不准确的问题的.所以,如果你想使用decimal库包,你必须先把float类型对象通过decimal.NewFromFloat()函数转成deci ...

  6. PL/SQL中文乱码修正

    我根据需求,,需要修改 数据库的部分表格的部分字段,然而在Update的时候,出现了中文乱码(Type字段). 此时,我用的是客户端,服务器没有安装,在另一台机器上,所以,我需要做的是修改客户端编码: ...

  7. 前后端数据传递之form-data

    前情 最近在项目开发中,跟服务端连调发现接口一直报错,服务端一直提示是数据没有传,而通过浏览器控制台发现数据是有传的. 坑 服务通过postman自测是OK的.经过和服务端一起定位发现服务端只接收以f ...

  8. 第36次ccf-csp题解(思维)

    比赛链接 https://sim.csp.thusaac.com/contest/36/home   比赛感受 这会刚打完上海icpc,比起区域赛的题,这个简单太多了. 感受还不错,写的很顺手.除了第 ...

  9. Collector的配置和使用

    目录 Collector的配置和使用 Collector配置 Receivers Processors Exporters Service Extensions 使用环境变量 Collector的使用 ...

  10. 【转载】Apache Doris、DorisDB傻傻分不清。。。

    https://www.sohu.com/a/488816742_827544   相信这两天很多社区小伙伴都看到 StarRocks 所谓"开源"的动态了,开源用户群里有很多小伙 ...