TRL(Transformer Reinforcement Learning)是一个使用强化学习来训练Transformer语言模型和Stable Diffusion模型的Python类库工具集,听上去很抽象,但如果说主要是做SFT(Supervised Fine-tuning)、RM(Reward Modeling)、RLHF(Reinforcement Learning from Human Feedback)和PPO(Proximal Policy Optimization)等的话,肯定就很熟悉了。最重要的是TRL构建于transformers库之上,两者均由Hugging Face公司开发。

一.TRL类库

1.TRL类库介绍

简单理解就是可以通过TRL库做RLHF训练,如下所示:


(1)SFTTrainer:是一个轻量级、友好的transformers Trainer包装器,可轻松在自定义数据集上微调语言模型或适配器。

(2)RewardTrainer:是一个轻量级的transformers Trainer包装器,可轻松为人类偏好(奖励建模)微调语言模型。

(3)PPOTrainer:一个PPO训练器,用于语言模型,只需要(query, response, reward)三元组来优化语言模型。

(4)AutoModelForCausalLMWithValueHead & AutoModelForSeq2SeqLMWithValueHead:一个带有额外标量输出的transformer模型,每个token都可以用作强化学习中的值函数。

(5)Examples:使用BERT情感分类器训练GPT2生成积极的电影评论,仅使用适配器的完整RLHF,训练GPT-j以减少毒性,Stack-Llama例子等。

2.PPO工作原理

  通过PPO对语言模型进行微调大致包括三个步骤:

(1)Rollout:语言模型根据query生成response或continuation,query可以是一个句子的开头。

(2)Evaluation:使用函数、模型、人类反馈或它们的某些组合对查询和响应进行评估。重要的是,此过程应为每个query/response对生成一个标量值。

(3)Optimization:这是最复杂的部分。在优化步骤中,query/response对用于计算序列中token的对数概率。这是使用经过训练的模型和Reference model完成的,Reference model通常是微调前的预训练模型。两个输出之间的KL散度用作额外的奖励信号,以确保生成的response不会偏离Reference model太远。然后使用PPO训练Active model。

二.TRL安装和使用方式

1.TRL安装

# 直接安装包
pip install trl

# 从源码安装
git clone https://github.com/huggingface/trl.git
cd trl/
pip install .

2.SFTTrainer使用方式

  SFTTrainer是围绕transformer Trainer的轻量级封装,可以轻松微调自定义数据集上的语言模型或适配器。如下所示:

# 导入Python包
from datasets import load_dataset
from trl import SFTTrainer

# 加载imdb数据集
dataset = load_dataset("imdb", split="train")

# 得到trainer
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
)

# 开始训练
trainer.train()

3.RewardTrainer使用方式

  RewardTrainer是围绕transformers Trainer的封装,可以轻松在自定义偏好数据集上微调奖励模型或适配器。如下所示:

# 导入Python包
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer

# 加载模型和数据集,数据集需要为指定格式
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
...
# 得到trainer
trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
)

# 开始训练
trainer.train()

4.PPOTrainer使用方式

  query通过语言模型输出一个response,然后对其进行评估。评估可以人类反馈,也可以是另一个模型的输出。如下所示:

# 导入Python包
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from trl.core import respond_to_batch

# 首先加载模型,然后创建参考模型
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
model_ref = create_reference_model(model)
tokenizer = AutoTokenizer.from_pretrained('gpt2')

# 初始化ppo配置对象
ppo_config = PPOConfig(
    batch_size=1,
)

# 编码一个query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")

# 得到模型response
response_tensor  = respond_to_batch(model, query_tensor)

# 创建一个ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)

# 为response定义一个reward(人类反馈或模型输出奖励) 
reward = [torch.tensor(1.0)]

# 使用ppo训练一步模型
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)

参考文献:

[1]https://github.com/huggingface/trl

[2]https://huggingface.co/docs/trl/v0.7.1/en/index

