训练代码:

# coding: utf-8
from __future__ import print_function
from __future__ import division import tensorflow as tf
import numpy as np
import argparse def dense_to_one_hot(input_data, class_num):
data_num = input_data.shape[0]
index_offset = np.arange(data_num) * class_num
labels_one_hot = np.zeros((data_num, class_num))
labels_one_hot.flat[index_offset + input_data.ravel()] = 1
return labels_one_hot def build_parser():
parser = argparse.ArgumentParser()
parser.add_argument('--data_path', type=str, required=True)
parser.add_argument('--model_path', type=str, required=True)
args = parser.parse_args()
return args p = build_parser()
origin = np.genfromtxt(p.data_path, delimiter=',') data = origin[:, 0:2]
labels = origin[:, 2] learning_rate = 0.001
training_epochs = 5000
display_step = 1 n_features = 2
n_class = 2
x = tf.placeholder(tf.float32, [None, n_features], "input")
y = tf.placeholder(tf.float32, [None, n_class]) W = tf.Variable(tf.zeros([n_features, n_class]), name="w")
b = tf.Variable(tf.zeros([n_class]), name="b") scores = tf.nn.xw_plus_b(x, W, b, name='scores')
pred_proba = tf.nn.softmax(scores, name="pred_proba") cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=scores, labels=y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) saver = tf.train.Saver()
tf.add_to_collection('pred_proba', pred_proba)
init = tf.global_variables_initializer() with tf.Session() as sess:
sess.run(init)
for epoch in range(training_epochs):
result_pred_proba, _, c = sess.run([pred_proba, optimizer, cost],
feed_dict={x: data, y: dense_to_one_hot(labels.astype(int), 2)})
if epoch % 100 == 0:
print(c)
saver.save(sess, p.model_path)
print("Optimization Finished!")

推理代码:

# coding: utf-8
from __future__ import print_function
from __future__ import division import tensorflow as tf
import numpy as np
import argparse def build_parser():
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, required=True)
args = parser.parse_args()
return args p = build_parser() with tf.Session() as sess:
new_saver = tf.train.import_meta_graph(p.model_path + ".meta")
new_saver.restore(sess, p.model_path)
pred_proba = tf.get_collection('pred_proba')[0]
graph = tf.get_default_graph()
input_x = graph.get_operation_by_name('input').outputs[0]
r = sess.run(pred_proba, feed_dict={input_x: np.array([[0.6211,5]])})
print(r)
print(0 if r[0][0] > r[0][1] else 1)

参考资料

TensorFlow 模型保存/载入的两种方法

tensorflow add_to_collection用法的更多相关文章

  1. Tensorflow Summary用法

    本文转载自:https://www.cnblogs.com/lyc-seu/p/8647792.html Tensorflow Summary用法 tensorboard 作为一款可视化神器,是学习t ...

  2. 第一节,TensorFlow基本用法

    一 TensorFlow安装 TensorFlow是谷歌基于DistBelief进行研发的第二代人工智能学习系统,其命名来源于本身的运行原理.Tsnsor(张量)意味着N维数组,Flow(流)意味着基 ...

  3. tensorflow SavedModelBuilder用法

    训练代码: # coding: utf-8 from __future__ import print_function from __future__ import division import t ...

  4. tensorflow基本用法个人笔记

    综述   TensorFlow程序分为构建阶段和执行阶段.通过构建一个图.执行这个图来得到结果. 构建图   创建源op,源op不需要任何输入,例如常量constant,源op的输出被传递给其他op做 ...

  5. Tensorflow学习笔记——Summary用法

    tensorboard 作为一款可视化神器,可以说是学习tensorflow时模型训练以及参数可视化的法宝. 而在训练过程中,主要用到了tf.summary()的各类方法,能够保存训练过程以及参数分布 ...

  6. (转)TensorFlow 入门

        TensorFlow 入门 本文转自:http://www.jianshu.com/p/6766fbcd43b9 字数3303 阅读904 评论3 喜欢5 CS224d-Day 2: 在 Da ...

  7. 统计学习方法:罗杰斯特回归及Tensorflow入门

    作者:桂. 时间:2017-04-21  21:11:23 链接:http://www.cnblogs.com/xingshansi/p/6743780.html 前言 看到最近大家都在用Tensor ...

  8. 芝麻HTTP:TensorFlow基础入门

    本篇内容基于 Python3 TensorFlow 1.4 版本. 本节内容 本节通过最简单的示例 -- 平面拟合来说明 TensorFlow 的基本用法. 构造数据 TensorFlow 的引入方式 ...

  9. tensorflow 学习日志

    Windows安装anaconda 和 TensorFlow anaconda : https://zhuanlan.zhihu.com/p/25198543        anaconda 使用与说 ...

随机推荐

  1. 关于GDAL读写Shp乱码的问题总结

    目录 1. 正文 1.1. shp文件本身的编码的问题 1.2. 设置读取的编码方式 1.2.1. GDAL设置 1.2.2. 解码方式 1.2.3. 其他 2. 参考 1. 正文 最近在使用GDAL ...

  2. Oracle基础:数据库操作_数据库事务_表的锁定

    数据库操作语句: INSERT INTO 表名[(字段列表)] VALUES ( 表达式列表); 例子:INSERT INTO emp(empno,ename,job,hiredate) VALUES ...

  3. linux下制作linux系统盘(光盘、U盘)

    cdrecord制作启动光盘 首先cdrecord -scanbus输出设备列表和标识,(我的此次为5,0,0)  [ˈrekərd] 然后用cdrecord -v dev=5,0,0 -eject ...

  4. 成员函数指针,动态绑定(vc平台)

    上一篇介绍了gcc对成员函数指针做了thunk的处理,本篇介绍vc对成员函数指针如何处理,还有动态绑定相关的处理. 同样用回上一篇的例子: struct point {float x,y;}; str ...

  5. head first 设计模式第一章笔记

    设计模式是告诉我们如何组织类和对象以解决某种问题. 学习设计模式,也就是学习其他开发人员的经验与智慧,解决遇到的相同的问题. 使用模式的最好方式是:把模式装进脑子,然后在设计的时候,寻找何处可以使用它 ...

  6. goland学习-go常用命令使用

    goland学习-go常用命令使用 1.跨平台编译:env GOOS=linux GOARCH=amd64 go build 2.获取go第三方包:go get -u github.com/go-sq ...

  7. Alibaba Nacos 学习(五):K8S Nacos搭建,使用nfs

    Alibaba Nacos 学习(一):Nacos介绍与安装 Alibaba Nacos 学习(二):Spring Cloud Nacos Config Alibaba Nacos 学习(三):Spr ...

  8. SCAU-1076 K尾相等数

    代码借鉴SCAU-OJ(感谢!!) 题目:1076 K尾相等数 时间限制:500MS  内存限制:65536K提交次数:251 通过次数:80 题型: 编程题   语言: G++;GCC   Desc ...

  9. 面向对象之classmethod和staticmethod(python内置装饰器)

    对象的绑定方法复习classmethodstaticmethod TOC 对象的绑定方法复习 由对象来调用 会将对象当做第一个参数传入 若对象的绑定方法中还有其他参数,会一并传入 classmetho ...

  10. Anaconda中启动Python时的错误:UnicodeDecodeError: 'gbk' codec can't decode byte 0xaf in position 553

    今天,在Anaconda prompt启动python遇到了如下错误: UnicodeDecodeError: ‘gbk’ codec can’t decode byte 0xaf in positi ...