python的易上手和pytorch的动态图特性,使得pytorch在学术研究中越来越受欢迎,但在生产环境,碍于python的GIL等特性,可能达不到高并发、低延迟的要求,存在需要用c++接口的情况。除了将模型导出为ONNX外,pytorch1.0给出了新的解决方案:pytorch 训练模型 - 通过torch script中间脚本保存模型 -- C++加载模型。最近工作需要尝试做了转换,总结一下步骤和遇到的坑。

用torch script把torch模型转成c++接口可读的模型有两种方式:trace && script. trace比script简单,但只适合结构固定的网络模型,即forward中没有控制流的情况,因为trace只会保存运行时实际走的路径。如果forward函数中有控制流,需要用script方式实现。

trace顾名思义,就是沿着数据运算的路径走一遍,官方例子:


import torch
def foo(x, y):
return 2*x + y
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

script稍复杂,主要改三处:

1. Model由之前继承 nn.Model 改为继承 torch.jit.ScriptModule

2. forward函数前加 @torch.jit.script_method

3. 其他需要调用的函数前加 @torch.jit.script

踩过的坑&&解决方法:

A. torch script默认函数或方法的参数都是Tensor类型的,如果不是需要说明,不然调用非Tensor参数时会报类型不符的编译错误。

python3可以直接:

def example_func(param_1: Tensor, param_2: int, param_3: List[int]):

python2需要用type注释:

def example_func(param_1, param_2, param_3):

#type: (Tensor, int, List[int]) -> Tensor

B. model的方法中forward加@torch.jit.script_method, __init__函数不用

C. 前面说过,torch scrip支持的函数是pytorch的子集,意味着有一部分函数不支持,例如: not boolean,pass, List的切片赋值,CPU和GPU切换的value.to( ), 需要想办法绕过去。看github上讨论区说新版好像已经支持not操作了,没有验证。

结论:pytorch 1.0目前的预览版还有比较多优化的空间,至少是在torch script支持的函数集合上,不建议使用,等稳定版发布再看看吧。

  

原创内容,转载请注明出处。

参考资料:

https://pytorch.org/docs/master/jit.html

https://pytorch.org/tutorials/beginner/deploy_seq2seq_hybrid_frontend_tutorial.html

pytorch1.0 用torch script导出模型的更多相关文章

  1. mysql数据库导出模型到powerdesigner,PDM图形窗口中显示数据列的中文注释

    1,mysql数据库导出模型到powerdesigner 2,CRL+Shift+X 3,复制以下内容,执行 '******************************************** ...

  2. 学习笔记TF022:产品环境模型部署、Docker镜像、Bazel工作区、导出模型、服务器、客户端

    产品环境模型部署,创建简单Web APP,用户上传图像,运行Inception模型,实现图像自动分类. 搭建TensorFlow服务开发环境.安装Docker,https://docs.docker. ...

  3. windows10 安装 Anaconda 并配置 pytorch1.0

    官网下载Anaconda安装包,按步骤安装即可安装完后,打开DOS,或Anaconda自带的Anaconda Prompt终端查看Anaconda已安装的安装包C:\Users\jiangshan&g ...

  4. revit导出模型数据到sqlserver数据库

    revit软件可以导出模型数据到sqlserver数据库,有时候,为了对模型做数据分析,需要导出模型的数据,下面总结一下导出过程: 首先在sqlserver中建立一个数据库,如:revit_wujin ...

  5. pytorch1.0进行Optimizer 优化器对比

    pytorch1.0进行Optimizer 优化器对比 import torch import torch.utils.data as Data # Torch 中提供了一种帮助整理数据结构的工具, ...

  6. pytorch1.0批训练神经网络

    pytorch1.0批训练神经网络 import torch import torch.utils.data as Data # Torch 中提供了一种帮助整理数据结构的工具, 叫做 DataLoa ...

  7. pytorch1.0神经网络保存、提取、加载

    pytorch1.0网络保存.提取.加载 import torch import torch.nn.functional as F # 包含激励函数 import matplotlib.pyplot ...

  8. 用pytorch1.0快速搭建简单的神经网络

    用pytorch1.0搭建简单的神经网络 import torch import torch.nn.functional as F # 包含激励函数 # 建立神经网络 # 先定义所有的层属性(__in ...

  9. 用pytorch1.0搭建简单的神经网络:进行多分类分析

    用pytorch1.0搭建简单的神经网络:进行多分类分析 import torch import torch.nn.functional as F # 包含激励函数 import matplotlib ...

随机推荐

  1. socket编程 ------ UDP服务器

    void vLANcommunication( void *pvParameters ) { int32 listenfd; do{ listenfd = socket(AF_INET, SOCK_D ...

  2. angular cli全局版本大于本地版本 把本地版本升级方式

    查看 angular 版本  ng version 如出现提示 Your global Angular CLI version (xxx) is greater than your local ver ...

  3. Java集合、Iterator迭代器和增强for循环整理

    集合 集合,集合是java中提供的一种容器,可以用来存储多个数据. 数组的长度是固定的.集合的长度是可变的.集合中存储的元素必须是引用类型数据 1.1      ArrayList集合存储元素 pac ...

  4. 163邮箱SMTP设置

    如题要设置系统邮件自动发送 首先注册的网易邮箱要去开通SMTP服务,然后就是端口的设置 授权码 端口

  5. golang etcdclientv3使用说明

    clientv3.New() 创建连接 config = ec.Config{ Endpoints: []string{"10.0.0.5:2379"}, //连接的etcd集群地 ...

  6. source命令导入大数据速度慢优化

    XX市邮政微商城的项目数据库,300多M,约220万条数据,source命令导入花了20个小时左右,太不可思议. 速度慢原因:220多万条数据,就 insert into 了220多万次,下图: 这是 ...

  7. 学习go语言编程系列之helloworld

    1. 下载https://golang.org/dl/ # Go语言官网地址,在国内下载太慢,甚至都无法访问.通过如下地址下载:https://golangtc.com/download. 2. 安装 ...

  8. Mac 键盘符号 及VSCode快捷键 说明

    Mac 键盘符号说明 ⌘ == Command ⇧ == Shift ⇪ == Caps Lock ⌥ == Option ⌃ == Control ↩ == Return/Enter ⌫ == De ...

  9. springboot(十八):解决跨域问题

    在controller上添加@CrossOrigin注解,如下: @RestController @RequestMapping("course") @CrossOrigin pu ...

  10. opencv实现坐标旋转(教你框住小姐姐)

    一.项目背景 最近在做一个人脸检测项目,需要接入百度AI的系统进行识别和检测.主要流程就是往指定的URL上post图片上去,之后接收检测结果就好了. 百度的检测结果包含这样的信息: left - 人脸 ...