参考学习博客:

# https://www.cnblogs.com/felixwang2/p/9190692.html

一、模型保存
 # https://www.cnblogs.com/felixwang2/p/9190692.html
# TensorFlow(十三):模型的保存与载入 import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data # 载入数据集
mnist = input_data.read_data_sets("MNIST_data", one_hot=True) # 每个批次100张照片
batch_size =
# 计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size # 定义两个placeholder
x = tf.placeholder(tf.float32, [None, ])
y = tf.placeholder(tf.float32, [None, ]) # 创建一个简单的神经网络,输入层784个神经元,输出层10个神经元
W = tf.Variable(tf.zeros([, ]))
b = tf.Variable(tf.zeros([]))
prediction = tf.nn.softmax(tf.matmul(x, W) + b) # 二次代价函数
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=prediction))
# 使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) # 初始化变量
init = tf.global_variables_initializer() # 结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y, ), tf.argmax(prediction, )) # argmax返回一维张量中最大的值所在的位置
# 求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) saver = tf.train.Saver() gpu_options = tf.GPUOptions(allow_growth=True)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
sess.run(init)
for epoch in range():
for batch in range(n_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys}) acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
# 保存模型
saver.save(sess, 'net/my_net.ckpt')
输出结果:
Iter ,Testing Accuracy 0.8629
Iter ,Testing Accuracy 0.896
Iter ,Testing Accuracy 0.9028
Iter ,Testing Accuracy 0.9052
Iter ,Testing Accuracy 0.9085
Iter ,Testing Accuracy 0.9099
Iter ,Testing Accuracy 0.9122
Iter ,Testing Accuracy 0.9139
Iter ,Testing Accuracy 0.9148
Iter ,Testing Accuracy 0.9163
Iter ,Testing Accuracy 0.9165

二、模型载入
 # https://www.cnblogs.com/felixwang2/p/9190692.html
# TensorFlow(十三):模型的保存与载入 import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data # 载入数据集
mnist = input_data.read_data_sets("MNIST_data", one_hot=True) # 每个批次100张照片
batch_size = 100
# 计算一共有多少批次
n_batch = mnist.train.num_examples // batch_size # 定义两个placeholder
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10]) # 创建一个简单的神经网络,输入层784个神经单元,输出层10个神经单元
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x, W) + b) # 二次代价函数
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=prediction))
# 使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) # 初始化变量
init = tf.global_variables_initializer() # 结果存放在一个布尔值列表中
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1)) # argmax返回一维张量中最大的值所在的位置
# 求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) saver = tf.train.Saver() gpu_options = tf.GPUOptions(allow_growth=True)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
sess.run(init)
# 未载入模型时的识别率
print('未载入识别率', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
saver.restore(sess, 'net/my_net.ckpt')
# 载入模型后的识别率
print('载入后识别率', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
未载入识别率 0.098
载入后识别率 0.9178

程序输出如上结果。

TensorFlow 模型的保存与载入的更多相关文章

  1. TensorFlow——训练模型的保存和载入的方法介绍

    我们在训练好模型的时候,通常是要将模型进行保存的,以便于下次能够直接的将训练好的模型进行载入. 1.保存模型 首先需要建立一个saver,然后在session中通过saver的save即可将模型保存起 ...

  2. [翻译] Tensorflow模型的保存与恢复

    翻译自:http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/ ...

  3. 三、TensorFlow模型的保存和加载

    1.模型的保存: import tensorflow as tf v1 = tf.Variable(1.0,dtype=tf.float32) v2 = tf.Variable(2.0,dtype=t ...

  4. tensorflow模型的保存与恢复

    1.tensorflow中模型的保存 创建tf.train.saver,使用saver进行保存: saver = tf.train.Saver() saver.save(sess, './traine ...

  5. Tensorflow模型变量保存

    Tensorflow:模型变量保存 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tensorflow1.4.0 pyt ...

  6. 超详细的Tensorflow模型的保存和加载(理论与实战详解)

    1.Tensorflow的模型到底是什么样的? Tensorflow模型主要包含网络的设计(图)和训练好的各参数的值等.所以,Tensorflow模型有两个主要的文件: a) Meta graph: ...

  7. tensorflow模型的保存与恢复,以及ckpt到pb的转化

    转自 https://www.cnblogs.com/zerotoinfinity/p/10242849.html 一.模型的保存 使用tensorflow训练模型的过程中,需要适时对模型进行保存,以 ...

  8. tensorflow模型的保存与加载

    模型的保存与加载一般有三种模式:save/load weights(最干净.最轻量级的方式,只保存网络参数,不保存网络状态),save/load entire model(最简单粗暴的方式,把网络所有 ...

  9. TensorFlow(十三):模型的保存与载入

    一:保存 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #载入数据集 mnist ...

随机推荐

  1. 题解 【洛谷P4995】跳跳!

    一看题目名字,下意识地认为DP. 打开题目,发现是一道水的贪心,和DP没一分钱关系(毕竟是洛谷最水月赛的T2). 废话不多说. 看完题面,首先想到排序.要将乱序的石头高度变为有序,才能更好地想题. C ...

  2. java课后作业3

    1.动手动脑 由于类中定义了需要参数的构造方法,导致系统不再提供默认的构造方法. 2.java字段初始化 运行结果 100 300 java字段在初始化时先按照对应的构造方法执行.若构造方法中没有对变 ...

  3. ASP.NET Razor 语法

    主要的 Razor C# 语法规则 Razor 代码块包含在 @{ ... } 中 内联表达式(变量和函数)以 @ 开头 代码语句用分号结束 变量使用 var 关键字声明 字符串用引号括起来 C# 代 ...

  4. Window逆向基础之逆向工程介绍

    逆向工程 以设计方法学为指导,以现代设计理论.方法.技术为基础,运用各种专业人员的工程设计经验.知识和创新思维,对已有产品进行解剖.深化和再创造. 逆向工程不仅仅在计算机行业.各行各业都存在逆向工程. ...

  5. el-popover 点击input框出现table表,可点击选中,可拼音检索完回车选中

    <template> <card> <el-popover placement="right" width="400" trigg ...

  6. HTML文字标签

    <h1>标题标签,总共六个等级,不能创造标签,只有预定义好的标签才可以被浏览器解析 <br>换行标签,没有内容可以修饰也称为空标签 <p>段落标签</p> ...

  7. 通过java代码HttpRequestUtil(服务器端)发送HTTP请求并解析

    关键代码:String jsonStr = HttpRequestUtil.sendGet(config.getAddress() + config.getPorts() + config.getFi ...

  8. 编写自定义的django-admin命令

    先写标题内容后续补充上 官方文档如下 :http://python.usyiyi.cn/documents/django_182/howto/custom-management-commands.ht ...

  9. GoAhead WebServer 架构

    https://blog.csdn.net/jungsagacity/article/details/7307012

  10. Docker - 查看容器进程在宿主机的 PID

    概述 查看 docker 进程, 在容器外的 pid 背景 docker 中运行的进程, 本质上是运行在 host 上的 这些进程, 在 host 上, 也可以有自己的 pid 如果某种情况下, 连不 ...