接上一篇完成的pytorch模型训练结果,模型结构为ResNet18+fc,参数量约为11M,最终测试集Acc达到94.83%。接下来有分两个部分:导出onnx和使用onnxruntime推理。

一、pytorch导出onnx

直接放函数吧,这部分我是直接放在test.py里面的,直接从dataloader中拿到一个batch的数据走一遍推理即可。

def export_onnx(net, testloader, output_file):
net.eval()
with torch.no_grad():
for data in testloader:
images, labels = data torch.onnx.export(net,
(images),
output_file,
training=False,
do_constant_folding=True,
input_names=["img"],
output_names=["output"],
dynamic_axes={"img": {0: "b"},"output": {0: "b"}}
)
print("onnx export done!")
break

上面函数中几个比较重要的参数:do_constant_folding是常量折叠,建议打开;输入张量通过一个tuple传入,并且最好指定每个输入和输出的名称,此外,为保证使用onnxruntime推理的时候batchsize可变,dynamic_axes的第一维需要像上述一样设置为动态的。如果是全卷积做分割的网络,类似的输入h和w也应该是动态的。

单独运行test.py计算测试集效果和平均相应时间,为方便比较,这里batch_size设置为1,结果为:

Test Acc is: 94.84%
Average response time cost: 8.703978610038757 ms

二、使用onnxruntime推理

这里我们使用gpu版本的onnxruntime库进行推理,其python包可直接pip install onnxruntime-gpu安装。onnxruntime推理代码和测试集推理代码很类似,如下:

import numpy as np
import onnxruntime as ort
import argparse, os
from lib import CIFARDataset def onnxruntime_test(session, testloader):
print("Start Testing!")
input_name = session.get_inputs()[0].name
correct = 0
total = 0 # 计数归零(初始化)
for data in testloader:
images, labels = data
images, labels = images.numpy(), labels.numpy()
outputs = session.run(None, {input_name:images})
predicted = np.argmax(outputs[0], axis=1) # 取得分最高的那个类
total += labels.shape[0] # 累加样本总数
correct += (predicted == labels).sum() # 累加预测正确的样本个数
acc = correct / total
print('ONNXRuntime Test Acc is: %.2f%%' % (100*acc)) if __name__ == '__main__':
# 命令行参数解析
parser = argparse.ArgumentParser("CNN backbone on cifar10")
parser.add_argument('--onnx', default='./output/test_resnet18_10_autoaug/densenet_best.onnx')
args = parser.parse_args() NUM_CLASS =10
BATCH_SIZE = 1 # 批处理尺寸(batch_size) # 数据集迭代器
data_path="./data"
dataset = CIFARDataset(dataset_path=data_path, batchsize=BATCH_SIZE)
_, testloader = dataset.get_cifar10_dataloader() # 构建session
sess = ort.InferenceSession(args.onnx, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) #onnxruntime推理
import time
start = time.time()
onnxruntime_test(sess, testloader)
end = time.time()
print(f"Average response time cost: {1000*(end-start)/len(testloader.dataset)} ms")

使用onnxruntime加载导出的onnx模型,计算测试集效果和平均响应时间,结果为:

ONNXRuntime Test Acc is: 94.83%
Average response time cost: 3.1050602436065673 ms

三、小结

分析上面的pytorch和onnxruntime的测试结果可知,最终测试集效果是一致的,Acc分别为94.84%和94.83%,相当于10000个样本里面只有1个的预测结果不一致,这是可以接受范围内。但onnxruntime的效率更高,平均耗时只有3.1ms,比pytorch的8.7ms快了将近3倍。这在实际部署中的优势是非常明显的。目前Python端的结论比最初目标设定的50ms高很多,如果说需要进一步优化,两个方向:模型量化或并行化推理(拼batch或多线程)。下一篇再分析。

