在某些任务中,我们需要针对不同的情况训练多个不同的神经网络模型,这时候,在测试阶段,我们就需要调用多个预训练好的模型分别来进行预测。

弄明白了如何调用单个模型,其实调用多个模型也就顺理成章。我们只需要建立多个图,然后每个图导入一个模型,再针对每个图创建一个会话,分别进行预测即可。

import tensorflow as tf
import numpy as np # 建立两个 graph
g1 = tf.Graph()
g2 = tf.Graph() # 为每个 graph 建创建一个 session
sess1 = tf.Session(graph=g1)
sess2 = tf.Session(graph=g2) X_1 = None
tst_1 = None
yhat_1 = None X_2 = None
tst_2 = None
yhat_2 = None def load_model(sess):
"""
Loading the pre-trained model and parameters.
"""
global X_1, tst_1, yhat_1
with sess1.as_default():
with sess1.graph.as_default():
modelpath = r'F:/resnet/model/new0.25-0.35/'
saver = tf.train.import_meta_graph(modelpath + 'model-10.meta')
saver.restore(sess1, tf.train.latest_checkpoint(modelpath))
graph = tf.get_default_graph()
X_1 = graph.get_tensor_by_name("X:0")
tst_1 = graph.get_tensor_by_name("tst:0")
yhat_1 = graph.get_tensor_by_name("tanh:0")
print('Successfully load the model_1!') def load_model_2():
"""
Loading the pre-trained model and parameters.
"""
global X_2, tst_2, yhat_2
with sess2.as_default():
with sess2.graph.as_default():
modelpath = r'F:/resnet/model/new0.25-0.352/'
saver = tf.train.import_meta_graph(modelpath + 'model-10.meta')
saver.restore(sess2, tf.train.latest_checkpoint(modelpath))
graph = tf.get_default_graph()
X_2 = graph.get_tensor_by_name("X:0")
tst_2 = graph.get_tensor_by_name("tst:0")
yhat_2 = graph.get_tensor_by_name("tanh:0")
print('Successfully load the model_2!') def test_1(txtdata):
"""
Convert data to Numpy array which has a shape of (-1, 41, 41, 41, 3).
Test a single axample.
Arg:
txtdata: Array in C.
Returns:
The normal of a face.
"""
global X_1, tst_1, yhat_1
data = np.array(txtdata)
data = data.reshape(-1, 41, 41, 41, 3)
output = sess1.run(yhat_1, feed_dict={X_1: data, tst_1: True}) # (100, 3)
output = output.reshape(-1, 1)
ret = output.tolist()
return ret def test_2(txtdata):
"""
Convert data to Numpy array which has a shape of (-1, 41, 41, 41, 3).
Test a single axample.
Arg:
txtdata: Array in C.
Returns:
The normal of a face.
"""
global X_2, tst_2, yhat_2 data = np.array(txtdata)
data = data.reshape(-1, 41, 41, 41, 3)
output = sess2.run(yhat_2, feed_dict={X_2: data, tst_2: True}) # (100, 3)
output = output.reshape(-1, 1)
ret = output.tolist() return ret

最后,本程序只是为了说明问题,抛砖引玉,代码有很多冗余之处,不要模仿!

获取更多精彩,请关注「seniusen」!

