作者|huggingface

编译|VK

来源|Github

注意:这是我们使用TorchScript进行实验的开始,我们仍在探索可变输入大小模型的功能。它是我们关注的焦点,我们将在即将发布的版本中加深我们的分析,提供更多代码示例,更灵活的实现以及将基于python的代码与已编译的TorchScript进行基准测试的比较。

根据Pytorch的文档:“TorchScript是一种从PyTorch代码创建可序列化和可优化模型的方法”。Pytorch的两个模块JIT和TRACE允许开发人员导出他们的模型,这些模型可以在其他程序中重用,例如面向效率的C++程序。

我们提供了一个接口,该接口允许将transformers模型导出到TorchScript,以便他们可在与基于Pytorch的python程序不同的环境中重用。在这里,我们解释如何使用我们的模型,以便可以导出它们,以及将这些模型与TorchScript一起使用时要注意的事项。

导出模型需要两件事:

  • 虚拟化输入以执行模型正向传播。
  • 需要使用torchscript标志实例化该模型。

这些必要性意味着开发人员应注意几件事。这些在下面详细说明。

含义

TorchScript标志和解绑权重

该标志是必需的,因为该存储库中的大多数语言模型都在它们的Embedding层及其Decoding层具有绑定权重关系。TorchScript不允许导出绑定权重的模型,因此,有必要事先解绑权重。

这意味着以torchscript标志实例化的模型使得Embedding层和Decoding层分开,这意味着不应该对他们进行同时训练,导致意外的结果。

对于没有语言模型头(Language Model head)的模型,情况并非如此,因为那些模型没有绑定权重。这些型号可以在没有torchscript标志的情况下安全地导出。

虚拟(dummy)输入和标准长度

虚拟输入用于进行模型前向传播。当输入值在各层中传播时,Pytorch跟踪在每个张量上执行的不同操作。然后使用这些记录的操作创建模型的“迹"。

迹是相对于输入的尺寸创建的。因此,它受到虚拟输入尺寸的限制,并且不适用于任何其他序列长度或批次大小。尝试使用其他尺寸时,会出现如下错误,如:

The expanded size of the tensor (3) must match the existing size (7) at non-singleton dimension 2

因此,建议使用至少与最大输入大小相同的虚拟输入大小来跟踪模型。在推理期间对于模型的输入,可以执行填充来填充缺少的值。作为模型

将以较大的输入大小来进行跟踪,但是,不同矩阵的尺寸也将很大,从而导致更多的计算。

建议注意每个输入上完成的操作总数,并密切关注各种序列长度对应性能的变化。

在Python中使用TorchScript

以下是使用Python保存,加载模型以及如何使用"迹"进行推理的示例。

保存模型

该代码段显示了如何使用TorchScript导出BertModel。在这里实例化BertModel,根据BertConfig类,然后以文件名traced_bert.pt保存到磁盘

from transformers import BertModel, BertTokenizer, BertConfig
import torch enc = BertTokenizer.from_pretrained("bert-base-uncased") # 标记输入文本
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text) # 输入标记之一进行掩码
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1] # 创建虚拟输入
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors] # 使用torchscript标志初始化模型
# 标志被设置为True,即使没有必要,因为该型号没有LM Head。
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, torchscript=True) # 实例化模型
model = BertModel(config) # 模型设置为评估模式
model.eval() # 如果您要​​使用from_pretrained实例化模型,则还可以设置TorchScript标志
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True) # 创建迹
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
torch.jit.save(traced_model, "traced_bert.pt")

载入模型

该代码段显示了如何加载以前以名称traced_bert.pt保存到磁盘的BertModel

我们重新使用之前初始化的dummy_input

loaded_model = torch.jit.load("traced_model.pt")
loaded_model.eval() all_encoder_layers, pooled_output = loaded_model(dummy_input)

使用跟踪模型进行推理

使用跟踪模型进行推理就像使用其__call__ 方法一样简单:

traced_model(tokens_tensor, segments_tensors)

原文链接:https://huggingface.co/transformers/torchscript.html

欢迎关注磐创AI博客站:

http://panchuang.net/

OpenCV中文官方文档:

http://woshicver.com/

欢迎关注磐创博客资源汇总站:

http://docs.panchuang.net/

