将TVM集成到PyTorch上
将TVM集成到PyTorch上
随着TVM不断展示出对深度学习执行效率的改进,很明显PyTorch将从直接利用编译器堆栈中受益。PyTorch的主要宗旨是提供无缝且强大的集成,而这不会妨碍用户。为此,PyTorch现在具有基于TVM的官方后端torch_tvm。
用法很简单:
import
torch_tvm
torch_tvm.enable()
PyTorch将尝试在其JIT编译过程中,将所有可能的运算符转换为已知的Relay运算符。
背景
与许多其他ML框架不同,PyTorch公开了一个渴望执行的编程接口。这种编程风格避免了基于图的元编程,而专注于以Python方式直接控制n维数组(张量)。因此,该框架最初非常适合模型的试验和开发,但不适用于自动性能优化或部署。为了利用优化的编译器技术,PyTorch引入了一些较大的更改来解决此问题。
PyTorch 1.0引入了PyTorch IR,PyTorch专用的中间表示形式,用于类似于Relay的模型。可以通过模型跟踪将PyTorch程序转换为IR,该跟踪记录模型或Python的子集TorchScript的执行。新的TVM后端将PyTorch的IR降低到了Relay,并能够透明地提高PyTorch的性能,而无需用户参与。
整合与结果
为了支持Relay,PyTorch JIT添加了两个功能:自定义转换过程和自定义子图解释器。
当torch_tvm启用时,可以转换到中继PyTorch IR的子图Expr旨意被标记为继电器兼容。由于PyTorch IR并不总是包含形状信息,因此在调用之前,无法以有用的方式编译任何子图。
在用户调用期间,PyTorch JIT运行时将确定输入形状信息,并使用新的Relay C ++构建系统编译先前标记的子图。根据输入形状来缓存编译,以供后续运行。可以在README中找到更多详细信息。
torch_tvm建立了一个连续的基准测试系统,该系统正在监视ResNet18在CPU上的性能。对于各种ResNet型号,TVM的性能都是默认PyTorch
JIT后端的两倍以上。在AWS c5n.4xlarge实例上使用16个线程实现的每秒迭代次数(越大越好)。
这些结果令人鼓舞,该项目将继续致力于,在更多模型上提高CPU推理速度。
未来的工作
现在,PyTorch JIT进行了大量工作来查找其IR的纯功能子集,以馈送到Relay。这避免了将别名和控制流信息映射到中继的需要,但这不是必需的。将更多的PyTorch IR映射到Relay可能会取得性能上的胜利,这是该项目的目标。PyTorch IR在开发过程中正在迅速变化,因此必须谨慎进行。
将做更多的工作来确保PyTorch和TVM代码之间的切换是有效的。这包括统一线程模型,分配器以及减少与将输入复制到TVM相关的开销。
解析
如果已经编写了PyTorch模型,最简单的入门方法就是使用torch.jit.trace以下方法
import
torch_tvm
from
your_model import model, inputs
torch_tvm.enable(opt_level=3)
iters
= 100
warmup
= 10
#
Ensure your model is in eval mode and also turn off gradients.
with
torch.no_grad():
# Use
tuned parameters for better performance.
with
autotvm.apply_history_best("test/autotvm_tuning.log"):
# This is where all the compilation
happens.
trace_tvm = torch.jit.trace(model, inputs)
# Warmup
for _ in range(warmup):
_ =
trace_tvm(*inputs)
# Benchmark
start = time.time()
for _ in range(iters):
_ = trace_tvm(*inputs)
tvm_time = time.time() - start
print("Took {}s to run {}
iters".format(tvm_time, iters))
这段代码大部分来自Benchmarks.py。请注意,用于AVX2 LLVM编译的调整参数位于存储库test/文件夹中。
如果更直接使用Relay,可以通过(隐式)跟踪或TorchScript,直接从PyTorch函数中提取表达式:
def
add(a, b, c):
return a + b + c
#
via tracing
relay_graph
= torch_tvm.to_relay(add, inputs)
@torch.jit.script
def
mul(a, b, c):
return a * b * c
#
via script
relay_graph
= torch_tvm.to_relay(mul, inputs)
将TVM集成到PyTorch上的更多相关文章
- 将TVM集成到PyTorch
将TVM集成到PyTorch 随着TVM不断展示出对深度学习执行效率的改进,很明显PyTorch将从直接利用编译器堆栈中受益.PyTorch的主要宗旨是提供无缝且强大的集成,而这不会妨碍用户.PyTo ...
- 如何在TVM上集成Codegen(上)
如何在TVM上集成Codegen(上) 许多常用的深度学习内核,或者提供DNNL或TensorRT等框架和图形引擎,让用户以某种方式描述他们的模型,从而获得高性能.此外,新兴的深度学习加速器也有自己的 ...
- [转载]PyTorch上的contiguous
[转载]PyTorch上的contiguous 来源:https://zhuanlan.zhihu.com/p/64551412 这篇文章写的非常好,我这里就不复制粘贴了,有兴趣的同学可以去看原文,我 ...
- 在Pytorch上使用稀疏矩阵
在Pytorch上使用稀疏矩阵 最近在写一个NLP的小项目,用到了Pytorch做神经网络模型.但是众所周知NLP的一个特点就是特征矩阵是稀疏矩阵,当时处理稀疏矩阵用的是scipy.sparse,现在 ...
- TVM 优化 ARM GPU 上的移动深度学习
TVM 优化 ARM GPU 上的移动深度学习 随着深度学习的巨大成功,将深度神经网络部署到移动设备的需求正在迅速增长.与桌面平台上所做的类似,在移动设备中使用 GPU 既有利于推理速度,也有利于能源 ...
- TVM在ARM GPU上优化移动深度学习
TVM在ARM GPU上优化移动深度学习 随着深度学习的巨大成功,将深度神经网络部署到移动设备的需求正在迅速增长.与在台式机平台上所做的类似,在移动设备中使用GPU可以提高推理速度和能源效率.但是,大 ...
- 【微服务专题之】.Net6下集成消息队列上-RabbitMQ
微信公众号:趣编程ACE关注可了解更多的.NET日常实战开发技巧,如需源码 请公众号后台留言 源码;[如果觉得本公众号对您有帮助,欢迎关注] .Net中RabbitMQ的使用 [微服务专题之].N ...
- Liferay7 BPM门户开发之45: 集成Activiti文件上传部署流程BPMN模型
开发文件上传,部署流程模板. 首先,开发jsp页面,deploy.jsp <%@ include file="/init.jsp" %> <h3>${RET ...
- iOS支付宝,微信,银联支付集成封装(上)
一.集成支付宝支付 支付宝集成官方教程https://docs.open.alipay.com/204/105295/ 支付宝集成官方demo https://docs.open.alipay.com ...
随机推荐
- day-26-封装-property装饰器-反射
一.super进阶 在多继承中:严格按照mro顺序来执行 super是按照mro顺序来寻找当前类的下一类 在py3中不需要传参数,自动就帮我们寻找当前类的mro顺序的下一个类中的同名方法 在py2中的 ...
- 【Scrapy(一)】 Scrapy爬虫的基础执行流程
安装scrapy模块 : pip install scrapy 创建scrapy项目 1.scrapy startprojecty 项目名称 注意:如果创建失败,可以先卸载原有的scrapy模块, ...
- LA3027简单带权并查集
题意: 有n个点,一开始大家都是独立的点,然后给出一些关系,a,b表示a是b的父亲节点,距离是abs(a-b)%1000,然后有一些询问,每次询问一个节点a到父亲节点的距离是多少? 思路: ...
- Win64 驱动内核编程-10.突破WIN7的PatchGuard
突破WIN7的PatchGuard WIN64 有两个内核保护机制,KPP 和 DSE.KPP 阻止我们 PATCH 内核,DSE 拦截我们加载驱动.当然 KPP 和 DSE 并不是不可战胜的,WIN ...
- Filter过滤器的基本使用方法
ProjectDescription Filter的使用 创建类实现javax.servlet.Filter. 重写方法: init(); //过滤器初始化 doFilter(); //过滤请求 1. ...
- js--吐血总结最近遇到的变态表单校验---element+原生+jq+easyUI(前端职业生涯见过的最烦的校验)
最近写了无数各种形式的表单,记录下奇奇怪怪的校验规则~ 一:首先是element自带的rules校验规则: element作为常用框架,自带rules属性简单易懂,官方文档一目了然,不再赘述,应付常用 ...
- 软件测试中的测试用例Test Case原来是这么回事!
如果你去找一份功能测试的工作,在软件测试工程师面试过程中,有一些面试官会来一两个非常简单的问题 什么是Test Case?你是如何去写Test Case的? 我们先来看一下测试用例的介绍 什么是测试用 ...
- LINQ之查询语法
新开一节LINQ的入门讲解. LINQ(Language Integrated Query)语言集成查询,是C#语言的扩展,它的主要功能是从数据集中查询数据,就像通过sql语句从数据库查询数据一样(本 ...
- Linux下线程pid和tid
#include <stdio.h> #include <pthread.h> #include <sys/types.h> #include <sys/sy ...
- Java初始化数据域的途径
Java调用构造器的具体处理步骤: 1.所有数据域被初始化为默认值(0,false或null); 2.按照在类声明中出现的次序,依次执行所有域的初始化语句和初始化块: 3.如果构造器第一行调用了第二个 ...