pytorch1.0 用torch script导出模型
python的易上手和pytorch的动态图特性,使得pytorch在学术研究中越来越受欢迎,但在生产环境,碍于python的GIL等特性,可能达不到高并发、低延迟的要求,存在需要用c++接口的情况。除了将模型导出为ONNX外,pytorch1.0给出了新的解决方案:pytorch 训练模型 - 通过torch script中间脚本保存模型 -- C++加载模型。最近工作需要尝试做了转换,总结一下步骤和遇到的坑。
用torch script把torch模型转成c++接口可读的模型有两种方式:trace && script. trace比script简单,但只适合结构固定的网络模型,即forward中没有控制流的情况,因为trace只会保存运行时实际走的路径。如果forward函数中有控制流,需要用script方式实现。
trace顾名思义,就是沿着数据运算的路径走一遍,官方例子:
import torch |
script稍复杂,主要改三处:
1. Model由之前继承 nn.Model 改为继承 torch.jit.ScriptModule
2. forward函数前加 @torch.jit.script_method
3. 其他需要调用的函数前加 @torch.jit.script
踩过的坑&&解决方法:
A. torch script默认函数或方法的参数都是Tensor类型的,如果不是需要说明,不然调用非Tensor参数时会报类型不符的编译错误。
python3可以直接:
| def example_func(param_1: Tensor, param_2: int, param_3: List[int]): |
python2需要用type注释:
|
def example_func(param_1, param_2, param_3): #type: (Tensor, int, List[int]) -> Tensor |
B. model的方法中forward加@torch.jit.script_method, __init__函数不用
C. 前面说过,torch scrip支持的函数是pytorch的子集,意味着有一部分函数不支持,例如: not boolean,pass, List的切片赋值,CPU和GPU切换的value.to( ), 需要想办法绕过去。看github上讨论区说新版好像已经支持not操作了,没有验证。
结论:pytorch 1.0目前的预览版还有比较多优化的空间,至少是在torch script支持的函数集合上,不建议使用,等稳定版发布再看看吧。
原创内容,转载请注明出处。
参考资料:
https://pytorch.org/docs/master/jit.html
https://pytorch.org/tutorials/beginner/deploy_seq2seq_hybrid_frontend_tutorial.html
pytorch1.0 用torch script导出模型的更多相关文章
- mysql数据库导出模型到powerdesigner,PDM图形窗口中显示数据列的中文注释
1,mysql数据库导出模型到powerdesigner 2,CRL+Shift+X 3,复制以下内容,执行 '******************************************** ...
- 学习笔记TF022:产品环境模型部署、Docker镜像、Bazel工作区、导出模型、服务器、客户端
产品环境模型部署,创建简单Web APP,用户上传图像,运行Inception模型,实现图像自动分类. 搭建TensorFlow服务开发环境.安装Docker,https://docs.docker. ...
- windows10 安装 Anaconda 并配置 pytorch1.0
官网下载Anaconda安装包,按步骤安装即可安装完后,打开DOS,或Anaconda自带的Anaconda Prompt终端查看Anaconda已安装的安装包C:\Users\jiangshan&g ...
- revit导出模型数据到sqlserver数据库
revit软件可以导出模型数据到sqlserver数据库,有时候,为了对模型做数据分析,需要导出模型的数据,下面总结一下导出过程: 首先在sqlserver中建立一个数据库,如:revit_wujin ...
- pytorch1.0进行Optimizer 优化器对比
pytorch1.0进行Optimizer 优化器对比 import torch import torch.utils.data as Data # Torch 中提供了一种帮助整理数据结构的工具, ...
- pytorch1.0批训练神经网络
pytorch1.0批训练神经网络 import torch import torch.utils.data as Data # Torch 中提供了一种帮助整理数据结构的工具, 叫做 DataLoa ...
- pytorch1.0神经网络保存、提取、加载
pytorch1.0网络保存.提取.加载 import torch import torch.nn.functional as F # 包含激励函数 import matplotlib.pyplot ...
- 用pytorch1.0快速搭建简单的神经网络
用pytorch1.0搭建简单的神经网络 import torch import torch.nn.functional as F # 包含激励函数 # 建立神经网络 # 先定义所有的层属性(__in ...
- 用pytorch1.0搭建简单的神经网络:进行多分类分析
用pytorch1.0搭建简单的神经网络:进行多分类分析 import torch import torch.nn.functional as F # 包含激励函数 import matplotlib ...
随机推荐
- opencv: 角点检测源码分析;
以下6个函数是opencv有关角点检测的函数 ConerHarris, cornoerMinEigenVal,CornorEigenValsAndVecs, preConerDetect, coner ...
- python自动化开发-[第四天]-函数
今日概要: - 函数对象 - 函数嵌套 - 命名空间和作用域 - 闭包 - 装饰器 - 迭代器 - 生成器 - 内置函数 一.函数对象 1.函数对象的定义: 函数是第一类对象,即函数可以当作数据传递 ...
- [报错]Could not get a resource from the pool
redis.clients.jedis.exceptions.JedisConnectionException: Could not get a resource from the pool解决:开启 ...
- mybatis无mapper.xml用法
在datasource配置类上加上 @MapperScan("cn.x.x.dao")@Configuration <project xmlns="http://m ...
- 网络编程基础【day09】:socket实现文件发送(六)
本节内容 1.概述 2.文件下载实现 3.MD5值校验 一.概述 我们如何利用socket去下载一个文件,整体思路是这样的: 读取文件名 检测文件是否存在 打开文件 检测文件大小 发送文件大小给客户端 ...
- MapReduce-WordCount
pom.xml <?xml version="1.0" encoding="UTF-8"?> <project xmlns="htt ...
- C语言复习---找出报数最后一人
题意: 有n个人围成一圈 顺序排号 从第1个人开始报数(从1到3报数),凡报到3的人退出圈子,问最后留下的是原来第几号的那位. 算法实现: (一)一种是按照链表数据结构(一)线性表循环链表之约瑟夫环 ...
- [NIO-4]选择器
选择器 最后,我们探索一下选择器.由于选择器内容比较多,所以本篇先偏理论地讲一下,后一篇讲代码,文章也没有什么概括.总结的,写到哪儿算哪儿了,只求能将选择器写明白,并且将一些相对重要的内容加粗标红. ...
- Golang入门教程(十六)Goridge -高性能的 PHP-to-Golang RPC编解码器库
什么是 RPC 框架? RPC(Remote Procedure Call)—远程过程调用,它是一种通过网络从远程计算机程序上请求服务,而不需要了解底层网络技术的协议.RPC协议假定某些传输协议的存在 ...
- Map接口、HashMap类、LinkedHashSet类
java.util.Map<K, V>接口 双列集合,key不可以重复 Map方法: 1.public V put(K key, V value):键值对添加到map,如果key不重复返回 ...