1 import tensorflow as tf
2 import onnx
3 import onnxsim
4 import numpy as np
5 import torch
6 from model.facedetector_model import mobilenetv2_yolov3
7
8 #提取pb模型中的参数
9 def extract_params_from_pb():
10 constant_values = {}
11 with tf.compat.v1.Session() as sess:
12 with tf.io.gfile.GFile('model/FaceDetector.pb', 'rb') as f:
13 graph_def = tf.compat.v1.GraphDef()
14 graph_def.ParseFromString(f.read())
15 sess.graph.as_default()
16 tf.import_graph_def(graph_def, name='')
17 # # input
18 # input_x = sess.graph.get_tensor_by_name('input/input_data:0')
19 # # output
20 # output = sess.graph.get_tensor_by_name('pred_bbox/Reshape:0')
21 # sess.run(output, feed_dict={'input/input_data:0': inputimage})
22
23 constant_ops = [op for op in sess.graph.get_operations()]#[op for op in sess.graph.get_operations() if op.type == "Const"]
24 for constant_op in constant_ops:
25 if constant_op.op_def.name == "Const":
26 if "Shape" in constant_op.name or "pred" in constant_op.name:
27 continue
28 constant_values[constant_op.name] = sess.run(constant_op.outputs[0])
29 return constant_values
30
31 #过滤提取出来的params
32 def filter_params(constant_values):
33 total = 0
34 prompt = []
35 res = {}
36 forbidden = ['shape','stack']
37
38 for k,v in constant_values.items():
39 # filtering some by checking ndim and name
40 if v.ndim<1: continue
41 if v.ndim==1:
42 token = k.split(r'/')[-1]
43 flag = False
44 for word in forbidden:
45 if token.find(word)!=-1:
46 flag = True
47 break
48 if flag:
49 continue
50
51 shape = v.shape
52 cnt = 1
53 for dim in shape:
54 cnt *= dim
55 prompt.append('{} with shape {} has {}'.format(k, shape, cnt))
56 res[k] = v
57 print(prompt[-1])
58 total += cnt
59 prompt.append('totaling {}'.format(total))
60 # print(prompt[-1])
61 return res
62
63 #将Tensorflow的张量转换成PyTorch的张量
64 def trans_tensor_pb2pth(k,a):
65
66 v = tf.convert_to_tensor(a).numpy()
67 # tensorflow weights to pytorch weights
68 if len(v.shape) == 4:
69 if "depthwise_weights" in k:#防止深度可分离卷积
70 return np.ascontiguousarray(v.transpose(2,3,0,1))
71 return np.ascontiguousarray(v.transpose(3,2,0,1))
72 elif len(v.shape) == 2:
73 return np.ascontiguousarray(v.transpose())
74 return v
75
76 #将pb的对应params名字转换为pth对应参数名
77 def trans_name_pb2pth(trans_weights):
78 model_dict = {}
79 for name,para in trans_weights.items():
80 name = name.replace('/',".")
81
82 if "MobilenetV2.Conv" in name:#处理MobilenetV2.Conv
83 name = name.replace('weights',"0.weight")
84 name = name.replace('BatchNorm',"1")
85 name = name.replace('gamma',"weight")
86 name = name.replace('beta',"bias")
87 name = name.replace('moving_mean',"running_mean")
88 name = name.replace('moving_variance',"running_var")
89 elif "MobilenetV2.expanded_conv." in name:#处理MobilenetV2.expanded_conv.
90 name = name.replace('depthwise.',"0.")
91 name = name.replace('project',"1")
92 name = name.replace('depthwise_weights',"0.weight")
93 name = name.replace('weights',"0.weight")
94 name = name.replace('BatchNorm',"1")
95 name = name.replace('gamma',"weight")
96 name = name.replace('beta',"bias")
97 name = name.replace('moving_mean',"running_mean")
98 name = name.replace('moving_variance',"running_var")
99 elif "MobilenetV2.expanded_conv_" in name:#处理MobilenetV2.expanded_conv_*
100 name = name.replace('expand.',"0.")
101 name = name.replace('depthwise.',"1.")
102 name = name.replace('project',"2")
103 name = name.replace('depthwise_weights',"0.weight")
104 name = name.replace('weights',"0.weight")
105 name = name.replace('BatchNorm',"1")
106 name = name.replace('gamma',"weight")
107 name = name.replace('beta',"bias")
108 name = name.replace('moving_mean',"running_mean")
109 name = name.replace('moving_variance',"running_var")
110 elif "yolo-v3" in name:
111 if "bbox" in name:
112 continue
113 name = name.replace('yolo-v3',"yolo_v3")
114 name = name.replace('weight',"0.weight")
115 name = name.replace('kernel',"weight")
116 name = name.replace('batch_normalization',"1")
117 name = name.replace('gamma',"weight")
118 name = name.replace('beta',"bias")
119 name = name.replace('moving_mean',"running_mean")
120 name = name.replace('moving_variance',"running_var")
121 print(name)
122 model_dict[name] = torch.Tensor(para)
123 return model_dict
124
125 #将pb参数copy给pth模型
126 def copy_pbParams2pthParams():
127 constant_values = extract_params_from_pb()
128 TF_weights = filter_params(constant_values)
129 trans_weights = {k:trans_tensor_pb2pth(k,v) for (k, v) in TF_weights.items() }
130
131 #创建pytorch模型
132 PyTorchModel = mobilenetv2_yolov3()
133 model_dict = trans_name_pb2pth(trans_weights)
134 # model_dict = PyTorchModel.state_dict()
135 # for name in model_dict.keys():
136 # print(name)
137 PyTorchModel.load_state_dict(model_dict)
138 PyTorchModel.cuda().eval()
139 dummy_input = torch.rand(1,1,224,224,device="cuda").float()
140 # out = PyTorchModel(dummy_input)
141 torch.onnx.export(PyTorchModel,dummy_input,"P3mNet.onnx",verbose = True,opset_version = 11)
142 print("====> Simplifying...")
143 model_opt,_ = onnxsim.simplify("P3mNet.onnx")
144 onnx.save(model_opt, 'P3mNet_sim.onnx')
145 print("onnx model simplify Ok!")
146 copy_pbParams2pthParams()

