TensorFlow 模型的保存与载入
参考学习博客:
# https://www.cnblogs.com/felixwang2/p/9190692.html 一、模型保存
# https://www.cnblogs.com/felixwang2/p/9190692.html
# TensorFlow(十三):模型的保存与载入 import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data # 载入数据集
mnist = input_data.read_data_sets("MNIST_data", one_hot=True) # 每个批次100张照片
batch_size =
# 计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size # 定义两个placeholder
x = tf.placeholder(tf.float32, [None, ])
y = tf.placeholder(tf.float32, [None, ]) # 创建一个简单的神经网络,输入层784个神经元,输出层10个神经元
W = tf.Variable(tf.zeros([, ]))
b = tf.Variable(tf.zeros([]))
prediction = tf.nn.softmax(tf.matmul(x, W) + b) # 二次代价函数
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=prediction))
# 使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) # 初始化变量
init = tf.global_variables_initializer() # 结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y, ), tf.argmax(prediction, )) # argmax返回一维张量中最大的值所在的位置
# 求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) saver = tf.train.Saver() gpu_options = tf.GPUOptions(allow_growth=True)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
sess.run(init)
for epoch in range():
for batch in range(n_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys}) acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
# 保存模型
saver.save(sess, 'net/my_net.ckpt')
输出结果:
Iter ,Testing Accuracy 0.8629
Iter ,Testing Accuracy 0.896
Iter ,Testing Accuracy 0.9028
Iter ,Testing Accuracy 0.9052
Iter ,Testing Accuracy 0.9085
Iter ,Testing Accuracy 0.9099
Iter ,Testing Accuracy 0.9122
Iter ,Testing Accuracy 0.9139
Iter ,Testing Accuracy 0.9148
Iter ,Testing Accuracy 0.9163
Iter ,Testing Accuracy 0.9165
二、模型载入
# https://www.cnblogs.com/felixwang2/p/9190692.html
# TensorFlow(十三):模型的保存与载入 import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data # 载入数据集
mnist = input_data.read_data_sets("MNIST_data", one_hot=True) # 每个批次100张照片
batch_size = 100
# 计算一共有多少批次
n_batch = mnist.train.num_examples // batch_size # 定义两个placeholder
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10]) # 创建一个简单的神经网络,输入层784个神经单元,输出层10个神经单元
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x, W) + b) # 二次代价函数
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=prediction))
# 使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) # 初始化变量
init = tf.global_variables_initializer() # 结果存放在一个布尔值列表中
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1)) # argmax返回一维张量中最大的值所在的位置
# 求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) saver = tf.train.Saver() gpu_options = tf.GPUOptions(allow_growth=True)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
sess.run(init)
# 未载入模型时的识别率
print('未载入识别率', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
saver.restore(sess, 'net/my_net.ckpt')
# 载入模型后的识别率
print('载入后识别率', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
未载入识别率 0.098
载入后识别率 0.9178
程序输出如上结果。
TensorFlow 模型的保存与载入的更多相关文章
- TensorFlow——训练模型的保存和载入的方法介绍
我们在训练好模型的时候,通常是要将模型进行保存的,以便于下次能够直接的将训练好的模型进行载入. 1.保存模型 首先需要建立一个saver,然后在session中通过saver的save即可将模型保存起 ...
- [翻译] Tensorflow模型的保存与恢复
翻译自:http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/ ...
- 三、TensorFlow模型的保存和加载
1.模型的保存: import tensorflow as tf v1 = tf.Variable(1.0,dtype=tf.float32) v2 = tf.Variable(2.0,dtype=t ...
- tensorflow模型的保存与恢复
1.tensorflow中模型的保存 创建tf.train.saver,使用saver进行保存: saver = tf.train.Saver() saver.save(sess, './traine ...
- Tensorflow模型变量保存
Tensorflow:模型变量保存 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tensorflow1.4.0 pyt ...
- 超详细的Tensorflow模型的保存和加载(理论与实战详解)
1.Tensorflow的模型到底是什么样的? Tensorflow模型主要包含网络的设计(图)和训练好的各参数的值等.所以,Tensorflow模型有两个主要的文件: a) Meta graph: ...
- tensorflow模型的保存与恢复,以及ckpt到pb的转化
转自 https://www.cnblogs.com/zerotoinfinity/p/10242849.html 一.模型的保存 使用tensorflow训练模型的过程中,需要适时对模型进行保存,以 ...
- tensorflow模型的保存与加载
模型的保存与加载一般有三种模式:save/load weights(最干净.最轻量级的方式,只保存网络参数,不保存网络状态),save/load entire model(最简单粗暴的方式,把网络所有 ...
- TensorFlow(十三):模型的保存与载入
一:保存 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #载入数据集 mnist ...
随机推荐
- K3老单序时簿开发示例
K3需要对老单进行二次开发,老单的二次开发比较麻烦,这里整理一下老单序时簿上添加按钮的二次开发示例. --以下SQL脚本--获取 MENU IDselect FID,FmenuID,FName fro ...
- Learn from Niu 2020.1.21
1. 你一定要看计算机领域的文章. 如果你是看一堆应用,你最终还是会不知道怎么做. 从计算机到energy是降维打击, 当你学习了计算机的hot skill,再去做应用很容易. 2. 搞研究的思路: ...
- TD - 系统异常 - 登录不上
问题描述:登录之后,页面停留在登录页面,没有任何错误提示 解决办法:重新安装插件
- SpringCloud大白话之服务注册中心
SpringCloud-Eureka白话说明 eureka.client.register-with-eureka 属性表示是否将自己注册到Eureka Server, 默认为true. 由于当前应用 ...
- vue mvvm原理与简单实现 -- 上篇
Object.defineProperty介绍-- let obj = {}; Object.defineProperty(obj,'school',{ configurable : true, // ...
- LED Decorative Light Manufacturer - Decorative Lighting: Functionality And Aesthetics
Whether it is for general ambient lighting, task lighting or accent lighting, the decorative lightin ...
- 华为:向充电宝说再见!有它,手机24h不断电
编辑 | 于斌 出品 | 于见(mpyujian) 虽然,美国与中国的谈判还在协商中,不知道最后的消息是好是坏. 但最近华为公司的成绩,却值得让我们为其喝彩和感到骄傲. 据悉,30日,华为在上半年业绩 ...
- shiro登录认证过程讲解
先粘出登录的代码 1. 可以看到已经获取到了username和password ,为了接下来的认证过程,我们需要获取subject对象,也就是代表当前登录用户,并且要将username和passw ...
- 【转载】Mapreduce实现自定义的InputFormat
转自:http://www.cnblogs.com/dlutxm/archive/2011/09/30/2196653.html 在mapreduce程序运行的开始阶段,hadoop需要将待处理的输入 ...
- kafka高吞吐,低延迟的分布式消息队列
核心概念 broker是kafka的节点,多台broker集群就是kafka topic消息分为多个topic partition分区,topic划分了多个partition分区,存在负载均衡策略 每 ...