TensorFlow 调用预训练好的模型—— Python 实现
1. 准备预训练好的模型
- TensorFlow 预训练好的模型被保存为以下四个文件
- data 文件是训练好的参数值,meta 文件是定义的神经网络图,checkpoint 文件是所有模型的保存路径,如下所示,为简单起见只保留了一个模型。
model_checkpoint_path: "/home/senius/python/c_python/test/model-40"
all_model_checkpoint_paths: "/home/senius/python/c_python/test/model-40"
2. 导入模型图、参数值和相关变量
import tensorflow as tf
import numpy as np
sess = tf.Session()
X = None # input
yhat = None # output
def load_model():
"""
Loading the pre-trained model and parameters.
"""
global X, yhat
modelpath = r'/home/senius/python/c_python/test/'
saver = tf.train.import_meta_graph(modelpath + 'model-40.meta')
saver.restore(sess, tf.train.latest_checkpoint(modelpath))
graph = tf.get_default_graph()
X = graph.get_tensor_by_name("X:0")
yhat = graph.get_tensor_by_name("tanh:0")
print('Successfully load the pre-trained model!')
- 通过 saver.restore 我们可以得到预训练的所有参数值,然后再通过 graph.get_tensor_by_name 得到模型的输入张量和我们想要的输出张量。
3. 运行前向传播过程得到预测值
def predict(txtdata):
"""
Convert data to Numpy array which has a shape of (-1, 41, 41, 41 3).
Test a single example.
Arg:
txtdata: Array in C.
Returns:
Three coordinates of a face normal.
"""
global X, yhat
data = np.array(txtdata)
data = data.reshape(-1, 41, 41, 41, 3)
output = sess.run(yhat, feed_dict={X: data}) # (-1, 3)
output = output.reshape(-1, 1)
ret = output.tolist()
return ret
- 通过 feed_dict 喂入测试数据,然后 run 输出的张量我们就可以得到预测值。
4. 测试
load_model()
testdata = np.fromfile('/home/senius/python/c_python/test/04t30t00.npy', dtype=np.float32)
testdata = testdata.reshape(-1, 41, 41, 41, 3) # (150, 41, 41, 41, 3)
testdata = testdata[0:2, ...] # the first two examples
txtdata = testdata.tolist()
output = predict(txtdata)
print(output)
# [[-0.13345889747142792], [0.5858198404312134], [-0.7211828231811523],
# [-0.03778800368309021], [0.9978875517845154], [0.06522832065820694]]
- 本例输入是一个三维网格模型处理后的 [41, 41, 41, 3] 的数据,输出一个表面法向量坐标 (x, y, z)。
获取更多精彩,请关注「seniusen」!
TensorFlow 调用预训练好的模型—— Python 实现的更多相关文章
- tensorflow 使用预训练好的模型的一部分参数
vars = tf.global_variables() net_var = [var for var in vars if 'bi-lstm_secondLayer' not in var.name ...
- 学习TensorFlow,调用预训练好的网络(Alex, VGG, ResNet etc)
视觉问题引入深度神经网络后,针对端对端的训练和预测网络,可以看是特征的表达和任务的决策问题(分类,回归等).当我们自己的训练数据量过小时,往往借助牛人已经预训练好的网络进行特征的提取,然后在后面加上自 ...
- 在 C/C++ 中使用 TensorFlow 预训练好的模型—— 间接调用 Python 实现
现在的深度学习框架一般都是基于 Python 来实现,构建.训练.保存和调用模型都可以很容易地在 Python 下完成.但有时候,我们在实际应用这些模型的时候可能需要在其他编程语言下进行,本文将通过 ...
- 在 C/C++ 中使用 TensorFlow 预训练好的模型—— 直接调用 C++ 接口实现
现在的深度学习框架一般都是基于 Python 来实现,构建.训练.保存和调用模型都可以很容易地在 Python 下完成.但有时候,我们在实际应用这些模型的时候可能需要在其他编程语言下进行,本文将通过直 ...
- TensorFlow 同时调用多个预训练好的模型
在某些任务中,我们需要针对不同的情况训练多个不同的神经网络模型,这时候,在测试阶段,我们就需要调用多个预训练好的模型分别来进行预测. 调用单个预训练好的模型请点击此处 弄明白了如何调用单个模型,其实调 ...
- 【猫狗数据集】使用预训练的resnet18模型
数据集下载地址: 链接:https://pan.baidu.com/s/1l1AnBgkAAEhh0vI5_loWKw提取码:2xq4 创建数据集:https://www.cnblogs.com/xi ...
- ubuntu16.04 使用tensorflow object detection训练自己的模型
一.构建自己的数据集 1.格式必须为jpg.jpeg或png. 2.在models/research/object_detection文件夹下创建images文件夹,在images文件夹下创建trai ...
- 深度学习tensorflow实战笔记 用预训练好的VGG-16模型提取图像特征
1.首先就要下载模型结构 首先要做的就是下载训练好的模型结构和预训练好的模型,结构地址是:点击打开链接 模型结构如下: 文件test_vgg16.py可以用于提取特征.其中vgg16.npy是需要单独 ...
- Tensorflow加载预训练模型和保存模型(ckpt文件)以及迁移学习finetuning
转载自:https://blog.csdn.net/huachao1001/article/details/78501928 使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我 ...
随机推荐
- HDU 1220 Cube(数学,找规律)
传送门: http://acm.hdu.edu.cn/showproblem.php?pid=1220 Cube Time Limit: 2000/1000 MS (Java/Others) M ...
- 下载YouTube视频的网站和工具
1.界面友好,可选择的清晰度较多(我个人用这个比较多) http://en.savefrom.net/ 2.几乎可以解析到所有的清晰度 http://www.clipconverter.cc 3.可选 ...
- linux下ssh/sftp配置和权限设置
基于 ssh 的 sftp 服务相比 ftp 有更好的安全性(非明文帐号密码传输)和方便的权限管理(限制用户的活动目录). 1.开通 sftp 帐号,使用户只能 sftp 操作文件, 而不能 ssh ...
- MySql使用入门
SQL是Structure Query Language(结构化查询语言)的缩写. SQL主要可以分为三个类别: 1.DDL(Data Definition Languages)语句:数据定义语言,这 ...
- 关于alert后,才能继续执行后续代码问题
如果在正常情况下,代码要在alert之后才执行,解决办法:将要执行的代码用setTimeout延迟执行即可(原因:页面未加载完毕) 首先,先说明问题情况: 如下JS代码,不能正常执行,只有在最前面加上 ...
- 实践笔记-VA05 销售订单清单 增加字段
现在都自开发很多报表 ,估计没有多少人 用 VA05 1.在结构 VBMTVZ 中增加需要的字段 2.表t180a 中 添加一条 “添加字段”的数据,如下: 3.取值 修改程序 INCLUDE V ...
- 【MYSQL笔记3】MYSQL过程式数据库对象之存储过程的调用、删除和修改
mysql从5.0版本开始支持存储过程.存储函数.触发器和事件功能的实现. 我们以一本书中的例题为例:创建xscj数据库的存储过程,判断两个输入的参数哪个更大.并调用该存储过程. (1)调用 首先,创 ...
- jQuery 切换图片(图标)效果
<!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8&quo ...
- shell 输出带颜色字体
输出特效格式控制:\033[0m 关闭所有属性 \033[1m 设置高亮度 \03[4m 下划线 \033[5m 闪烁 \033[7m 反显 \033[8m 消隐 \ ...
- http状态码(status_codes)
首先:1XX 接受的请求正在处理,2XX请求正常处理完毕,3XX需要进行附加操作以完成请求(重定向?),4XX服务器无法处理请求(也就是客户端请求错误),5XX服务器处理请求出错. 当然不仅仅是一张图 ...