作者|huggingface

编译|VK

来源|Github

本章介绍使用Transformers库时最常见的用例。可用的模型允许许多不同的配置,并且在用例中具有很强的通用性。这里介绍了最简单的方法,展示了诸如问答、序列分类、命名实体识别等任务的用法。

这些示例利用Auto Model,这些类将根据给定的checkpoint实例化模型,并自动选择正确的模型体系结构。有关详细信息,请查看:AutoModel文档。请随意修改代码,使其更具体,并使其适应你的特定用例。

  • 为了使模型能够在任务上良好地执行,必须从与该任务对应的checkpoint加载模型。这些checkpoint通常是在大量数据上预先训练的,并针对特定任务进行微调。这意味着:并非所有模型都针对所有任务进行了微调。如果要对特定任务的模型进行微调,可以利用examples目录中的run\$task.py脚本。
  • 微调模型是在特定的数据集上微调的。此数据集可能与你的用例和域重叠,也可能不重叠。如前所述,你可以利用示例脚本来微调模型,也可以创建自己的训练脚本。

为了对任务进行推理,库提供了几种机制:

  • 管道是非常易于使用的抽象,只需要两行代码。
  • 直接将模型与Tokenizer(PyTorch/TensorFlow)结合使用来使用模型的完整推理。这种机制稍微复杂,但是更强大。

这里展示了两种方法。

请注意,这里介绍的所有任务都利用了在预训练模型针对特定任务进行微调后的模型。加载未针对特定任务进行微调的checkpoint时,将只加载transformer层,而不会加载用于该任务的附加层,从而随机初始化该附加层的权重。这将产生随机输出。

序列分类

序列分类是根据已经给定的类别然后对序列进行分类的任务。序列分类的一个例子是GLUE数据集,它就是完全基于该任务的。如果你想在GLUE序列分类任务上微调模型,可以利用run_GLUE.pyrun_tf_GLUE.py脚本。

下面是一个使用管道进行情绪分析的例子:识别该序列是积极的还是消极的。它利用sst2上的微调模型,这是一个GLUE任务。

from transformers import pipeline

nlp = pipeline("sentiment-analysis")

print(nlp("I hate you"))
print(nlp("I love you"))

这将返回一个标签(“积极”或“消极”)和一个分数,如下所示:

[{'label': 'NEGATIVE', 'score': 0.9991129}]
[{'label': 'POSITIVE', 'score': 0.99986565}]

下面是一个使用模型进行序列分类的示例,以确定两个序列是否是彼此的解释。该过程如下:

  • 从checkpoint名称实例化一个tokenizer和一个模型。该模型被识别为一个BERT模型,并用存储在checkpoint中的权重加载它。
  • 从这两句话中构建一个序列,使用正确的特定于模型的分隔符标记类型id和注意力掩码(encode()和encode_plus()处理这个问题)
  • 将这个序列传递到模型中,以便将其分类到两个可用的类中的一个:0(不是解释)和1(是解释)
  • 计算结果的softmax获取类的概率
  • 打印结果

Pytorch代码

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch tokenizer = AutoTokenizer.from_pretrained("bert-base-cased-finetuned-mrpc")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased-finetuned-mrpc") classes = ["not paraphrase", "is paraphrase"] sequence_0 = "The company HuggingFace is based in New York City"
sequence_1 = "Apples are especially bad for your health"
sequence_2 = "HuggingFace's headquarters are situated in Manhattan" paraphrase = tokenizer.encode_plus(sequence_0, sequence_2, return_tensors="pt")
not_paraphrase = tokenizer.encode_plus(sequence_0, sequence_1, return_tensors="pt") paraphrase_classification_logits = model(**paraphrase)[0]
not_paraphrase_classification_logits = model(**not_paraphrase)[0] paraphrase_results = torch.softmax(paraphrase_classification_logits, dim=1).tolist()[0]
not_paraphrase_results = torch.softmax(not_paraphrase_classification_logits, dim=1).tolist()[0] print("Should be paraphrase")
for i in range(len(classes)):
print(f"{classes[i]}: {round(paraphrase_results[i] * 100)}%") print("\nShould not be paraphrase")
for i in range(len(classes)):
print(f"{classes[i]}: {round(not_paraphrase_results[i] * 100)}%")

TensorFlow代码

