TensorFlow 训练好模型参数的保存和恢复代码,之前就在想模型不应该每次要个结果都要重新训练一遍吧,应该训练一次就可以一直使用吧。

TensorFlow 提供了 Saver 类,可以进行保存和恢复。下面是 TensorFlow-Examples 项目中提供的保存和恢复代码。

'''
Save and Restore a model using TensorFlow.
This example is using the MNIST database of handwritten digits
(http://yann.lecun.com/exdb/mnist/) Author: Aymeric Damien
Project: https://github.com/aymericdamien/TensorFlow-Examples/
''' from __future__ import print_function # Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) import tensorflow as tf # Parameters
learning_rate = 0.001
batch_size = 100
display_step = 1
model_path = "/tmp/model.ckpt" # Network Parameters
n_hidden_1 = 256 # 1st layer number of features
n_hidden_2 = 256 # 2nd layer number of features
n_input = 784 # MNIST data input (img shape: 28*28)
n_classes = 10 # MNIST total classes (0-9 digits) # tf Graph input
x = tf.placeholder("float", [None, n_input])
y = tf.placeholder("float", [None, n_classes]) # Create model
def multilayer_perceptron(x, weights, biases):
# Hidden layer with RELU activation
layer_1 = tf.add(tf.matmul(x, weights['h1']), biases['b1'])
layer_1 = tf.nn.relu(layer_1)
# Hidden layer with RELU activation
layer_2 = tf.add(tf.matmul(layer_1, weights['h2']), biases['b2'])
layer_2 = tf.nn.relu(layer_2)
# Output layer with linear activation
out_layer = tf.matmul(layer_2, weights['out']) + biases['out']
return out_layer # Store layers weight & bias
weights = {
'h1': tf.Variable(tf.random_normal([n_input, n_hidden_1])),
'h2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2])),
'out': tf.Variable(tf.random_normal([n_hidden_2, n_classes]))
}
biases = {
'b1': tf.Variable(tf.random_normal([n_hidden_1])),
'b2': tf.Variable(tf.random_normal([n_hidden_2])),
'out': tf.Variable(tf.random_normal([n_classes]))
} # Construct model
pred = multilayer_perceptron(x, weights, biases) # Define loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost) # Initializing the variables
init = tf.global_variables_initializer() # 'Saver' op to save and restore all the variables
saver = tf.train.Saver() # Running first session
print("Starting 1st session...")
with tf.Session() as sess:
# Initialize variables
sess.run(init) # Training cycle
for epoch in range(3):
avg_cost = 0.
total_batch = int(mnist.train.num_examples/batch_size)
# Loop over all batches
for i in range(total_batch):
batch_x, batch_y = mnist.train.next_batch(batch_size)
# Run optimization op (backprop) and cost op (to get loss value)
_, c = sess.run([optimizer, cost], feed_dict={x: batch_x,
y: batch_y})
# Compute average loss
avg_cost += c / total_batch
# Display logs per epoch step
if epoch % display_step == 0:
print("Epoch:", '%04d' % (epoch+1), "cost=", \
"{:.9f}".format(avg_cost))
print("First Optimization Finished!") # Test model
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
# Calculate accuracy
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels})) # Save model weights to disk
save_path = saver.save(sess, model_path)
print("Model saved in file: %s" % save_path) # Running a new session
print("Starting 2nd session...")
with tf.Session() as sess:
# Initialize variables
sess.run(init) # Restore model weights from previously saved model
saver.restore(sess, model_path)
print("Model restored from file: %s" % save_path) # Resume training
for epoch in range(7):
avg_cost = 0.
total_batch = int(mnist.train.num_examples / batch_size)
# Loop over all batches
for i in range(total_batch):
batch_x, batch_y = mnist.train.next_batch(batch_size)
# Run optimization op (backprop) and cost op (to get loss value)
_, c = sess.run([optimizer, cost], feed_dict={x: batch_x,
y: batch_y})
# Compute average loss
avg_cost += c / total_batch
# Display logs per epoch step
if epoch % display_step == 0:
print("Epoch:", '%04d' % (epoch + 1), "cost=", \
"{:.9f}".format(avg_cost))
print("Second Optimization Finished!") # Test model
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
# Calculate accuracy
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print("Accuracy:", accuracy.eval(
{x: mnist.test.images, y: mnist.test.labels}))

原文链接:http://www.tensorflownews.com/

