我们经常遇到训练时间很长,使用起来就是Weight和Bias。那么如何将训练和测试分开操作呢?

TF给出了模型的加载与保存操作,看了网上都是很简单的使用了一下,这里给出一个神经网络的小程序去测试。

本博文使用了Titanic的数据进行操作:

Train.Py

 import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split ################################
# Preparing Data
################################ # read data from file
data = pd.read_csv('data/train.csv') # fill nan values with 0
data = data.fillna(0)
# convert ['male', 'female'] values of Sex to [1, 0]
data['Sex'] = data['Sex'].apply(lambda s: 1 if s == 'male' else 0)
# 'Survived' is the label of one class,
# add 'Deceased' as the other class
data['Deceased'] = data['Survived'].apply(lambda s: 1 - s) # select features and labels for training
dataset_X = data[['Sex', 'Age', 'Pclass', 'SibSp', 'Parch', 'Fare']].as_matrix()
dataset_Y = data[['Deceased', 'Survived']].as_matrix() # split training data and validation set data
X_train, X_val, y_train, y_val = train_test_split(dataset_X, dataset_Y,
test_size=0.2,
random_state=42) ################################
# Constructing Dataflow Graph
################################ # create symbolic variables
X = tf.placeholder(tf.float32, shape=[None, 6])
y = tf.placeholder(tf.float32, shape=[None, 2]) # weights and bias are the variables to be trained
weights = tf.Variable(tf.random_normal([6, 2]), name='weights')
bias = tf.Variable(tf.zeros([2]), name='bias')
y_pred = tf.nn.softmax(tf.matmul(X, weights) + bias) # Minimise cost using cross entropy
# NOTE: add a epsilon(1e-10) when calculate log(y_pred),
# otherwise the result will be -inf
cross_entropy = - tf.reduce_sum(y * tf.log(y_pred + 1e-10),
reduction_indices=1)
cost = tf.reduce_mean(cross_entropy) # use gradient descent optimizer to minimize cost
train_op = tf.train.GradientDescentOptimizer(0.001).minimize(cost) # calculate accuracy
correct_pred = tf.equal(tf.argmax(y, 1), tf.argmax(y_pred, 1))
acc_op = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) ################################
# Training and Evaluating the model
################################
saver = tf.train.Saver()
# use session to run the calculation
with tf.Session() as sess:
# variables have to be initialized at the first place
tf.global_variables_initializer().run()
# training loop
for epoch in range(10):
total_loss = 0.
for i in range(len(X_train)):
# prepare feed data and run
feed_dict = {X: [X_train[i]], y: [y_train[i]]}
_, loss = sess.run([train_op, cost], feed_dict=feed_dict)
total_loss += loss
# display loss per epoch
print('Epoch: %04d, total loss=%.9f' % (epoch + 1, total_loss))
saver_path = saver.save(sess,"wjy_data/model.ckpt")
# Accuracy calculated by TensorFlow
accuracy = sess.run(acc_op, feed_dict={X: X_val, y: y_val})
print("Accuracy on validation set: %.9f" % accuracy) # Accuracy calculated by NumPy
pred = sess.run(y_pred, feed_dict={X: X_val})
correct = np.equal(np.argmax(pred, 1), np.argmax(y_val, 1))
numpy_accuracy = np.mean(correct.astype(np.float32))
print("Accuracy on validation set (numpy): %.9f" % numpy_accuracy) # predict on test data
testdata = pd.read_csv('data/test.csv')
testdata = testdata.fillna(0)
# convert ['male', 'female'] values of Sex to [1, 0]
testdata['Sex'] = testdata['Sex'].apply(lambda s: 1 if s == 'male' else 0)
X_test = testdata[['Sex', 'Age', 'Pclass', 'SibSp', 'Parch', 'Fare']]
predictions = np.argmax(sess.run(y_pred, feed_dict={X: X_test}), 1)
submission = pd.DataFrame({
"PassengerId": testdata["PassengerId"],
"Survived": predictions
}) submission.to_csv("titanic-submission.csv", index=False)

注意:

  saver_path = saver.save(sess,"wjy_data/model.ckpt")
  项目目录下面必须新建一个wjy_data的文件夹,不然会报错!!!

Test.Py

 import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split # create symbolic variables
X = tf.placeholder(tf.float32, shape=[None, 6])
y = tf.placeholder(tf.float32, shape=[None, 2]) # weights and bias are the variables to be trained
weights = tf.Variable(tf.random_normal([6, 2]), name='weights')
bias = tf.Variable(tf.zeros([2]), name='bias')
y_pred = tf.nn.softmax(tf.matmul(X, weights) + bias) # predict on test data
testdata = pd.read_csv('data/test.csv')
testdata = testdata.fillna(0)
# convert ['male', 'female'] values of Sex to [1, 0]
testdata['Sex'] = testdata['Sex'].apply(lambda s: 1 if s == 'male' else 0)
X_test = testdata[['Sex', 'Age', 'Pclass', 'SibSp', 'Parch', 'Fare']]
################################
# Training and Evaluating the model
################################
saver = tf.train.Saver()
# use session to run the calculation
with tf.Session() as sess:
# variables have to be initialized at the first place
tf.global_variables_initializer().run()
#save_path = saver.save(sess,"Saved_model/model.ckpt")
saver.restore(sess,"wjy_data/model.ckpt")#加载模型
predictions = np.argmax(sess.run(y_pred, feed_dict={X: X_test}), 1)
submission = pd.DataFrame({
"PassengerId": testdata["PassengerId"],
"Survived": predictions
})
#saver = tf.train.Saver()
submission.to_csv("titanic-submission.csv", index=False)