from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
import tensorflow as tf tokenizer = AutoTokenizer.from_pretrained("bert-base-cased-finetuned-mrpc")
model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased-finetuned-mrpc") classes = ["not paraphrase", "is paraphrase"] sequence_0 = "The company HuggingFace is based in New York City"
sequence_1 = "Apples are especially bad for your health"
sequence_2 = "HuggingFace's headquarters are situated in Manhattan" paraphrase = tokenizer.encode_plus(sequence_0, sequence_2, return_tensors="tf")
not_paraphrase = tokenizer.encode_plus(sequence_0, sequence_1, return_tensors="tf") paraphrase_classification_logits = model(paraphrase)[0]
not_paraphrase_classification_logits = model(not_paraphrase)[0] paraphrase_results = tf.nn.softmax(paraphrase_classification_logits, axis=1).numpy()[0]
not_paraphrase_results = tf.nn.softmax(not_paraphrase_classification_logits, axis=1).numpy()[0] print("Should be paraphrase")
for i in range(len(classes)):
print(f"{classes[i]}: {round(paraphrase_results[i] * 100)}%") print("\nShould not be paraphrase")
for i in range(len(classes)):
print(f"{classes[i]}: {round(not_paraphrase_results[i] * 100)}%")

这将输出以下结果:

Should be paraphrase
not paraphrase: 10%
is paraphrase: 90% Should not be paraphrase
not paraphrase: 94%
is paraphrase: 6%

抽取式问答

抽取式问答是从给定问题的文本中抽取答案的任务。问答数据集的一个例子是SQuAD数据集,它完全基于该任务。如果你想在团队任务中微调模型,可以利用run_SQuAD.py

下面是一个使用管道进行问答的示例:从给定问题的文本中提取答案。它利用了一个小队的微调模型。

from transformers import pipeline

nlp = pipeline("question-answering")

context = r"""
Extractive Question Answering is the task of extracting an answer from a text given a question. An example of a
question answering dataset is the SQuAD dataset, which is entirely based on that task. If you would like to fine-tune
a model on a SQuAD task, you may leverage the `run_squad.py`.
""" print(nlp(question="What is extractive question answering?", context=context))
print(nlp(question="What is a good example of a question answering dataset?", context=context))

这将返回从文本中提取的答案,一个置信度,以及“开始”和“结束”值,这些值是提取的答案在文本中的位置。

{'score': 0.622232091629833, 'start': 34, 'end': 96, 'answer': 'the task of extracting an answer from a text given a question.'}
{'score': 0.5115299158662765, 'start': 147, 'end': 161, 'answer': 'SQuAD dataset,'}

下面是一个使用模型和Tokenizer回答问题的示例。该过程如下:

  • 从checkpoint名称实例化一个tokenizer和一个模型。该模型被识别为一个BERT模型,并用存储在checkpoint中的权重加载它。
  • 定义一段文本和几个问题。
  • 遍历问题并根据文本和当前问题构建一个序列,使用正确的模型特定分隔符标记类型id和注意力掩码将此序列传递到模型中。这将输出整个序列标记(问题和文本)的开始位置和结束位置的一系列分数。
  • 计算结果的softmax以获得从标记的开始位置和停止位置对应的概率
  • 将这些标记转换为字符串。
  • 打印结果

Pytorch代码

from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")
model = AutoModelForQuestionAnswering.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad") text = r"""