ONNXRuntime学习笔记(三)的更多相关文章

  1. Oracle学习笔记三 SQL命令

    SQL简介 SQL 支持下列类别的命令: 1.数据定义语言(DDL) 2.数据操纵语言(DML) 3.事务控制语言(TCL) 4.数据控制语言(DCL)  

  2. [Firefly引擎][学习笔记三][已完结]所需模块封装

    原地址:http://www.9miao.com/question-15-54671.html 学习笔记一传送门学习笔记二传送门 学习笔记三导读:        笔记三主要就是各个模块的封装了,这里贴 ...

  3. JSP学习笔记(三):简单的Tomcat Web服务器

    注意:每次对Tomcat配置文件进行修改后,必须重启Tomcat 在E盘的DATA文件夹中创建TomcatDemo文件夹,并将Tomcat安装路径下的webapps/ROOT中的WEB-INF文件夹复 ...

  4. java之jvm学习笔记三(Class文件检验器)

    java之jvm学习笔记三(Class文件检验器) 前面的学习我们知道了class文件被类装载器所装载,但是在装载class文件之前或之后,class文件实际上还需要被校验,这就是今天的学习主题,cl ...

  5. VSTO学习笔记(三) 开发Office 2010 64位COM加载项

    原文:VSTO学习笔记(三) 开发Office 2010 64位COM加载项 一.加载项简介 Office提供了多种用于扩展Office应用程序功能的模式,常见的有: 1.Office 自动化程序(A ...

  6. Java IO学习笔记三

    Java IO学习笔记三 在整个IO包中,实际上就是分为字节流和字符流,但是除了这两个流之外,还存在了一组字节流-字符流的转换类. OutputStreamWriter:是Writer的子类,将输出的 ...

  7. NumPy学习笔记 三 股票价格

    NumPy学习笔记 三 股票价格 <NumPy学习笔记>系列将记录学习NumPy过程中的动手笔记,前期的参考书是<Python数据分析基础教程 NumPy学习指南>第二版.&l ...

  8. Learning ROS for Robotics Programming Second Edition学习笔记(三) 补充 hector_slam

    中文译著已经出版,详情请参考:http://blog.csdn.net/ZhangRelay/article/category/6506865 Learning ROS for Robotics Pr ...

  9. Learning ROS for Robotics Programming Second Edition学习笔记(三) indigo rplidar rviz slam

    中文译著已经出版,详情请参考:http://blog.csdn.net/ZhangRelay/article/category/6506865 Learning ROS for Robotics Pr ...

随机推荐

  1. JVM的小总结(转)

    ref:http://www.cnblogs.com/ityouknow/p/6482464.html 注1:看了大神:纯洁的微笑的JVM系列篇,发现好多地方还是似懂非懂,理解的并不透彻,jvm的调优 ...

  2. springDataRedis忽略实体指定的属性

    如果是通过 RedisRepository定义的实体,可能存在想要忽略的属性,那么,就可以 使用 org.springframework.data.annotation.Transient 注解,就可 ...

  3. SublimeText 建立构建Node js系统

    Sublime Text 3 构建系统:https://www.sublimetext.com/docs/3/build_systems.html 注意: 文档中出现的 shell_cmd 和 cmd ...

  4. 设置python 虚拟环境 virtualenv django 虚拟环境

    https://developer.mozilla.org/en-US/docs/Learn/Server-side/Django/development_environment Ubuntu vir ...

  5. c、c++中-int型以float或者float型以int输出问题

    1.将浮点型以整形的类型输出问题 用VC6.0,会把以整形输出形式的浮点数输出为0: 1 #include"stdio.h" 2 int main() 3 { 4 float x= ...

  6. (Math.round(num*100)/100).toFixed(2); 将输入的数字变成保留两位小数

    <input type="number" @input="onInputPrice" @blur="onPrice" data-id= ...

  7. Centos搭建 LAMP 服务器教程

    搭建 LAMP 服务 搭建 MySQL 数据库 安装 MySQL 使用 yum 安装 MySQL: yum install mysql-server -y 安装完成后,启动 MySQL 服务: ser ...

  8. 常见的JVM 面试题

    1.讲一讲JVM的跨平台与跨语言 跨平台 我们写的一个类,在不同的操作系统上(Linux.windows.Mac OS)执行,效果是一样的.这就是JVM的跨平台性. 跨语言 JVM只识别字节码,JVM ...

  9. springcloud报错:org.springframework.beans.factory.BeanCreationException: Error creating bean with name 'armeriaServer' defined in class path resource

    spring boot配置zipkin 无法启动 加入 Zipkin Server 由于需要收集 Spring Cloud 系统的跟踪信息,以便及时地发现系统中出现的延迟升高问题并找出系统性能瓶颈的根 ...

  10. Java对象和多态

    Java对象和多态 (面向对象) 面向对象基础 面向对象程序设计(Object Oriented Programming) 对象基于类创建,类相当于一个模板,对象就是根据模板创建出来的实体(就像做月饼 ...