DPO: Direct Preference Optimization 直接偏好优化(学习笔记)
学习参考:链接1
一、为什么要提出DPO
在之前,我们已经了解到基于人类反馈的强化学习RLHF分为三个阶段:全监督微调(SFT)、奖励模型(RM)、强化学习(PPO)。但是RLHF面临缺陷:RLHF 是一个复杂且经常不稳定的过程,首先拟合反映人类偏好的奖励模型,然后使用强化学习微调大型无监督 LM,以最大化这种估计奖励,而不会偏离原始模型太远。为解决这一问题,提出一个直接偏好优化 (DPO) 的新算法:通过利用奖励函数与最优策略之间的映射关系,证明这个受限的奖励最大化问题可以通过单阶段的策略训练来精确优化,本质上是在人类偏好数据上解决一个分类问题。DPO是稳定的、性能和计算成本轻量级的,无需拟合奖励模型,在微调期间从 LM 中采样,或执行显着的超参数调整。通过实验表明:DPO 进行微调超过了 RLHF 效果,并提高了摘要和单轮对话的响应质量。

二、什么是DPO
DPO,一种基于人类偏好优化语言模型的新方法。与RLHF不同,DPO不依赖于明确的奖励建模或强化学习。它针对与RLHF相同的目标,但提供了一种更简单、更直接的培训方法。
DPO的工作原理:增加偏好样本的对数概率与减小非偏好样本响应的对数概率。它结合了动态加权机制,以避免仅使用概率比目标时遇到的模型退化问题。
DPO依赖于理论上的偏好模型,如Bradley-Terry模型,来测量奖励函数与经验偏好数据的对齐程度。与传统方法不同,传统方法使用偏好模型来训练奖励模型,然后基于该奖励模型训练策略,DPO直接根据策略定义偏好损失。给定一个关于模型响应的人类偏好数据集,DPO可以使用简单的二元交叉熵目标来优化策略,无需在训练过程中明确学习奖励函数或从策略中采样。具体推导见链接1
(1)原RLHF的优化目标:最大化奖励和最小化参考策略的KL散度

(2)DPO优化目标:利用了从奖励函数到最优策略的解析映射,允许直接使用人类偏好数据进行简化的优化过程

该目标增加了对偏好数据$y_w$的可能性,并减少了非偏好数据$y_l$的可能性。这些示例按照隐式奖励模型的评级加权,由$\beta$缩放.
DPO重参数化等效于具有隐式奖励函数:

参数模型$\pi_{\theta}$的优化等效于在此变量更改下的奖励模型优化。
(3)DPO在干什么?
为了从原理上理解 DPO,分析损失函数的梯度$L_{DPO} $。 相对于参数 θ 的梯度可以写为:

其中
是由语言模型$\pi_{\theta}$和参考模型$\pi_{ref}$隐式定义的奖励函数。直观上,损失函数 $L_{DPO} $的梯度增加了偏好$y_w$ 的可能性,并降低了非偏好$y_l$的可能性。更重要的是,样例的权重是通过: 隐式奖励模型$\hat{r}_{\theta}$对非偏好的评分高多少来衡量的,即$\hat{r}_{\theta}(x,y_l)-\hat{r}_{\theta}(x,y_w)$,按 β 进行缩放,即隐式奖励模型认为策略模型错误的程度。 我们的实验表明了这种加权的重要性,因为没有加权系数的这种方法的简单版本可能会导致语言模型退化。
(4)DPO outline

步骤1)是在构造数据集,通过对同一问题的两种回复的倾向性:chosen or rejected,反映人类偏好。
步骤2)在于优化,具体过程大概是,对于同一个question prompt,模型在两种模型:language/policy model 和 reference model下分别生成,对应chosen 和 rejected label真值标签的生成概率,因此可以获得四种概率值:policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, 用于DPO loss计算。
1、DPO trainer 期望数据集具有非常特定的格式。 给定两个句子时,模型将被训练为直接优化偏好:那一个句子最相关。
数据集由三部分组成:
promptchosenrejected
可以由prompt 模板: Human: prompt. Assistant: chosen/rejected 构成如下数据:Anthropic/hh-rlhf dataset

