MxNet 模型转Tensorflow pb模型
用mmdnn实现模型转换
参考链接:https://www.twblogs.net/a/5ca4cadbbd9eee5b1a0713af
- 安装mmdnn
pip install mmdnn
- 准备好mxnet模型的.json文件和.params文件, 以InsightFace MxNet r50为例 https://github.com/deepinsight/insightface
- 用mmdnn运行命令行
python -m mmdnn.conversion._script.convertToIR -f mxnet -n model-symbol.json -w model-.params -d resnet50 --inputShape ,,
会生成resnet50.json(可视化文件) resnet50.npy(权重参数) resnet50.pb(网络结构)三个文件。
- 用mmdnn运行命令行
python -m mmdnn.conversion._script.IRToCode -f tensorflow --IRModelPath resnet50.pb --IRWeightPath resnet50.npy --dstModelPath tf_resnet50.py
生成tf_resnet50.py文件,可以调用tf_resnet50.py中的KitModel函数加载npy权重参数重新生成原网络框架。
打开tf_resnet.py文件,修改load_weights()中的代码 (tensorflow=1.14.0报错)
try:
weights_dict = np.load(weight_file).item()
except:
weights_dict = np.load(weight_file, encoding='bytes').item()改为
try:
weights_dict = np.load(weight_file, allow_pickle=True).item()
except:
weights_dict = np.load(weight_file, allow_pickle=True, encoding='bytes').item()基于resnet50.npy和tf_resnet50.py文件,固化参数,生成PB文件:
import tensorflow as tf
import tf_resnet50 as tf_fun
def netWork():
model=tf_fun.KitModel("./resnet50.npy")
return model
def freeze_graph(output_graph):
output_node_names = "output"
data,fc1=netWork()
fc1=tf.identity(fc1,name="output") graph = tf.get_default_graph() # 獲得默認的圖
input_graph_def = graph.as_graph_def() # 返回一個序列化的圖代表當前的圖
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
output_graph_def = tf.graph_util.convert_variables_to_constants( # 模型持久化,將變量值固定
sess=sess,
input_graph_def=input_graph_def, # 等於:sess.graph_def
output_node_names=output_node_names.split(",")) # 如果有多個輸出節點,以逗號隔開 with tf.gfile.GFile(output_graph, "wb") as f: # 保存模型
f.write(output_graph_def.SerializeToString()) # 序列化輸出 if __name__ == '__main__':
freeze_graph("frozen_insightface_r50.pb")
print("finish!")- 采用tensorflow的post-train quantization离线量化方法(有一定的精度损失)转换成tflite模型,从而完成端侧的模型部署:
import tensorflow as tf convert=tf.lite.TFLiteConverter.from_frozen_graph("frozen_insightface_r50.pb",input_arrays=["data"],output_arrays=["output"],
input_shapes={"data":[1,112,112,3]})
convert.post_training_quantize=True
tflite_model=convert.convert()
open("quantized_insightface_r50.tflite","wb").write(tflite_model)
print("finish!")
MxNet 模型转Tensorflow pb模型的更多相关文章
- 查看tensorflow pb模型文件的节点信息
查看tensorflow pb模型文件的节点信息: import tensorflow as tf with tf.Session() as sess: with open('./quantized_ ...
- Problem after converting keras model into Tensorflow pb - 将keras模型转换为Tensorflow pb后的问题
I'm using keras 2.1.* with tensorflow 1.13.* backend. I save my model during training with .h5 forma ...
- 查看tensorflow Pb模型所有层的名字
代码如下: import tensorflow as tf def get_all_layernames(): """get all layers name"& ...
- TensorFlow-Bitcoin-Robot:一个基于 TensorFlow LSTM 模型的 Bitcoin 价格预测机器人
简介 TensorFlow-Bitcoin-Robot:一个基于 TensorFlow LSTM 模型的 Bitcoin 价格预测机器人. 文章包括一下几个部分: 1.为什么要尝试做这个项目? 2.为 ...
- TensorFlow-Bitcoin-Robot:一个基于 TensorFlow LSTM 模型的 Bitcoin 价格预测机器人。
简介 TensorFlow-Bitcoin-Robot:一个基于 TensorFlow LSTM 模型的 Bitcoin 价格预测机器人. 文章包括一下几个部分: 1.为什么要尝试做这个项目? 2.为 ...
- tensorflow模型ckpt转pb以及其遇到的问题
使用tensorflow训练模型,ckpt作为tensorflow训练生成的模型,可以在tensorflow内部使用.但是如果想要永久保存,最好将其导出成pb的形式. tensorflow已经准备好c ...
- TensorFlow 自定义模型导出:将 .ckpt 格式转化为 .pb 格式
本文承接上文 TensorFlow-slim 训练 CNN 分类模型(续),阐述通过 tf.contrib.slim 的函数 slim.learning.train 训练的模型,怎么通过人为的加入数据 ...
- tensorflow c++ API加载.pb模型文件并预测图片
tensorflow python创建模型,训练模型,得到.pb模型文件后,用c++ api进行预测 #include <iostream> #include <map> # ...
- tflearn 中文汉字识别,训练后模型存为pb给TensorFlow使用——模型层次太深,或者太复杂训练时候都不会收敛
tflearn 中文汉字识别,训练后模型存为pb给TensorFlow使用. 数据目录在data,data下放了汉字识别图片: data$ ls0 1 10 11 12 13 14 15 ...
随机推荐
- matlab 矢量化编程(一)—— 计算 AUC
AUC = sum( (Y(2:end)+Y(1:end-1))/2 .* (X(2:end) - X(1:end-1)) X 和 Y 均是向量: Y(2:end) - Y(1:end-1),是 Y( ...
- Easyui Tab刷新
Easyui Tab刷新: function refreshTab(title){ var tab = $('#id').tab('getTab',title); $('#id').tab('upda ...
- SendMessage函数与MSDN系统预定义消息
SendMessage function https://msdn.microsoft.com/en-us/library/windows/desktop/ms644950%28v=vs.85%29. ...
- 如何自定义WPF项目的Main函数
原文:如何自定义WPF项目的Main函数 与Winform项目不同,WPF项目的Main函数在项目生成的时候,系统自动在后台为我们生成.根据项目生成方式的不同,其文件位于obj/Debug/App.g ...
- vmware linux无法正常上网
不知道自己怎么搞的整的vmware里面的fedora 12 不能正常上网,但是在宿主机上ping XXX,是正常的.当service network restart 的时候提示MAC有问题.网上百度了 ...
- css3 hover平滑过渡效果,鼠标经过元素,背景渐隐渐现效果
下面实例,演示,鼠标经过时,改变div宽度,平滑改变,带动画 div { width:100px; height:100px; background:blue; transition:width 2s ...
- DSP Builder 12.0安装及crack方法
在安装dsp_builder之前请确保已安装所需要的matlab版本 在此之前我已经安装了matlab R2011a,下面安装dsp builder 下面就是破解了,因为12.0的版本刚出,还没有相应 ...
- Win8 Metro(C#)数字图像处理--2.70修正后的阿尔法滤波器
原文:Win8 Metro(C#)数字图像处理--2.70修正后的阿尔法滤波器 /// <summary> /// Alpha filter. /// </summary> / ...
- Win8 Metro(C#)数字图像处理--2.56简单统计法图像二值化
原文:Win8 Metro(C#)数字图像处理--2.56简单统计法图像二值化 [函数名称] 简单统计法图像二值化 WriteableBitmap StatisticalThSegment(Wr ...
- dotnet core 跨平台编译发布
vs2017 建立的项目,在项目目录 ,执行 dotnet publish -r ubuntu.15.04-x64 dotnet publish -r linux-x64 dotnet publish ...