ONNX预训练模型加载
tvm官网中,对从ONNX预训练模型中加载模型的教程说明
教程来自于:https://docs.tvm.ai/tutorials/frontend/from_onnx.html#sphx-glr-tutorials-frontend-from-onnx-py
首先我对教程进行了一些修改,很多东西没有必要,比如不是每次都需要从网上下载图片和模型,super_resolution.onnx和cat.png都预先下载到了文件同目录下,
同时,最新版本的tvm中不支持Python2.7,我没有编译llvm,所以我把我的设置都改到了cuda上,在24行和32行有体现,注意最新版本
import onnx
import numpy as np
import tvm
import tvm.relay as relay
# from tvm.contrib.download import download_testdata # model_url = ''.join(['https://gist.github.com/zhreshold/',
# 'bcda4716699ac97ea44f791c24310193/raw/',
# '93672b029103648953c4e5ad3ac3aadf346a4cdc/',
# 'super_resolution_0.2.onnx'])
# model_path = download_testdata(model_url, 'super_resolution.onnx', module='onnx')
# now you have super_resolution.onnx on disk
onnx_model = onnx.load('super_resolution.onnx') from PIL import Image
# img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true'
# img_path = download_testdata(img_url, 'cat.png', module='data')
img_path = 'cat.png'
img = Image.open(img_path).resize((224, 224))
img_ycbcr = img.convert("YCbCr") # convert to YCbCr
img_y, img_cb, img_cr = img_ycbcr.split()
x = np.array(img_y)[np.newaxis, np.newaxis, :, :] target = 'cuda' input_name = ''
shape_dict = {input_name: x.shape}
sym, params = relay.frontend.from_onnx(onnx_model, shape_dict)
print(sym) with relay.build_config(opt_level=1):
intrp = relay.build_module.create_executor('graph', sym, tvm.gpu(0), target) dtype = 'float32'
tvm_output = intrp.evaluate(sym)(tvm.nd.array(x.astype(dtype)), **params).asnumpy()
第28行有一个从模型加载的函数from_onnx
官方的解释:tvm.relay.frontend.
from_onnx
(model, shape=None, dtype='float32')
Convert a ONNX model into an equivalent Relay Function.
ONNX graphs are represented as Python Protobuf objects. The companion parameters will be handled automatically. However, the input names from onnx graph is vague, mixing inputs and network weights/bias such as “1”, “2”… For convenience, we rename the real input names to “input_0”, “input_1”… And renaming parameters to “param_0”, “param_1”…
Parameters: |
|
---|---|
Returns: |
|
看返回值,sym是relay Function,在后边加一个print(sym)输出,可以看到图这一级的IR
fn (%v1: Tensor[(1, 1, 224, 224), float32], %v2: Tensor[(64, 1, 5, 5), float32], %v3: Tensor[(64,), float32], %v4: Tensor[(64, 64, 3, 3), float32], %v5: Tensor[(64,), float32], %v6: Tensor[(32, 64, 3, 3), float32], %v7: Tensor[(32,), float32], %v8: Tensor[(9, 32, 3, 3), float32], %v9: Tensor[(9,), float32]) {
%0 = nn.conv2d(%v1, %v2, padding=[2, 2], kernel_size=[5, 5])
%1 = expand_dims(%v3, axis=1, num_newaxis=2)
%2 = add(%0, %1)
%3 = nn.relu(%2)
%4 = nn.conv2d(%3, %v4, padding=[1, 1], kernel_size=[3, 3])
%5 = expand_dims(%v5, axis=1, num_newaxis=2)
%6 = add(%4, %5)
%7 = nn.relu(%6)
%8 = nn.conv2d(%7, %v6, padding=[1, 1], kernel_size=[3, 3])
%9 = expand_dims(%v7, axis=1, num_newaxis=2)
%10 = add(%8, %9)
%11 = nn.relu(%10)
%12 = nn.conv2d(%11, %v8, padding=[1, 1], kernel_size=[3, 3])
%13 = expand_dims(%v9, axis=1, num_newaxis=2)
%14 = add(%12, %13)
%15 = reshape(%14, newshape=[1, 1, 3, 3, 224, 224])
%16 = transpose(%15, axes=[0, 1, 4, 2, 5, 3])
reshape(%16, newshape=[1, 1, 672, 672])
}
ONNX预训练模型加载的更多相关文章
- [.NET6]使用ML.NET+ONNX预训练模型整活B站经典《华强买瓜》
最近在看微软开源的机器学习框架ML.NET使用别人的预训练模型(开放神经网络交换格式.onnx)来识别图像,然后逛github发现一个好玩的repo.决定整活一期博客. 首先还是稍微科普一下机器学习相 ...
- 全面解析Pytorch框架下模型存储,加载以及冻结
最近在做试验中遇到了一些深度网络模型加载以及存储的问题,因此整理了一份比较全面的在 PyTorch 框架下有关模型的问题.首先咱们先定义一个网络来进行后续的分析: 1.本文通用的网络模型 import ...
- PyTorch保存模型与加载模型+Finetune预训练模型使用
Pytorch 保存模型与加载模型 PyTorch之保存加载模型 参数初始化参 数的初始化其实就是对参数赋值.而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了da ...
- 关于document.write()加载JS等静态资源 和 异步async加载JS
现流行浏览器对于静态资源的预加载 传统的浏览器,对于静态资源加载,会阻塞 HTML 解析器的线程进行,无论内联还是外链. 例如: <script src="test1.js" ...
- [Pytorch]Pytorch加载预训练模型(转)
转自:https://blog.csdn.net/Vivianyzw/article/details/81061765 东风的地方 1. 直接加载预训练模型 在训练的时候可能需要中断一下,然后继续训练 ...
- 如何使用 opencv 加载 darknet yolo 预训练模型?
如何使用 opencv 加载 darknet yolo 预训练模型? opencv 版本 > 3.4 以上 constexpr const char *image_path = "da ...
- 【tf.keras】tf.keras加载AlexNet预训练模型
目录 从 PyTorch 中导出模型参数 第 0 步:配置环境 第 1 步:安装 MMdnn 第 2 步:得到 PyTorch 保存完整结构和参数的模型(pth 文件) 第 3 步:导出 PyTorc ...
- PyTorch-网络的创建,预训练模型的加载
本文是PyTorch使用过程中的的一些总结,有以下内容: 构建网络模型的方法 网络层的遍历 各层参数的遍历 模型的保存与加载 从预训练模型为网络参数赋值 主要涉及到以下函数的使用 add_module ...
- pytorch中修改后的模型如何加载预训练模型
问题描述 简单来说,比如你要加载一个vgg16模型,但是你自己需要的网络结构并不是原本的vgg16网络,可能你删掉某些层,可能你改掉某些层,这时你去加载预训练模型,就会报错,错误原因就是你的模型和原本 ...
随机推荐
- Connection
作用: * 获取执行sql语句对象 ** createStatement(): 获取Statement对象 ** prepareStatement(String sql): 获取预处理对象 ** pr ...
- DriverManager
1: 注册驱动 Class.forName("com.mysql.jdbc.Driver") ; static { try { java.sql.DriverManager.reg ...
- 【UE】常用的UltraEdit使用技巧
Tip 1: Alt+C 列模式可以说最初选择使用这个文本编辑软件,原因很简单,就是因为“她”具有列编辑模式.如果您还不知道什么是列编辑模式的话,我想您应该好好研究一下啦.这是一个超级“赞”的功能.在 ...
- 修改阿里源为Ubuntu 18.04默认的源
步骤如下: Step1:备份/etc/apt/sources.list sudo cp /etc/apt/sources.list /etc/apt/sources.list.bak Step2:在/ ...
- 冲刺Noip2017模拟赛4 解题报告——五十岚芒果酱
题1 韬韬抢苹果(apple) [问题描述] 又到了收获的季节,树上结了许多韬韬,错了,是许多苹果,有很多个小韬韬都来摘苹 果.每个韬韬都想要最大的苹果,所以发生了争执,为了解决他们的矛盾,出题人定了 ...
- 解决从github上下载代码仓库慢的问题
一,打开命令提示符,最好之前准备一个仓库地址,这样下载下来的文件方便查看,这里打开你想要的下载根目录,进行下载. github上下载代码仓库慢的问题"> 二:复制代码仓库的地址 三:右 ...
- POJ2594 Treasure Exploration【DAG有向图可相交的最小路径覆盖】
题目链接:http://poj.org/problem?id=2594 Treasure Exploration Time Limit: 6000MS Memory Limit: 65536K T ...
- 【转贴】linux 终端报Message from syslogd
linux 终端报Message from syslogd xiao9873341760人评论8537人阅读2017-03-27 14:19:31 https://blog.51cto.com/xia ...
- 【AI】【人工智能】【计算机】人工智能工程技术人员等职业信息公示
人社厅发[2019]48号 各省.自治区.直辖市及新疆生产建设兵团人力资源社会保障厅(局).市场监管局.统计局,国务院各部门.各直属机构.各中央企业.有关社会组织人事劳动保障工作机构,中央军委政治工作 ...
- k-近邻(KNN) 算法预测签到位置
分类算法-k近邻算法(KNN): 定义: 如果一个样本在特征空间中的k个最相似 (即特征空间中最邻近) 的样本中的大多数属于某一个类别,则该样本也属于这个类别 来源: KNN算法最早是由Cover和H ...