pipelines 是使用模型进行推理的一种很好且简单的方法。这些pipelines 是从库中抽象出大部分复杂代码的对象,提供了一个简单的API,专门用于多个任务,包括命名实体识别、屏蔽语言建模、情感分析、特征提取和问答等。

参数说明

初始化pipeline时可能的参数:

task (str) — 定义pipeline需要返回的任务。

model (str or PreTrainedModel or TFPreTrainedModel, optional) — 拟使用的模型,有时可以只指定模型,不指定task

config (str or PretrainedConfig, optional) — 实例化模型的配置。取值可以是一个模型标志符(模型名称),也可以是利用PretrainedConfig继承得来

tokenizer (str or PreTrainedTokenizer, optional) — 用于编码模型中的数据。取值可以是一个模型标志符(模型名称),也可以是利用 PreTrainedTokenizer继承得来

feature_extractor (str or PreTrainedFeatureExtractor, optional) — 特征提取器

framework (str, optional) — 指明运行模型的框架,要么是"pt"(表示pytorch), 要么是"tf"(表示tensorflow)

revision (str, optional, defaults to "main") — 指定所加载模型的版本

use_fast (bool, optional, defaults to True) — 如果可以的话(a PreTrainedTokenizerFast),是否使用Fast tokenizer

use_auth_token (str or bool, optional) — 是否需要认证

device (int or str or torch.device) — 指定运行模型的硬件设备。(例如:"cpu","cuda:1","mps",或者是一个GPU的编号,比如 1)

