背景

在模型的部署中,为了高效利用硬件算力,常常会需要将多个输入组成一个batch同时输入网络进行推理,这个batch的大小根据系统的负载或者摄像头的路数时刻在变化,因此网络的输入batch是在动态变化的。对于pytorch等框架来说,我们并不会感受到这个问题,因为整个网络在pytorch中都是动态的。而在实际的工程化部署中,为了运行效率,却并不能有这样的灵活性。可能会有人说,那我就把batch固定在一个最大值,然后输入实际的batch,这样实际上网络是以最大batch在推理的,浪费了算力。所以我们需要能支持动态的batch,能够根据输入的batch数来运行。

一个常见的训练到部署的路径是:pytorch→onnx→tensorrt。在pytorch导出onnx时,我们可以指定输出为动态的输入:

  1. torch_out = torch.onnx.export(model, inp,
  2. save_path,input_names=["data"],output_names=["fc1"],dynamic_axes={
  3. "data":{0:'batch_size'},"fc1":{0:'batch_size'}
  4. })

而另一些时候,我们部署的模型来源于他人或开源模型,已经失去了原始的pytorch模型,此时如果onnx是静态batch的,在移植到tensorrt时,其输入就为静态输入了。想要动态输入,就需要对onnx模型本身进行修改了。另一方面,算法工程师在导模型的时候,如果没有指定输入层输出层的名称,导出的模型的层名有时候可读性比较差,比如输出是batchnorm_274这类名称,为了方便维护,也有需要对onnx的输入输出层名称进行修改。

操作

修改输入输出层

  1. def change_input_output_dim(model):
  2. # Use some symbolic name not used for any other dimension
  3. sym_batch_dim = "batch"
  4. # The following code changes the first dimension of every input to be batch-dim
  5. # Modify as appropriate ... note that this requires all inputs to
  6. # have the same batch_dim
  7. inputs = model.graph.input
  8. for input in inputs:
  9. # Checks omitted.This assumes that all inputs are tensors and have a shape with first dim.
  10. # Add checks as needed.
  11. dim1 = input.type.tensor_type.shape.dim[0]
  12. # update dim to be a symbolic value
  13. dim1.dim_param = sym_batch_dim
  14. # or update it to be an actual value:
  15. # dim1.dim_value = actual_batch_dim
  16. outputs = model.graph.output
  17. for output in outputs:
  18. # Checks omitted.This assumes that all inputs are tensors and have a shape with first dim.
  19. # Add checks as needed.
  20. dim1 = output.type.tensor_type.shape.dim[0]
  21. # update dim to be a symbolic value
  22. dim1.dim_param = sym_batch_dim
  23. model = onnx.load(onnx_path)
  24. change_input_output_dim(model)

通过将输入层和输出层的shape的第一维修改为非数字,就可以将onnx模型改为动态batch。

修改输入输出层名称

  1. def change_input_node_name(model, input_names):
  2. for i,input in enumerate(model.graph.input):
  3. input_name = input_names[i]
  4. for node in model.graph.node:
  5. for i, name in enumerate(node.input):
  6. if name == input.name:
  7. node.input[i] = input_name
  8. input.name = input_name
  9. def change_output_node_name(model, output_names):
  10. for i,output in enumerate(model.graph.output):
  11. output_name = output_names[i]
  12. for node in model.graph.node:
  13. for i, name in enumerate(node.output):
  14. if name == output.name:
  15. node.output[i] = output_name
  16. output.name = output_name

代码中input_names和output_names是我们希望改到的名称,做法是遍历网络,若有node的输入层名与要修改的输入层名称相同,则改成新的输入层名。输出层类似。