将pb模型参数提取转成torch模型的更多相关文章

  1. 利用反射将Datatable、SqlDataReader转换成List模型

    1. DataTable转IList public class DataTableToList<T>whereT :new() { ///<summary> ///利用反射将D ...

  2. (原)torch模型转pytorch模型

    转载请注明出处: http://www.cnblogs.com/darkknightzh/p/7839263.html 目前使用的torch模型转pytorch模型的程序为: https://gith ...

  3. 「新手必看」Python+Opencv实现摄像头调用RGB图像并转换成HSV模型

    在ROS机器人的应用开发中,调用摄像头进行机器视觉处理是比较常见的方法,现在把利用opencv和python语言实现摄像头调用并转换成HSV模型的方法分享出来,希望能对学习ROS机器人的新手们一点帮助 ...

  4. 【tensorflow-v2.0】如何将模型转换成tflite模型

    前言 TensorFlow Lite 提供了转换 TensorFlow 模型,并在移动端(mobile).嵌入式(embeded)和物联网(IoT)设备上运行 TensorFlow 模型所需的所有工具 ...

  5. DEX-6-caffe模型转成pytorch模型办法

    在python2.7环境下 文件下载位置:https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/ 1.可视化模型文件prototxt 1)在线可视化 ...

  6. 使用C#语言,将DataTable 转换成域模型

    DataTable dt = SqlHelper.Query(strQuery); ) * size).Take(pagesize); List<Model> listData = new ...

  7. PB之取下来列修改后的值(AcceptText)

    AcceptText()功能 将“漂浮”在数据窗口控件上编辑框的内容放入到数据窗口控件的当前项中(主缓区中).在将数据放入到当前项之前,编辑框中的数据必须通过有效性规则检查语法  dwcontrol. ...

  8. pytorch1.0 用torch script导出模型

    python的易上手和pytorch的动态图特性,使得pytorch在学术研究中越来越受欢迎,但在生产环境,碍于python的GIL等特性,可能达不到高并发.低延迟的要求,存在需要用c++接口的情况. ...

  9. MxNet 模型转Tensorflow pb模型

    用mmdnn实现模型转换 参考链接:https://www.twblogs.net/a/5ca4cadbbd9eee5b1a0713af 安装mmdnn pip install mmdnn 准备好mx ...

  10. iOS swift HandyJSON组合Alamofire发起网络请求并转换成模型

    在swift开发中,发起网络请求大部分开发者应该都是使用Alamofire发起的网络请求,至于请求完成后JSON解析这一块有很多解决方案,我们今天这里使用HandyJSON来解析请求返回的数据并转化成 ...

随机推荐

  1. Postgresql 二进制字符串函数和操作符

    1.SQL 二进制字符串函数和操作符 函数 返回类型 描述 例子 结果 string || string bytea 字符串连接 E'\\\\Post'::bytea || E'\\047gres\\ ...

  2. LeetCode-825 适龄的朋友

    来源:力扣(LeetCode)链接:https://leetcode-cn.com/problems/friends-of-appropriate-ages 题目描述 在社交媒体网站上有 n 个用户. ...

  3. day09-MyBatis缓存

    MyBatis缓存 mybatis – MyBatis 3 | cache MyBatis 一级缓存全详解(一) MyBatis 内置了一个强大的事务性查询缓存机制,它可以非常方便地配置和定制. 为了 ...

  4. Anaconda 环境中安装OpenCV (cv2)

    1.使用Anaconda 的对应环境,查看环境中的Python版本号 (1)使用Anaconda 查看存在的环境:conda info --env (2)激活环境:conda activate XXX ...

  5. Vicinity Vision Transformer概述

    0.前言 相关资料: arxiv github 论文解读 论文基本信息: 发表时间:arxiv2022(2022.6.21) 1.针对的问题 视觉transformer计算复杂度和内存占用都是二次的, ...

  6. 如何简化跨网络安全域的文件发送流程,大幅降低IT人员工作量?

    为什么要做安全域的隔离? 随着企业数字化转型的逐步深入,企业投入了大量资源进行信息系统建设,信息化程度日益提升.在这一过程中,企业也越来越重视核心数据资产的保护,数据资产的安全防护成为企业面临的重大挑 ...

  7. 02_IntelliJ IDEA常用快捷键

    [常见快捷键] Ctrl+Shift + Enter 语句完成   "!" 否定完成 输入表达式时按 "!"键 Ctrl+E 最近的文件   Ctrl+Shif ...

  8. oracle 行转列,动态年份,月份列。已解决!

    -----------------存储过程包体----------- procedure GetComparativeAnalysisTB(p_StartTime varchar2, ----开始时间 ...

  9. SpringCloudBus实现配置文件动态更新

    前言 在SpringCloud之配置中心(config)的使用的基础上加上SpringCloudBus实现配置文件动态更新 在此之前需要修改版本,否则会出现"Endpoint ID 'bus ...

  10. 073_SFDC Limit

    我们在开发的过程中,应多注意一些系统自身的限制,以及遇到此类问题的应对措施: Description Synchronous Limit Asynchronous Limit Total number ...