TensorFlow 训练好模型参数的保存和恢复代码的更多相关文章

  1. 模型文件(checkpoint)对模型参数的储存与恢复

    1.  模型参数的保存: import tensorflow as tfw=tf.Variable(0.0,name='graph_w')ww=tf.Variable(tf.random_normal ...

  2. Tensorflow学习教程------模型参数和网络结构保存且载入,输入一张手写数字图片判断是几

    首先是模型参数和网络结构的保存 #coding:utf-8 import tensorflow as tf from tensorflow.examples.tutorials.mnist impor ...

  3. tensorflow训练线性回归模型

    tensorflow安装 tensorflow安装过程不是很顺利,在这里记录一下 环境:Ubuntu 安装 sudo pip install tensorflow 如果出现错误 Could not f ...

  4. TensorFlow基础笔记(14) 网络模型的保存与恢复_mnist数据实例

    http://blog.csdn.net/huachao1001/article/details/78502910 http://blog.csdn.net/u014432647/article/de ...

  5. java web应用调用python深度学习训练的模型

    之前参见了中国软件杯大赛,在大赛中用到了深度学习的相关算法,也训练了一些简单的模型.项目线上平台是用java编写的web应用程序,而深度学习使用的是python语言,这就涉及到了在java代码中调用p ...

  6. 深度学习原理与框架-猫狗图像识别-卷积神经网络(代码) 1.cv2.resize(图片压缩) 2..get_shape()[1:4].num_elements(获得最后三维度之和) 3.saver.save(训练参数的保存) 4.tf.train.import_meta_graph(加载模型结构) 5.saver.restore(训练参数载入)

    1.cv2.resize(image, (image_size, image_size), 0, 0, cv2.INTER_LINEAR) 参数说明:image表示输入图片,image_size表示变 ...

  7. TensorFlow进阶(六)---模型保存与恢复、自定义命令行参数

    模型保存与恢复.自定义命令行参数. 在我们训练或者测试过程中,总会遇到需要保存训练完成的模型,然后从中恢复继续我们的测试或者其它使用.模型的保存和恢复也是通过tf.train.Saver类去实现,它主 ...

  8. tensorflow笔记:模型的保存与训练过程可视化

    tensorflow笔记系列: (一) tensorflow笔记:流程,概念和简单代码注释 (二) tensorflow笔记:多层CNN代码分析 (三) tensorflow笔记:多层LSTM代码分析 ...

  9. TensorFlow保存、加载模型参数 | 原理描述及踩坑经验总结

    写在前面 我之前使用的LSTM计算单元是根据其前向传播的计算公式手动实现的,这两天想要和TensorFlow自带的tf.nn.rnn_cell.BasicLSTMCell()比较一下,看看哪个训练速度 ...

随机推荐

  1. 一起了解 .Net Foundation 项目 No.8

    .Net 基金会中包含有很多优秀的项目,今天就和笔者一起了解一下其中的一些优秀作品吧. 中文介绍 中文介绍内容翻译自英文介绍,主要采用意译.如与原文存在出入,请以原文为准. IdentityModel ...

  2. IP 数据报

    IP 数据报 1.IP 数据报的格式 一个 IP 数据报由首部和数据两部分组成.(数据报也可以说是数据包) 首部的前一部分是固定长度,共 20 字节,是所有 IP 数据报必须具有的. 在首部的固定部分 ...

  3. Tornado 简述

    前言 python 旗下,群英荟萃,豪杰并起.单是用于 web 开发的,就有 webpy.web2py.bottle.pyramid.zope2.flask.tornado.django 等等,不一而 ...

  4. 非对称加密算法RSA 学习

    非对称加密算法RSA 学习 RSA加密算法是一种非对称加密算法.RSA是1977年由罗纳德·李维斯特(Ron Rivest).阿迪·萨莫尔(Adi Shamir)和伦纳德·阿德曼(Leonard Ad ...

  5. celery异步任务框架

    目录 Celery 一.官方 二.Celery异步任务框架 Celery架构图 消息中间件 任务执行单元 任务结果存储 三.使用场景 四.Celery的安装配置 五.两种celery任务结构:提倡用包 ...

  6. web前端性能优化一

    作为一个前端会允许自己的作品,在非硬性条件下出现卡顿? 肯定是不会,所以给大家梳理一下前端性能的优化. 一:文件操作 html文件压缩: 删除无用的空格和换行符等其他无意义字符 css文件压缩: 删除 ...

  7. Swift --闭包表达式与闭包(汇编分析)

    在Swift中,可以通过func定义一个函数,也可以通过闭包表达式定义一个函数! 一.闭包表达式 概念 闭包表达式与定义函数的语法相对比,有区别如下: 去除了func 去除函数名 返回值类型添加了关键 ...

  8. 集成google翻译的小tips

    文章首发于github.io 2018-08-04 12:43:20 google翻译的强大,就像我们公司的slogan : "让语言无国界,让世人心相通" 友情提示: googl ...

  9. python数据转换

    主要内容 1:数字类型:算术运算 bool:判断真假,运用场景在逻辑运算里较多,比如while循环了. 字符串:可以索引取值,可以嵌套 列表:存放任意数据类型,因为是按序存放的,故可以索引取值, 字典 ...

  10. Netty学习(4):NIO网络编程

    概述 在 Netty学习(3)中,我们已经学习了 Buffer 和 Channel 的概念, 接下来就让我们通过实现一个 NIO 的多人聊天服务器来深入理解 NIO 的第 3个组件:Selector. ...