完整代码

  1. import onnx
  2. def change_input_output_dim(model):
  3. # Use some symbolic name not used for any other dimension
  4. sym_batch_dim = "batch"
  5. # The following code changes the first dimension of every input to be batch-dim
  6. # Modify as appropriate ... note that this requires all inputs to
  7. # have the same batch_dim
  8. inputs = model.graph.input
  9. for input in inputs:
  10. # Checks omitted.This assumes that all inputs are tensors and have a shape with first dim.
  11. # Add checks as needed.
  12. dim1 = input.type.tensor_type.shape.dim[0]
  13. # update dim to be a symbolic value
  14. dim1.dim_param = sym_batch_dim
  15. # or update it to be an actual value:
  16. # dim1.dim_value = actual_batch_dim
  17. outputs = model.graph.output
  18. for output in outputs:
  19. # Checks omitted.This assumes that all inputs are tensors and have a shape with first dim.
  20. # Add checks as needed.
  21. dim1 = output.type.tensor_type.shape.dim[0]
  22. # update dim to be a symbolic value
  23. dim1.dim_param = sym_batch_dim
  24. def change_input_node_name(model, input_names):
  25. for i,input in enumerate(model.graph.input):
  26. input_name = input_names[i]
  27. for node in model.graph.node:
  28. for i, name in enumerate(node.input):
  29. if name == input.name:
  30. node.input[i] = input_name
  31. input.name = input_name
  32. def change_output_node_name(model, output_names):
  33. for i,output in enumerate(model.graph.output):
  34. output_name = output_names[i]
  35. for node in model.graph.node:
  36. for i, name in enumerate(node.output):
  37. if name == output.name:
  38. node.output[i] = output_name
  39. output.name = output_name
  40. onnx_path = ""
  41. save_path = ""
  42. model = onnx.load(onnx_path)
  43. change_input_output_dim(model)
  44. change_input_node_name(model, ["data"])
  45. change_output_node_name(model, ["fc1"])
  46. onnx.save(model, save_path)

经过修改后的onnx模型输入输出将成为动态batch,可以方便的移植到tensorrt等框架以支持高效推理。

将onnx的静态batch改为动态batch及修改输入输出层的名称的更多相关文章

  1. Spark Streaming中动态Batch Size实现初探

    本期内容 : BatchDuration与 Process Time 动态Batch Size Spark Streaming中有很多算子,是否每一个算子都是预期中的类似线性规律的时间消耗呢? 例如: ...

  2. 《CMake实践》笔记三:构建静态库(.a) 与 动态库(.so) 及 如何使用外部共享库和头文件

    <CMake实践>笔记一:PROJECT/MESSAGE/ADD_EXECUTABLE <CMake实践>笔记二:INSTALL/CMAKE_INSTALL_PREFIX &l ...

  3. 浅谈在静态页面上使用动态参数,会造成spider多次和重复抓取的解决方案

    原因: 早期由于搜索引擎蜘蛛的不完善,蜘蛛在爬行动态的url的时候很容易由于网站程序的不合理等原因造成蜘蛛迷路死循环. 所以蜘蛛为了避免之前现象就不读取动态的url,特别是带?的url 解决方案: 1 ...

  4. 在Linux中创建静态库.a和动态库.so

    转自:http://www.cnblogs.com/laojie4321/archive/2012/03/28/2421056.html 在Linux中创建静态库.a和动态库.so 我们通常把一些公用 ...

  5. Spark Streaming揭秘 Day21 动态Batch size实现初探(下)

    Spark Streaming揭秘 Day21 动态Batch size实现初探(下) 接昨天的描述,今天继续解析动态Batch size调整的实现. 算法 动态调整采用了Fix-point迭代算法, ...

  6. Spark Streaming揭秘 Day20 动态Batch size实现初探(上)

    Spark Streaming揭秘 Day20 动态Batch size实现初探(上) 今天开始,主要是通过对动态Batch size调整的论文的解析,来进一步了解SparkStreaming的处理机 ...

  7. 动态库DLL加载方式-静态加载和动态加载

    静态加载: 如果你有a.dll和a.lib,两个文件都有的话可以用静态加载的方式: message函数的声明你应该知道吧,把它的声明和下面的语句写到一个头文件中 #pragma comment(lib ...

  8. WPF中静态引用资源与动态引用资源的区别

    WPF中静态引用资源与动态引用资源的区别   WPF中引用资源分为静态引用与动态引用,两者的区别在哪里呢?我们通过一个小的例子来理解. 点击“Update”按钮,第2个按钮的文字会变成“更上一层楼”, ...

  9. 解决在静态页面上使用动态参数,造成spider多次和重复抓取的问题

    我们在使用百度统计中的SEO建议检查网站时,总是发现“静态页参数”一项被扣了18分,扣分原因是“在静态页面上使用动态参数,会造成spider多次和重复抓取”.一般来说静态页面上使用少量的动态参数的话并 ...

  10. C++的静态联编和动态联编

    联编的概念 联编是指一个计算机程序自身彼此关联的过程,在这个联编过程中,需要确定程序中的操作调用(函数调用)与执行该操作(函数)的代码段之间的映射关系. 意思就是这个函数的实现有多种,联编就是把调用和 ...

