强化学习框架:OpenRLHF源码解读,模型处理
强化学习框架:OpenRLHF源码解读,模型处理
本文主要介绍 强化学习框架:OpenRLHF源码解读,模型处理
models框架设计
了解一下 OpenRLHF的模型框架设计范式:

可以知道一个大概的流程:输入Pormpt通过Actor model输出回复 Response,而后将两部分进行拼接再去由其他模型进行处理
1、actor.py
https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/models/actor.py
这部分主要为加载所需要的模型
class Actor(nn.Module):
def __init__(...):
if isinstance(pretrain_or_model, str):
...
self.model = model_class.from_pretrained(
pretrain_or_model,
trust_remote_code=True,
attn_implementation=attn_implementation,
quantization_config=nf4_config,
torch_dtype=torch.bfloat16 if bf16 else "auto",
device_map=device_map,
)
if lora_rank > 0:
self.model.enable_input_require_grads()
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=lora_rank,
lora_alpha=lora_alpha,
target_modules=target_modules,
lora_dropout=lora_dropout,
bias="none",
)
self.model = get_peft_model(self.model, lora_config)
...
else:
self.model = pretrain_or_model
@torch.no_grad()
def generate(self, input_ids: torch.Tensor, **kwargs):
...
sequences = self.model.generate(**generate_args)
eos_token_id = generate_args["eos_token_id"]
pad_token_id = generate_args["pad_token_id"]
return self.process_sequences(sequences, input_ids.size(1), eos_token_id, pad_token_id)
def forward(...):
...
output["logits"] = output["logits"].to(torch.float32) # 得到每一个token概率
...
log_probs = log_probs_from_logits(
output["logits"][:, :-1, :], sequences[:, 1:], temperature=self.temperature
)
...
action_log_probs = log_probs[:, -num_actions:]
这个actor比较简单,首先从huggingface加载需要的模型,并且对模型进行部分设置如:量化/lora微调。或者直接加载自己预训练好的模型。
1、generate:模块则是根据输入的内容(比如说被 tokenizer处理好的文本)input_ids通过模型输出新的内容(根据 **kwargs获取生成文本参数设置比如说:top_k等)
2、forward:根据输入的 token 序列(sequences),计算模型在生成最后若干个 token(即 "动作")时的对数概率(log probs),之所以要这么处理是因为,在强化学习模型中(PPO、DPO等)一般而言模型的输出是一个序列,但优化目标不是“能不能生成这个序列”,而是:这个序列中,哪些 token 是“好”的?模型对这些 token 的概率应该更高!比如说在 DPO中:
\]
里面的
\]
就是概率比值,上面代码中:
log_probs_from_logits(output["logits"][:, :-1, :], sequences[:, 1:], temperature=self.temperature)
计算的就是:\(log(\pi_{\theta}(a|s))\),在具体代码中:
def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
if temperature != 1.0:
logits.div_(temperature)
if logits.dtype in [torch.float32, torch.float64]:
batch_dim = logits.shape[:-1]
last_dim = logits.shape[-1]
try:
from flash_attn.ops.triton.cross_entropy import cross_entropy_loss
output = cross_entropy_loss(logits.reshape(-1, last_dim), labels.reshape(-1))
log_probs_labels = -output[0].view(*batch_dim)
except ImportError:
logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
logsumexp_values = _logsumexp_by_chunk(logits.reshape(-1, last_dim))
logsumexp_values = logsumexp_values.view(*batch_dim)
log_probs_labels = logits_labels - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
else:
log_probs_labels = []
for row_logits, row_labels in zip(logits, labels): # loop to reduce peak mem consumption
row_log_probs = F.log_softmax(row_logits, dim=-1)
row_log_probs_labels = row_log_probs.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
log_probs_labels.append(row_log_probs_labels)
log_probs_labels = torch.stack(log_probs_labels)
return log_probs_labels
补充-1:
在使用AutoModelForCausalLM.from_pretrained使用得到model之后,其支持输入参数为:
outputs = model(
input_ids=None, # 输入的token(batch_size, seq_length)
attention_mask=None, # 指示哪些 token 是有效的(非 padding),形状同 input_ids
position_ids=None, # 位置编码
past_key_values=None,
inputs_embeds=None,
use_cache=None, # 是否使用k-v cache
labels=None, # 输入标签就直接计算loss
output_attentions=None,
output_hidden_states=None,
return_dict=None,
)
补充-2:
在LLM训练过程中遇到过短的语句为了节约显存(如果都将内容补充到相同长度,那么就会有较多的padding造成浪费),因此可以将几个短的拼接起来,但是为了区分那些是一个句子那些不是的,在 OpenRLHF中通过参数:self.packing_samples。如果没有packing那么直接根据attention_mask将位置编码在处理一下
if not self.packing_samples:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
else:
# convert attention_mask to position_ids
if ring_attn_group is not None:
labels = sequences
sequences, attention_mask, position_ids = convert_ring_attn_params(
sequences, attention_mask, packed_seq_lens, ring_attn_group
)
else:
position_ids = reset_position_ids(attention_mask)
# explicitly ignore attention_mask for packing_samples
attention_mask = None
其中
reset_position_ids做的就是重新做位置编码重新处理
2、model.py
https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/models/model.py

