生产与学术之Pytorch模型导出为安卓Apk尝试记录
生产与学术
写于 2019-01-08 的旧文, 当时是针对一个比赛的探索. 觉得可能对其他人有用, 就放出来分享一下
生产与学术, 真实的对立...
这是我这两天对pytorch深度学习->android实际使用的这个流程的一个切身感受.
说句实在的, 对于模型转换的探索, 算是我这两天最大的收获了...
全部浓缩在了这里: https://github.com/lartpang/DHSNet-PyTorch/blob/master/converter.ipynb
鉴于github加载ipynb太慢, 这里可以使用这个链接 https://nbviewer.jupyter.org/github/lartpang/DHSNet-PyTorch/blob/master/converter.ipynb
这两天
最近在研究将pytorch的模型转换为独立的app, 网上寻找, 找到了一个流程: pytorch->onnx->caffe2->android apk. 主要是基于这篇文章的启发: caffe2&pytorch之在移动端部署深度学习模型(全过程!).
这两天就在折腾这个工具链,为了导出onnx的模型, 不确定要基于怎样的网络, 是已经训练好的, 还是原始搭建网络后再训练来作为基础. 所以不断地翻阅pytorch和onnx的官方示例, 想要研究出来点什么, 可是, 都是自己手动搭建的模型. 而且使用的是预训练权重, 不是这样:
def squeezenet1_1(pretrained=False, **kwargs):
r"""SqueezeNet 1.1 model from the `official SqueezeNet repo
<https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.
SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
than SqueezeNet 1.0, without sacrificing accuracy.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = SqueezeNet(version=1.1, **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_1']))
return model
# Get pretrained squeezenet model
torch_model = squeezenet1_1(True)
from torch.autograd import Variable
batch_size = 1 # just a random number
# Input to the model
x = Variable(torch.randn(batch_size, 3, 224, 224), requires_grad=True)
# Export the model
torch_out = torch.onnx._export(
torch_model, # model being run
x, # model input (or a tuple for multiple inputs)
"squeezenet.onnx", # where to save the model (can be a file or file-like object)
export_params=True) # store the trained parameter weights inside the model file
就是这样:
# Create the super-resolution model by using the above model definition.
torch_model = SuperResolutionNet(upscale_factor=3)
# Load pretrained model weights
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
batch_size = 1 # just a random number
# Initialize model with the pretrained weights
torch_model.load_state_dict(model_zoo.load_url(model_url))
# set the train mode to false since we will only run the forward pass.
torch_model.train(False)
两种都在载入预训练权重, 直接加载到搭建好的网络上. 对于我手头有的已经训练好的模型, 似乎并不符合这样的条件.
导出整体模型
最后采用尽可能模仿上面的例子代码的策略, 将整个网络完整的导出(torch.save(model)), 然后再仿照上面那样, 将完整的网络加载(torch.load())到转换的代码中, 照猫画虎, 以进一步处理.
这里也很大程度上受到这里的启发: https://github.com/akirasosa/mobile-semantic-segmentation
本来想尝试使用之前找到的不论效果还是性能都很强的R3Net进行转换, 可是, 出于作者搭建网络使用的特殊手段, 加上pickle和onnx的限制, 这个尝试没有奏效, 只好转回头使用之前学习的DHS-Net的代码, 因为它的实现是基于VGG的, 里面的搭建的网络也是需要修改来符合onnx的要求, 主要是更改上采样操作为转置卷积(也就是分数步长卷积, 这里顺带温习了下pytorch里的nn.ConvTranspose2d()的计算方式), 因为pytorch的上采样在onnx转换过程中有很多的问题, 特别麻烦, 外加上修改最大池化的一个参数(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=False)的参数ceil_mode改为ceil_mode=False, 这里参考自前面的知乎专栏的那篇文章), 这样终于可以转换了, 为了方便和快速的测试, 我只是训练了一个epoch, 就直接导出模型, 这次终于可以顺利的torch.save()了.
filename_opti = ('%s/model-best.pth' % check_root_model)
torch.save(model, filename_opti)
之后便利用类似的代码进行了书写.
IMG_SIZE = 224
TMP_ONNX = 'cache/onnx/DHSNet.onnx'
MODEL_PATH = 'cache/opti/total-opti-current.pth'
# Convert to ONNX once
model = torch.load(MODEL_PATH).cuda()
model.train(False)
x = Variable(torch.randn(1, 3, 224, 224), requires_grad=True).cuda()
torch_out = torch.onnx._export(model, x, TMP_ONNX, export_params=True)
caffe2模型转换
载入模型后, 便可以开始转换了, 这里需要安装caffe2, 官方推荐直接conda安装pytorch1每夜版即可, 会自动安装好依赖.
说起来这个conda, 就让我又爱又恨, 用它装pytorch从这里可以看出来, 确实不错, 对系统自身的环境没有太多的破坏, 可是用它装tensorflow-gpu的时候, 却是要自动把conda源里的cuda, cudnn工具包都给带上, 有时候似乎会破坏掉系统自身装载的cuda环境(? 不太肯定, 反正现在我不这样装, 直接上pip装, 干净又快速).
之后的代码中, 主要的问题也就是tensor的cpu/cuda, 或者numpy的转换的问题了. 多尝试一下, 输出下类型就可以看到了.
# Let's also save the init_net and predict_net to a file that we will later use for running them on mobile
with open('./cache/model_mobile/init_net.pb', "wb") as fopen:
fopen.write(init_net.SerializeToString())
with open('./cache/model_mobile/predict_net.pb', "wb") as fopen:
fopen.write(predict_net.SerializeToString())
预处理的补充
这里记录下, 查看pytorch的tensor的形状使用tensor.size()方法, 查看numpy数组的形状则使用numpy数组的adarray.shape方法, 而对于PIL(from PIL import Image)读取的Image对象而言, 使用Image.size查看, 而且, 这里只会显示宽和高的长度, 而且Image的对象, 是三维, 在于pytorch的tensor转换的时候, 或者输入网络的时候, 要注意添加维度, 而且要调整通道位置(img = img.transpose(2, 0, 1)).
由于网络保存的部分中, 只涉及到了网络的结构内的部分, 对于数据的预处理的部分并不涉及, 所以说要想真正的利用网络, 还得调整真实的输入, 来作为更适合网络的数据输入.
要注意, 这里针对导出的模型的相关测试, 程实际上是按照测试网络的流程来的.
# load the resized image and convert it to Ybr format
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
img = Image.open("./data/ILSVRC2012_test_00000004_224x224.jpg")
img = np.array(img)
img = img.astype(np.float64) / 255
img -= mean
img /= std
img = img.transpose(2, 0, 1)
安卓的尝试
首先安卓环境的配置就折腾了好久, 一堆破事, 真实的生产开发, 真心不易啊...
这里最终还是失败了, 因为对于安卓的代码是在是不熟悉, 最起码的基础认知都不足, 只有这先前学习Java的一点皮毛知识, 根本不足以二次开发. 也就跑了跑几个完整的demo而已.
AiCamera
这个跑通了, 但是这是个分类网络的例子, 对于我们要做的分割的任务而言, 有很多细节不一样.
- 输入有差异: 比赛要求的是若是提交apk, 那么要求可以从相册读取图片, 而例子是从摄像头读取的视频数据流. 虽然也处理的是视频帧, 但是要我们再次补充的内容又多了起来, 还是那句话, android一窍不通.
- 输出有差异: 自我猜测, 比赛为了测评, 输出必然也要输出到相册里, 不然何来测评一说?
AICamera-Style-Transfer
这个例子我们参考了一下, 只是因为它的任务是对摄像头视频流数据风格迁移, 而且会直接回显到手机屏幕上, 这里我们主要是想初步实现对于我们网络模型安卓迁移的测试, 在第一个例子的基础上能否实现初步的摄像头视频流的分割, 然后下一步再进一步满足比赛要求.
可是, 尝试失败了. 虽然AS打包成了APK, 手机也安装上了, 可是莫名的, 在"loading..."中便闪退了...
JejuNet
这个例子很给力, 但是使用的是tensorflowlite, 虽然可以用, 能够实现下面的效果, 可是, 不会改.

而且是量化网络, 准确率还是有待提升.
最后的思考
最后还是要思考一下的, 做个总结.
没经验
吃就吃在没经验的亏上了, 都是初次接触, 之前没怎么接触过安卓, 主要是安卓的开发对于电脑的配置要求太高了, 自己的笔记本根本不够玩的. 也就没有接触过了.
外加上之前的研究学习, 主要是在学术的环境下搞得, 和实际的生产还有很大的距离, 科研与生产的分离, 这对于深度学习这一实际上更偏重实践的领域来说, 有些时候是尤为致命的. 关键时刻下不去手, 这多么无奈, 科学技术无法转化为实实在在的生产力, 忽然有些如梦一般的缥缈.
当然, 最关键的还是, 没有仔细分析赛方的需求, 没有完全思考清楚, 直接就开干了, 这个鲁莽的毛病, 还是没有改掉, 浪费时间不说, 也无助于实际的进度. 赛方的说明含糊, 应该问清楚.
若是担心时间, 那更应该看清楚要求, 切莫随意下手. 比赛说明里只是说要提交一个打包好的应用, 把环境, 依赖什么都处理好, 但是不一定是安卓apk呀, 可以有很多的形式, 但是这也只是最后的一点额外的辅助而已, 重点是模型的性能和效率呢.
莫忘初心, 方得始终. 为什么我想到的是这句.
下一步
基本上就定了还是使用R3Net, 只能是进一步的细节修改了, 换换后面的循环结构了, 改改连接什么的.
我准备再开始看论文, 学姐的论文可以看看, 似乎提出了一种很不错的后处理的方法, 效果提升很明显, 需要研究下.
pickle和onnx的限制
pytorch的torch.save(model)保存模型的时候, 模型架构的代码里不能使用一些特殊的构建形式, R3Net的ResNeXt结构就用了, 主要是一些lambda结构, 虽然不是太清楚, 但是一般的搭建手段都是可以的.
onnx对于pytorch的支持的操作, 在我的转化中, 主要是最大池化和上采样的问题, 前者可以修改ceil_mode为False, 后者则建议修改为转置卷积, 避免不必要的麻烦. 可见"导出整体模型"小节的描述.
打包apk安装
这里主要是用release版本构建的apk.
未签名的apk在我的mi 8se (android 8.1)上不能安装, 会解析失败, 需要签名, AS的签名的生成也很简单, 和生成apk在同一级上, 有生成的选项.
生产与学术之Pytorch模型导出为安卓Apk尝试记录的更多相关文章
- 使用C++调用并部署pytorch模型
1.背景(Background) 上图显示了目前深度学习模型在生产环境中的方法,本文仅探讨如何部署pytorch模型! 至于为什么要用C++调用pytorch模型,其目的在于:使用C++及多线程可以加 ...
- Revit二次开发-BIM模型导出
最近一个月一直在研究Revit二次开发-BIM模型的导出,在网上找了很多相关资料学习.下面简单介绍一下我最近做的这个BIM模型的导出功能. 开始尝试使用Revit2015的样例程序里提供的读取模型几何 ...
- TensorFlow 自定义模型导出:将 .ckpt 格式转化为 .pb 格式
本文承接上文 TensorFlow-slim 训练 CNN 分类模型(续),阐述通过 tf.contrib.slim 的函数 slim.learning.train 训练的模型,怎么通过人为的加入数据 ...
- (原)torch模型转pytorch模型
转载请注明出处: http://www.cnblogs.com/darkknightzh/p/7839263.html 目前使用的torch模型转pytorch模型的程序为: https://gith ...
- 使用C++调用pytorch模型(Linux)
前言 模型转换思路通常为: Pytorch -> ONNX -> TensorRT Pytorch -> ONNX -> TVM Pytorch -> 转换工具 -> ...
- DEX-6-caffe模型转成pytorch模型办法
在python2.7环境下 文件下载位置:https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/ 1.可视化模型文件prototxt 1)在线可视化 ...
- PyTorch模型加载与保存的最佳实践
一般来说PyTorch有两种保存和读取模型参数的方法.但这篇文章我记录了一种最佳实践,可以在加载模型时避免掉一些问题. 第一种方案是保存整个模型: 1 torch.save(model_object, ...
- 资源分享 | PyTea:不用运行代码,静态分析pytorch模型的错误
前言 本文介绍一个Pytorch模型的静态分析器 PyTea,它不需要运行代码,即可在几秒钟之内扫描分析出模型中的张量形状错误.文末附使用方法. 本文转载自机器之心 编辑:CV技 ...
- 从零搭建Pytorch模型教程(三)搭建Transformer网络
前言 本文介绍了Transformer的基本流程,分块的两种实现方式,Position Emebdding的几种实现方式,Encoder的实现方式,最后分类的两种方式,以及最重要的数据格式的介绍. ...
随机推荐
- JAVA WEB项目中开启流量控制Filter
Flow Control:控流的概念 主要是用来限定server所能承载的最大(高并发)流量峰值,以免在峰值是Server过载而宕机,对于WEB系统而言 通常是分布式部署,如果请求并发量很大,会导致整 ...
- 【HDFS API编程】副本系数深度剖析
上一节我们使用Java API操作HDFS文件系统创建了文件a.txt并写入了hello hadoop(回顾:https://www.cnblogs.com/Liuyt-61/p/10739018.h ...
- spring boot 接口返回值去掉为null的字段
现在项目都是前后端分离的,返回的数据都是使用json,但有些接口的返回值存在 null或者"",这种字段不仅影响理解,还浪费带宽,需要统一做一下处理,不返回空字段,或者把NULL转 ...
- Mac端StartUML的安装和破解
**本人安装的是StarUML-3.0.1版本** 一.下载与安装 1. 从官方网站下载,网址:http://staruml.io/ 2. dmg文件下载完成后,双击安装. 二.破解 1. 安装npm ...
- OO第十五次作业
一.测试与正确性论证的效果差异 测试和正确性论证分别是从理论和实践两个角度去规范程序的正确性的,我认为其主要的区别在于对于程序透明度的需求上, 测试作为一种实践手段,他的实施的要求是比较低的:在完全了 ...
- Vsftpd服务重启、暂停命令
VSFTP是一个基于GPL发布的类Unix系统上使用的FTP服务器软件,它的全称是Very Secure FTP 从此名称可以看出来,编制者的初衷是代码的安全. 在使用Vsftp服务是经常需要启动.停 ...
- cookies相关概念
1.什么是Cookie Cookie实际上是一小段的文本信息.客户端请求服务器,如果服务器需要记录该用户状态,就使用response向客户端浏览器颁发一个Cookie.客户端浏览器会把Cookie保存 ...
- json与csv的基础用与法
json库是处理json格式的python标准库 有两个过程: 编码(encoding):将python数据类型转换为json格式的过程 解码(decoding):从json格式中解析数据得到的pyt ...
- Hadoop HDFS常用命令
1.查看hdfs文件目录 hadoop fs -ls / 2.上传文件 hadoop fs -put 文件路径 目标路径 在浏览器查看:namenodeIP:50070 3.下载文件 hadoop f ...
- U3D学习资料收集
1,风宇冲的博客 2,gkEngine 3,@浅墨_毛星云 4,聊聊引擎底层如何实现BRDF渲染算法