从一个例子看tvm执行流程
TVM整体流程(参考:TVM介绍)
机器学习模型在用TVM优化编译器框架进行变换时的步骤:
从Tensorflow/pytorch或ONNX等框架导入模型
import层是TVM从其他框架中导入模型的地方
注:TVM为每个前端提供的支持水平不尽相同,可尝试将模型转换为ONNX转换到Relay
Relay是TVM的高级模型语言,导入到TVM的模型是用Relay表示的。Relay是一种函数式语言(function language)和神经网络的中间表示法(intermediate representation,IR),它支持以下内容:
- 传统的数据流图式表示法
- Functional-style scoping 和 let-binding 使其成为一种功能齐全的可微分语言
- 能够允许用户混合两种编程风格
Relay应用图级(graph-level)优化passes来优化模型
- Lower to Tensor Expression (TE) representation
lower是指较高层表示转换成较低层表示。在high-level经过优化后,Relay 运行 FuseOps,将模型分割成许多小的子图,并将子图 lower 到 TE 表示。
TE 即Tensor Expression,张量表达,是描述张量计算的专属性语言
TE 还提供了一些schedule 原语来指定低级的循环优化,例如平铺(tiling)、矢量化(vectorization)、并行化(parallelization)、unrolling 和 fusion。
为了帮助将 Relay 表示转换为 TE 表示的过程,TVM 包含张量算子清单(Tensor Operator Inventory, TOPI),它有预先定义的常见张量算子的模板(如 conv2d、transpose)。
- Search for the best schedule using the auto-tuning module AutoTVM or AutoScheduler.
schedule 指定在 TE 中定义了算子或子图的低级循环优化。auto-tuning 模块搜索最佳 schedule 并将其与 cost 模型和设备上的测量结果进行比较。
在 TVM 中,有两个 auto-tuning 模块:
- AutoTVM: 基于模板的 auto-tuning 模块。它运行搜索算法为用户定义的模板中的可调节旋钮找到最佳值。对于常见的运算符,其模板已经在 TOPI 中提供。
- AutoScheduler (别名 Ansor) :无模板的auto-tuning 模块。它不需要预先定义的 schedule 模板。相反,它通过分析计算的定义自动生成搜索空间。然后,它在生成的搜索空间中搜索最佳 schedule。
Choose the optimal configurations for model compilation.
tuning 后,auto-tuning 模块会生成 JSON 格式的 auto-tuning 记录。这一步为每个子图挑选出最佳的 schedule。Lower to Tensor Intermediate Representation (TIR),TVM's low-level intermediate representation
TIR 是张量级的中间表示(Tensor Intermediate Representation),TVM 的低层次中间表示。
在根据 tuning 步骤选择最佳配置后,每个 TE 子图被降低到 TIR,并通过低级别的优化 passes 进行优化。
接下来,优化后的 TIR 被 lower 到硬件平台的目标编译器中。这是最后的代码生成阶段,产生可以部署到生产中的优化模型。
TVM 支持几种不同的编译器后端,包括:
- LVM:它可以针对任意的微处理器架构,包括 标准 x86 和 ARM 处理器,AMDGPU 和 NVPTX 代码生成,以及 LLVM 支持的任何其他平台。
- 专门的编译器,如 NVCC,NVIDIA 的编译器。
- 嵌入式和专用目标,通过 TVM 的 Bring Your Own Codegen(BYOC)框架实现。
- Compile down to machine code.
在这个过程结束时,特定的编译器生成的代码可以 lower 为机器码。
TVM 可以将模型编译成可链接的对象模块,然后可以用轻量级的 TVM 运行时来运行,该运行时提供 C 语言的 API 来动态加载模型,以及其他语言的入口,如 Python 和 Rust。TVM 还可以建立捆绑式部署,其中运行时与模型结合在一个包中。
例子 -- 编译Pytorch模型(参考:官网例子)
- 导入库
import tvm
from tvm import relay
import numpy as np
from tvm.contrib.download import download_testdata
# PyTorch imports
import torch
import torchvision
- 加载pytorch预训练模型
model_name = "resnet18"
model = getattr(torchvision.models, model_name)(pretrained=True)
model = model.eval()
# We grab the TorchScripted model via tracing
input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data).eval()
- 加载一张测试图片
经典的猫图片
from PIL import Image
img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_path = download_testdata(img_url, "cat.png", module="data")
img = Image.open(img_path).resize((224, 224))
# Preprocess the image and convert to tensor
from torchvision import transforms
my_preprocess = transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
img = my_preprocess(img)
img = np.expand_dims(img, 0)
- Import the graph to Relay
将Pytorch 图转换为Relay图,Input name是随意的
input_name = "input0"
shape_list = [(input_name, img.shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
- Relay Build
使用给定的输入将graph图编译成llvm目标
target = tvm.target.Target("llvm", host="llvm")
dev = tvm.cpu(0)
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target, params=params)
本地执行过程中,执行结果如下:
URLError(OSError(99, 'Cannot assign requested address'))
Download attempt 0/3 failed, retrying.
URLError(OSError(99, 'Cannot assign requested address'))
Download attempt 1/3 failed, retrying.
WARNING:root:Failed to download tophub package for llvm: <urlopen error [Errno 99] Cannot assign requested address>
/home/workspace/tvm/tvm/python/tvm/driver/build_module.py:267: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.
warnings.warn(
WARNING:autotvm:One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
虽有Error报错,但似乎并没有影响(此处也奇怪,为什么会有URLError的报错呢,编译的时候还需要请求什么吗?不清楚,也许看过源码才知道为什么吧)
根据官方的示例结果,应该如下:
/workspace/python/tvm/driver/build_module.py:268: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.
"target_host parameter is going to be deprecated. "
- TVM上执行可移植图(Execute the portable graph on TVM)
现在,尝试在目标上部署已编译的模型
from tvm.contrib import graph_executor
dtype = "float32"
m = graph_executor.GraphModule(lib["default"](dev))
# Set inputs
m.set_input(input_name, tvm.nd.array(img.astype(dtype)))
# Execute
m.run()
# Get outputs
tvm_output = m.get_output(0)
- 查找合集名称(Look up synset name)
在synset中查找预测的top1的索引
注:其中的imagenet_synsets.txt和imagenet_classes.txt两个文件,示例中给的url链接无效,可自行下载:
https://github.com/Cadene/pretrained-models.pytorch/blob/master/data/ 下的imagenet_synsets.txt和imagenet_classes.txt两个文件
亲测有效,可用
synset_url = "".join(
[
"https://raw.githubusercontent.com/Cadene/",
"pretrained-models.pytorch/master/data/",
"imagenet_synsets.txt",
]
)
synset_name = "imagenet_synsets.txt"
synset_path = download_testdata(synset_url, synset_name, module="data")
with open(synset_path) as f:
synsets = f.readlines()
synsets = [x.strip() for x in synsets]
splits = [line.split(" ") for line in synsets]
key_to_classname = {spl[0]: " ".join(spl[1:]) for spl in splits}
class_url = "".join(
[
"https://raw.githubusercontent.com/Cadene/",
"pretrained-models.pytorch/master/data/",
"imagenet_classes.txt",
]
)
class_name = "imagenet_classes.txt"
class_path = download_testdata(class_url, class_name, module="data")
with open(class_path) as f:
class_id_to_key = f.readlines()
class_id_to_key = [x.strip() for x in class_id_to_key]
# Get top-1 result for TVM
top1_tvm = np.argmax(tvm_output.numpy()[0])
tvm_class_key = class_id_to_key[top1_tvm]
# Convert input to PyTorch variable and get PyTorch result for comparison
with torch.no_grad():
torch_img = torch.from_numpy(img)
output = model(torch_img)
# Get top-1 result for PyTorch
top1_torch = np.argmax(output.numpy())
torch_class_key = class_id_to_key[top1_torch]
print("Relay top-1 id: {}, class name: {}".format(top1_tvm, key_to_classname[tvm_class_key]))
print("Torch top-1 id: {}, class name: {}".format(top1_torch, key_to_classname[torch_class_key]))
最后输出结果:
Relay top-1 id: 281, class name: tabby, tabby cat
Torch top-1 id: 281, class name: tabby, tabby cat
从一个例子看tvm执行流程的更多相关文章
- Spark小课堂Week7 从Spark中一个例子看面向对象设计
Spark小课堂Week7 从Spark中一个例子看面向对象设计 今天我们讨论了个问题,来设计一个Spark中的常用功能. 功能描述:数据源是一切处理的源头,这次要实现下加载数据源的方法load() ...
- 一个例子看懂所有nodejs的官方网络demo
今天看群里有人用AI技术写了个五子棋,正好用的socket.io,本身我自己很久没看nodejs了,再加上Tcp/IP的知识一直很弱,我就去官网看了下net.socket 发现之前以为懂的一个官方例子 ...
- 关于类、方法、对象(实例):通过一个例子看一下self都做了哪些事情
我们在定义一个类时,经常会在类的各个方法中看到self,那么在程序执行时self到底起了什么作用,什么时候要加self,这一点需要我们思考并好好理解.之前在学习时没有想这么多,加之用pycharm写代 ...
- 1020关于mysql一个简单语句的执行流程
MySQL的语句执行顺序 转自http://www.cnblogs.com/rollenholt/p/3776923.html MySQL的语句一共分为11步,如下图所标注的那样,最先执行的总是FRO ...
- 从一个例子看现代C++的威力
引子 最近准备重构一下我的kapok库,让meta函数可以返回元素为kv的tuple,例如: struct person { std::string name; int age; META(name, ...
- vc++深入跟踪MFC程序的执行流程
在MFC程序设计的学习过程中最令人感到难受,甚至于有时会动摇学习者信心的就是一种对于程序的一切细节都没有控制权的感觉.这种感觉来源于学习者不知道一个MFC程序是如何运行起来的(即一个MFC程序的执行流 ...
- 深入跟踪MFC程序的执行流程
来源: http://blog.csdn.net/ljianhui/article/details/8781991 在MFC程序设计的学习过程中最令人感到难受,甚至于有时会动摇学习者信心的就是一种对于 ...
- 从源码角度了解SpringMVC的执行流程
目录 从源码角度了解SpringMVC的执行流程 SpringMVC介绍 源码分析思路 源码解读 几个关键接口和类 前端控制器 DispatcherServlet 结语 从源码角度了解SpringMV ...
- Golang源码学习:调度逻辑(三)工作线程的执行流程与调度循环
本文内容主要分为三部分: main goroutine 的调度运行 非 main goroutine 的退出流程 工作线程的执行流程与调度循环. main goroutine 的调度运行 runtim ...
- springmvc执行流程详细介绍
1.什么是MVC MVC是Model View Controller的缩写,它是一个设计模式 2.springmvc执行流程详细介绍 第一步:发起请求到前端控制器(DispatcherServlet) ...
随机推荐
- 【软件开发】Git 概念与常用命令
[软件开发]Git 概念与常用命令 Git 概念 存储方式 Git 是分布式存储,每一个 clone 下来的仓库都可以看成独立的个体,只是 Git 有提供同步功能,因此 Git 支持离线使用,因为本质 ...
- macbookpro m3本地部署DeepSeek模型
macbookpro m3有着十分强大的性能.在deepseek如火如荼的当下,可以尝试在本地部署并使用.还可以将自己的文档作为语料喂给deepseek,使其能成为自己专属的AI助手. 本文介绍使用o ...
- 推荐一款最新开源,基于AI人工智能UI自动化测试工具!支持自然语言编写脚本!
随着互联网技术的飞速发展,Web应用越来越普及,前端页面也越来越复杂.为了确保产品质量,UI自动化测试成为了开发过程中不可或缺的一环.然而,传统的UI自动化测试工具往往存在学习成本高.维护困难等问题. ...
- Flink学习(十二) Sink到JDBC(可扩展到任何关系型数据库)
导入依赖 <dependency> <groupId>mysql</groupId> <artifactId>mysql-connector-java& ...
- 浅谈Tox之一
本文分享自天翼云开发者社区<浅谈Tox之一>,作者:Moonriver What is tox? tox是通用的virtualenv管理和测试命令行工具,可用于: 使用不同的Python版 ...
- Angular CLI 源码分析
准备: 安装 Node.js https://nodejs.org/: 安装 VS Code https://code.visualstudio.com/: 创建文件夹 angular-cli-sou ...
- Vulnhub-FristiLeaks_1.3
一.靶机搭建 选择扫描虚拟机 选择路径即可 二.信息收集 靶机信息 产品名称:Fristileaks 1.3 作者:Ar0xA 发布日期: 2015 年 12 月 14 日 目标:获取root(uid ...
- 【数值计算方法】蒙特卡洛方法积分的Python实现
原理不做赘述,参见[数值计算方法]数值积分&微分-python实现 - FE-有限元鹰 - 博客园,直接上代码,只实现1d,2d积分,N维积分的蒙特卡洛方法也类似. 代码 from typin ...
- 认识webRTC
什么是 WebRTC 2010 年 5 月,谷歌收购了 Global IP Solutions(简称 GIPS),这是一家专注于 VoIP 和视频会议软件的公司,已开发出 RTC 所需的多项关键组件, ...
- Ansible忽略任务失败
在默认情况下,任务失败时会中止剧本任务,不过可以通过忽略失败的任务来覆盖此类行为.在可能出错且不影响全局的段中使用ignore_errors关键词来达到目的. 环境: 受控主机清单文件: [dev] ...