Llama2-Chinese项目:8-TRL资料整理的更多相关文章

  1. iOS 开发学习资料整理(持续更新)

      “如果说我看得比别人远些,那是因为我站在巨人们的肩膀上.” ---牛顿   iOS及Mac开源项目和学习资料[超级全面] http://www.kancloud.cn/digest/ios-mac ...

  2. zz 圣诞丨太阁所有的免费算法视频资料整理

    首发于 太阁实验室 关注专栏   写文章     圣诞丨太阁所有的免费算法视频资料整理 Ray Cao· 12 小时前 感谢大家一年以来对太阁实验室的支持,我们特地整理了在过去一年中我们所有的原创算法 ...

  3. 【转】iOS超全开源框架、项目和学习资料汇总

    iOS超全开源框架.项目和学习资料汇总(1)UI篇iOS超全开源框架.项目和学习资料汇总(2)动画篇iOS超全开源框架.项目和学习资料汇总(3)网络和Model篇iOS超全开源框架.项目和学习资料汇总 ...

  4. H.264的一些资料整理

    本文转载自 http://blog.csdn.net/ljzcom/article/details/7258978, 如有需要,请移步查看. Technorati 标签: H.264 资料整理 --- ...

  5. Java 学习资料整理

    Java 学习资料整理 Java 精品学习视频教程下载汇总 Java视频教程 孙鑫Java无难事 (全12CD) Java视频教程 即学即会java 上海交大 Java初级编程基础 共25讲下载 av ...

  6. word2vec剖析,资料整理备存

    声明:word2vec剖析,资料整理备存,以下资料均为转载,膜拜大神,仅作学术交流之用. word2vec是google最新发布的深度学习工具,它利用神经网络将单词映射到低维连续实数空间,又称为单词嵌 ...

  7. F4NNIU 的 Docker 学习资料整理

    F4NNIU 的 Docker 学习资料整理 Docker 介绍 以下来自 Wikipedia Docker是一个开放源代码软件项目,让应用程序部署在软件货柜下的工作可以自动化进行,借此在Linux操 ...

  8. Burpsuite 资料整理

    Burpsuite 资料整理, 整到一起比较方便.大家有更多关于Burpsuite的Tip请一起增量.谢谢! 插件 序号 名称 功能 参考文档 1 Turbo intruder 并发 https:// ...

  9. iOS 学习资料整理

    iOS学习资料整理 https://github.com/NunchakusHuang/trip-to-iOS 很好的个人博客 http://www.cnblogs.com/ygm900/ 开发笔记 ...

  10. 转:基于IOS上MDM技术相关资料整理及汇总

    一.MDM相关知识: MDM (Mobile Device Management ),即移动设备管理.在21世纪的今天,数据是企业宝贵的资产,安全问题更是重中之重,在移动互联网时代,员工个人的设备接入 ...

随机推荐

  1. Composite 组合模式简介与 C# 示例【结构型3】【设计模式来了_8】

    〇.简介 1.什么是组合设计模式? 一句话解释:   针对树形结构的任意节点,都实现了同一接口,他们具有相同的操作,可以通过某一操作来遍历全部节点. 组合模式通过使用树形结构来组合对象,用来表示部分以 ...

  2. [NSSRound#1 Basic]basic_check

    打开网站,发现啥也没有: 就用dirsearch扫了一遍.发现还是没有有用信息: 只有再另找方法: 再用nikto扫一次: 发现一个put方法,就用put上传一个一句话木马:可以用插件restlien ...

  3. Intervals 题解

    Intervals 题目大意 给定 \(m\) 条形如 \((l_i,r_i,a_i)\) 的规则,你需要求出一个长为 \(n\) 的分数最大的 01 串的分数,其中一个 01 串 \(A\) 的分数 ...

  4. 最新 2023.2 版本 IDEA 永久破解教程,IDEA 破解补丁永久激活(亲测有效)

    最近 jetbrains 官方发布了 2023.2 版本的 IDEA,之前的激活方法并不支持这个新的版本. 下面是最新的激活教程,激活步骤和之前是类似的,只是换用了不同的补丁文件. 本教程支持 Jet ...

  5. CSS z-index属性层重叠顺序

    作者:WangMin 格言:努力做好自己喜欢的每一件事 对于所有定位,最后都不免遇到两个元素试图放在同一位置上的情况.显然,其中一个必须遮住另一个.但是如何控制哪个元素放在上层,这就出现了z-inde ...

  6. jap复制一条数据插入数据库,报:identifier of an instance of com.kxkd.shop.entity.goods.GoodsSpu was alt

    因为修改了jpa实体id 可以先使用springframework的BeanUtils复制一个相同的对象 BeanUtils.copyProperties(source, target); //复制属 ...

  7. 【pwn】[SWPUCTF 2022 新生赛]InfoPrinter--格式化字符串漏洞,got表劫持,data段修改

    下载附件,checksec检查程序保护情况: No RELRO,说明got表可修改 接下来看主程序: 函数逻辑还是比较简单,14行出现格式化字符串漏洞,配合pwntools的fmtstr_payloa ...

  8. re1-100

    虽然关键的判断函数和"成功"的提示也在这里,但是具体对输入flag的操作却在后面 看到对数组bufParentRead[1]开始赋值"53fc275d81",b ...

  9. DP:三角形的最小路径和

    给定一个三角形,找出自顶向下的最小路径和.每一步只能移动到下一行中相邻的结点上. 例如,给定三角形: [     [2],    [3,4],   [6,5,7],  [4,1,8,3]] 自顶向下的 ...

  10. .NET周刊【11月第3期 2023-11-19】

    国内文章 .NET8.0 AOT 经验分享 FreeSql/FreeRedis/FreeScheduler 均已通过测试 https://www.cnblogs.com/FreeSql/p/17836 ...