Transformers 中使用 TorchScript | 四的更多相关文章

  1. C#中的线程四(System.Threading.Thread)

    C#中的线程四(System.Threading.Thread) 1.最简单的多线程调用 System.Threading.Thread类构造方法接受一个ThreadStart委托,改委托不带参数,无 ...

  2. ORACLE中CONSTRAINT的四对属性

    ORACLE中CONSTRAINT的四对属性 summary:在data migrate时,某些表的约束总是困扰着我们,让我们的migratet举步维艰,怎样利用约束本身的属性来处理这些问题呢?本文具 ...

  3. iOS中常用的四种数据持久化方法简介

    iOS中常用的四种数据持久化方法简介 iOS中的数据持久化方式,基本上有以下四种:属性列表.对象归档.SQLite3和Core Data 1.属性列表涉及到的主要类:NSUserDefaults,一般 ...

  4. js中this的四种使用方法

    0x00:js中this的四种调用模式 1,方法调用模式 2,函数调用模式 3,构造器调用模式 4,apply.call.bind调用模式 0x01:第一种:方法调用模式 (也就是用.调用的)this ...

  5. {Django基础八之cookie和session}一 会话跟踪 二 cookie 三 django中操作cookie 四 session 五 django中操作session

    Django基础八之cookie和session 本节目录 一 会话跟踪 二 cookie 三 django中操作cookie 四 session 五 django中操作session 六 xxx 七 ...

  6. C#中datagridviewz中SelectionMode的四个属性的含义

    C#中datagridviewz中SelectionMode的四个属性的含义 DataGridViewSelectionMode.ColumnHeaderSelect 单击列头就可以选择整列DataG ...

  7. Django中模型(四)

    Django中模型(四) 五.创建对象 1.目的 向数据库中添加数据.当创建对象时,Django不会对数据库进行读写操作,当调用save()方法时,才与数据库交互,将对象保存到数据库中 2.注意 __ ...

  8. Android中Activity的四种启动方式

    谈到Activity的启动方式必须要说的是数据结构中的栈.栈是一种只能从一端进入存储数据的线性表,它以先进后出的原则存储数据,先进入的数据压入栈底,后进入的数据在栈顶.需要读取数据的时候就需要从顶部开 ...

  9. Python实现接口测试中的常见四种Post请求数据

    前情: 在日常的接口测试工作中,模拟接口请求通常有两种方法, 利用工具来模拟,比如fiddler,postman,poster,soapUI等 利用代码来模拟,使用到一些网络模块,比如HttpClie ...

随机推荐

  1. 算发帖——俄罗斯方块覆盖问题一共有多少个解

    问题的提出:如下图,用13块俄罗斯方块覆盖8*8的正方形.   那么一共可以有多少个解呢?(若通过旋转.翻转一个解而得到的新解,则两个解视为同一个解)   首先,求解的问题,已经在上一篇帖子里完成 算 ...

  2. NumPy——统计函数

    引入模块import numpy as np 1.numpy.sum(a, axis=None)/a.sum(axis=None) 根据给定轴axis计算数组a相关元素之和,axis整数或元组,不指定 ...

  3. python大佬养成计划----HTML网页设计(序列)

    序列化标签 1.有序标签--ol和li 有序列表标签是<ol>,是一个双标签.在每一个列表项目前要使用<li>标签.<ol>标签的形式是带有前后顺序之分的编号.如果 ...

  4. 逐行分析jQuery2.0.3源码-完整笔记

    概览 (function (){ (21 , 94) 定义了一些变量和函数 jQuery=function(); (96 , 293) 给jQuery对象添加一些方法和属性; (285 , 347) ...

  5. 浏览器渲染流程&Composite(渲染层合并)简单总结

    梳理浏览器渲染流程 首先简单了解一下浏览器请求.加载.渲染一个页面的大致过程: DNS 查询 TCP 连接 HTTP 请求即响应 服务器响应 客户端渲染 这里主要将客户端渲染展开梳理一下,从浏览器器内 ...

  6. 打造你的第一个 Electron 应用

    Electron 可以让你使用纯 JavaScript 调用丰富的原生(操作系统) APIs 来创造桌面应用. 你可以把它看作一个 Node. js 的变体,它专注于桌面应用而不是 Web 服务器端. ...

  7. 单片机基础——使用GPIO扫描检测按键

    1. 准备工作 硬件准备 开发板首先需要准备一个小熊派IoT开发板,并通过USB线与电脑连接. 软件准备 需要安装好Keil - MDK及芯片对应的包,以便编译和下载生成的代码,可参考MDK安装教程 ...

  8. mysql 学习日记 悲观和乐观锁

    理解  悲观锁就是什么事情都是需要小心翼翼,生怕弄错了出大问题, 一般情况下 "增删改" 都是有事务在进行操作的,但是 "查" 是不需要事务操作的, 但是凡事没 ...

  9. freecplus框架,Linux平台下C/C++程序员提高开发效率的利器

    目录 一.freecplus框架简介 二.freecplus开源许可协议 三.freecplus框架内容 字符串操作 2.xml解析 3.日期时间 4.目录操作 5.文件操作 6.日志文件 7.参数文 ...

  10. Spring Boot从入门到精通(九)整合Spring Data JPA应用框架

    JPA是什么? JPA全称Java Persistence API,是Sun官方提出的Java持久化规范.是JDK 5.0注解或XML描述对象-关系表的映射关系,并将运行期的实体对象持久化到数据库中. ...