将pb模型参数提取转成torch模型
1 import tensorflow as tf
2 import onnx
3 import onnxsim
4 import numpy as np
5 import torch
6 from model.facedetector_model import mobilenetv2_yolov3
7
8 #提取pb模型中的参数
9 def extract_params_from_pb():
10 constant_values = {}
11 with tf.compat.v1.Session() as sess:
12 with tf.io.gfile.GFile('model/FaceDetector.pb', 'rb') as f:
13 graph_def = tf.compat.v1.GraphDef()
14 graph_def.ParseFromString(f.read())
15 sess.graph.as_default()
16 tf.import_graph_def(graph_def, name='')
17 # # input
18 # input_x = sess.graph.get_tensor_by_name('input/input_data:0')
19 # # output
20 # output = sess.graph.get_tensor_by_name('pred_bbox/Reshape:0')
21 # sess.run(output, feed_dict={'input/input_data:0': inputimage})
22
23 constant_ops = [op for op in sess.graph.get_operations()]#[op for op in sess.graph.get_operations() if op.type == "Const"]
24 for constant_op in constant_ops:
25 if constant_op.op_def.name == "Const":
26 if "Shape" in constant_op.name or "pred" in constant_op.name:
27 continue
28 constant_values[constant_op.name] = sess.run(constant_op.outputs[0])
29 return constant_values
30
31 #过滤提取出来的params
32 def filter_params(constant_values):
33 total = 0
34 prompt = []
35 res = {}
36 forbidden = ['shape','stack']
37
38 for k,v in constant_values.items():
39 # filtering some by checking ndim and name
40 if v.ndim<1: continue
41 if v.ndim==1:
42 token = k.split(r'/')[-1]
43 flag = False
44 for word in forbidden:
45 if token.find(word)!=-1:
46 flag = True
47 break
48 if flag:
49 continue
50
51 shape = v.shape
52 cnt = 1
53 for dim in shape:
54 cnt *= dim
55 prompt.append('{} with shape {} has {}'.format(k, shape, cnt))
56 res[k] = v
57 print(prompt[-1])
58 total += cnt
59 prompt.append('totaling {}'.format(total))
60 # print(prompt[-1])
61 return res
62
63 #将Tensorflow的张量转换成PyTorch的张量
64 def trans_tensor_pb2pth(k,a):
65
66 v = tf.convert_to_tensor(a).numpy()
67 # tensorflow weights to pytorch weights
68 if len(v.shape) == 4:
69 if "depthwise_weights" in k:#防止深度可分离卷积
70 return np.ascontiguousarray(v.transpose(2,3,0,1))
71 return np.ascontiguousarray(v.transpose(3,2,0,1))
72 elif len(v.shape) == 2:
73 return np.ascontiguousarray(v.transpose())
74 return v
75
76 #将pb的对应params名字转换为pth对应参数名
77 def trans_name_pb2pth(trans_weights):
78 model_dict = {}
79 for name,para in trans_weights.items():
80 name = name.replace('/',".")
81
82 if "MobilenetV2.Conv" in name:#处理MobilenetV2.Conv
83 name = name.replace('weights',"0.weight")
84 name = name.replace('BatchNorm',"1")
85 name = name.replace('gamma',"weight")
86 name = name.replace('beta',"bias")
87 name = name.replace('moving_mean',"running_mean")
88 name = name.replace('moving_variance',"running_var")
89 elif "MobilenetV2.expanded_conv." in name:#处理MobilenetV2.expanded_conv.
90 name = name.replace('depthwise.',"0.")
91 name = name.replace('project',"1")
92 name = name.replace('depthwise_weights',"0.weight")
93 name = name.replace('weights',"0.weight")
94 name = name.replace('BatchNorm',"1")
95 name = name.replace('gamma',"weight")
96 name = name.replace('beta',"bias")
97 name = name.replace('moving_mean',"running_mean")
98 name = name.replace('moving_variance',"running_var")
99 elif "MobilenetV2.expanded_conv_" in name:#处理MobilenetV2.expanded_conv_*
100 name = name.replace('expand.',"0.")
101 name = name.replace('depthwise.',"1.")
102 name = name.replace('project',"2")
103 name = name.replace('depthwise_weights',"0.weight")
104 name = name.replace('weights',"0.weight")
105 name = name.replace('BatchNorm',"1")
106 name = name.replace('gamma',"weight")
107 name = name.replace('beta',"bias")
108 name = name.replace('moving_mean',"running_mean")
109 name = name.replace('moving_variance',"running_var")
110 elif "yolo-v3" in name:
111 if "bbox" in name:
112 continue
113 name = name.replace('yolo-v3',"yolo_v3")
114 name = name.replace('weight',"0.weight")
115 name = name.replace('kernel',"weight")
116 name = name.replace('batch_normalization',"1")
117 name = name.replace('gamma',"weight")
118 name = name.replace('beta',"bias")
119 name = name.replace('moving_mean',"running_mean")
120 name = name.replace('moving_variance',"running_var")
121 print(name)
122 model_dict[name] = torch.Tensor(para)
123 return model_dict
124
125 #将pb参数copy给pth模型
126 def copy_pbParams2pthParams():
127 constant_values = extract_params_from_pb()
128 TF_weights = filter_params(constant_values)
129 trans_weights = {k:trans_tensor_pb2pth(k,v) for (k, v) in TF_weights.items() }
130
131 #创建pytorch模型
132 PyTorchModel = mobilenetv2_yolov3()
133 model_dict = trans_name_pb2pth(trans_weights)
134 # model_dict = PyTorchModel.state_dict()
135 # for name in model_dict.keys():
136 # print(name)
137 PyTorchModel.load_state_dict(model_dict)
138 PyTorchModel.cuda().eval()
139 dummy_input = torch.rand(1,1,224,224,device="cuda").float()
140 # out = PyTorchModel(dummy_input)
141 torch.onnx.export(PyTorchModel,dummy_input,"P3mNet.onnx",verbose = True,opset_version = 11)
142 print("====> Simplifying...")
143 model_opt,_ = onnxsim.simplify("P3mNet.onnx")
144 onnx.save(model_opt, 'P3mNet_sim.onnx')
145 print("onnx model simplify Ok!")
146 copy_pbParams2pthParams()
将pb模型参数提取转成torch模型的更多相关文章
- 利用反射将Datatable、SqlDataReader转换成List模型
1. DataTable转IList public class DataTableToList<T>whereT :new() { ///<summary> ///利用反射将D ...
- (原)torch模型转pytorch模型
转载请注明出处: http://www.cnblogs.com/darkknightzh/p/7839263.html 目前使用的torch模型转pytorch模型的程序为: https://gith ...
- 「新手必看」Python+Opencv实现摄像头调用RGB图像并转换成HSV模型
在ROS机器人的应用开发中,调用摄像头进行机器视觉处理是比较常见的方法,现在把利用opencv和python语言实现摄像头调用并转换成HSV模型的方法分享出来,希望能对学习ROS机器人的新手们一点帮助 ...
- 【tensorflow-v2.0】如何将模型转换成tflite模型
前言 TensorFlow Lite 提供了转换 TensorFlow 模型,并在移动端(mobile).嵌入式(embeded)和物联网(IoT)设备上运行 TensorFlow 模型所需的所有工具 ...
- DEX-6-caffe模型转成pytorch模型办法
在python2.7环境下 文件下载位置:https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/ 1.可视化模型文件prototxt 1)在线可视化 ...
- 使用C#语言,将DataTable 转换成域模型
DataTable dt = SqlHelper.Query(strQuery); ) * size).Take(pagesize); List<Model> listData = new ...
- PB之取下来列修改后的值(AcceptText)
AcceptText()功能 将“漂浮”在数据窗口控件上编辑框的内容放入到数据窗口控件的当前项中(主缓区中).在将数据放入到当前项之前,编辑框中的数据必须通过有效性规则检查语法 dwcontrol. ...
- pytorch1.0 用torch script导出模型
python的易上手和pytorch的动态图特性,使得pytorch在学术研究中越来越受欢迎,但在生产环境,碍于python的GIL等特性,可能达不到高并发.低延迟的要求,存在需要用c++接口的情况. ...
- MxNet 模型转Tensorflow pb模型
用mmdnn实现模型转换 参考链接:https://www.twblogs.net/a/5ca4cadbbd9eee5b1a0713af 安装mmdnn pip install mmdnn 准备好mx ...
- iOS swift HandyJSON组合Alamofire发起网络请求并转换成模型
在swift开发中,发起网络请求大部分开发者应该都是使用Alamofire发起的网络请求,至于请求完成后JSON解析这一块有很多解决方案,我们今天这里使用HandyJSON来解析请求返回的数据并转化成 ...
随机推荐
- 《HelloGitHub》第 83 期
兴趣是最好的老师,HelloGitHub 让你对编程感兴趣! 简介 HelloGitHub 分享 GitHub 上有趣.入门级的开源项目. https://github.com/521xueweiha ...
- KingbaseES DBLink 介绍
DBLink 扩展插件功能与 Kingbase_FDW 类似,用于远程访问KingbaseES 数据库.相比于Kingbase_FDW,DBLink 功能更强大,可以执行DML,还可以通过 begin ...
- python编辑excel表格文件的简单方法练习
一.创建一个Excel文件from openpyxl import Workbook #需要用到openpyxl模块来操作Excel文件.openpyxl需要先安装.#实例化对象wb = Workbo ...
- android 实现检测版本,下载apk更新(附源码)
其实这不是什么难事了,都有热更新的技术了,只是记录一下,大神勿嘲笑. 先说下思路,首先要有更新的接口,只要进入app,就监测一下接口,是否更新,更新的话,检测本地版本是否低于接口返回的版本,低的话,就 ...
- libuv 网络库设计概览译
设计概览 libuv 是一种支持跨平台的网络库,最初是为了NodeJS作为某个模块实现的,主要基于事件驱动的I/O 模型设计的. 这个库不仅仅对不同的I/O polling 机制提供简单的抽象. ha ...
- 退役*CPCer的找实习总结
从2月底开始到今天,我终于拿到了第一个也是唯一一个offer(字节跳动).找实习的过程告一段落,所以想记录一下这段时间的经历. 最开始找$meopass$学长内推了小马智行,很快就接到了面试通知(再次 ...
- python 使用异常来中断/暂停线程
""""python 使用异常来中断/暂停线程h_thread 线程句柄stoptype 线程停止类型,返回1则正常中断了线程""" ...
- 软件测试肖sir__多线程、多进程、多协程
Python并发编程有三种方式: 1.多线程Thread(threading)(读音:思来d,丁).多进程Process(multiprocessing).多协程Coroutine(asyncio) ...
- reset slave
reset slave 所有中继日志文件都被删除,即使它们还没有被复制 SQL 线程完全执行. reset slave all 所有中继日志文件都被删除,它会清除连接参数(需要重新change mas ...
- Bugku-不可破译的密码[wp]
一 题目分析 flag.txt cipher.txt (1)密码表形式和维吉尼亚密码一样 (2)看到504Q0304 很容易想到 504B0304 Zip文件头. 二 解题步骤 2.1 解密密文 根据 ...