背景

在模型的部署中,为了高效利用硬件算力,常常会需要将多个输入组成一个batch同时输入网络进行推理,这个batch的大小根据系统的负载或者摄像头的路数时刻在变化,因此网络的输入batch是在动态变化的。对于pytorch等框架来说,我们并不会感受到这个问题,因为整个网络在pytorch中都是动态的。而在实际的工程化部署中,为了运行效率,却并不能有这样的灵活性。可能会有人说,那我就把batch固定在一个最大值,然后输入实际的batch,这样实际上网络是以最大batch在推理的,浪费了算力。所以我们需要能支持动态的batch,能够根据输入的batch数来运行。

一个常见的训练到部署的路径是:pytorch→onnx→tensorrt。在pytorch导出onnx时,我们可以指定输出为动态的输入:

torch_out = torch.onnx.export(model, inp,
save_path,input_names=["data"],output_names=["fc1"],dynamic_axes={
"data":{0:'batch_size'},"fc1":{0:'batch_size'}
})

而另一些时候,我们部署的模型来源于他人或开源模型,已经失去了原始的pytorch模型,此时如果onnx是静态batch的,在移植到tensorrt时,其输入就为静态输入了。想要动态输入,就需要对onnx模型本身进行修改了。另一方面,算法工程师在导模型的时候,如果没有指定输入层输出层的名称,导出的模型的层名有时候可读性比较差,比如输出是batchnorm_274这类名称,为了方便维护,也有需要对onnx的输入输出层名称进行修改。

操作

修改输入输出层

def change_input_output_dim(model):
# Use some symbolic name not used for any other dimension
sym_batch_dim = "batch" # The following code changes the first dimension of every input to be batch-dim
# Modify as appropriate ... note that this requires all inputs to
# have the same batch_dim
inputs = model.graph.input
for input in inputs:
# Checks omitted.This assumes that all inputs are tensors and have a shape with first dim.
# Add checks as needed.
dim1 = input.type.tensor_type.shape.dim[0]
# update dim to be a symbolic value
dim1.dim_param = sym_batch_dim
# or update it to be an actual value:
# dim1.dim_value = actual_batch_dim outputs = model.graph.output
for output in outputs:
# Checks omitted.This assumes that all inputs are tensors and have a shape with first dim.
# Add checks as needed.
dim1 = output.type.tensor_type.shape.dim[0]
# update dim to be a symbolic value
dim1.dim_param = sym_batch_dim model = onnx.load(onnx_path)
change_input_output_dim(model)

通过将输入层和输出层的shape的第一维修改为非数字,就可以将onnx模型改为动态batch。

修改输入输出层名称

def change_input_node_name(model, input_names):
for i,input in enumerate(model.graph.input):
input_name = input_names[i]
for node in model.graph.node:
for i, name in enumerate(node.input):
if name == input.name:
node.input[i] = input_name
input.name = input_name def change_output_node_name(model, output_names):
for i,output in enumerate(model.graph.output):
output_name = output_names[i]
for node in model.graph.node:
for i, name in enumerate(node.output):
if name == output.name:
node.output[i] = output_name
output.name = output_name

代码中input_names和output_names是我们希望改到的名称,做法是遍历网络,若有node的输入层名与要修改的输入层名称相同,则改成新的输入层名。输出层类似。

完整代码

import onnx
def change_input_output_dim(model):
# Use some symbolic name not used for any other dimension
sym_batch_dim = "batch" # The following code changes the first dimension of every input to be batch-dim
# Modify as appropriate ... note that this requires all inputs to
# have the same batch_dim
inputs = model.graph.input
for input in inputs:
# Checks omitted.This assumes that all inputs are tensors and have a shape with first dim.
# Add checks as needed.
dim1 = input.type.tensor_type.shape.dim[0]
# update dim to be a symbolic value
dim1.dim_param = sym_batch_dim
# or update it to be an actual value:
# dim1.dim_value = actual_batch_dim outputs = model.graph.output
for output in outputs:
# Checks omitted.This assumes that all inputs are tensors and have a shape with first dim.
# Add checks as needed.
dim1 = output.type.tensor_type.shape.dim[0]
# update dim to be a symbolic value
dim1.dim_param = sym_batch_dim def change_input_node_name(model, input_names):
for i,input in enumerate(model.graph.input):
input_name = input_names[i]
for node in model.graph.node:
for i, name in enumerate(node.input):
if name == input.name:
node.input[i] = input_name
input.name = input_name def change_output_node_name(model, output_names):
for i,output in enumerate(model.graph.output):
output_name = output_names[i]
for node in model.graph.node:
for i, name in enumerate(node.output):
if name == output.name:
node.output[i] = output_name
output.name = output_name onnx_path = ""
save_path = ""
model = onnx.load(onnx_path)
change_input_output_dim(model)
change_input_node_name(model, ["data"])
change_output_node_name(model, ["fc1"]) onnx.save(model, save_path)

经过修改后的onnx模型输入输出将成为动态batch,可以方便的移植到tensorrt等框架以支持高效推理。

