训练代码:

# coding: utf-8
from __future__ import print_function
from __future__ import division import tensorflow as tf
import numpy as np
import argparse def dense_to_one_hot(input_data, class_num):
data_num = input_data.shape[0]
index_offset = np.arange(data_num) * class_num
labels_one_hot = np.zeros((data_num, class_num))
labels_one_hot.flat[index_offset + input_data.ravel()] = 1
return labels_one_hot def build_parser():
parser = argparse.ArgumentParser()
parser.add_argument('--data_path', type=str, required=True)
parser.add_argument('--model_path', type=str, required=True)
args = parser.parse_args()
return args p = build_parser()
origin = np.genfromtxt(p.data_path, delimiter=',') data = origin[:, 0:2]
labels = origin[:, 2] learning_rate = 0.001
training_epochs = 5000
display_step = 1 n_features = 2
n_class = 2
x = tf.placeholder(tf.float32, [None, n_features], "input")
y = tf.placeholder(tf.float32, [None, n_class]) W = tf.Variable(tf.zeros([n_features, n_class]), name="w")
b = tf.Variable(tf.zeros([n_class]), name="b") scores = tf.nn.xw_plus_b(x, W, b, name='scores')
pred_proba = tf.nn.softmax(scores, name="pred_proba") cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=scores, labels=y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) saver = tf.train.Saver()
tf.add_to_collection('pred_proba', pred_proba)
init = tf.global_variables_initializer() with tf.Session() as sess:
sess.run(init)
for epoch in range(training_epochs):
result_pred_proba, _, c = sess.run([pred_proba, optimizer, cost],
feed_dict={x: data, y: dense_to_one_hot(labels.astype(int), 2)})
if epoch % 100 == 0:
print(c)
saver.save(sess, p.model_path)
print("Optimization Finished!")

推理代码:

# coding: utf-8
from __future__ import print_function
from __future__ import division import tensorflow as tf
import numpy as np
import argparse def build_parser():
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, required=True)
args = parser.parse_args()
return args p = build_parser() with tf.Session() as sess:
new_saver = tf.train.import_meta_graph(p.model_path + ".meta")
new_saver.restore(sess, p.model_path)
pred_proba = tf.get_collection('pred_proba')[0]
graph = tf.get_default_graph()
input_x = graph.get_operation_by_name('input').outputs[0]
r = sess.run(pred_proba, feed_dict={input_x: np.array([[0.6211,5]])})
print(r)
print(0 if r[0][0] > r[0][1] else 1)

参考资料

TensorFlow 模型保存/载入的两种方法

tensorflow add_to_collection用法的更多相关文章

  1. Tensorflow Summary用法

    本文转载自:https://www.cnblogs.com/lyc-seu/p/8647792.html Tensorflow Summary用法 tensorboard 作为一款可视化神器,是学习t ...

  2. 第一节,TensorFlow基本用法

    一 TensorFlow安装 TensorFlow是谷歌基于DistBelief进行研发的第二代人工智能学习系统,其命名来源于本身的运行原理.Tsnsor(张量)意味着N维数组,Flow(流)意味着基 ...

  3. tensorflow SavedModelBuilder用法

    训练代码: # coding: utf-8 from __future__ import print_function from __future__ import division import t ...

  4. tensorflow基本用法个人笔记

    综述   TensorFlow程序分为构建阶段和执行阶段.通过构建一个图.执行这个图来得到结果. 构建图   创建源op,源op不需要任何输入,例如常量constant,源op的输出被传递给其他op做 ...

  5. Tensorflow学习笔记——Summary用法

    tensorboard 作为一款可视化神器,可以说是学习tensorflow时模型训练以及参数可视化的法宝. 而在训练过程中,主要用到了tf.summary()的各类方法,能够保存训练过程以及参数分布 ...

  6. (转)TensorFlow 入门

        TensorFlow 入门 本文转自:http://www.jianshu.com/p/6766fbcd43b9 字数3303 阅读904 评论3 喜欢5 CS224d-Day 2: 在 Da ...

  7. 统计学习方法:罗杰斯特回归及Tensorflow入门

    作者:桂. 时间:2017-04-21  21:11:23 链接:http://www.cnblogs.com/xingshansi/p/6743780.html 前言 看到最近大家都在用Tensor ...

  8. 芝麻HTTP:TensorFlow基础入门

    本篇内容基于 Python3 TensorFlow 1.4 版本. 本节内容 本节通过最简单的示例 -- 平面拟合来说明 TensorFlow 的基本用法. 构造数据 TensorFlow 的引入方式 ...

  9. tensorflow 学习日志

    Windows安装anaconda 和 TensorFlow anaconda : https://zhuanlan.zhihu.com/p/25198543        anaconda 使用与说 ...

随机推荐

  1. 在oracle数据库中创建DBLink

    涉及到两个数据库之间的访问时,可以创建datebase link来互相访问. ’创建方法: 1.通过PL/SQL客户端,找到datebase link,右键新建 输入相应信息 2.直接用命令行创建 一 ...

  2. ansible-template

    template简介 template功能: 根据模板文件动态生成对应的配置文件 template文件必须存放于templates目录下,且命名为 .j2 结尾 ansible的template模板使 ...

  3. Python3安装mysql模块

    pip3 install mysql 1.错误1 原因:在 Python 3.x 版本后,ConfigParser.py 已经更名为 configparser.py 所以出错! 解决,将模块cp一份为 ...

  4. Grid表格的js触发事件

    没怎么接触过Grid插件: 解决的问题是:点击Grid表行里的内容触发js方法弹出模态框,用以显示选中内容的详细信息. 思路:给准备要触发的列加上一个css属性,通过这个css属性来获取元素并触发js ...

  5. ubuntu server 1604 搭建FTP服务器

    1.查看是否安装 ftp服务器vsftpd -v 2.安装ftp服务器sudo apt-get install vsftpd 3.如果安装失败或者配置出现问题,可以卸载 ftp服务器sudo apt- ...

  6. 减少HTTP请求的方式

    1. 图片地图 缺点:坐标难定义:除了矩形之外几乎无法定义其他形状:通过DHTML(动态DOM操作)创建的图片地图在 IE 不兼容 <img usemap="#map1" b ...

  7. vue动态样式设置

    思路: 通过 v-bind:class="true ? style1 : style2 " 配合三元表达式完成样式的切换 具体实现 //return设置控制的参数 //有多个需要样 ...

  8. 看了这篇Redis,我以大专生的身份,进入了阿里,定级P7

    摘要: 前几天讲了Redis的面试知识点,当然那只是一部分,我相信各位在面试,或者实际开发过程中对缓存雪崩,穿透,击穿也不陌生吧,就算没遇到过但是你肯定听过,那三者到底有什么区别,我们又应该怎么去防止 ...

  9. python线程条件变量Condition(31)

    对于线程与线程之间的交互我们在前面的文章已经介绍了 python 互斥锁Lock / python事件Event , 今天继续介绍一种线程交互方式 – 线程条件变量Condition. 一.线程条件变 ...

  10. 使用 Topshelf 组件一步一步创建 Windows 服务 (2) 使用Quartz.net 调度

    上一篇说了如何使用 Topshelf 组件快速创建Windows服务,接下来介绍如何使用 Quartz.net 关于Quartz.net的好处,网上搜索都是一大把一大把的,我就不再多介绍. 先介绍需要 ...