Transformers 中使用 TorchScript | 四
作者|huggingface
编译|VK
来源|Github
注意:这是我们使用TorchScript进行实验的开始,我们仍在探索可变输入大小模型的功能。它是我们关注的焦点,我们将在即将发布的版本中加深我们的分析,提供更多代码示例,更灵活的实现以及将基于python的代码与已编译的TorchScript进行基准测试的比较。
根据Pytorch的文档:“TorchScript是一种从PyTorch代码创建可序列化和可优化模型的方法”。Pytorch的两个模块JIT和TRACE允许开发人员导出他们的模型,这些模型可以在其他程序中重用,例如面向效率的C++程序。
我们提供了一个接口,该接口允许将transformers模型导出到TorchScript,以便他们可在与基于Pytorch的python程序不同的环境中重用。在这里,我们解释如何使用我们的模型,以便可以导出它们,以及将这些模型与TorchScript一起使用时要注意的事项。
导出模型需要两件事:
- 虚拟化输入以执行模型正向传播。
- 需要使用
torchscript标志实例化该模型。
这些必要性意味着开发人员应注意几件事。这些在下面详细说明。
含义
TorchScript标志和解绑权重
该标志是必需的,因为该存储库中的大多数语言模型都在它们的Embedding层及其Decoding层具有绑定权重关系。TorchScript不允许导出绑定权重的模型,因此,有必要事先解绑权重。
这意味着以torchscript标志实例化的模型使得Embedding层和Decoding层分开,这意味着不应该对他们进行同时训练,导致意外的结果。
对于没有语言模型头(Language Model head)的模型,情况并非如此,因为那些模型没有绑定权重。这些型号可以在没有torchscript标志的情况下安全地导出。
虚拟(dummy)输入和标准长度
虚拟输入用于进行模型前向传播。当输入值在各层中传播时,Pytorch跟踪在每个张量上执行的不同操作。然后使用这些记录的操作创建模型的“迹"。
迹是相对于输入的尺寸创建的。因此,它受到虚拟输入尺寸的限制,并且不适用于任何其他序列长度或批次大小。尝试使用其他尺寸时,会出现如下错误,如:
The expanded size of the tensor (3) must match the existing size (7) at non-singleton dimension 2
因此,建议使用至少与最大输入大小相同的虚拟输入大小来跟踪模型。在推理期间对于模型的输入,可以执行填充来填充缺少的值。作为模型
将以较大的输入大小来进行跟踪,但是,不同矩阵的尺寸也将很大,从而导致更多的计算。
建议注意每个输入上完成的操作总数,并密切关注各种序列长度对应性能的变化。
在Python中使用TorchScript
以下是使用Python保存,加载模型以及如何使用"迹"进行推理的示例。
保存模型
该代码段显示了如何使用TorchScript导出BertModel。在这里实例化BertModel,根据BertConfig类,然后以文件名traced_bert.pt保存到磁盘
from transformers import BertModel, BertTokenizer, BertConfig
import torch
enc = BertTokenizer.from_pretrained("bert-base-uncased")
# 标记输入文本
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)
# 输入标记之一进行掩码
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
# 创建虚拟输入
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]
# 使用torchscript标志初始化模型
# 标志被设置为True,即使没有必要,因为该型号没有LM Head。
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, torchscript=True)
# 实例化模型
model = BertModel(config)
# 模型设置为评估模式
model.eval()
# 如果您要使用from_pretrained实例化模型,则还可以设置TorchScript标志
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
# 创建迹
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
torch.jit.save(traced_model, "traced_bert.pt")
载入模型
该代码段显示了如何加载以前以名称traced_bert.pt保存到磁盘的BertModel。
我们重新使用之前初始化的dummy_input。
loaded_model = torch.jit.load("traced_model.pt")
loaded_model.eval()
all_encoder_layers, pooled_output = loaded_model(dummy_input)
使用跟踪模型进行推理
使用跟踪模型进行推理就像使用其__call__ 方法一样简单:
traced_model(tokens_tensor, segments_tensors)
原文链接:https://huggingface.co/transformers/torchscript.html
欢迎关注磐创AI博客站:
http://panchuang.net/
OpenCV中文官方文档:
http://woshicver.com/
欢迎关注磐创博客资源汇总站:
http://docs.panchuang.net/
Transformers 中使用 TorchScript | 四的更多相关文章
- C#中的线程四(System.Threading.Thread)
C#中的线程四(System.Threading.Thread) 1.最简单的多线程调用 System.Threading.Thread类构造方法接受一个ThreadStart委托,改委托不带参数,无 ...
- ORACLE中CONSTRAINT的四对属性
ORACLE中CONSTRAINT的四对属性 summary:在data migrate时,某些表的约束总是困扰着我们,让我们的migratet举步维艰,怎样利用约束本身的属性来处理这些问题呢?本文具 ...
- iOS中常用的四种数据持久化方法简介
iOS中常用的四种数据持久化方法简介 iOS中的数据持久化方式,基本上有以下四种:属性列表.对象归档.SQLite3和Core Data 1.属性列表涉及到的主要类:NSUserDefaults,一般 ...
- js中this的四种使用方法
0x00:js中this的四种调用模式 1,方法调用模式 2,函数调用模式 3,构造器调用模式 4,apply.call.bind调用模式 0x01:第一种:方法调用模式 (也就是用.调用的)this ...
- {Django基础八之cookie和session}一 会话跟踪 二 cookie 三 django中操作cookie 四 session 五 django中操作session
Django基础八之cookie和session 本节目录 一 会话跟踪 二 cookie 三 django中操作cookie 四 session 五 django中操作session 六 xxx 七 ...
- C#中datagridviewz中SelectionMode的四个属性的含义
C#中datagridviewz中SelectionMode的四个属性的含义 DataGridViewSelectionMode.ColumnHeaderSelect 单击列头就可以选择整列DataG ...
- Django中模型(四)
Django中模型(四) 五.创建对象 1.目的 向数据库中添加数据.当创建对象时,Django不会对数据库进行读写操作,当调用save()方法时,才与数据库交互,将对象保存到数据库中 2.注意 __ ...
- Android中Activity的四种启动方式
谈到Activity的启动方式必须要说的是数据结构中的栈.栈是一种只能从一端进入存储数据的线性表,它以先进后出的原则存储数据,先进入的数据压入栈底,后进入的数据在栈顶.需要读取数据的时候就需要从顶部开 ...
- Python实现接口测试中的常见四种Post请求数据
前情: 在日常的接口测试工作中,模拟接口请求通常有两种方法, 利用工具来模拟,比如fiddler,postman,poster,soapUI等 利用代码来模拟,使用到一些网络模块,比如HttpClie ...
随机推荐
- 算发帖——俄罗斯方块覆盖问题一共有多少个解
问题的提出:如下图,用13块俄罗斯方块覆盖8*8的正方形. 那么一共可以有多少个解呢?(若通过旋转.翻转一个解而得到的新解,则两个解视为同一个解) 首先,求解的问题,已经在上一篇帖子里完成 算 ...
- NumPy——统计函数
引入模块import numpy as np 1.numpy.sum(a, axis=None)/a.sum(axis=None) 根据给定轴axis计算数组a相关元素之和,axis整数或元组,不指定 ...
- python大佬养成计划----HTML网页设计(序列)
序列化标签 1.有序标签--ol和li 有序列表标签是<ol>,是一个双标签.在每一个列表项目前要使用<li>标签.<ol>标签的形式是带有前后顺序之分的编号.如果 ...
- 逐行分析jQuery2.0.3源码-完整笔记
概览 (function (){ (21 , 94) 定义了一些变量和函数 jQuery=function(); (96 , 293) 给jQuery对象添加一些方法和属性; (285 , 347) ...
- 浏览器渲染流程&Composite(渲染层合并)简单总结
梳理浏览器渲染流程 首先简单了解一下浏览器请求.加载.渲染一个页面的大致过程: DNS 查询 TCP 连接 HTTP 请求即响应 服务器响应 客户端渲染 这里主要将客户端渲染展开梳理一下,从浏览器器内 ...
- 打造你的第一个 Electron 应用
Electron 可以让你使用纯 JavaScript 调用丰富的原生(操作系统) APIs 来创造桌面应用. 你可以把它看作一个 Node. js 的变体,它专注于桌面应用而不是 Web 服务器端. ...
- 单片机基础——使用GPIO扫描检测按键
1. 准备工作 硬件准备 开发板首先需要准备一个小熊派IoT开发板,并通过USB线与电脑连接. 软件准备 需要安装好Keil - MDK及芯片对应的包,以便编译和下载生成的代码,可参考MDK安装教程 ...
- mysql 学习日记 悲观和乐观锁
理解 悲观锁就是什么事情都是需要小心翼翼,生怕弄错了出大问题, 一般情况下 "增删改" 都是有事务在进行操作的,但是 "查" 是不需要事务操作的, 但是凡事没 ...
- freecplus框架,Linux平台下C/C++程序员提高开发效率的利器
目录 一.freecplus框架简介 二.freecplus开源许可协议 三.freecplus框架内容 字符串操作 2.xml解析 3.日期时间 4.目录操作 5.文件操作 6.日志文件 7.参数文 ...
- Spring Boot从入门到精通(九)整合Spring Data JPA应用框架
JPA是什么? JPA全称Java Persistence API,是Sun官方提出的Java持久化规范.是JDK 5.0注解或XML描述对象-关系表的映射关系,并将运行期的实体对象持久化到数据库中. ...