将onnx的静态batch改为动态batch及修改输入输出层的名称的更多相关文章

  1. Spark Streaming中动态Batch Size实现初探

    本期内容 : BatchDuration与 Process Time 动态Batch Size Spark Streaming中有很多算子,是否每一个算子都是预期中的类似线性规律的时间消耗呢? 例如: ...

  2. 《CMake实践》笔记三:构建静态库(.a) 与 动态库(.so) 及 如何使用外部共享库和头文件

    <CMake实践>笔记一:PROJECT/MESSAGE/ADD_EXECUTABLE <CMake实践>笔记二:INSTALL/CMAKE_INSTALL_PREFIX &l ...

  3. 浅谈在静态页面上使用动态参数,会造成spider多次和重复抓取的解决方案

    原因: 早期由于搜索引擎蜘蛛的不完善,蜘蛛在爬行动态的url的时候很容易由于网站程序的不合理等原因造成蜘蛛迷路死循环. 所以蜘蛛为了避免之前现象就不读取动态的url,特别是带?的url 解决方案: 1 ...

  4. 在Linux中创建静态库.a和动态库.so

    转自:http://www.cnblogs.com/laojie4321/archive/2012/03/28/2421056.html 在Linux中创建静态库.a和动态库.so 我们通常把一些公用 ...

  5. Spark Streaming揭秘 Day21 动态Batch size实现初探(下)

    Spark Streaming揭秘 Day21 动态Batch size实现初探(下) 接昨天的描述,今天继续解析动态Batch size调整的实现. 算法 动态调整采用了Fix-point迭代算法, ...

  6. Spark Streaming揭秘 Day20 动态Batch size实现初探(上)

    Spark Streaming揭秘 Day20 动态Batch size实现初探(上) 今天开始,主要是通过对动态Batch size调整的论文的解析,来进一步了解SparkStreaming的处理机 ...

  7. 动态库DLL加载方式-静态加载和动态加载

    静态加载: 如果你有a.dll和a.lib,两个文件都有的话可以用静态加载的方式: message函数的声明你应该知道吧,把它的声明和下面的语句写到一个头文件中 #pragma comment(lib ...

  8. WPF中静态引用资源与动态引用资源的区别

    WPF中静态引用资源与动态引用资源的区别   WPF中引用资源分为静态引用与动态引用,两者的区别在哪里呢?我们通过一个小的例子来理解. 点击“Update”按钮,第2个按钮的文字会变成“更上一层楼”, ...

  9. 解决在静态页面上使用动态参数,造成spider多次和重复抓取的问题

    我们在使用百度统计中的SEO建议检查网站时,总是发现“静态页参数”一项被扣了18分,扣分原因是“在静态页面上使用动态参数,会造成spider多次和重复抓取”.一般来说静态页面上使用少量的动态参数的话并 ...

  10. C++的静态联编和动态联编

    联编的概念 联编是指一个计算机程序自身彼此关联的过程,在这个联编过程中,需要确定程序中的操作调用(函数调用)与执行该操作(函数)的代码段之间的映射关系. 意思就是这个函数的实现有多种,联编就是把调用和 ...

随机推荐

  1. [Java]排序算法>插入排序>【折半插入排序】(O(N*N)/稳定/N较大/无序/顺序存储)

    1 折半插入排序 1.1 算法思想 相比于[直接插入排序]:采用"顺序查找法"查找当前记录在已排好序的序列中的插入位置, 折半插入排序利用"折半查找法"快速查出 ...

  2. YII2.0使用ActiveForm表单

    Controller控制器层代码 <?php namespace frontend\controllers; use frontend\models\UserForm; class UserCo ...

  3. 数组描述线性表(C++实现)

    线性表也称有序表,其每一个实例都是元素的一个有序集合 抽象类linearList 一个抽象类包含没有实现代码的成员函数,这样的成员函数称为纯虚函数,用数字0作为初始值来说明 template<c ...

  4. vue3.0

    https://www.yuque.com/gdnnth/vue-v3 http://www.liulongbin.top:8085/#/ https://www.yuque.com/woniuppp ...

  5. 如何将 Spire.Doc for C++ 集成到 C++ 程序中

    Spire.Doc for C++是一个专业的 Word 库,供开发人员在任何类型的 C++ 应用程序中阅读.创建.编辑.比较和转换 Word 文档. 本文演示了如何以两种不同的方式将 Spire.D ...

  6. JUC并发常用工具学习

    今天主要来和大家分享一下JUC相关的一些简单知识,线程池文章就不介绍了,前面的文章有介绍,本文主要介绍Lock和认识synchronized和并发的一些工具类的使用. Lock 传统的锁有synchr ...

  7. 求解 LCA の方法

    最近公共祖先(LCA) 最近公共祖先简称 LCA(Lowest Common Ancestor).两个节点的最近公共祖先,就是这两个点的公共祖先里面,离根最远的那个. -----oi wiki 举个例 ...

  8. JVM面试和学习中需要注意的部分

    内存结构 1.方法区用来存储类加载的数据,例如类的名称,方法入口 2.JVM虚拟机栈用于存储线程,包括局部变量和方法参数 3.堆内存用来存储对象 4.方法区的规范实现:永久代和元空间 5.方法区 JV ...

  9. 基于APM模式的异步实现及跨线程操作窗体或控件方法的实现示例

    最近在一家某电力外派公司开发相关于GIS的功能,在实现代码的过程中出现了一些常见的问题比如: 1.跨线程执行窗体或控件操作(直接使用委拖) 2.异步模式执行某长时间耗时方法 经过一系列摸索可算找到解决 ...

  10. BUG解决-Vscode/Sublime C++ 打印中文乱码问题

    #include <iostream> using namespace std; #ifdef _WIN32 #include <windows.h> #endif int m ...