TRL(Transformer Reinforcement Learning)是一个使用强化学习来训练Transformer语言模型和Stable Diffusion模型的Python类库工具集,听上去很抽象,但如果说主要是做SFT(Supervised Fine-tuning)、RM(Reward Modeling)、RLHF(Reinforcement Learning from Human Feedback)和PPO(Proximal Policy Optimization)等的话,肯定就很熟悉了。最重要的是TRL构建于transformers库之上,两者均由Hugging Face公司开发。

一.TRL类库

1.TRL类库介绍

简单理解就是可以通过TRL库做RLHF训练,如下所示:


(1)SFTTrainer:是一个轻量级、友好的transformers Trainer包装器,可轻松在自定义数据集上微调语言模型或适配器。

(2)RewardTrainer:是一个轻量级的transformers Trainer包装器,可轻松为人类偏好(奖励建模)微调语言模型。

(3)PPOTrainer:一个PPO训练器,用于语言模型,只需要(query, response, reward)三元组来优化语言模型。

(4)AutoModelForCausalLMWithValueHead & AutoModelForSeq2SeqLMWithValueHead:一个带有额外标量输出的transformer模型,每个token都可以用作强化学习中的值函数。

(5)Examples:使用BERT情感分类器训练GPT2生成积极的电影评论,仅使用适配器的完整RLHF,训练GPT-j以减少毒性,Stack-Llama例子等。

2.PPO工作原理

  通过PPO对语言模型进行微调大致包括三个步骤:

(1)Rollout:语言模型根据query生成response或continuation,query可以是一个句子的开头。

(2)Evaluation:使用函数、模型、人类反馈或它们的某些组合对查询和响应进行评估。重要的是,此过程应为每个query/response对生成一个标量值。

(3)Optimization:这是最复杂的部分。在优化步骤中,query/response对用于计算序列中token的对数概率。这是使用经过训练的模型和Reference model完成的,Reference model通常是微调前的预训练模型。两个输出之间的KL散度用作额外的奖励信号,以确保生成的response不会偏离Reference model太远。然后使用PPO训练Active model。

二.TRL安装和使用方式

1.TRL安装

# 直接安装包
pip install trl

# 从源码安装
git clone https://github.com/huggingface/trl.git
cd trl/
pip install .

2.SFTTrainer使用方式

  SFTTrainer是围绕transformer Trainer的轻量级封装,可以轻松微调自定义数据集上的语言模型或适配器。如下所示:

# 导入Python包
from datasets import load_dataset
from trl import SFTTrainer

# 加载imdb数据集
dataset = load_dataset("imdb", split="train")

# 得到trainer
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
)

# 开始训练
trainer.train()

3.RewardTrainer使用方式

  RewardTrainer是围绕transformers Trainer的封装,可以轻松在自定义偏好数据集上微调奖励模型或适配器。如下所示:

# 导入Python包
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer

# 加载模型和数据集,数据集需要为指定格式
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
...
# 得到trainer
trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
)

# 开始训练
trainer.train()

4.PPOTrainer使用方式

  query通过语言模型输出一个response,然后对其进行评估。评估可以人类反馈,也可以是另一个模型的输出。如下所示:

# 导入Python包
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from trl.core import respond_to_batch

# 首先加载模型,然后创建参考模型
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
model_ref = create_reference_model(model)
tokenizer = AutoTokenizer.from_pretrained('gpt2')

# 初始化ppo配置对象
ppo_config = PPOConfig(
    batch_size=1,
)

# 编码一个query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")

# 得到模型response
response_tensor  = respond_to_batch(model, query_tensor)

# 创建一个ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)

# 为response定义一个reward(人类反馈或模型输出奖励) 
reward = [torch.tensor(1.0)]

# 使用ppo训练一步模型
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)

参考文献:

[1]https://github.com/huggingface/trl

[2]https://huggingface.co/docs/trl/v0.7.1/en/index