随机推荐

  1. python入门教程之十八正则表达式

    re.match函数 re.match 尝试从字符串的起始位置匹配一个模式,如果不是起始位置匹配成功的话,match()就返回none. 函数语法: re.match(pattern, string, ...

  2. [Windows]CMD命令入门教程 与 Windows常见维护问题

    本博文最早是记录在本地电脑的,由于清理电脑的缘故,顺便将这篇笔记转移到公共博客,以便日后查阅和快速上手使用. 开门见山,步入正题,以下是Windows系统的常用CMD命令. ----2018-03-2 ...

  3. 三天吃透Redis八股文

    Redis连环40问,绝对够全! Redis是什么? Redis(Remote Dictionary Server)是一个使用 C 语言编写的,高性能非关系型的键值对数据库.与传统数据库不同的是,Re ...

  4. odoo 开发入门教程系列-继承(Inheritance)

    继承(Inheritance) Odoo的一个强大方面是它的模块化.模块专用于业务需求,但模块也可以相互交互.这对于扩展现有模块的功能非常有用.例如,在我们的房地产场景中,我们希望在常规用户视图中直接 ...

  5. C++重载的奥义之运算符重载

    0.引言 重载,顾名思义从字面上理解就是重复装载,打一个不恰当的比方,你可以用一个篮子装蔬菜,也可以装水果或者其它,使用的是同一个篮子,但是可以用篮子重复装载的东西不一样. 正如在之前的文章<重 ...

  6. 官宣 | Hugging Face 中文博客正式发布!

    作者:Tiezhen.Adina.Luke Hugging Face 的中国社区成立已经有五个月之久,我们也非常高兴的看到 Hugging Face 相关的中文内容在各个平台广受好评,我们也注意到,H ...

  7. Uniswap V2 — 从代码解释 DeFi 协议

    Uniswap V2 - 从代码解释 DeFi 协议 为了理解我们在分析代码时将要经历的不同组件,首先了解哪些是主要概念以及它们的作用是很重要的.所以,和我一起裸露吧,因为这是值得的. 我在 5 个段 ...

  8. Python 项目:外星人入侵--第三部分

    1.项目内容: 在屏幕左上角添加一个外星人,并指定合适的边框,根据第一个外星人的边距和屏幕尺寸计算屏幕上可容纳多少个外星人. 让外星人群向两边和下方移动,直到外星人被全部击落,有外星人撞到飞船,或有外 ...

  9. [Pytorch框架] 2.1.1 PyTorch 基础 : 张量

    文章目录 PyTorch 基础 : 张量 张量(Tensor) 基本类型 Numpy转换 设备间转换 初始化 常用方法 PyTorch 基础 : 张量 在第一章中我们已经通过官方的入门教程对PyTor ...

  10. Word中使用ChatGPT,写文档如有神助

    [部署教程]国内网络可用,最强 ChatGPT 学术论文写作工具原创****付费 简介 Word GPT Plus 是一个集成了 chatGPT 模型的 Word 插件.它允许你基于你在文档中写的内容 ...