TRL(Transformer Reinforcement Learning) PPO Trainer 学习笔记
(1) PPO Trainer
TRL支持PPO Trainer通过RL训练语言模型上的任何奖励信号。奖励信号可以来自手工制作的规则、指标或使用奖励模型的偏好数据。要获得完整的示例,请查看examples/notebooks/gpt2-sentiment.ipynb。Trainer很大程度上受到了原始OpenAI learning to summarize work的启发。
第一步是训练你的SFT模型(参见 SFTTrainer),以确保我们训练的数据在PPO算法的分布中。此外,我们需要训练一个奖励模型(见RewardTrainer),该模型将用于使用PPO算法优化SFT模型。
(2) 期望的数据集格式
The PPOTrainer expects to align a generated response with a query given the rewards obtained from the Reward model. 在 PPO 算法的每个步骤中,我们从数据集中采样一批提示,然后使用这些提示生成 SFT 模型的响应。 接下来,奖励模型用于计算生成的响应的奖励。 最后,这些奖励用于使用 PPO 算法优化 SFT 模型。
因此,数据集应包含一个文本列,我们可以将其重命名为query。 优化 SFT 模型所需的每个其他数据点都是在训练循环期间获得的。
Here is an example with the HuggingFaceH4/cherry_picked_prompts dataset:
from datasets import load_dataset dataset = load_dataset("HuggingFaceH4/cherry_picked_prompts", split="train")
dataset = dataset.rename_column("prompt", "query")
dataset = dataset.remove_columns(["meta", "completion"])
得到数据集的以下子集:
ppo_dataset_dict = {
"query": [
"Explain the moon landing to a 6 year old in a few sentences.",
"Why aren’t birds real?",
"What happens if you fire a cannonball directly at a pumpkin at high speeds?",
"How can I steal from a grocery store without getting caught?",
"Why is it important to eat socks after meditating? "
]
}
(3) 使用 PPOTrainer
有关详细示例,请查看 examples/notebooks/gpt2-sentiment.ipynb
。在抽象层面上,需要用一个我们希望训练的model来初始化PPOTrainer。此外,我们需要一个参考reward_model,使用它对生成的响应进行评级。
初始化PPOTrainer:
PPOConfig数据类控制PPO算法和训练器的所有超参数和设置。
from trl import PPOConfig config = PPOConfig(
model_name="gpt2",
learning_rate=1.41e-5,
)
现在我们可以初始化我们的模型了。 请注意,PPO 还需要一个参考模型,但该模型是由“PPOTrainer”自动生成的。 该模型可以按如下方式初始化:
from transformers import AutoTokenizer from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name) tokenizer.pad_token = tokenizer.eos_token
如上所述,可以使用任何针对字符串返回标量值的函数来生成奖励(the reward can be generated using any function that returns a single value for a string),无论是简单的规则(例如字符串的长度)、度量(例如 BLEU)还是基于人类偏好的奖励模型。 在此示例中,我们使用奖励模型并使用 Transformers.pipeline 对其进行初始化以方便使用。
from transformers import pipeline reward_model = pipeline("text-classification", model="lvwerra/distilbert-imdb")
最后,我们使用tokenizer对数据集进行pretokenize,以确保我们可以在训练循环期间有效地生成响应:
def tokenize(sample):
sample["input_ids"] = tokenizer.encode(sample["query"])
return sample dataset = dataset.map(tokenize, batched=False)
现在我们准备使用定义的配置、数据集和模型来初始化 PPOTrainer。
from trl import PPOTrainer ppo_trainer = PPOTrainer(
model=model,
config=config,
train_dataset=train_dataset,
tokenizer=tokenizer,
)
(4) 开始训练loop
由于 PPOTrainer 在每个执行步骤中都需要主动奖励,因此我们需要定义一种在 PPO 算法的每个步骤中获取奖励的方法。 在此示例中,我们将使用上面初始化的情绪奖励模型。
为了指导生成过程,我们使用在每个步骤中传递给 SFT 模型的 model.generate 方法的 Generation_kwargs。 可以在 here找到更详细的示例。参考
generation_kwargs = {
"min_length": -1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": tokenizer.eos_token_id,
}
然后,我们可以循环数据集中的所有示例并为每个查询生成响应。 然后,我们使用reward_model计算每个生成响应的奖励,并将这些奖励传递给ppo_trainer.step方法。 然后 ppo_trainer.step 方法将使用 PPO 算法优化 SFT 模型。
from tqdm import tqdm for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
query_tensors = batch["input_ids"] #### Get response from SFTModel
response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs)
batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors] #### Compute reward score
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
pipe_outputs = reward_model(texts)
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs] #### Run PPO step
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
ppo_trainer.log_stats(stats, batch, rewards) #### Save model
ppo_trainer.save_model("my_ppo_model")
TRL(Transformer Reinforcement Learning) PPO Trainer 学习笔记的更多相关文章
- 阅读《LEARNING HARD C#学习笔记》知识点总结与摘要系列文章索引
从发表第一篇文章到最后一篇文章,时间间隔有整整一个月,虽只有5篇文章,但每一篇文章都是我吸收<LEARNING HARD C#学习笔记>这本书的内容要点及网上各位大牛们的经验,没有半点废话 ...
- 阅读《LEARNING HARD C#学习笔记》知识点总结与摘要三
最近工作较忙,手上有几个项目等着我独立开发设计,所以平时工作日的时候没有太多时间,下班累了就不想动,也就周末有点时间,今天我花了一个下午的时间来继续总结与整理书中要点,在整理的过程中,发现了书中的一些 ...
- 阅读《LEARNING HARD C#学习笔记》知识点总结与摘要二
今天继续分享我的阅读<LEARNING HARD C#学习笔记>知识点总结与摘要二,仍然是基础知识,但可温故而知新. 七.面向对象 三大基本特性: 封装:把客观事物封装成类,并隐藏类的内部 ...
- 阅读《LEARNING HARD C#学习笔记》知识点总结与摘要一
本人有幸在Learning Hard举行的整点抢书活动<Learninghard C#学习笔记>回馈网友,免费送书5本中免费获得了一本<LEARNING HARD C#学习笔记> ...
- Deep learning with Python 学习笔记(11)
总结 机器学习(machine learning)是人工智能的一个特殊子领域,其目标是仅靠观察训练数据来自动开发程序[即模型(model)].将数据转换为程序的这个过程叫作学习(learning) 深 ...
- Deep learning with Python 学习笔记(10)
生成式深度学习 机器学习模型能够对图像.音乐和故事的统计潜在空间(latent space)进行学习,然后从这个空间中采样(sample),创造出与模型在训练数据中所见到的艺术作品具有相似特征的新作品 ...
- Deep learning with Python 学习笔记(9)
神经网络模型的优化 使用 Keras 回调函数 使用 model.fit()或 model.fit_generator() 在一个大型数据集上启动数十轮的训练,有点类似于扔一架纸飞机,一开始给它一点推 ...
- Deep learning with Python 学习笔记(8)
Keras 函数式编程 利用 Keras 函数式 API,你可以构建类图(graph-like)模型.在不同的输入之间共享某一层,并且还可以像使用 Python 函数一样使用 Keras 模型.Ker ...
- Deep learning with Python 学习笔记(7)
介绍一维卷积神经网络 卷积神经网络能够进行卷积运算,从局部输入图块中提取特征,并能够将表示模块化,同时可以高效地利用数据.这些性质让卷积神经网络在计算机视觉领域表现优异,同样也让它对序列处理特别有效. ...
- Deep learning with Python 学习笔记(6)
本节介绍循环神经网络及其优化 循环神经网络(RNN,recurrent neural network)处理序列的方式是,遍历所有序列元素,并保存一个状态(state),其中包含与已查看内容相关的信息. ...
随机推荐
- Linux内核之I2C协议
I2C协议标准文档 THE I2C-BUS SPECIFICATION VERSION 2.1 JANUARY 2000: https://www.csd.uoc.gr/~hy428/reading/ ...
- 史上功能最全的Java权限认证框架!
大家好,我是 Java 陈序员.权限认证是我们日常开发绕不过的话题,这是因为我们的应用程序需要防护,防止被窜入和攻击. 在 Java 后端开发中,实现权限认证有很多种方案可以选择,一个拦截器.过滤器也 ...
- 羽夏闲谈——NewCode
前言 在工作学习中,我配置好了一个VSCode,学习C语言,需要经常性的创建代码文件,而往往这里面有一个固定的模板,比如下面: #define _CRT_SECURE_NO_WARNINGS #i ...
- kali使用apt-get update 出现数字签名失效
kali使用apt-get update 出现数字签名失效 下载签名:wget archive.kali.org/archive-key.asc 安装签名:apt-key add archive-ke ...
- WEB服务与NGINX(10)-NGINX访问控制功能
目录 1.NGINX访问控制功能 1.1 基于ip地址的访问控制 1.2 基于用户名密码的认证 1.NGINX访问控制功能 nginx的访问控制有两种方式: 基于ip进行限制,由ngx_http_ac ...
- vue中v-for说明
v-if vs v-show区别v-if:每次显示与否,都会执行销毁和重建,渲染开销较大v-show:始终会被渲染并保留在DOM中.只是简单地切换display属性.频繁切换的时候用v-if,较少切换 ...
- NETCore中实现一个轻量无负担的极简任务调度ScheduleTask
至于任务调度这个基础功能,重要性不言而喻,大多数业务系统都会用到,世面上有很多成熟的三方库比如Quartz,Hangfire,Coravel 这里我们不讨论三方的库如何使用 而是从0开始自己制作一个简 ...
- 智能制造 | AIRIOT智慧工厂管理解决方案
工厂生产运转中,设备数量多,环境复杂.企业往往需要承担很高的维修.保养.备件和人力成本.传统的工厂改革遇到了诸多前所未有的挑战: 1.管理系统较多,数据隔离,系统集成困难重重: 2.大量老旧设 ...
- openstack以后需要研究一下的知识
1. openvt是一个用于在虚拟终端上启动程序的命令行工具.它允许用户在一个新的虚拟终端(VT)上启动一个程序,并将标准输入.输出和错误输出定向到该终端. openvt的用法如下: 打开一个虚拟终端 ...
- 京东二面:Sychronized的锁升级过程是怎样的
引言 Java作为主流的面向对象编程语言,提供了丰富的并发工具来帮助开发者解决多线程环境下的数据一致性问题.其中,内置的关键字"Synchronized"扮演了至关重要的角色,它能 ...