Llama2-Chinese项目:8-TRL资料整理的更多相关文章

  1. iOS 开发学习资料整理(持续更新)

      “如果说我看得比别人远些,那是因为我站在巨人们的肩膀上.” ---牛顿   iOS及Mac开源项目和学习资料[超级全面] http://www.kancloud.cn/digest/ios-mac ...

  2. zz 圣诞丨太阁所有的免费算法视频资料整理

    首发于 太阁实验室 关注专栏   写文章     圣诞丨太阁所有的免费算法视频资料整理 Ray Cao· 12 小时前 感谢大家一年以来对太阁实验室的支持,我们特地整理了在过去一年中我们所有的原创算法 ...

  3. 【转】iOS超全开源框架、项目和学习资料汇总

    iOS超全开源框架.项目和学习资料汇总(1)UI篇iOS超全开源框架.项目和学习资料汇总(2)动画篇iOS超全开源框架.项目和学习资料汇总(3)网络和Model篇iOS超全开源框架.项目和学习资料汇总 ...

  4. H.264的一些资料整理

    本文转载自 http://blog.csdn.net/ljzcom/article/details/7258978, 如有需要,请移步查看. Technorati 标签: H.264 资料整理 --- ...

  5. Java 学习资料整理

    Java 学习资料整理 Java 精品学习视频教程下载汇总 Java视频教程 孙鑫Java无难事 (全12CD) Java视频教程 即学即会java 上海交大 Java初级编程基础 共25讲下载 av ...

  6. word2vec剖析,资料整理备存

    声明:word2vec剖析,资料整理备存,以下资料均为转载,膜拜大神,仅作学术交流之用. word2vec是google最新发布的深度学习工具,它利用神经网络将单词映射到低维连续实数空间,又称为单词嵌 ...

  7. F4NNIU 的 Docker 学习资料整理

    F4NNIU 的 Docker 学习资料整理 Docker 介绍 以下来自 Wikipedia Docker是一个开放源代码软件项目,让应用程序部署在软件货柜下的工作可以自动化进行,借此在Linux操 ...

  8. Burpsuite 资料整理

    Burpsuite 资料整理, 整到一起比较方便.大家有更多关于Burpsuite的Tip请一起增量.谢谢! 插件 序号 名称 功能 参考文档 1 Turbo intruder 并发 https:// ...

  9. iOS 学习资料整理

    iOS学习资料整理 https://github.com/NunchakusHuang/trip-to-iOS 很好的个人博客 http://www.cnblogs.com/ygm900/ 开发笔记 ...

  10. 转:基于IOS上MDM技术相关资料整理及汇总

    一.MDM相关知识: MDM (Mobile Device Management ),即移动设备管理.在21世纪的今天,数据是企业宝贵的资产,安全问题更是重中之重,在移动互联网时代,员工个人的设备接入 ...

随机推荐

  1. 一些常见小程序的UI设计分享

    外卖优惠券小程序的UI设计 电子商城系统UI分享 A B C

  2. docker入门加实战—Docker镜像和Dockerfile语法

    docker入门加实战-Docker镜像和Dockerfile语法 镜像 镜像就是包含了应用程序.程序运行的系统函数库.运行配置等文件的文件包.构建镜像的过程其实就是把上述文件打包的过程. 镜像结构 ...

  3. NFT(数字藏品)热度没了?这玩意是机会还是泡沫?

    感谢你阅读本文! 大家好,今天分享一下NFT(数字藏品)这个领域,虽然今天的NFT已经没有之前那么火热,不过市场上依旧还是有很多平台存在,有人离开,也有人不断进来,所以很有必要再分析一番. 需要注意的 ...

  4. 栩栩如生,音色克隆,Bert-vits2文字转语音打造鬼畜视频实践(Python3.10)

    诸公可知目前最牛逼的TTS免费开源项目是哪一个?没错,是Bert-vits2,没有之一.它是在本来已经极其强大的Vits项目中融入了Bert大模型,基本上解决了VITS的语气韵律问题,在效果非常出色的 ...

  5. Hall定理(霍尔定理)证明及推广

    引言 网络上有许多Hall定理的证明,但是对于Hall定理的几个推广的介绍却少之又少,因此本文来简单介绍一下 注:为了使这篇文章看起来简单易懂,本文将不会使用图论语言,会图论的朋友们可以自行翻译为图论 ...

  6. Go 常用标准库之 fmt 介绍与基本使用

    Go 常用标准库之 fmt 介绍与基本使用 目录 Go 常用标准库之 fmt 介绍与基本使用 一.介绍 二.向外输出 2.1 Print 系列 2.2 Fprint 系列 2.3 Sprint 系列 ...

  7. 前端JavaScript编码规范 和react编码规范

    JavaScript编码规范 点击链接查看:https://github.com/ecomfe/spec/blob/master/javascript-style-guide.md 前端React编码 ...

  8. 【Azure Durable Function】PowerShell Activity 函数遇见 Newtonsoft.Json.JsonReaderException: The reader's MaxDepth of 64 has been exceeded.

    问题描述 创建PowerShell Azure Durable Function,执行大量的PowerShell脚本操作Azure Resource,遇见了一个非常非常奇怪的问题: Function ...

  9. 贪心算法:7-6 Swan学院社团招新

    Swan学院社团招新,招新宣讲会分散在不同时间段,大一新生小花花想知道自己最多能完整的参加多少个招新宣讲会(参加一个招新宣讲会的时候不能中断或离开). [问题说明]这个问题是对几个相互竞争的招新宣讲会 ...

  10. WPF --- 如何以Binding方式隐藏DataGrid列

    引言 如题,如何以Binding的方式动态隐藏DataGrid列? 预想方案 像这样: 先在ViewModel创建数据源 People 和控制列隐藏的 IsVisibility,这里直接以 MainW ...