作者|huggingface

编译|VK

来源|Github

理念

Transformers是一个为NLP的研究人员寻求使用/研究/扩展大型Transformers模型的库。

该库的设计有两个强烈的目标:

  • 尽可能简单和快速使用:

    • 我们尽可能限制了要学习的面向对象抽象的类的数量,实际上几乎没有抽象,每个模型只需要使用三个标准类:配置、模型和tokenizer,
    • 所有这些类都可以通过使用公共的from_pretrained()实例化方法从预训练实例以简单统一的方式初始化,该方法将负责从库中下载,缓存和加载相关类提供的预训练模型或你自己保存的模型。
    • 因此,这个库不是构建神经网络模块的工具箱。如果您想扩展/构建这个库,只需使用常规的Python/PyTorch模块,并从这个库的基类继承,以重用诸如模型加载/保存等功能。
  • 提供最先进的模型与性能尽可能接近的原始模型:
    • 我们为每个架构提供了至少一个例子,该例子再现了上述架构的官方作者提供的结果
    • 代码通常尽可能地接近原始代码,这意味着一些PyTorch代码可能不那么pytorch化,因为这是转换TensorFlow代码后的结果。

其他几个目标:

  • 尽可能一致地暴露模型的内部:

    • 我们使用一个API来访问所有的隐藏状态和注意力权重,
    • 对tokenizer和基本模型的API进行了标准化,以方便在模型之间进行切换。
  • 结合一个主观选择的有前途的工具微调/调查这些模型:
    • 向词汇表和嵌入项添加新标记以进行微调的简单/一致的方法,
    • 简单的方法面具和修剪变压器头。

主要概念

该库是建立在三个类型的类为每个模型:

  • model类是目前在库中提供的8个模型架构的PyTorch模型(torch.nn.Modules),例如BertModel
  • configuration类,它存储构建模型所需的所有参数,例如BertConfig。您不必总是自己实例化这些配置,特别是如果您使用的是未经任何修改的预训练的模型,创建模型将自动负责实例化配置(它是模型的一部分)
  • tokenizer类,它存储每个模型的词汇表,并在要输送到模型的词汇嵌入索引列表中提供用于编码/解码字符串的方法,例如BertTokenizer

所有这些类都可以从预训练模型来实例化,并使用两种方法在本地保存:

  • from_pretraining()允许您从一个预训练版本实例化一个模型/配置/tokenizer,这个预训练版本可以由库本身提供(目前这里列出了27个模型),也可以由用户在本地(或服务器上)存储,
  • save_pretraining()允许您在本地保存模型/配置/tokenizer,以便可以使用from_pretraining()重新加载它。

我们将通过一些简单的快速启动示例来完成这个快速启动之旅,看看如何实例化和使用这些类。其余的文件分为两部分:

  • 主要的类详细介绍了三种主要类(配置、模型、tokenizer)的公共功能/方法/属性,以及一些作为训练工具提供的优化类,
  • 包引用部分详细描述了每个模型体系结构的每个类的所有变体,特别是调用它们时它们期望的输入和输出。

快速入门:使用

这里有两个例子展示了一些Bert和GPT2类以及预训练模型。

有关每个模型类的示例,请参阅完整的API参考。

BERT示例

让我们首先使用BertTokenizer从文本字符串准备一个标记化的输入(要输入给BERT的标记嵌入索引列表)

import torch
from transformers import BertTokenizer, BertModel, BertForMaskedLM # 可选:如果您想了解发生的信息,请按以下步骤logger
import logging
logging.basicConfig(level=logging.INFO) # 加载预训练的模型标记器(词汇表)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') # 标记输入
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = tokenizer.tokenize(text) # 用“BertForMaskedLM”掩盖我们试图预测的标记`
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
assert tokenized_text == ['[CLS]', 'who', 'was', 'jim', 'henson', '?', '[SEP]', 'jim', '[MASK]', 'was', 'a', 'puppet', '##eer', '[SEP]'] # 将标记转换为词汇索引
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# 定义与第一句和第二句相关的句子A和B索引(见论文)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1] # 将输入转换为PyTorch张量
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

让我们看看如何使用BertModel在隐藏状态下对输入进行编码:

# 加载预训练模型(权重)
model = BertModel.from_pretrained('bert-base-uncased') # 将模型设置为评估模式
# 在评估期间有可再现的结果这是很重要的!
model.eval() # 如果你有GPU,把所有东西都放在cuda上
tokens_tensor = tokens_tensor.to('cuda')
segments_tensors = segments_tensors.to('cuda')
model.to('cuda') #预测每个层的隐藏状态特征
with torch.no_grad():
# 有关输入的详细信息,请参见models文档字符串
outputs = model(tokens_tensor, token_type_ids=segments_tensors)
# Transformer模型总是输出元组。
# 有关所有输出的详细信息,请参见模型文档字符串。在我们的例子中,第一个元素是Bert模型最后一层的隐藏状态
encoded_layers = outputs[0]
# 我们已将输入序列编码为形状(批量大小、序列长度、模型隐藏维度)的FloatTensor
assert tuple(encoded_layers.shape) == (1, len(indexed_tokens), model.config.hidden_size)

以及如何使用BertForMaskedLM预测屏蔽的标记:

# 加载预训练模型(权重)
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval() # 如果你有GPU,把所有东西都放在cuda上
tokens_tensor = tokens_tensor.to('cuda')
segments_tensors = segments_tensors.to('cuda')
model.to('cuda') # 预测所有标记
with torch.no_grad():
outputs = model(tokens_tensor, token_type_ids=segments_tensors)
predictions = outputs[0] # 确认我们能预测“henson”
predicted_index = torch.argmax(predictions[0, masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
assert predicted_token == 'henson'

OpenAI GPT-2

下面是一个快速开始的例子,使用GPT2TokenizerGPT2LMHeadModel类以及OpenAI的预训练模型来预测文本提示中的下一个标记。

首先,让我们使用GPT2Tokenizer

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel # 可选:如果您想了解发生的信息,请按以下步骤logger
import logging
logging.basicConfig(level=logging.INFO) # 加载预训练模型(权重)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2') # 编码输入
text = "Who was Jim Henson ? Jim Henson was a"
indexed_tokens = tokenizer.encode(text) # 转换为PyTorch tensor
tokens_tensor = torch.tensor([indexed_tokens])

让我们看看如何使用GPT2LMHeadModel生成下一个跟在我们的文本后面的token:

# 加载预训练模型(权重)
model = GPT2LMHeadModel.from_pretrained('gpt2') # 将模型设置为评估模式
# 在评估期间有可再现的结果这是很重要的!
model.eval() # 如果你有GPU,把所有东西都放在cuda上
tokens_tensor = tokens_tensor.to('cuda')
model.to('cuda') # 预测所有标记
with torch.no_grad():
outputs = model(tokens_tensor)
predictions = outputs[0] # 得到预测的下一个子词(在我们的例子中,是“man”这个词)
predicted_index = torch.argmax(predictions[0, -1, :]).item()
predicted_text = tokenizer.decode(indexed_tokens + [predicted_index])
assert predicted_text == 'Who was Jim Henson? Jim Henson was a man'

每个模型架构(Bert、GPT、GPT-2、Transformer XL、XLNet和XLM)的每个模型类的示例,可以在文档中找到。

使用过去的GPT-2

以及其他一些模型(GPT、XLNet、Transfo XL、CTRL),使用pastmems属性,这些属性可用于防止在使用顺序解码时重新计算键/值对。它在生成序列时很有用,因为注意力机制的很大一部分得益于以前的计算。

下面是一个使用带pastGPT2LMHeadModel和argmax解码的完整工作示例(只能作为示例,因为argmax decoding引入了大量重复):

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained('gpt2') generated = tokenizer.encode("The Manhattan bridge")
context = torch.tensor([generated])
past = None for i in range(100):
print(i)
output, past = model(context, past=past)
token = torch.argmax(output[..., -1, :]) generated += [token.tolist()]
context = token.unsqueeze(0) sequence = tokenizer.decode(generated) print(sequence)

由于以前所有标记的键/值对都包含在past,因此模型只需要一个标记作为输入。

Model2Model示例

编码器-解码器架构需要两个标记化输入:一个用于编码器,另一个用于解码器。假设我们想使用Model2Model进行生成性问答,从标记将输入模型的问答开始。

import torch
from transformers import BertTokenizer, Model2Model # 可选:如果您想了解发生的信息,请按以下步骤logger
import logging
logging.basicConfig(level=logging.INFO) # 加载预训练模型(权重)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') # 编码输入(问题)
question = "Who was Jim Henson?"
encoded_question = tokenizer.encode(question) # 编码输入(答案)
answer = "Jim Henson was a puppeteer"
encoded_answer = tokenizer.encode(answer) # 将输入转换为PyTorch张量
question_tensor = torch.tensor([encoded_question])
answer_tensor = torch.tensor([encoded_answer])

让我们看看如何使用Model2Model获取与此(问题,答案)对相关联的loss值:

#为了计算损失,我们需要向解码器提供语言模型标签(模型生成的标记id)。
lm_labels = encoded_answer
labels_tensor = torch.tensor([lm_labels]) # 加载预训练模型(权重)
model = Model2Model.from_pretrained('bert-base-uncased') # 将模型设置为评估模式
# 在评估期间有可再现的结果这是很重要的!
model.eval() # 如果你有GPU,把所有东西都放在cuda上
question_tensor = question_tensor.to('cuda')
answer_tensor = answer_tensor.to('cuda')
labels_tensor = labels_tensor.to('cuda')
model.to('cuda') # 预测每个层的隐藏状态特征
with torch.no_grad():
# 有关输入的详细信息,请参见models文档字符串
outputs = model(question_tensor, answer_tensor, decoder_lm_labels=labels_tensor)
# Transformers模型总是输出元组。
# 有关所有输出的详细信息,请参见models文档字符串
# 在我们的例子中,第一个元素是LM损失的值
lm_loss = outputs[0]

此损失可用于对Model2Model的问答任务进行微调。假设我们对模型进行了微调,现在让我们看看如何生成答案:

# 让我们重复前面的问题
question = "Who was Jim Henson?"
encoded_question = tokenizer.encode(question)
question_tensor = torch.tensor([encoded_question]) # 这次我们试图生成答案,所以我们从一个空序列开始
answer = "[CLS]"
encoded_answer = tokenizer.encode(answer, add_special_tokens=False)
answer_tensor = torch.tensor([encoded_answer]) # 加载预训练模型(权重)
model = Model2Model.from_pretrained('fine-tuned-weights')
model.eval() # 如果你有GPU,把所有东西都放在cuda上
question_tensor = question_tensor.to('cuda')
answer_tensor = answer_tensor.to('cuda')
model.to('cuda') # 预测所有标记
with torch.no_grad():
outputs = model(question_tensor, answer_tensor)
predictions = outputs[0] # 确认我们能预测“jim”
predicted_index = torch.argmax(predictions[0, -1]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
assert predicted_token == 'jim'

欢迎关注磐创博客资源汇总站:

http://docs.panchuang.net/

欢迎关注PyTorch官方中文教程站:

http://pytorch.panchuang.net/

OpenCV中文官方文档:

http://woshicver.com/

Transformers 快速入门 | 一的更多相关文章

  1. Node.js快速入门

    Node.js是什么? Node.js是建立在谷歌Chrome的JavaScript引擎(V8引擎)的Web应用程序框架. 它的最新版本是:v0.12.7(在编写本教程时的版本).Node.js在官方 ...

  2. Web Api 入门实战 (快速入门+工具使用+不依赖IIS)

    平台之大势何人能挡? 带着你的Net飞奔吧!:http://www.cnblogs.com/dunitian/p/4822808.html 屁话我也就不多说了,什么简介的也省了,直接简单概括+demo ...

  3. SignalR快速入门 ~ 仿QQ即时聊天,消息推送,单聊,群聊,多群公聊(基础=》提升)

     SignalR快速入门 ~ 仿QQ即时聊天,消息推送,单聊,群聊,多群公聊(基础=>提升,5个Demo贯彻全篇,感兴趣的玩才是真的学) 官方demo:http://www.asp.net/si ...

  4. 前端开发小白必学技能—非关系数据库又像关系数据库的MongoDB快速入门命令(2)

    今天给大家道个歉,没有及时更新MongoDB快速入门的下篇,最近有点小忙,在此向博友们致歉.下面我将简单地说一下mongdb的一些基本命令以及我们日常开发过程中的一些问题.mongodb可以为我们提供 ...

  5. 【第三篇】ASP.NET MVC快速入门之安全策略(MVC5+EF6)

    目录 [第一篇]ASP.NET MVC快速入门之数据库操作(MVC5+EF6) [第二篇]ASP.NET MVC快速入门之数据注解(MVC5+EF6) [第三篇]ASP.NET MVC快速入门之安全策 ...

  6. 【番外篇】ASP.NET MVC快速入门之免费jQuery控件库(MVC5+EF6)

    目录 [第一篇]ASP.NET MVC快速入门之数据库操作(MVC5+EF6) [第二篇]ASP.NET MVC快速入门之数据注解(MVC5+EF6) [第三篇]ASP.NET MVC快速入门之安全策 ...

  7. Mybatis框架 的快速入门

    MyBatis 简介 什么是 MyBatis? MyBatis 是支持普通 SQL 查询,存储过程和高级映射的优秀持久层框架.MyBatis 消除 了几乎所有的 JDBC 代码和参数的手工设置以及结果 ...

  8. grunt快速入门

    快速入门 Grunt和 Grunt 插件是通过 npm 安装并管理的,npm是 Node.js 的包管理器. Grunt 0.4.x 必须配合Node.js >= 0.8.0版本使用.:奇数版本 ...

  9. 【第一篇】ASP.NET MVC快速入门之数据库操作(MVC5+EF6)

    目录 [第一篇]ASP.NET MVC快速入门之数据库操作(MVC5+EF6) [第二篇]ASP.NET MVC快速入门之数据注解(MVC5+EF6) [第三篇]ASP.NET MVC快速入门之安全策 ...

随机推荐

  1. 国际控制报文协议ICMP

    国际控制报文协议ICMP ICMP简介 ICMP 用于主机或路由器报告差错情况和提供有关异常情况的报告(检测网络错误). ICMP 不是高层协议,而是 IP 层的协议. ICMP 报文的格式 ICMP ...

  2. 初学Qt——QTableView+QSqlqueryModel

    我们在显示报表时可以用到上面两个类来实现,QTableView负责对视图显示:QSqlqueryModel则负责数据模块. 这里数据查询使用QSqlqueryModel主要是这个类可以通过自己写的查询 ...

  3. react-native 使用leanclound消息推送

    iOS消息推送的基本流程 1.注册:为应用程序申请消息推送服务.此时你的设备会向APNs服务器发送注册请求.2. APNs服务器接受请求,并将deviceToken返给你设备上的应用程序 3.客户端应 ...

  4. JZOJ 3453.【NOIP2013中秋节模拟】连通块(connect)

    3453.[NOIP2013中秋节模拟]连通块(connect) Time Limits: 1000 ms Memory Limits: 262144 KB (File IO): input:conn ...

  5. iOS8 定位失败问题

    iOS7升级到iOS8后,百度地图 iOS SDK 中的定位功能不可用,给广大开发者带来了不便,在此向大家分享一个方法来解决次问题.(官方的适配工作还在进行中,不久将会和广大开发者见面) 1.在inf ...

  6. django 从零开始 11 根据时间戳加密数据

    django自带一个加密的方法signer,对数据进行一个加密 一般这种方式用于账号密码邮箱找回,或者token设置 class TimestampSigner(Signer): def timest ...

  7. NLP(二十六)限定领域的三元组抽取的一次尝试

      本文将会介绍笔者在2019语言与智能技术竞赛的三元组抽取比赛方面的一次尝试.由于该比赛早已结束,笔者当时也没有参加这个比赛,因此没有测评成绩,我们也只能拿到训练集和验证集.但是,这并不耽误我们在这 ...

  8. 4000字干货长文!从校招和社招的角度说说如何准备Java后端大厂面试?

    插个题外话,为了写好这篇文章内容,我自己前前后后花了一周的时间来总结完善,文章内容应该适用于每一个学习 Java 的朋友!我觉得这篇文章的很多东西也是我自己写给自己的,比如从大厂招聘要求中我们能看到哪 ...

  9. 从零搭建Spring Cloud Gateway网关(一)

    新建Spring Boot项目 怎么新建Spring Boot项目这里不再具体赘述,不会的可以翻看下之前的博客或者直接百度.这里直接贴出对应的pom文件. pom依赖如下: <?xml vers ...

  10. XiaoQi.Study 项目(三)

    一.配置跨域 1.首先注册跨域要求 ,(可访问的IP.端口) //注册跨域 services.AddCors(options => { options.AddPolicy("XiaoQ ...