导出pb模型之后测试的python代码
链接:https://blog.csdn.net/thriving_fcl/article/details/75213361
saved_model模块主要用于TensorFlow Serving。TF Serving是一个将训练好的模型部署至生产环境的系统,主要的优点在于可以保持Server端与API不变的情况下,部署新的算法或进行试验,同时还有很高的性能。
保持Server端与API不变有什么好处呢?有很多好处,我只从我体会的一个方面举例子说明一下,比如我们需要部署一个文本分类模型,那么输入和输出是可以确定的,输入文本,输出各类别的概率或类别标签。为了得到较好的效果,我们可能想尝试很多不同的模型,CNN,RNN,RCNN等,这些模型训练好保存下来以后,在inference阶段需要重新载入这些模型,我们希望的是inference的代码有一份就好,也就是使用新模型的时候不需要针对新模型来修改inference的代码。这应该如何实现呢?
在TensorFlow 模型保存/载入的两种方法中总结过。
1. 仅用Saver来保存/载入变量。这个方法显然不行,仅保存变量就必须在inference的时候重新定义Graph(定义模型),这样不同的模型代码肯定要修改。即使同一种模型,参数变化了,也需要在代码中有所体现,至少需要一个配置文件来同步,这样就很繁琐了。
2. 使用tf.train.import_meta_graph导入graph信息并创建Saver, 再使用Saver restore变量。相比第一种,不需要重新定义模型,但是为了从graph中找到输入输出的tensor,还是得用graph.get_tensor_by_name()来获取,也就是还需要知道在定义模型阶段所赋予这些tensor的名字。如果创建各模型的代码都是同一个人完成的,还相对好控制,强制这些输入输出的命名都一致即可。如果是不同的开发者,要在创建模型阶段就强制tensor的命名一致就比较困难了。这样就不得不再维护一个配置文件,将需要获取的tensor名称写入,然后从配置文件中读取该参数。
经过上面的分析发现,要实现inference的代码统一,使用原来的方法也是可以的,只不过TensorFlow官方提供了更好的方法,并且这个方法不仅仅是解决这个问题,所以还是得学习使用saved_model这个模块。
saved_model 保存/载入模型
先列出会用到的API
class tf.saved_model.builder.SavedModelBuilder
# 初始化方法
__init__(export_dir)
# 导入graph与变量信息
add_meta_graph_and_variables(
sess,
tags,
signature_def_map=None,
assets_collection=None,
legacy_init_op=None,
clear_devices=False,
main_op=None
)
# 载入保存好的模型
tf.saved_model.loader.load(
sess,
tags,
export_dir,
**saver_kwargs
)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
(1) 最简单的场景,只是保存/载入模型
保存
要保存一个已经训练好的模型,使用下面三行代码就可以了。
builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
builder.add_meta_graph_and_variables(sess, ['tag_string'])
builder.save()
1
2
3
首先构造SavedModelBuilder对象,初始化方法只需要传入用于保存模型的目录名,目录不用预先创建。
add_meta_graph_and_variables方法导入graph的信息以及变量,这个方法假设变量都已经初始化好了,对于每个SavedModelBuilder这个方法一定要执行一次用于导入第一个meta graph。
第一个参数传入当前的session,包含了graph的结构与所有变量。
第二个参数是给当前需要保存的meta graph一个标签,标签名可以自定义,在之后载入模型的时候,需要根据这个标签名去查找对应的MetaGraphDef,找不到就会报如RuntimeError: MetaGraphDef associated with tags 'foo' could not be found in SavedModel这样的错。标签也可以选用系统定义好的参数,如tf.saved_model.tag_constants.SERVING与tf.saved_model.tag_constants.TRAINING。
save方法就是将模型序列化到指定目录底下。
保存好以后到saved_model_dir目录下,会有一个saved_model.pb文件以及variables文件夹。顾名思义,variables保存所有变量,saved_model.pb用于保存模型结构等信息。
载入
使用tf.saved_model.loader.load方法就可以载入模型。如
meta_graph_def = tf.saved_model.loader.load(sess, ['tag_string'], saved_model_dir)
1
第一个参数就是当前的session,第二个参数是在保存的时候定义的meta graph的标签,标签一致才能找到对应的meta graph。第三个参数就是模型保存的目录。
load完以后,也是从sess对应的graph中获取需要的tensor来inference。如
x = sess.graph.get_tensor_by_name('input_x:0')
y = sess.graph.get_tensor_by_name('predict_y:0')
# 实际的待inference的样本
_x = ...
sess.run(y, feed_dict={x: _x})
1
2
3
4
5
6
这样和之前的第二种方法一样,也是要知道tensor的name。那么如何可以在不知道tensor name的情况下使用呢? 那就需要给add_meta_graph_and_variables方法传入第三个参数,signature_def_map。
(2) 使用SignatureDef
关于SignatureDef我的理解是,它定义了一些协议,对我们所需的信息进行封装,我们根据这套协议来获取信息,从而实现创建与使用模型的解耦。SignatureDef的结构以及相关详细的文档在:https://github.com/tensorflow/serving/blob/master/tensorflow_serving/g3doc/signature_defs.md
相关API
# 构建signature
tf.saved_model.signature_def_utils.build_signature_def(
inputs=None,
outputs=None,
method_name=None
)
# 构建tensor info
tf.saved_model.utils.build_tensor_info(tensor)
1
2
3
4
5
6
7
8
9
SignatureDef,将输入输出tensor的信息都进行了封装,并且给他们一个自定义的别名,所以在构建模型的阶段,可以随便给tensor命名,只要在保存训练好的模型的时候,在SignatureDef中给出统一的别名即可。
TensorFlow的关于这部分的例子中用到了不少signature_constants,这些constants的用处主要是提供了一个方便统一的命名。在我们自己理解SignatureDef的作用的时候,可以先不用管这些,遇到需要命名的时候,想怎么写怎么写。
保存
假设定义模型输入的别名为“input_x”,输出的别名为“output” ,使用SignatureDef的代码如下
builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
# x 为输入tensor, keep_prob为dropout的prob tensor
inputs = {'input_x': tf.saved_model.utils.build_tensor_info(x),
'keep_prob': tf.saved_model.utils.build_tensor_info(keep_prob)}
# y 为最终需要的输出结果tensor
outputs = {'output' : tf.saved_model.utils.build_tensor_info(y)}
signature = tf.saved_model.signature_def_utils.build_signature_def(inputs, outputs, 'test_sig_name')
builder.add_meta_graph_and_variables(sess, ['test_saved_model'], {'test_signature':signature})
builder.save()
1
2
3
4
5
6
7
8
9
10
11
12
上述inputs增加一个keep_prob是为了说明inputs可以有多个, build_tensor_info方法将tensor相关的信息序列化为TensorInfo protocol buffer。
inputs,outputs都是dict,key是我们约定的输入输出别名,value就是对具体tensor包装得到的TensorInfo。
然后使用build_signature_def方法构建SignatureDef,第三个参数method_name暂时先随便给一个。
创建好的SignatureDef是用在add_meta_graph_and_variables的第三个参数signature_def_map中,但不是直接传入SignatureDef对象。事实上signature_def_map接收的是一个dict,key是我们自己命名的signature名称,value是SignatureDef对象。
载入
载入与使用的代码如下
## 略去构建sess的代码
signature_key = 'test_signature'
input_key = 'input_x'
output_key = 'output'
meta_graph_def = tf.saved_model.loader.load(sess, ['test_saved_model'], saved_model_dir)
# 从meta_graph_def中取出SignatureDef对象
signature = meta_graph_def.signature_def
# 从signature中找出具体输入输出的tensor name
x_tensor_name = signature[signature_key].inputs[input_key].name
y_tensor_name = signature[signature_key].outputs[output_key].name
# 获取tensor 并inference
x = sess.graph.get_tensor_by_name(x_tensor_name)
y = sess.graph.get_tensor_by_name(y_tensor_name)
# _x 实际输入待inference的data
sess.run(y, feed_dict={x:_x})
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
从上面两段代码可以知道,我们只需要约定好输入输出的别名,在保存模型的时候使用这些别名创建signature,输入输出tensor的具体名称已经完全隐藏,这就实现创建模型与使用模型的解耦。
---------------------
作者:thriving_fcl
来源:CSDN
原文:https://blog.csdn.net/thriving_fcl/article/details/75213361?utm_source=copy
版权声明:本文为博主原创文章,转载请附上博文链接!
导出pb模型之后测试的python代码的更多相关文章
- ROS系统python代码测试之rostest
ROS系统中提供了测试框架,可以实现python/c++代码的单元测试,python和C++通过不同的方式实现, 之后的两篇文档分别详细介绍各自的实现步骤,以及测试结果和覆盖率的获取. ROS系统中p ...
- 利用Python中的mock库对Python代码进行模拟测试
这篇文章主要介绍了利用Python中的mock库对Python代码进行模拟测试,mock库自从Python3.3依赖成为了Python的内置库,本文也等于介绍了该库的用法,需要的朋友可以参考下 ...
- 【转】利用Python中的mock库对Python代码进行模拟测试
出处 https://www.toptal.com/python/an-introduction-to-mocking-in-python http://www.oschina.net/transla ...
- 隐马尔科夫模型,第三种问题解法,维比特算法(biterbi) algorithm python代码
上篇介绍了隐马尔科夫模型 本文给出关于问题3解决方法,并给出一个例子的python代码 回顾上文,问题3是什么, 下面给出,维比特算法(biterbi) algorithm 下面通过一个具体例子,来说 ...
- 交互模式下测试python代码及变量的四则运算
在交互模式下,python代码可以立即执行,所以这很方便我们进行代码测试 1.命令窗口,输入python (如果没配置环境变量则需带python安装目录的绝对路径) >>> 这个就是 ...
- Python代码缩进与测试模块
一.Python代码缩进 Python 函数没有明显的 begin 和 end ,没有标明函数的开始和结束的花括号.唯一的分隔符是一个冒号 ( : ),接着代码本身是缩进的. 例如:缩进 bui ...
- 让 Python 代码更易维护的七种武器——代码风格(pylint、Flake8、Isort、Autopep8、Yapf、Black)测试覆盖率(Coverage)CI(JK)
让 Python 代码更易维护的七种武器 2018/09/29 · 基础知识 · 武器 原文出处: Jeff Triplett 译文出处:linux中国-Hank Chow 检查你的代码的质 ...
- tensorflow学习笔记——模型持久化的原理,将CKPT转为pb文件,使用pb模型预测
由题目就可以看出,本节内容分为三部分,第一部分就是如何将训练好的模型持久化,并学习模型持久化的原理,第二部分就是如何将CKPT转化为pb文件,第三部分就是如何使用pb模型进行预测. 一,模型持久化 为 ...
- 基于深度学习的人脸性别识别系统(含UI界面,Python代码)
摘要:人脸性别识别是人脸识别领域的一个热门方向,本文详细介绍基于深度学习的人脸性别识别系统,在介绍算法原理的同时,给出Python的实现代码以及PyQt的UI界面.在界面中可以选择人脸图片.视频进行检 ...
随机推荐
- en笔记音标
清辅音和浊辅音区别 开音节和闭音节区别 1 2 3 4 5 6 7 a o e i u w y 开音节 /eɪ/ /əu/ /i:/ /aɪ/ Ju: /aɪ/ 闭音节 /æ/ /ɒ/ /ə/ / ...
- 51nod1513
题解: 更据题意,在树上深度为没一个数的都放在一起,要用的时候二分出来,看结果 用c++的数据结构 代码: #include<bits/stdc++.h> using namespace ...
- Tomcat 域名绑定多个Host配置要点
一.在server.xml中添加Host节点,name就是需要绑定的域名,多个域名在Host节点下建立<Alias></Alias>子节点,可建立多个. <Engine ...
- 安装Windows Installer服务
Windows Installer 5.0.810.500 下载地址: 电信:http://mdl1.mydown.yesky.com/soft/201303/WindowsInstaller.rar ...
- oracle索引原理
B-TREE索引(二叉树索引,默认情况下,我们建的索引都是此种类型) 一个B树索引只有一个根节点,它实际就是位于树的最顶端的分支节点.可以用下图一来描述B树索引的结构.其中,B表示分支节点,而L表示叶 ...
- spring 定时任务参数配置详解
注:本文摘自<Quartz Cron 触发器 Cron Expression 的格式>http://blog.csdn.net/yefengmeander/article/details/ ...
- php include,require 主要是向网页中引入文件
- tp5中捕获异常的配置
首选在配置文件中加入配置如下 // 异常处理handle类 留空使用 \think\exception\Handle 'exception_handle' => '\\app\ ...
- HDU - 6311:Cover(欧拉回路,最少的一笔画覆盖无向图)
The Wall has down and the King in the north has to send his soldiers to sentinel. The North can be r ...
- 配置搭建与使用redis
redis单点.redis主从.redis哨兵 sentinel,redis集群cluster配置搭建与使用 redis是如今被互联网公司使用最广泛的一个中间件,我们打开GitHub搜索redis,边 ...