Transformers 库常见的用例 | 三的更多相关文章

  1. Java避坑宝典《Java业务开发常见错误100例》上线了

    写这个专栏的缘起 之前我写过一篇博客:<朱晔的互联网架构实践心得S2E2:写业务代码最容易掉的10种坑>,引起的关注还是挺多的.后来和极客时间的编辑一拍即合决定以这个为题写一个专栏.其实所 ...

  2. shell常见脚本30例

    shell常见脚本30例 author:headsen chen  2017-10-19  10:12:12 本文原素材出自网上,特此申明.有些地方加入我自己的改动 常见的30例shell脚本 1.用 ...

  3. {MySQL的库、表的详细操作}一 库操作 二 表操作 三 行操作

    MySQL的库.表的详细操作 MySQL数据库 本节目录 一 库操作 二 表操作 三 行操作 一 库操作 1.创建数据库 1.1 语法 CREATE DATABASE 数据库名 charset utf ...

  4. 常见MIME类型例表

    常见MIME类型例表: 序号 内容类型 文件扩展名 描述 1 application/msword doc Microsoft Word 2 application/octet-stream bin ...

  5. 常见的装包的三种宝,包 bao-devel bao-utils bao-agent ,包 开发包 工具包 客户端

    常见的装包的三种宝,包  bao-devel    bao-utils   bao-agent  ,包    开发包   工具包  客户端

  6. VC中引用第三方库,常见的库冲突问题

    Q:VC中引用第三方库,常见的库冲突问题 环境:[1]VS2008 [2]WinXP SP3 A1(方法一): [S1]第三方库(Binary形式的)如果同主程序冲突,则下载第三方库的源码[S2]保持 ...

  7. TCP常见的定时器及三次握手与四次挥手

    1.TCP常见的定时器 在TCP协议中有的时候需要定期或者按照某个算法对某个事件进行触发,那么这个时候,TCP协议是使用定时器进行实现的.在TCP中,会有七种定时器: 建立连接定时器(connecti ...

  8. [WPF自定义控件库]以Button为例谈谈如何模仿Aero2主题

    1. 为什么选择Aero2 除了以外观为卖点的控件库,WPF的控件库都默认使用"素颜"的外观,然后再提供一些主题包.这样做的最大好处是可以和原生控件或其它控件库兼容,而且对于大部分 ...

  9. 关于阿里图标库Iconfont生成图标的三种使用方式(fontclass/unicode/symbol)

    1.附阿里图标库链接:http://www.iconfont.cn/ 2.登录阿里图标库以后,搜索我们需要的图标,将其加入购物车,如图3.将我们需要的图标全部挑选完毕以后,点击购物车图标4.这时候右侧 ...

随机推荐

  1. 从2019-nCoV趋势预测问题,联想到关于网络安全态势预测问题的讨论

    0. 引言 在这篇文章中,笔者希望和大家讨论一个话题,即未来趋势是否可以被精确或概率性地预测. 对笔者所在的网络安全领域来说,由于网络攻击和网络入侵常常变现出随机性.非线性性的特征,因此纯粹的未来预测 ...

  2. CSS——NO.5(格式化排版)

    */ * Copyright (c) 2016,烟台大学计算机与控制工程学院 * All rights reserved. * 文件名:text.cpp * 作者:常轩 * 微信公众号:Worldhe ...

  3. js数据类型大全

    声明变量的命名规范(标识符) 1.不能以数字开头,只能以字母或者¥或者_开头 2.js变量名称区分大小写 3.变量名不能含有关键字(this.if.for.while) 4.驼峰命名法 console ...

  4. Cisco模拟器的基本使用

    获取帮助查找命令 只需输入一个'?'便可得到详细的帮助信息,如果想获取c开头的命名,那么直接输入'c?'即可. 在各个模式下切换的方法 给如图所示路由器接口配置IP地址 第一步:安装HWIC-2T(串 ...

  5. CSS实现响应式布局

    用CSS实现响应式布局 响应式布局感觉很高大上,很难,但实际上只用CSS也能实现响应式布局要用的就是CSS中的没接查询,下面就介绍一下怎么运用: 使用@media 的三种方法 1.直接在CSS文件中使 ...

  6. http请求缓存头详解

    缓存的作用:1.减少延迟(页面打开的速度).2.降低服务器负载(先取缓存,无缓存在请求服务器,有效降低服务器的负担).3.保证稳定性(有个笑话是手机抢购时为了保证服务器的稳定性,在前端写个随机数限制百 ...

  7. 2020年,如何成为一名 iOS 开发高手!

    2020年对应程序员来说,是一个多灾的年份,很多公司都进行了不同比例的优化和裁员.等疫情得到控制后,将会是找工作的高峰期,从去年的面试经历来看,现在只会单纯写业务代码的人找工作特别难,很多大厂的面试官 ...

  8. JavaScript的函数(三)

    函数也是对象,拥有属性和方法,就类似普通对象那样.1,length属性 arguments.lenght 表示传入实参的个数. 函数的length属性时只读属性,代表形参的个数.可以用argument ...

  9. 通过欧拉计划学Rust编程(第500题)

    由于研究Libra等数字货币编程技术的需要,学习了一段时间的Rust编程,一不小心刷题上瘾. "欧拉计划"的网址: https://projecteuler.net 英文如果不过关 ...

  10. 使用selenium模拟登陆淘宝、新浪和知乎

    如果直接使用selenium访问淘宝.新浪和知乎这些网址.一般会识别出这是自动化测试工具,会有反制措施.当开启开发者模式后,就可以绕过他们的检测啦.(不行的,哭笑) 如果网站只是对windows.na ...