DPO: Direct Preference Optimization 直接偏好优化(学习笔记)
学习参考:链接1
一、为什么要提出DPO
在之前,我们已经了解到基于人类反馈的强化学习RLHF分为三个阶段:全监督微调(SFT)、奖励模型(RM)、强化学习(PPO)。但是RLHF面临缺陷:RLHF 是一个复杂且经常不稳定的过程,首先拟合反映人类偏好的奖励模型,然后使用强化学习微调大型无监督 LM,以最大化这种估计奖励,而不会偏离原始模型太远。为解决这一问题,提出一个直接偏好优化 (DPO) 的新算法:通过利用奖励函数与最优策略之间的映射关系,证明这个受限的奖励最大化问题可以通过单阶段的策略训练来精确优化,本质上是在人类偏好数据上解决一个分类问题。DPO是稳定的、性能和计算成本轻量级的,无需拟合奖励模型,在微调期间从 LM 中采样,或执行显着的超参数调整。通过实验表明:DPO 进行微调超过了 RLHF 效果,并提高了摘要和单轮对话的响应质量。
二、什么是DPO
DPO,一种基于人类偏好优化语言模型的新方法。与RLHF不同,DPO不依赖于明确的奖励建模或强化学习。它针对与RLHF相同的目标,但提供了一种更简单、更直接的培训方法。
DPO的工作原理:增加偏好样本的对数概率与减小非偏好样本响应的对数概率。它结合了动态加权机制,以避免仅使用概率比目标时遇到的模型退化问题。
DPO依赖于理论上的偏好模型,如Bradley-Terry模型,来测量奖励函数与经验偏好数据的对齐程度。与传统方法不同,传统方法使用偏好模型来训练奖励模型,然后基于该奖励模型训练策略,DPO直接根据策略定义偏好损失。给定一个关于模型响应的人类偏好数据集,DPO可以使用简单的二元交叉熵目标来优化策略,无需在训练过程中明确学习奖励函数或从策略中采样。具体推导见链接1
(1)原RLHF的优化目标:最大化奖励和最小化参考策略的KL散度
(2)DPO优化目标:利用了从奖励函数到最优策略的解析映射,允许直接使用人类偏好数据进行简化的优化过程
该目标增加了对偏好数据$y_w$的可能性,并减少了非偏好数据$y_l$的可能性。这些示例按照隐式奖励模型的评级加权,由$\beta$缩放.
DPO重参数化等效于具有隐式奖励函数:
参数模型$\pi_{\theta}$的优化等效于在此变量更改下的奖励模型优化。
(3)DPO在干什么?
为了从原理上理解 DPO,分析损失函数的梯度$L_{DPO} $。 相对于参数 θ 的梯度可以写为:
其中是由语言模型$\pi_{\theta}$和参考模型$\pi_{ref}$隐式定义的奖励函数。直观上,损失函数 $L_{DPO} $的梯度增加了偏好$y_w$ 的可能性,并降低了非偏好$y_l$的可能性。更重要的是,样例的权重是通过: 隐式奖励模型$\hat{r}_{\theta}$对非偏好的评分高多少来衡量的,即$\hat{r}_{\theta}(x,y_l)-\hat{r}_{\theta}(x,y_w)$,按 β 进行缩放,即隐式奖励模型认为策略模型错误的程度。 我们的实验表明了这种加权的重要性,因为没有加权系数的这种方法的简单版本可能会导致语言模型退化。
(4)DPO outline
步骤1)是在构造数据集,通过对同一问题的两种回复的倾向性:chosen or rejected,反映人类偏好。
步骤2)在于优化,具体过程大概是,对于同一个question prompt,模型在两种模型:language/policy model 和 reference model下分别生成,对应chosen 和 rejected label真值标签的生成概率,因此可以获得四种概率值:policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, 用于DPO loss计算。
1、DPO trainer 期望数据集具有非常特定的格式。 给定两个句子时,模型将被训练为直接优化偏好:那一个句子最相关。
数据集由三部分组成:
prompt
chosen
rejected
可以由prompt 模板: Human: prompt. Assistant: chosen/rejected 构成如下数据:Anthropic/hh-rlhf
dataset
2、 预期模型格式
与 PPO 期望 AutoModelForCausalLMWithValueHead 作为值函数相比,DPO 训练器期望 AutoModelForCausalLM 模型。
3、使用 DPOTrainer 源码
有关详细示例,请查看 Examples/scripts/dpo.py 脚本。 在较高级别上,我们需要使用我们希望训练的模型、参考 ref_model 来初始化 DPOTrainer,我们将使用它来计算首选和拒绝响应的隐式奖励,beta 指隐式奖励的超参数, 数据集包含上面列出的 3 个条目。 请注意,模型和 ref_model 需要具有相同的架构(即仅解码器或编码器-解码器)。
dpo_trainer = DPOTrainer(
model,
model_ref,
args=training_args,
beta=0.1,
train_dataset=train_dataset,
tokenizer=tokenizer,
)
之后就可以调用:
dpo_trainer.train()
请注意,β 是 DPO 损失的温度参数,通常在 0.1 到 0.5 范围内。 当beta -> 0 ,意味着忽略参考模型。
4、损失函数
给定偏好数据,我们可以根据 Bradley-Terry 模型拟合二元分类器,事实上,DPO 作者通过 Logsigmoid 提出标准化似然的 sigmoid 损失来拟合逻辑回归。
def dpo_loss(
self,
policy_chosen_logps: torch.FloatTensor,
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
reference_free: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Compute the DPO loss for a batch of policy and reference model log probabilities. Args:
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses. Returns:
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
The losses tensor contains the DPO loss for each example in the batch.
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
"""
pi_logratios = policy_chosen_logps - policy_rejected_logps
if reference_free:
ref_logratios = 0
else:
ref_logratios = reference_chosen_logps - reference_rejected_logps pi_logratios = pi_logratios.to(self.accelerator.device)
ref_logratios = ref_logratios.to(self.accelerator.device)
logits = pi_logratios - ref_logratios # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
# We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
# calculates a conservative DPO loss.
if self.loss_type == "sigmoid":
losses = (
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
)
elif self.loss_type == "hinge":
losses = torch.relu(1 - self.beta * logits)
elif self.loss_type == "ipo":
# eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
losses = (logits - 1 / (2 * self.beta)) ** 2
elif self.loss_type == "kto_pair":
# eqn (7) of the HALOs paper
chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0) chosen_logratios = policy_chosen_logps - reference_chosen_logps
rejected_logratios = policy_rejected_logps - reference_rejected_logps
# As described in the KTO report, the KL term for chosen (rejected) is estimated using the rejected (chosen) half.
losses = torch.cat(
(
1 - F.sigmoid(self.beta * (chosen_logratios - rejected_KL)),
1 - F.sigmoid(self.beta * (chosen_KL - rejected_logratios)),
),
0,
)
else:
raise ValueError(
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair']"
) chosen_rewards = (
self.beta
* (
policy_chosen_logps.to(self.accelerator.device) - reference_chosen_logps.to(self.accelerator.device)
).detach()
)
rejected_rewards = (
self.beta
* (
policy_rejected_logps.to(self.accelerator.device)
- reference_rejected_logps.to(self.accelerator.device)
).detach()
) return losses, chosen_rewards, rejected_rewards
其他改进的损失函数:
RSO 作者建议在 SLiC 论文中的归一化似然上使用 hinge损失。 DPOTrainer 可以通过 loss_type="hinge" 参数切换到此损失,这种情况下的 beta 是margin的倒数。
IPO 作者对 DPO 算法提供了更深入的理论理解,并识别了过度拟合的问题,并提出了一种替代损失,可以通过训练器的 loss_type="ipo" 参数来使用。
cDPO 是对 DPO 损失的调整,其中我们假设偏好标签有一定的噪声,可以通过 label_smoothing 参数(0 到 0.5 之间)传递到 DPOTrainer,然后使用保守的 DPO 损失。 使用 loss_type="cdpo" 参数给训练器来使用它。
KTO 损失的导出是为了直接最大化 LLM 代的效用,而不是偏好的对数似然。 因此,数据集不一定是偏好,而是期望的完成与不期望的完成。 对于 DPOTrainer 所需的配对偏好数据,请使用训练器的 loss_type="kto_pair" 参数来利用此损失,而对于所需和不需要的数据的更一般情况,请使用尚未实现的 KTOTrainer。
5、指标:在训练和评估时,记录以下奖励指标:
rewards/chosen
: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses scaled by betarewards/rejected
: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses scaled by betarewards/accuracies
: mean of how often the chosen rewards are > than the corresponding rejected rewardsrewards/margins
: the mean difference between the chosen and corresponding rejected rewards
def get_batch_loss_metrics(
self,
model,
batch: Dict[str, Union[List, torch.LongTensor]],
train_eval: Literal["train", "eval"] = "train",
):
"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
metrics = {} (
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
) = self.concatenated_forward(model, batch) # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model
if "reference_chosen_logps" in batch and "reference_rejected_logps" in batch:
reference_chosen_logps = batch["reference_chosen_logps"]
reference_rejected_logps = batch["reference_rejected_logps"]
else:
with torch.no_grad():
if self.ref_model is None:
with self.null_ref_context():
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
) = self.concatenated_forward(self.model, batch)
else:
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
) = self.concatenated_forward(self.ref_model, batch) losses, chosen_rewards, rejected_rewards = self.dpo_loss(
policy_chosen_logps,
policy_rejected_logps,
reference_chosen_logps,
reference_rejected_logps,
)
reward_accuracies = (chosen_rewards > rejected_rewards).float() prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu() return losses.mean(), metrics
DPO: Direct Preference Optimization 直接偏好优化(学习笔记)的更多相关文章
- KVM性能优化学习笔记
本学习笔记系列都是采用CentOS6.x操作系统,KVM虚拟机的管理也是采用virsh方式,网上的很多的文章都基于ubuntu高版本内核下,KVM的一些新的特性支持更好,本文只是记录了CentOS6. ...
- 深挖计算机基础:Linux性能优化学习笔记
参考极客时间专栏<Linux性能优化实战>学习笔记 一.CPU性能:13讲 Linux性能优化实战学习笔记:第二讲 Linux性能优化实战学习笔记:第三讲 Linux性能优化实战学习笔记: ...
- Pandas 性能优化 学习笔记
摘要 本文介绍了使用 Pandas 进行数据挖掘时常用的加速技巧. 实验环境 import numpy as np import pandas as pd print(np.__version__) ...
- mysql性能优化学习笔记(2)如何发现有问题的sql
一.使用mysql慢查询日志对有效率问题的sql进行监控 1)开启慢查询 show variables like ‘slow_query_log’;//查看是否开启慢查询日志 ...
- HIVE优化学习笔记
概述 之前写过关于hive的已经有两篇随笔了,但是作者依然还是一枚小白,现在把那些杂七杂八的总结一下,供以后查阅和总结.今天的文章介绍一下hive的优化.hive是好多公司都在使用的东西,也有好多大公 ...
- 燕十八MySQL优化学习笔记
观察 show status; 里面的这三个参数;Queries Threads_connected Threads_running判断周期性变化 -------------------------- ...
- mysql性能优化学习笔记-参数介绍及优化建议
MySQL服务器参数介绍 mysql参数介绍(客户端中执行),尽量只修改session级别的参数. 全局参数(新连接的session才会生效,原有已经连接的session不生效) set global ...
- mysql性能优化学习笔记
mysql性能优化 硬件对数据库的影响 CPU资源和可用内存大小 服务器硬件对mysql性能的影响 我们的应用是CPU密集型? 我们的应用的并发量如何? 数量比频率更好 64位使用32位的服务器版本 ...
- mysql优化学习笔记
优化sql的一般步骤 通过show status了解各种sql的执行频率 定位执行效率低的sql语句 通过explain分析效率低的sql 通过show profile分析sql 通过trace分析优 ...
- js性能优化--学习笔记
<高性能网站建设进阶指南>: 1.使用局部变量,避免深入作用域查找,局部变量是读写速度最快的:把函数中使用次数超过一次的对象属性和数组存储为局部变量是一个好方法:比如for循环中的.len ...
随机推荐
- 用 C# 写脚本 如何输出文件夹内所有文件名
大部分在 Windows 下的脚本都是使用 bat 或 cmd 写的,这部分的脚本对我来说可读性不好.这个可读性也是很主观的,对我来说用 C# 写脚本的可读性很强,但是换个小伙伴就不是了.在 .NET ...
- 阿里云OSS文件上传几种方法(主要是前端)
目录 零.准备 一.服务端签名后直传 1. 阿里云控制台配置 2. 后端接口开发(PHP) 3. 前端获取签名后上传 二.使用STS临时凭证进行上传 1. 后端接口开发(node) 2. 前端获取临时 ...
- 009_原理图中电气互连,Net alias,分页符,总线
009_原理图中电气互连,Net alias,分页符,总线 1.电气互连,就是画线. 2.端口名,适用同一页相连的端口. 3.分页符off page connector,适用于不同页的端口连接. 4. ...
- 阿里巴巴Canal常见问题:重复解析/Filter失效/消费落后
前言 Canal是阿里巴巴开源的数据库Binlog日志解析框架,主要用途是基于 MySQL 数据库增量日志解析,提供增量数据订阅和消费. 在之前我写的文章阿里开源MySQL中间件Canal快速入门中, ...
- 羽夏逆向破解日记簿——关于逆向epub格式转化器与思考
看前必读 本软件是商业软件,本篇文章仅仅介绍 逆向分析过程 和 关于开发软件防止逆向的思考 ,不会提供任何成品破解补丁或成品软件,仅限用于学习和研究目的,否则,一切后果自负.您必须在下载后的24个 ...
- 网络安全—SSL安全访问应用
文章目录 网络拓扑 部署CA服务器颁发证书 开启Web服务 安装IIS服务 修改Web默认网页 申请Web证书 前提准备 申请文件生成 申请web证书 开始安装web证书 客户机访问web默认网站 使 ...
- C 语言编程 — 高级数据类型 — 共用体
目录 文章目录 目录 前文列表 共用体 定义共用体 访问共用体成员 前文列表 <程序编译流程与 GCC 编译器> <C 语言编程 - 基本语法> <C 语言编程 - 基本 ...
- TypingLearn解决了我在学习英语中的一大痛点
上一次在博客园发贴还是在上一次(2021年),那个时候博客园就遇到了危机(被罚款).彼时在疫情期间,我个人生活也受到了影响,先后去了多个城市,最终在上海找到了 .NET Web开发的岗位,还是比较幸运 ...
- docker --link容器互联
目录 一.系统环境 二.docker容器互联概述 2.1 docker容器互联的三种方式 2.2 docker --link使用注意事项 2.3 docker --link原理 三.docker容器互 ...
- SpringMVC在返回JSON数据时出现406错误解决方案
在SpringMVC框架的使用中常常会使用@ResponseBody注解,修饰"处理器"(Controller的方法),这样在处理器在返回完毕后,就不走逻辑视图,而是将返回的对象转 ...