device_map (str or Dict[str, Union[int, str, torch.device], optional) — Sent directly as model_kwargs (just a simpler shortcut). When accelerate library is present, set device_map="auto" to compute the most optimized device_map automatically. More information

torch_dtype (str or torch.dtype, optional) — 指定模型可用的精度。sent directly as model_kwargs (just a simpler shortcut) to use the available precision for this model (torch.float16, torch.bfloat16, … or "auto").

trust_remote_code (bool, optional, defaults to False) —

使用pipeline对象处理数据可能的参数:

batch_size (int) — 数据处理的批次大小

truncation (bool, optional, defaults to False) — 是否截断

padding (bool, optional, defaults to False) — 是否padding

实例:

初始化一个文本分类其的pipeline 对象

from transformers import pipeline
classifier = pipeline(task="text-classification") #

模型的输入inputs(可以是一个字典、列表、单个字符串)。

inputs = "嘻嘻嘻嘻嘻嘻嘻嘻嘻嘻嘻嘻嘻嘻嘻嘻嘻"

使用pipeline对象处理数据

results = classifier(inputs, truncation=True, padding=True, max_length=512):

批处理的使用建议

  • 在有延迟限制的实时任务中, 别用批处理
  • 使用CPU进行预测时,别用批处理
  • 如果您不知道sequence_length的大小(例如自然数据),别用批处理。设置OOM检查,以便于过长输入序列导致模型执行异常时,模型可以自动重启
    • 如果输入中包含100个样本序列,仅一个样本序列长度是600,其余长度为4,那么当它们作为一个批次输入时,输入数据的shape也仍是(100, 600)。
  • 如果样本sequence_length比较规整,则建议使用尽可能大的批次。
  • 总之,使用批处理需要处理更好溢出问题

Pipelines 主要包括三大模块(以TextClassificationPipeline为例)

数据预处理:tokenize 未处理的数据。对应pipeline中的preprocess()

前向计算:模型进行计算。对应pipeline中的_forward()

后处理:对模型计算出来的logits进行进一步处理。对应pipeline中的postprocess()

PS: 继承pipeline类(如TextClassificationPipeline),并重写以上的三个函数,可以实现自定义pipeline

实例(修改模型预测的标签):此处以修改模型预测标签为例,重写后处理过程postprocess()

1、导库,并读取提前准备好的标签映射数据

# coding=utf-8
from transformers import pipeline
from transformers import TextClassificationPipeline
import numpy as np
import json
import pandas as pd with open('model_save_epochs100_batch1/labelmap.json')as fr:
id2label = {ind: label for ind, label in enumerate(json.load(fr).values())}

2、自定义pipeline(继承TextClassificationPipeline,并重写postprocess()

class CustomTextClassificationPipeline(TextClassificationPipeline):

    def sigmoid_(self, _outputs):
return 1.0 / (1.0 + np.exp(-_outputs)) def postprocess(self, model_outputs, function_to_apply=None, top_k=1, _legacy=True): outputs = model_outputs["logits"][0] # 感觉这里每次只会返回一个样本的计算结果
outputs = outputs.numpy()
scores = self.sigmoid_(outputs)
dict_scores = [
{"label": id2label[i], "score": score.item()} for i, score in enumerate(scores) if score > 0.5
]
return dict_scores

3、使用自定义的pipeline实例化一个分类器(有两种方式)

方式一:将自定义类名传参给pipeline_class

classifier = pipeline(model='model_save_epochs100_batch1/checkpoint-325809',
pipeline_class=CustomTextClassificationPipeline,
task="text-classification",
function_to_apply='sigmoid', top_k=10, device=0) # return_all_scores=True

方式二:直接使用自定义类创建分类器

from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained('model_save_epochs100_batch1/checkpoint-325809')

classifier = CustomTextClassificationPipeline(model=model, # 此处model的值得是加载好了得模型,不能是一个字符串
pipeline_class=CustomTextClassificationPipeline,
task="text-classification",
function_to_apply='sigmoid', top_k=10, device=0) # return_all_scores=True

4、利用分类器对文本数据进行预测

# 读取待测数据
with open('raw_data/diseasecontent.json', 'r', encoding='utf-8') as fr:
texts = [text.strip() for text in json.load(fr)if text.strip("000").strip()] res = []
for text in texts:
labels = []
for ite in classifier(text, truncation=True, max_length=512): # padding=True, 执行预测
labels.append(ite['label'])
res.append({'text': text, 'labels': labels}) # 保存预测结果
df = pd.DataFrame(res)
df.to_excel('model_save_epochs100_batch1/test_res.xlsx')

Transformers Pipelines的更多相关文章

  1. Spark2.0 Pipelines

    MLlib中众多机器学习算法API在单一管道或工作流中更容易相互结合起来使用.管道的思想主要是受到scikit-learn库的启发. ML API使用Spark SQL中的DataFrame作为机器学 ...

  2. kaggle Pipelines

    # Most scikit-learn objects are either transformers or models. # Transformers are for pre-processing ...

  3. ML Pipelines管道

    ML Pipelines管道 In this section, we introduce the concept of ML Pipelines. ML Pipelines provide a uni ...

  4. Nancy之Pipelines三兄弟(Before After OnError)

    一.简单描述 Before:如果返回null,拦截器将主动权转给路由:如果返回Response对象,则路由不起作用. After : 没有返回值,可以在这里修改或替换当前的Response. OnEr ...

  5. Coax Transformers[转载]

    Coax Transformers How to determine the needed Z for a wanted Quarter Wave Lines tranformation ratio ...

  6. 【最短路】ACdream 1198 - Transformers' Mission

    Problem Description A group of transformers whose leader is Optimus Prime(擎天柱) were assigned a missi ...

  7. Linux - 命令行 管道(Pipelines) 详细解释

    命令行 管道(Pipelines) 详细解释 本文地址: http://blog.csdn.net/caroline_wendy/article/details/24249529 管道操作符" ...

  8. 使用 Bitbucket Pipelines 持续交付托管项目

    简介 Bitbucket Pipelines 是Atlassian公司为Bitbucket Cloud产品添加的一个新功能, 它为托管在Bitbucket上的项目提供了一个良好的持续集成/交付的服务. ...

  9. Easy machine learning pipelines with pipelearner: intro and call for contributors

    @drsimonj here to introduce pipelearner – a package I'm developing to make it easy to create machine ...

  10. sql hibernate查询转换成实体或对应的VO Transformers

    sql查询转换成实体或对应的VO Transformers //addScalar("id") 默认查询出来的id是全部大写的(sql起别名也无效,所以使用.addScalar(& ...

随机推荐

  1. linux下redis_单机版_主从_集群_部署文档

    一 单机版部署 1.1 Redis下载地址 http://download.redis.io/releases/ 本次部署版本:3.2.8 当前最新版本:5.0.5 1.2 安装 部署路径说明规划 / ...

  2. CF803G Periodic RMQ Problem

    简要题意 你需要维护一个序列 \(a\),有 \(q\) 个操作,支持: 1 l r x 将 \([l,r]\) 赋值为 \(x\). 2 l r 询问 \([l,r]\) 的最小值. 为了加大难度, ...

  3. 前端Linux部署命令与流程记录

    以前写过一篇在Linux上从零开始部署前后端分离的Vue+Spring boot项目,但那时候是部署自己的个人项目,磕磕绊绊地把问题解决了,后来在公司有了几次应用到实际生产环境的经验,发现还有很多可以 ...

  4. JavaScript 中 this 关键字的作用和如何改变其上下文

    一.this 关键字的作用 JavaScript 中的 this 关键字引用了所在函数正在被调用时的对象.在不同的上下文中,this 的指向会发生变化. 在全局上下文中,this 指向全局对象(在浏览 ...

  5. 动力节点—day06

    常用类 String String表示字符串类型,属于引用数据类型,不属于基本数据类型 在Java中随便使用双引号括起来的都是String对象,例如"abc"."def& ...

  6. MySQL之字段约束条件

    MySQL之字段约束条件 一.MySQL之字段约束条件 1.unsigned 无符号 unsigned 为非负数,可以用此类增加数据长度 eg:tinyint最大是127,那tinyint unsig ...

  7. 练习:集合元素处理(传统方式)-练习:集合元素处理(Stream方式)

    练习:集合元素处理(传统方式) 题目 现在有两个ArrayList集合存储队伍当中的多个成员姓名,要求使用传统的for循环(或增强for循环依次进行以下若干操作步骤︰ 1.第一个队伍只要名字为3个字的 ...

  8. java 进阶P-2.3+P-2.4

    封闭的访问属性 private 封装:把数据和对数据的操作放在一起. (所谓封装就是把数据和对这些数据的操作放在一个地方,通过这些操作把这些数据保护起来,别人不能直接接触到这些数据) 1 privat ...

  9. 如何通过Java代码向Word文档添加文档属性

    Word文档属性包括常规.摘要.统计.内容.自定义.其中摘要包括标题.主题.作者.经理.单位.类别.关键词.备注等项目.属性相当于文档的名片,可以添加你想要的注释.说明等.还可以标注版权. 今天就为大 ...

  10. windows消息处理机制

    三层机制 1.顶端就是Windows内核.Windows内核维护着一个消息队列,第二级控制中心从这个消息队列中获取属于自己管辖的消息,后做出处理,有些消息直接处理掉,有些还要发送给下一级窗体(Wind ...