TensorFlow 模型的保存与载入
参考学习博客:
# https://www.cnblogs.com/felixwang2/p/9190692.html 一、模型保存
# https://www.cnblogs.com/felixwang2/p/9190692.html
# TensorFlow(十三):模型的保存与载入 import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data # 载入数据集
mnist = input_data.read_data_sets("MNIST_data", one_hot=True) # 每个批次100张照片
batch_size =
# 计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size # 定义两个placeholder
x = tf.placeholder(tf.float32, [None, ])
y = tf.placeholder(tf.float32, [None, ]) # 创建一个简单的神经网络,输入层784个神经元,输出层10个神经元
W = tf.Variable(tf.zeros([, ]))
b = tf.Variable(tf.zeros([]))
prediction = tf.nn.softmax(tf.matmul(x, W) + b) # 二次代价函数
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=prediction))
# 使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) # 初始化变量
init = tf.global_variables_initializer() # 结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y, ), tf.argmax(prediction, )) # argmax返回一维张量中最大的值所在的位置
# 求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) saver = tf.train.Saver() gpu_options = tf.GPUOptions(allow_growth=True)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
sess.run(init)
for epoch in range():
for batch in range(n_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys}) acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
# 保存模型
saver.save(sess, 'net/my_net.ckpt')
输出结果:
Iter ,Testing Accuracy 0.8629
Iter ,Testing Accuracy 0.896
Iter ,Testing Accuracy 0.9028
Iter ,Testing Accuracy 0.9052
Iter ,Testing Accuracy 0.9085
Iter ,Testing Accuracy 0.9099
Iter ,Testing Accuracy 0.9122
Iter ,Testing Accuracy 0.9139
Iter ,Testing Accuracy 0.9148
Iter ,Testing Accuracy 0.9163
Iter ,Testing Accuracy 0.9165
二、模型载入
# https://www.cnblogs.com/felixwang2/p/9190692.html
# TensorFlow(十三):模型的保存与载入 import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data # 载入数据集
mnist = input_data.read_data_sets("MNIST_data", one_hot=True) # 每个批次100张照片
batch_size = 100
# 计算一共有多少批次
n_batch = mnist.train.num_examples // batch_size # 定义两个placeholder
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10]) # 创建一个简单的神经网络,输入层784个神经单元,输出层10个神经单元
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x, W) + b) # 二次代价函数
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=prediction))
# 使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) # 初始化变量
init = tf.global_variables_initializer() # 结果存放在一个布尔值列表中
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1)) # argmax返回一维张量中最大的值所在的位置
# 求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) saver = tf.train.Saver() gpu_options = tf.GPUOptions(allow_growth=True)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
sess.run(init)
# 未载入模型时的识别率
print('未载入识别率', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
saver.restore(sess, 'net/my_net.ckpt')
# 载入模型后的识别率
print('载入后识别率', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
未载入识别率 0.098
载入后识别率 0.9178
程序输出如上结果。
TensorFlow 模型的保存与载入的更多相关文章
- TensorFlow——训练模型的保存和载入的方法介绍
我们在训练好模型的时候,通常是要将模型进行保存的,以便于下次能够直接的将训练好的模型进行载入. 1.保存模型 首先需要建立一个saver,然后在session中通过saver的save即可将模型保存起 ...
- [翻译] Tensorflow模型的保存与恢复
翻译自:http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/ ...
- 三、TensorFlow模型的保存和加载
1.模型的保存: import tensorflow as tf v1 = tf.Variable(1.0,dtype=tf.float32) v2 = tf.Variable(2.0,dtype=t ...
- tensorflow模型的保存与恢复
1.tensorflow中模型的保存 创建tf.train.saver,使用saver进行保存: saver = tf.train.Saver() saver.save(sess, './traine ...
- Tensorflow模型变量保存
Tensorflow:模型变量保存 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tensorflow1.4.0 pyt ...
- 超详细的Tensorflow模型的保存和加载(理论与实战详解)
1.Tensorflow的模型到底是什么样的? Tensorflow模型主要包含网络的设计(图)和训练好的各参数的值等.所以,Tensorflow模型有两个主要的文件: a) Meta graph: ...
- tensorflow模型的保存与恢复,以及ckpt到pb的转化
转自 https://www.cnblogs.com/zerotoinfinity/p/10242849.html 一.模型的保存 使用tensorflow训练模型的过程中,需要适时对模型进行保存,以 ...
- tensorflow模型的保存与加载
模型的保存与加载一般有三种模式:save/load weights(最干净.最轻量级的方式,只保存网络参数,不保存网络状态),save/load entire model(最简单粗暴的方式,把网络所有 ...
- TensorFlow(十三):模型的保存与载入
一:保存 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #载入数据集 mnist ...
随机推荐
- 地铁合作(hjy)
用时大概时间 18 个小时 我的队友是 :黄珺瑜我们一开始的想法是自己建立数据库,用来查询路线,再调用一下百度api作为地图(仅作地图没有任何操作的那种):后来我们的步骤也的确是这样,我们先确 ...
- C++-指针阅读能力提升
下面的标识符你肯定在工程中看不到,但是在面试题中却非常常见,掌握掌握还是不错的. int (*p1)(int*, int (*f)(int*)); int (*p2[5])( ...
- python collections 模块 常用集合
1.nametulpe # tuple 可以表示不变集合 列如坐标 point = (1,1) # 缺点是 只可以通过下标方式访问 #namedtuple是一个函数,它用来创建一个自定义的tuple对 ...
- js或者jq的string类型或者number类型的相互转换及json对象与字符串的转换
1.将值乘以1,将string类型转为number类型 //算合计价值function summoney(money) { var zijin = $("#main_xm_dam09&quo ...
- 利用Python数据分析基础
一.Numpy的ndarray:一种多维数组 ndarray是一个通用的同构多维数据容器,其所有元素必须是相同的类型.每个数组都有一个shape(一个表示各维度的元组)和dtype(一个用于说明数据数 ...
- (转)json格式转换成javaBean对象的方法
把json格式转换成javaBean才可以.于是查了一下资料,网上最多的资料就是下面的这种方式: Java code? 1 2 3 4 5 6 7 8 9 String str = "[{\ ...
- 题解【SP1716】GSS3 - Can you answer these queries III
题目描述 You are given a sequence \(A\) of \(N (N <= 50000)\) integers between \(-10000\) and \(10000 ...
- b 解题报告
本题已收录至2019/9/15 本周总结 题目 [问题描述] Hja有一棵\(n\)个点的树,树上每个点有点权,每条边有颜色.一条路径的权值是这条路径上所有点的点权和,一条合法的路径需要满足该路径上任 ...
- 磁盘分区(1):fdisk和parted
一.Linux存储管理 关于Linux硬盘的识别: (1)如果是IDE设备,在计算机中将被识别为hd,第一个IDE设备会被识别为hda,第二个IDE设备会被识别为hdb,依此类推. (2)如果是SAT ...
- winform学习(5)MDI窗体
SDI窗体 single document interface 单文档界面,即单个简单的窗体 MDI窗体 multiple document interface 多文档界面(主窗体与子窗体的关系,避免 ...