TensorFlow 同时调用多个预训练好的模型的更多相关文章

  1. 在 C/C++ 中使用 TensorFlow 预训练好的模型—— 直接调用 C++ 接口实现

    现在的深度学习框架一般都是基于 Python 来实现,构建.训练.保存和调用模型都可以很容易地在 Python 下完成.但有时候,我们在实际应用这些模型的时候可能需要在其他编程语言下进行,本文将通过直 ...

  2. 在 C/C++ 中使用 TensorFlow 预训练好的模型—— 间接调用 Python 实现

    现在的深度学习框架一般都是基于 Python 来实现,构建.训练.保存和调用模型都可以很容易地在 Python 下完成.但有时候,我们在实际应用这些模型的时候可能需要在其他编程语言下进行,本文将通过 ...

  3. TensorFlow 调用预训练好的模型—— Python 实现

    1. 准备预训练好的模型 TensorFlow 预训练好的模型被保存为以下四个文件 data 文件是训练好的参数值,meta 文件是定义的神经网络图,checkpoint 文件是所有模型的保存路径,如 ...

  4. 【猫狗数据集】使用预训练的resnet18模型

    数据集下载地址: 链接:https://pan.baidu.com/s/1l1AnBgkAAEhh0vI5_loWKw提取码:2xq4 创建数据集:https://www.cnblogs.com/xi ...

  5. tensorflow如何正确加载预训练词向量

    使用预训练词向量和随机初始化词向量的差异还是挺大的,现在说一说我使用预训练词向量的流程. 一.构建本语料的词汇表,作为我的基础词汇 二.遍历该词汇表,从预训练词向量中提取出该词对应的词向量 三.初始化 ...

  6. tensorflow 使用预训练好的模型的一部分参数

    vars = tf.global_variables() net_var = [var for var in vars if 'bi-lstm_secondLayer' not in var.name ...

  7. 深度学习tensorflow实战笔记 用预训练好的VGG-16模型提取图像特征

    1.首先就要下载模型结构 首先要做的就是下载训练好的模型结构和预训练好的模型,结构地址是:点击打开链接 模型结构如下: 文件test_vgg16.py可以用于提取特征.其中vgg16.npy是需要单独 ...

  8. 从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史(转载)

    转载 https://zhuanlan.zhihu.com/p/49271699 首发于深度学习前沿笔记 写文章   从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史 张 ...

  9. 【译】深度双向Transformer预训练【BERT第一作者分享】

    目录 NLP中的预训练 语境表示 语境表示相关研究 存在的问题 BERT的解决方案 任务一:Masked LM 任务二:预测下一句 BERT 输入表示 模型结构--Transformer编码器 Tra ...

随机推荐

  1. 轻量ORM-SqlRepoEx (四)INSERT、UPDATE、DELETE 语句

    *本文中所用类声明见上一篇博文<轻量ORM-SqlRepoEx (三)Select语句>中Customers类 一.增加记录 1.工厂一个实例仓储 var repository = Rep ...

  2. Linux CentOS7下安装Zookeeper-3.4.10服务(最新)

    Linux CentOS7下安装Zookeeper-3.4.10服务(最新) 2017年10月27日 01:25:26 极速-蜗牛 阅读数:1933   版权声明:本文为博主原创文章,未经博主允许不得 ...

  3. OSI七层模型详解(物理层、数据链路层、网络层、传输层.....应用层协议与硬件)

    原文链接 https://blog.csdn.net/xw20084898/article/details/39438783

  4. cornerstone提交报错"but is missing"以及xocde提示"missing from working copy"

    问题描述 xocde提示"missing from working copy" 虽然这种警告不会影响程序到运行,但是数量很多,而且在svn提交的时候回出现这种问题 使用的svn工具 ...

  5. c# 分析SQL语句中的表操作

    最近写了很多方向的总结和demo.基本包含了工作中的很多方面,毕竟c#已经高度封装并且提供了很多类库.前面已经总结了博文.最近2天突然感觉前面的SQL分析阻组件的确麻烦,也注意看了下.为了方便大家学习 ...

  6. 【PTA 天梯赛训练】修理牧场(哈夫曼树+优先队列)

    农夫要修理牧场的一段栅栏,他测量了栅栏,发现需要N块木头,每块木头长度为整数L​i​​个长度单位,于是他购买了一条很长的.能锯成N块的木头,即该木头的长度是L​i​​的总和. 但是农夫自己没有锯子,请 ...

  7. MySQL备份恢复之mysqldump

      Preface       The day before yesterday,there's a motif about the lock procedure when backing up My ...

  8. PHP学习day1

    PHP 变量规则: 变量以 $ 符号开头,其后是变量的名称 变量名称必须以字母或下划线开头 变量名称不能以数字开头 变量名称只能包含字母数字字符和下划线(A-z.0-9 以及 _) 变量名称对大小写敏 ...

  9. XPath Helper的安装使用

    XPath Helper的安装使用 xpath helper 是一款chrome浏览器插件,主要用来分析当前网页信息的xpath,在抓取数据时一般会使用到xpath. 安装 下载地址:http://c ...

  10. redis之cluster(集群)

    搭建redis cluster 1. 准备节点 2. 节点间的通信 3. 分配槽位给节点 redis-cluster架构 多个服务端,负责读写,彼此通信,redis指定了16384个槽. 多匹马儿,负 ...