TensorFlow 模型的保存与载入
参考学习博客:
# https://www.cnblogs.com/felixwang2/p/9190692.html 一、模型保存
# https://www.cnblogs.com/felixwang2/p/9190692.html
# TensorFlow(十三):模型的保存与载入 import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data # 载入数据集
mnist = input_data.read_data_sets("MNIST_data", one_hot=True) # 每个批次100张照片
batch_size =
# 计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size # 定义两个placeholder
x = tf.placeholder(tf.float32, [None, ])
y = tf.placeholder(tf.float32, [None, ]) # 创建一个简单的神经网络,输入层784个神经元,输出层10个神经元
W = tf.Variable(tf.zeros([, ]))
b = tf.Variable(tf.zeros([]))
prediction = tf.nn.softmax(tf.matmul(x, W) + b) # 二次代价函数
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=prediction))
# 使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) # 初始化变量
init = tf.global_variables_initializer() # 结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y, ), tf.argmax(prediction, )) # argmax返回一维张量中最大的值所在的位置
# 求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) saver = tf.train.Saver() gpu_options = tf.GPUOptions(allow_growth=True)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
sess.run(init)
for epoch in range():
for batch in range(n_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys}) acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
# 保存模型
saver.save(sess, 'net/my_net.ckpt')
输出结果:
Iter ,Testing Accuracy 0.8629
Iter ,Testing Accuracy 0.896
Iter ,Testing Accuracy 0.9028
Iter ,Testing Accuracy 0.9052
Iter ,Testing Accuracy 0.9085
Iter ,Testing Accuracy 0.9099
Iter ,Testing Accuracy 0.9122
Iter ,Testing Accuracy 0.9139
Iter ,Testing Accuracy 0.9148
Iter ,Testing Accuracy 0.9163
Iter ,Testing Accuracy 0.9165
二、模型载入
# https://www.cnblogs.com/felixwang2/p/9190692.html
# TensorFlow(十三):模型的保存与载入 import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data # 载入数据集
mnist = input_data.read_data_sets("MNIST_data", one_hot=True) # 每个批次100张照片
batch_size = 100
# 计算一共有多少批次
n_batch = mnist.train.num_examples // batch_size # 定义两个placeholder
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10]) # 创建一个简单的神经网络,输入层784个神经单元,输出层10个神经单元
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x, W) + b) # 二次代价函数
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=prediction))
# 使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) # 初始化变量
init = tf.global_variables_initializer() # 结果存放在一个布尔值列表中
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1)) # argmax返回一维张量中最大的值所在的位置
# 求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) saver = tf.train.Saver() gpu_options = tf.GPUOptions(allow_growth=True)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
sess.run(init)
# 未载入模型时的识别率
print('未载入识别率', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
saver.restore(sess, 'net/my_net.ckpt')
# 载入模型后的识别率
print('载入后识别率', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
未载入识别率 0.098
载入后识别率 0.9178
程序输出如上结果。
TensorFlow 模型的保存与载入的更多相关文章
- TensorFlow——训练模型的保存和载入的方法介绍
我们在训练好模型的时候,通常是要将模型进行保存的,以便于下次能够直接的将训练好的模型进行载入. 1.保存模型 首先需要建立一个saver,然后在session中通过saver的save即可将模型保存起 ...
- [翻译] Tensorflow模型的保存与恢复
翻译自:http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/ ...
- 三、TensorFlow模型的保存和加载
1.模型的保存: import tensorflow as tf v1 = tf.Variable(1.0,dtype=tf.float32) v2 = tf.Variable(2.0,dtype=t ...
- tensorflow模型的保存与恢复
1.tensorflow中模型的保存 创建tf.train.saver,使用saver进行保存: saver = tf.train.Saver() saver.save(sess, './traine ...
- Tensorflow模型变量保存
Tensorflow:模型变量保存 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tensorflow1.4.0 pyt ...
- 超详细的Tensorflow模型的保存和加载(理论与实战详解)
1.Tensorflow的模型到底是什么样的? Tensorflow模型主要包含网络的设计(图)和训练好的各参数的值等.所以,Tensorflow模型有两个主要的文件: a) Meta graph: ...
- tensorflow模型的保存与恢复,以及ckpt到pb的转化
转自 https://www.cnblogs.com/zerotoinfinity/p/10242849.html 一.模型的保存 使用tensorflow训练模型的过程中,需要适时对模型进行保存,以 ...
- tensorflow模型的保存与加载
模型的保存与加载一般有三种模式:save/load weights(最干净.最轻量级的方式,只保存网络参数,不保存网络状态),save/load entire model(最简单粗暴的方式,把网络所有 ...
- TensorFlow(十三):模型的保存与载入
一:保存 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #载入数据集 mnist ...
随机推荐
- Centos7部署jenkins
1. 下载rpm包: a) 下载地址:https://pkg.jenkins.io/redhat-stable/ b) 点选一个下载即可,例如点选:“jen ...
- [SDOI2014] 重建 - 矩阵树定理,概率期望
#include <bits/stdc++.h> #define eps 1e-6 using namespace std; const int N = 55; namespace mat ...
- K-means VS K-NN and 手肘法
1. The difference between classification and clustering. from here. Classification: supervised learn ...
- Shell的 for 循环小例子
<1> 上例子 for i in f1 f2 f3; do @echo $i; done 执行结果: f1 f2 f3 但是,请注意:如果是在makefile 中写,要写成这个样子: al ...
- shiro中setUnauthorizedUrl("/403")不起作用
最近学习shiro框架,在用户没有权限的情况下想让其跳转到403页面,结果非自己预想的效果.后来找到一个解决办法如下: 转载来源 SpringBoot中集成Shiro的时候, 配置setUnautho ...
- 关于Win32串口
因为近段时间接触Hid相对来说多一些,由此忽略了串口中获取cbInQue这个重要的东西,下面是错误代码 // Win32SerialPortLib.cpp : 定义 DLL 应用程序的导出函数. // ...
- ieee-explore文献免费下载办法
假设文献网址为:http://ieeexplore.ieee.org/document/xxxxxxx/ 下载保存的话,改为http://ieeexplore.ieee.org.sci-hub.tw/ ...
- 1167E - Range Deleting 双指针
题意:给出n个数的序列,并给出x,这n个数的范围为[1,x],f(L,R)表示删除序列中取值为[l,r]的数,问有几对L,R使得操作后的序列为非递减序列 思路:若[l,r]成立,那么[l,r+1],. ...
- Customized Mini LED Keychain For Better Brand Identity
Looking for products that tell people the brand name? Then you'll find an affordable product that wi ...
- java项目连接Oracle配置文件
转载自:https://blog.csdn.net/shijing266/article/details/42527471 driverClassName=oracle.jdbc.driver.Ora ...