很方便的使用保存模型的方式去测试和训练数据,不然怎么办~~

参考:

  《深度学习原理与TensorFlow实战》

  https://blog.csdn.net/lujiandong1/article/details/53301994

TensorFlow模型加载与保存的更多相关文章

  1. Tensorflow模型加载与保存、Tensorboard简单使用

    先上代码: from __future__ import absolute_import from __future__ import division from __future__ import ...

  2. PyTorch模型加载与保存的最佳实践

    一般来说PyTorch有两种保存和读取模型参数的方法.但这篇文章我记录了一种最佳实践,可以在加载模型时避免掉一些问题. 第一种方案是保存整个模型: 1 torch.save(model_object, ...

  3. tensorflow 模型加载(没有checkpoint文件或者说只加载其中一个模型)

    1.如果有checkpoint文件的话,加载模型很简单: 第一步:都是加载图: with tf.Session() as sess: saver=tf.train.import_meta_graph( ...

  4. Tensorflow同时加载使用多个模型

    在Tensorflow中,所有操作对象都包装到相应的Session中的,所以想要使用不同的模型就需要将这些模型加载到不同的Session中并在使用的时候申明是哪个Session,从而避免由于Sessi ...

  5. KnockoutJS 3.X API 第七章 其他技术(1) 加载和保存JSON数据

    Knockout允许您实现复杂的客户端交互性,但几乎所有Web应用程序还需要与服务器交换数据,或至少将本地存储的数据序列化. 最方便的交换或存储数据的方式是JSON格式 - 大多数Ajax应用程序今天 ...

  6. DirectX11 With Windows SDK--19 模型加载:obj格式的读取及使用二进制文件提升读取效率

    前言 一个模型通常是由三个部分组成:网格.纹理.材质.在一开始的时候,我们是通过Geometry类来生成简单几何体的网格.但现在我们需要寻找合适的方式去表述一个复杂的网格,而且包含网格的文件类型多种多 ...

  7. OpenGL OBJ模型加载.

    在我们前面绘制一个屋,我们可以看到,需要每个立方体一个一个的自己来推并且还要处理位置信息.代码量大并且要时间.现在我们通过加载模型文件的方法来生成模型文件,比较流行的3D模型文件有OBJ,FBX,da ...

  8. 6_1 持久化模型与再次加载_探讨(1)_三种持久化模型加载方式以及import_meta_graph方式加载持久化模型会存在的变量管理命名混淆的问题

    笔者提交到gitHub上的问题描述地址是:https://github.com/tensorflow/tensorflow/issues/20140 三种持久化模型加载方式的一个小结论 加载持久化模型 ...

  9. 从零开始openGL——三、模型加载及鼠标交互实现

    前言 在上篇文章中,介绍了基本图形的绘制.这篇博客中将介绍模型的加载.绘制以及鼠标交互的实现. 模型加载 模型存储 要实现模型的读取.绘制,我们首先需要知道模型是如何存储在文件中的. 通常模型是由网格 ...

随机推荐

  1. Python中的yield生成器的简单介绍

    Python yield 使用浅析(整理自:廖 雪峰, 软件工程师, HP 2012 年 11 月 22 日 ) 初学 Python 的开发者经常会发现很多 Python 函数中用到了 yield 关 ...

  2. Python-接口自动化(一)

    python基础知识(一) 一.python语言特点 1.易于学习:python有相对较少的关键字,结构简单,有一个明确定义的语法,学起来比较简单: 2.易于阅读:python代码定义的更清晰: 3. ...

  3. kafka 分区和副本以及kafaka 执行流程,以及消息的高可用

    1.Kafka概览 Apache下的项目Kafka(卡夫卡)是一个分布式流处理平台,它的流行是因为卡夫卡系统的设计和操作简单,能充分利用磁盘的顺序读写特性.kafka每秒钟能有百万条消息的吞吐量,因此 ...

  4. [Ajax] 如何使用Ajax传递多个复选框的值

    HTML+JavaScript代码: <!DOCTYPE html> <html> <head> <meta charset="UTF-8" ...

  5. RSA加密解密实现(JAVA)

    1.关于RSA算法的原理解析参考:http://www.ruanyifeng.com/blog/2013/06/rsa_algorithm_part_one.html 2.RSA密钥长度.明文长度和密 ...

  6. 判断是不是微信浏览器和QQ内置浏览器

    is_weixn() { let ua = navigator.userAgent.toLowerCase(); if (ua.match(/MicroMessenger/i) == "mi ...

  7. webapp js与安卓,ios怎么交互

    ) } } export default { callhandler (name, data, callback) { setupWebViewJavascriptBridge(function (b ...

  8. java基础(1)IntelliJ IDEA入门和数组操作 解决idea启动速度慢--配置JVM

    一. IntelliJ IDEA入门 1 快捷键和技巧 智能补全代码,比如只写首字母按回车: psvm+Enter :public stactic void main(String[] args) s ...

  9. 【软件安装与环境配置】TX2刷机过程

    前言 使用TX2板子之前需要进行刷机,一般都是按照官网教程的步骤刷机,无奈买不起的宝宝只有TX2核心板,其他外设自己搭建,所以只能重新制作镜像,使用该镜像进行刷机. 系统需求 1.Host Platf ...

  10. 文笔很差系列1 - 也谈谈AlphaGo

    距离AlphaGo击败李世石已经过去数月了,心中的震撼至今犹在,全刊报道此项比赛的<围棋天地>杂志我已经看了不下十遍.总也想说点自己的意见,却也不知道从哪里说起,更不知道想表达些什么. 作 ...