2、 预期模型格式
与 PPO 期望 AutoModelForCausalLMWithValueHead 作为值函数相比,DPO 训练器期望 AutoModelForCausalLM 模型。
3、使用 DPOTrainer 源码
有关详细示例,请查看 Examples/scripts/dpo.py 脚本。 在较高级别上,我们需要使用我们希望训练的模型、参考 ref_model 来初始化 DPOTrainer,我们将使用它来计算首选和拒绝响应的隐式奖励,beta 指隐式奖励的超参数, 数据集包含上面列出的 3 个条目。 请注意,模型和 ref_model 需要具有相同的架构(即仅解码器或编码器-解码器)。
dpo_trainer = DPOTrainer(
model,
model_ref,
args=training_args,
beta=0.1,
train_dataset=train_dataset,
tokenizer=tokenizer,
)
之后就可以调用:
dpo_trainer.train()
请注意,β 是 DPO 损失的温度参数,通常在 0.1 到 0.5 范围内。 当beta -> 0 ,意味着忽略参考模型。
4、损失函数
给定偏好数据,我们可以根据 Bradley-Terry 模型拟合二元分类器,事实上,DPO 作者通过 Logsigmoid 提出标准化似然的 sigmoid 损失来拟合逻辑回归。
def dpo_loss(
self,
policy_chosen_logps: torch.FloatTensor,
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
reference_free: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Compute the DPO loss for a batch of policy and reference model log probabilities. Args:
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses. Returns:
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
The losses tensor contains the DPO loss for each example in the batch.
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
"""
pi_logratios = policy_chosen_logps - policy_rejected_logps
if reference_free:
ref_logratios = 0
else:
ref_logratios = reference_chosen_logps - reference_rejected_logps pi_logratios = pi_logratios.to(self.accelerator.device)
ref_logratios = ref_logratios.to(self.accelerator.device)
logits = pi_logratios - ref_logratios # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
# We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
# calculates a conservative DPO loss.
if self.loss_type == "sigmoid":
losses = (
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
)
elif self.loss_type == "hinge":
losses = torch.relu(1 - self.beta * logits)
elif self.loss_type == "ipo":
# eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
losses = (logits - 1 / (2 * self.beta)) ** 2
elif self.loss_type == "kto_pair":
# eqn (7) of the HALOs paper
chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0) chosen_logratios = policy_chosen_logps - reference_chosen_logps
rejected_logratios = policy_rejected_logps - reference_rejected_logps
# As described in the KTO report, the KL term for chosen (rejected) is estimated using the rejected (chosen) half.
losses = torch.cat(
(
1 - F.sigmoid(self.beta * (chosen_logratios - rejected_KL)),
1 - F.sigmoid(self.beta * (chosen_KL - rejected_logratios)),
),
0,
)
else:
raise ValueError(
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair']"
) chosen_rewards = (
self.beta
* (
policy_chosen_logps.to(self.accelerator.device) - reference_chosen_logps.to(self.accelerator.device)
).detach()
)
rejected_rewards = (
self.beta
* (
policy_rejected_logps.to(self.accelerator.device)
- reference_rejected_logps.to(self.accelerator.device)
).detach()
) return losses, chosen_rewards, rejected_rewards
其他改进的损失函数:
RSO 作者建议在 SLiC 论文中的归一化似然上使用 hinge损失。 DPOTrainer 可以通过 loss_type="hinge" 参数切换到此损失,这种情况下的 beta 是margin的倒数。
IPO 作者对 DPO 算法提供了更深入的理论理解,并识别了过度拟合的问题,并提出了一种替代损失,可以通过训练器的 loss_type="ipo" 参数来使用。
cDPO 是对 DPO 损失的调整,其中我们假设偏好标签有一定的噪声,可以通过 label_smoothing 参数(0 到 0.5 之间)传递到 DPOTrainer,然后使用保守的 DPO 损失。 使用 loss_type="cdpo" 参数给训练器来使用它。
KTO 损失的导出是为了直接最大化 LLM 代的效用,而不是偏好的对数似然。 因此,数据集不一定是偏好,而是期望的完成与不期望的完成。 对于 DPOTrainer 所需的配对偏好数据,请使用训练器的 loss_type="kto_pair" 参数来利用此损失,而对于所需和不需要的数据的更一般情况,请使用尚未实现的 KTOTrainer。
5、指标:在训练和评估时,记录以下奖励指标:
rewards/chosen: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses scaled by betarewards/rejected: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses scaled by betarewards/accuracies: mean of how often the chosen rewards are > than the corresponding rejected rewardsrewards/margins: the mean difference between the chosen and corresponding rejected rewards
def get_batch_loss_metrics(
self,
model,
batch: Dict[str, Union[List, torch.LongTensor]],
train_eval: Literal["train", "eval"] = "train",
):
"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
metrics = {} (
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
) = self.concatenated_forward(model, batch) # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model
if "reference_chosen_logps" in batch and "reference_rejected_logps" in batch:
reference_chosen_logps = batch["reference_chosen_logps"]
reference_rejected_logps = batch["reference_rejected_logps"]
else:
with torch.no_grad():
if self.ref_model is None:
with self.null_ref_context():
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
) = self.concatenated_forward(self.model, batch)
else:
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
) = self.concatenated_forward(self.ref_model, batch) losses, chosen_rewards, rejected_rewards = self.dpo_loss(
policy_chosen_logps,
policy_rejected_logps,
reference_chosen_logps,
reference_rejected_logps,
)
reward_accuracies = (chosen_rewards > rejected_rewards).float() prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu() return losses.mean(), metrics
DPO: Direct Preference Optimization 直接偏好优化(学习笔记)的更多相关文章
- KVM性能优化学习笔记
本学习笔记系列都是采用CentOS6.x操作系统,KVM虚拟机的管理也是采用virsh方式,网上的很多的文章都基于ubuntu高版本内核下,KVM的一些新的特性支持更好,本文只是记录了CentOS6. ...
- 深挖计算机基础:Linux性能优化学习笔记
参考极客时间专栏<Linux性能优化实战>学习笔记 一.CPU性能:13讲 Linux性能优化实战学习笔记:第二讲 Linux性能优化实战学习笔记:第三讲 Linux性能优化实战学习笔记: ...
- Pandas 性能优化 学习笔记
摘要 本文介绍了使用 Pandas 进行数据挖掘时常用的加速技巧. 实验环境 import numpy as np import pandas as pd print(np.__version__) ...
- mysql性能优化学习笔记(2)如何发现有问题的sql
一.使用mysql慢查询日志对有效率问题的sql进行监控 1)开启慢查询 show variables like ‘slow_query_log’;//查看是否开启慢查询日志 ...
- HIVE优化学习笔记
概述 之前写过关于hive的已经有两篇随笔了,但是作者依然还是一枚小白,现在把那些杂七杂八的总结一下,供以后查阅和总结.今天的文章介绍一下hive的优化.hive是好多公司都在使用的东西,也有好多大公 ...
- 燕十八MySQL优化学习笔记
观察 show status; 里面的这三个参数;Queries Threads_connected Threads_running判断周期性变化 -------------------------- ...
- mysql性能优化学习笔记-参数介绍及优化建议
MySQL服务器参数介绍 mysql参数介绍(客户端中执行),尽量只修改session级别的参数. 全局参数(新连接的session才会生效,原有已经连接的session不生效) set global ...
- mysql性能优化学习笔记
mysql性能优化 硬件对数据库的影响 CPU资源和可用内存大小 服务器硬件对mysql性能的影响 我们的应用是CPU密集型? 我们的应用的并发量如何? 数量比频率更好 64位使用32位的服务器版本 ...
- mysql优化学习笔记
优化sql的一般步骤 通过show status了解各种sql的执行频率 定位执行效率低的sql语句 通过explain分析效率低的sql 通过show profile分析sql 通过trace分析优 ...
- js性能优化--学习笔记
<高性能网站建设进阶指南>: 1.使用局部变量,避免深入作用域查找,局部变量是读写速度最快的:把函数中使用次数超过一次的对象属性和数组存储为局部变量是一个好方法:比如for循环中的.len ...
随机推荐
- 开发日志:Kylin麒麟操作系统部署ASP.NET CORE
需求场景: 我需要部署的项目是在Windows上开发的,目标框架为.net core 6.0 因此我们需要先在kylin上部署项目运行所需要的环境. 借助百度词条,先看看Kylin是什么: 服务器资源 ...
- 通过Google浏览器Cookie文件获取cookie信息,80以上版本有效
public class ReadCookie { /// <summary> /// </summary> /// <param name="hostName ...
- Oracle、达梦:数据库大小写不敏感,但是又要区分大小写敏感(默认敏感)
一. 艹,这个需求就很操蛋. 实现 SELECT * FROM T1 WHERE REGEXP_LIKE(field, '.*value.*', 'c'); 在 Oracle 数据库中使用 REGEX ...
- JavaWeb 中 “转发”与 “重定向”的区别
JavaWeb 中 "转发"与 "重定向"的区别 每博一文案 人生的常态,就是有聚有散,有得有失,就像山峰一样,总有高低,起伏不断. 曾经,我们是鲜衣怒马的少年 ...
- 用Vue仿了一个类似抖音的App
大家好,我是 Java陈序员. 今天,给大家介绍一个基于 Vue3 实现的高仿抖音开源项目. 关注微信公众号:[Java陈序员],获取开源项目分享.AI副业分享.超200本经典计算机电子书籍等. 项目 ...
- 检索增强生成(RAG)实践:基于LlamaIndex和Qwen1.5搭建智能问答系统
检索增强生成(RAG)实践:基于LlamaIndex和Qwen1.5搭建智能问答系统 什么是 RAG LLM 会产生误导性的 "幻觉",依赖的信息可能过时,处理特定知识时效率不高, ...
- geojson介绍和常用转换编辑工具
GeoJSON是一种基于JSON的地理空间数据交换格式,它定义了几种类型JSON对象以及它们组合在一起的方法,以表示有关地理要素.属性和它们的空间范围的数据. 2015年,互联网工程任务组(IETF) ...
- 基于uniapp+vue3自定义增强版table表格组件「兼容H5+小程序+App端」
vue3+uniapp多端自定义table组件|uniapp加强版综合表格组件 uv3-table:一款基于uniapp+vue3跨端自定义手机端增强版表格组件.支持固定表头/列.边框.斑马纹.单选/ ...
- handsontable多选下拉框编辑器扩展
一.效果截图 二.文件引用 多选下拉框扩展自handsontable的BaseEditor. 多选下拉框组件由两个文件构成, 一个下拉框样式表MultiSelect.css 一个组件实现脚本Multi ...
- linux用户管理:创建用户,删除用户,管理用户,用户配置
目录 一.关于用户 二.用户的三种类型 三.与用户有关的配置文件详解 四.创建用户 五.设置用户密码 六.删除用户 七.用户密码时效管理 八.查看用户相关信息的命令 九.修改用户基本信息 十.管理用户 ...