主要功能返回所需要的模型,主要返回2个模型:1、CriticModel;2、RewardModel 回顾一下这几类模型的作用:无论是在GRPO还是DPO中都会输出token然后需要去对token进行评分,起评分作用的就是 reward model 对应上面图中 reward model,除此之外都会计算 优势函数(\(Q(s,a)-V(s)\))来评估策略的好坏优势函数里面计算就是通过 critic model来对某一个策略进行评估对应上面图像中的:value model
def _get_reward_model(base_pretrained_model, base_llm_model, value_head_prefix="score", packing_samples=False):
class RewardModel(base_pretrained_model):
def __init__(...):
...
# 加载模型
setattr(self, self.base_model_prefix, base_llm_model(config))
self.value_head_prefix = value_head_prefix
setattr(self, value_head_prefix, nn.Linear(config.hidden_size, 1, bias=False) # 输出评分
...
def forward(self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, return_output=False, ring_attn_group=None,pad_sequence=False, packed_seq_lens=None,):
...# 1、处理packing
outputs = getattr(self, self.base_model_prefix)(
input_ids, attention_mask=attention_mask, position_ids=position_ids
)
last_hidden_states = outputs["last_hidden_state"]
values = getattr(self, self.value_head_prefix)(last_hidden_states).squeeze(-1)
...# 1、处理packing
else:
# 输出最后一个有效token的评分代替整个句子评分
eos_indices = attention_mask.size(1) - 1 - attention_mask.long().fliplr().argmax(dim=1, keepdim=True)
reward = values.gather(dim=1, index=eos_indices).squeeze(1)
if not self.training and self.normalize_reward:
reward = (reward - self.mean) / self.std
return (reward, outputs) if return_output else reward
return RewardModel
def _get_critic_model(base_pretrained_model, base_llm_model, value_head_prefix="score", packing_samples=False):
class CriticModel(base_pretrained_model):
def __init__(...):
...
def forward(...):
...# 1、处理packing
outputs = getattr(self, self.base_model_prefix)(
input_ids, attention_mask=attention_mask, position_ids=position_ids
)
last_hidden_states = outputs["last_hidden_state"]
values = getattr(self, self.value_head_prefix)(last_hidden_states).squeeze(-1)
...
if num_actions is None:
assert return_output
return outputs
if not self.packing_samples:
action_values = values[:, -num_actions:]
else:
assert isinstance(num_actions, list) and len(num_actions) == len(packed_seq_lens)
action_values = []
offset = 0
for num_action, seq_len in zip(num_actions, packed_seq_lens):
start, end = max(0, offset + seq_len - num_action - 1), offset + seq_len - 1
action_values.append(values[:, start:end])
offset += seq_len
action_values = torch.cat(action_values, dim=1)
if return_output:
return (action_values, outputs)
else:
return action_values
return CriticModel
1、reward model: 传入一个 base_pretrained_model(比如 PreTrainedModel)、一个 base_llm_model(比如 AutoModel)以及一些控制参数。函数内部返回一个定制化的奖励模型类 RewardModel,它可以在给定输入句子时,输出一个数值(reward 分数),反映输出文本的质量。在forward计算中,直接将输入model使用的几个参数(见上面的补充有具体解释)计算最后取最后一个状态的值,并且将这个值取计算评分。也就是说 reward model:首先计算下一个预测的token而后对这些token进行打分
2、critic model:具体输入参数和 reward model相同。参考之前介绍,上面代码中直接返回action_values = values[:, -num_actions:]( num_actions存在条件下)这样就会得到不同的Q(s, a1), Q(s, a2), ...
总结上面两组模型,在 LLM 的强化学习场景下,Reward Model 和 Critic Model 都从 last_hidden_state 得到 token-level 表达,再用 Linear 层输出每个 token 的 score。
Reward Model最后提取的是 EOS token 的 score,表示整句话的奖励。Critic Model会进一步提取最后 num_actions 个 token 的 value,这些 token 是 Actor 生成的动作,对应到 PPO 中的:(,)=(,)−()。
理解上面内容,回顾最上面的框架设计,用下面例子进行解释。
Prompt:"The capital of France is"
Actor model:"Paris is beautiful"。那么合并得到:input_ids = ["The", "capital", "of", "France", "is", " Paris", "is", "beautiful"]
Reward model:对上面每个单词进行评分,假设:values = [0.1, 0.2, 0.3, 0.2, 0.4, 0.7, 0.5, 0.8] # 每个 token 的 score 而后输出句子中整体评分 0.8
Critic model:只对最后几个 token 的 action 计算 loss,于是:action_values = values[:, -3:] # 即取出最后 3 个生成 token 的 Q 值这些值也就对应了我们模型的生成
3、loss.py
https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/models/loss.py
补充-1:
裁剪使用的是torch.clamp(https://pytorch.org/docs/stable/generated/torch.clamp.html)强制将范围外的数值处理为边界值,范围内数字保持不变
1、PolicyLoss:Policy Loss for PPO
r_t &= \exp(\log \pi(a_t \mid s_t) - \log \pi_{\text{old}}(a_t \mid s_t)) \\
\mathcal{L}_{\text{clip}}(t) &= \min\left(r_t \cdot A_t,\ \text{clip}(r_t,\ 1 - \epsilon,\ 1 + \epsilon) \cdot A_t\right) \\
\mathcal{L}_{\text{policy}} &= -\mathbb{E}_t \left[ \mathcal{L}_{\text{clip}}(t) \right]
\end{align*}
\]
2、ValueLoss: Value Loss for PPO
(V_{\text{clip}, t} - R_t)^2, \, (V_t - R_t)^2
\right) \right]\\
\text{其中:}V_{\text{clip}} = V_{\text{old}} + \text{clip}(V - V_{\text{old}}, -\epsilon, \epsilon)
\]
代码测试
修改了代码见链接:https://www.big-yellow-j.top/_jupyter/OpenRLHF_model.py
总结
本文主要介绍了在 OpenRLHF中模型框架设计,主要分为3类模型:1、actor model;2、critic model;3、reward model这三类模型中分别起到作用:1、直接更具prompt输出response;2、输出token的评分(action_values = values[:, -3:]);3、返回整句输出评分(找出最后一个有效 token 的索引,然后从 value 向量中提取该位置的值作为 reward。)
强化学习框架:OpenRLHF源码解读,模型处理的更多相关文章
- SDWebImage源码解读 之 NSData+ImageContentType
第一篇 前言 从今天开始,我将开启一段源码解读的旅途了.在这里先暂时不透露具体解读的源码到底是哪些?因为也可能随着解读的进行会更改计划.但能够肯定的是,这一系列之中肯定会有Swift版本的代码. 说说 ...
- seajs 源码解读
之前面试时老问一个问题seajs 是怎么加载js 文件的 在网上找一些资料,觉得这个写的不错就转载了,记录一下,也学习一下 seajs 源码解读 seajs 简单介绍 seajs是前端应用模块化开发的 ...
- 基于Docker的TensorFlow机器学习框架搭建和实例源码解读
概述:基于Docker的TensorFlow机器学习框架搭建和实例源码解读,TensorFlow作为最火热的机器学习框架之一,Docker是的容器,可以很好的结合起来,为机器学习或者科研人员提供便捷的 ...
- Restful 1 -- REST、DRF(View源码解读、APIView源码解读)及框架实现
一.REST 1.什么是编程? 数据结构和算法的结合 2.什么是REST? - url用来唯一定位资源,http请求方式来区分用户行为 首先回顾我们曾经做过的图书管理系统,我们是这样设计url的,如下 ...
- Bert系列(二)——源码解读之模型主体
本篇文章主要是解读模型主体代码modeling.py.在阅读这篇文章之前希望读者们对bert的相关理论有一定的了解,尤其是transformer的结构原理,网上的资料很多,本文内容对原理部分就不做过多 ...
- etcd学习(6)-etcd实现raft源码解读
etcd中raft实现源码解读 前言 raft实现 看下etcd中的raftexample newRaftNode startRaft serveChannels 领导者选举 启动并初始化node节点 ...
- Spark学习之路 (十六)SparkCore的源码解读(二)spark-submit提交脚本
一.概述 上一篇主要是介绍了spark启动的一些脚本,这篇主要分析一下Spark源码中提交任务脚本的处理逻辑,从spark-submit一步步深入进去看看任务提交的整体流程,首先看一下整体的流程概要图 ...
- Prometheus 源码解读(一)
Prometheus 源码解读(一) Prometheus 是云原生监控领域的事实标准,越来越来的开源项目开始支持 Prometheus 监控数据格式.从本篇开始,我将和大家一起阅读分析 Promet ...
- AFNetworking 3.0 源码解读 总结(干货)(上)
养成记笔记的习惯,对于一个软件工程师来说,我觉得很重要.记得在知乎上看到过一个问题,说是人类最大的缺点是什么?我个人觉得记忆算是一个缺点.它就像时间一样,会自己消散. 前言 终于写完了 AFNetwo ...
- AFNetworking 3.0 源码解读(三)之 AFURLRequestSerialization
这篇就讲到了跟请求相关的类了 关于AFNetworking 3.0 源码解读 的文章篇幅都会很长,因为不仅仅要把代码进行详细的的解释,还会大概讲解和代码相关的知识点. 上半篇: URI编码的知识 关于 ...
随机推荐
- Flink批处理-简单案例-01
一.简单案例 <?xml version="1.0" encoding="UTF-8"?> <project xmlns="http ...
- Map模糊搜索key
一.代码 public class StringTest { public static void main(String[] args) { Map<String, Object>map ...
- window本地部署deepseek
window本地部署deepseek 学习自[[教程]DeepSeek本地免费部署教程,丝滑不卡顿!带你解锁隐藏功能!]https://www.bilibili.com/video/BV1viFaeB ...
- Java8 stream 提取对象 List 中的某一字段生成新的 List
//输出List StudentInfo.printStudents(studentList); //从对象列表中提取一列(以name为例) List<String> nameList = ...
- 使用mybatis-plus转换枚举值
1. 使用mybatis-plus转换枚举值 枚举值转换方式有很多,有以下方式: 后端写一个通用方法,只要前端传枚举类型,后端返回相应的枚举值前端去匹配 优点:能够实时保持数据一致性 缺点:如果有大量 ...
- Linux用户登录失败锁定策略
1.账户锁定策略介绍 在Linux系统中,为了提高系统安全性,防止暴力破解攻击,我们可以通过配置PAM(Pluggable Authentication Modules)模块来限制登录失败次数并锁定用 ...
- FUSE,从内核到用户态文件系统的设计之路
FUSE(Filesystem in Userspace)是一个允许用户在用户态创建自定义文件系统的接口,诞生于 2001 年.FUSE 的出现大大降低了文件系统开发的门槛,使得开发者能够在不修改内核 ...
- php禁止跨域调用api(来自文心快码)
在PHP中,禁止跨域调用API通常涉及到设置正确的HTTP响应头,以告知浏览器不允许来自不同源的请求.跨域资源共享(CORS)是一个W3C标准,它允许服务器放宽同源策略(SOP),从而允许某些跨站请求 ...
- Git pull(拉取),push(上传)命令整理(详细)
转自:https://www.cnblogs.com/wbl001/p/11495110.html (文档较长,请大家耐心阅读,很有帮助) git比较本地仓库和远程仓库的差异 更新本地的远程分支 gi ...
- Docker镜像的内部机制
Docker镜像的内部机制 镜像就是一个打包文件,里面包含了应用程序还有它运行所依赖的环境,例如文件系统.环境变量.配置参数等等. 环境变量.配置参数这些东西还是比较简单的,随便用一个 manifes ...