tensorflow SavedModelBuilder用法
训练代码:
# coding: utf-8
from __future__ import print_function
from __future__ import division
import tensorflow as tf
import numpy as np
import argparse
def dense_to_one_hot(input_data, class_num):
data_num = input_data.shape[0]
index_offset = np.arange(data_num) * class_num
labels_one_hot = np.zeros((data_num, class_num))
labels_one_hot.flat[index_offset + input_data.ravel()] = 1
return labels_one_hot
def build_parser():
parser = argparse.ArgumentParser()
parser.add_argument('--data_path', type=str, required=True)
parser.add_argument('--model_dir', type=str, required=True)
args = parser.parse_args()
return args
p = build_parser()
origin = np.genfromtxt(p.data_path, delimiter=',')
data = origin[:, 0:2]
labels = origin[:, 2]
learning_rate = 0.001
training_epochs = 5000
display_step = 1
n_features = 2
n_class = 2
x = tf.placeholder(tf.float32, [None, n_features], "input")
y = tf.placeholder(tf.float32, [None, n_class])
W = tf.Variable(tf.zeros([n_features, n_class]), name="w")
b = tf.Variable(tf.zeros([n_class]), name="b")
scores = tf.nn.xw_plus_b(x, W, b, name='scores')
pred_proba = tf.nn.softmax(scores, name="pred_proba")
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=scores, labels=y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
saver = tf.train.Saver()
tf.add_to_collection('pred_proba', pred_proba)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
for epoch in range(training_epochs):
result_pred_proba, _, c = sess.run([pred_proba, optimizer, cost],
feed_dict={x: data, y: dense_to_one_hot(labels.astype(int), 2)})
if epoch % 100 == 0:
print(c)
builder = tf.saved_model.builder.SavedModelBuilder(p.model_dir)
inputs = {'input': tf.saved_model.utils.build_tensor_info(x)}
outputs = {'pred_proba': tf.saved_model.utils.build_tensor_info(pred_proba)}
signature = tf.saved_model.signature_def_utils.build_signature_def(inputs, outputs, 'test_sig_name')
builder.add_meta_graph_and_variables(sess, ['test_saved_model'], {'test_signature': signature})
builder.save()
推理代码:
# coding: utf-8
from __future__ import print_function
from __future__ import division
import tensorflow as tf
import numpy as np
import argparse
def build_parser():
parser = argparse.ArgumentParser()
parser.add_argument('--model_dir', type=str, required=True)
args = parser.parse_args()
return args
p = build_parser()
with tf.Session() as sess:
signature_key = 'test_signature'
input_key = 'input'
output_key = 'pred_proba'
meta_graph_def = tf.saved_model.loader.load(sess, ['test_saved_model'], p.model_dir)
signature = meta_graph_def.signature_def
x_tensor_name = signature[signature_key].inputs[input_key].name
y_tensor_name = signature[signature_key].outputs[output_key].name
x = sess.graph.get_tensor_by_name(x_tensor_name)
y = sess.graph.get_tensor_by_name(y_tensor_name)
r = sess.run(y, feed_dict={x: np.array([[0.6211, 5]])})
print(r)
print(0 if r[0][0] > r[0][1] else 1)
参考资料
tensorflow SavedModelBuilder用法的更多相关文章
- Tensorflow Summary用法
本文转载自:https://www.cnblogs.com/lyc-seu/p/8647792.html Tensorflow Summary用法 tensorboard 作为一款可视化神器,是学习t ...
- 第一节,TensorFlow基本用法
一 TensorFlow安装 TensorFlow是谷歌基于DistBelief进行研发的第二代人工智能学习系统,其命名来源于本身的运行原理.Tsnsor(张量)意味着N维数组,Flow(流)意味着基 ...
- tensorflow add_to_collection用法
训练代码: # coding: utf-8 from __future__ import print_function from __future__ import division import t ...
- tensorflow基本用法个人笔记
综述 TensorFlow程序分为构建阶段和执行阶段.通过构建一个图.执行这个图来得到结果. 构建图 创建源op,源op不需要任何输入,例如常量constant,源op的输出被传递给其他op做 ...
- Tensorflow学习笔记——Summary用法
tensorboard 作为一款可视化神器,可以说是学习tensorflow时模型训练以及参数可视化的法宝. 而在训练过程中,主要用到了tf.summary()的各类方法,能够保存训练过程以及参数分布 ...
- (转)TensorFlow 入门
TensorFlow 入门 本文转自:http://www.jianshu.com/p/6766fbcd43b9 字数3303 阅读904 评论3 喜欢5 CS224d-Day 2: 在 Da ...
- 统计学习方法:罗杰斯特回归及Tensorflow入门
作者:桂. 时间:2017-04-21 21:11:23 链接:http://www.cnblogs.com/xingshansi/p/6743780.html 前言 看到最近大家都在用Tensor ...
- 芝麻HTTP:TensorFlow基础入门
本篇内容基于 Python3 TensorFlow 1.4 版本. 本节内容 本节通过最简单的示例 -- 平面拟合来说明 TensorFlow 的基本用法. 构造数据 TensorFlow 的引入方式 ...
- tensorflow 学习日志
Windows安装anaconda 和 TensorFlow anaconda : https://zhuanlan.zhihu.com/p/25198543 anaconda 使用与说 ...
随机推荐
- jenkins里的定时构建
1. 定时构建语法:* * * * * (五颗星,多个时间点,中间用逗号隔开)第一个*表示分钟,取值0~59第二个*表示小时,取值0~23第三个*表示一个月的第几天,取值1~31第四个*表示第几月,取 ...
- 50行Python代码实现视频中物体颜色识别和跟踪(必须以红色为例)
目前计算机视觉(CV)与自然语言处理(NLP)及语音识别并列为人工智能三大热点方向,而计算机视觉中的对象检测(objectdetection)应用非常广泛,比如自动驾驶.视频监控.工业质检.医疗诊断等 ...
- logback日志回顾整理--2018年8月8日
几年前使用过logback作为项目的日志框架. 当时觉得这个框架比log4j更加好用. 所以系统的学习了一遍. 后来换了公司, 不再使用logback. 如今, 又有机会使用logback了, 所以, ...
- [ch01-03]神经网络基本原理
系列博客,原文在笔者所维护的github上:https://aka.ms/beginnerAI, 点击star加星不要吝啬,星越多笔者越努力. 前言 For things I don't know h ...
- ACE框架 基于共享内存的分配器 (算法设计)
继承上一篇<ACE框架 基于共享内存的分配器设计>,本篇分析算法部分的设计. ACE_Malloc_T模板定义了这样一个分配器组件 分配器组件聚合了三个功能组件:同步组件ACE_LOCK, ...
- 从cocos2dx源代码看android和iOS跨平台那些事
cocos2dx一个跨移动(平板)平台的游戏引擎,支持2d和3d,基于c/c++,网上介绍多在此不详叙.我们本篇关心的是跨平台那些事,自然而然就找到platform目录.好家伙,支持的操作平台还真不少 ...
- 【Oracle】Oracle ASM管理监控命令
目录 Oracle ASM管理监控命令 目的: 1.查看磁盘组 2.查看目前归档 3.查看ASM的磁盘路径 4. asmcmd Oracle ASM管理监控命令 目的: 查看目前Oracle ASM相 ...
- js 关于apply和call的理解使用
关于call和apply,以前也思考良久,很多时候都以为记住了,但是,我太难了.今天我特地写下笔记,希望可以完全掌握这个东西,也希望可以帮助到任何想对学习这个东西的同学. 一.apply函数定义与理解 ...
- k8s 获取 Pod ip 添加到环境变量
0x00 事件 有一个需要将 Pod 自身的 ip 地址添加到环境变量的需求,可以在 yaml 文件的 env 中这样设置: env: - name: POD_OWN_IP_ADDRESS value ...
- 使用laravel快速构建vuepress管理器
使用laravel快速构建vuepress管理器 介绍 刚刚学了下laravel感觉很方便,最近也在用vuepress做个人博客,但是感觉每次写文章管理文章不是特别方便,就使用laravel写了这个v ...