tf 模型保存
tf用 tf.train.Saver类来实现神经网络模型的保存和读取。无论保存还是读取,都首先要创建saver对象。
用saver对象的save方法保存模型
保存的是所有变量
save(
sess,
save_path,
global_step=None,
latest_filename=None,
meta_graph_suffix='meta',
write_meta_graph=True,
write_state=True
)
保存模型需要session,初始化变量
用法示例
import tensorflow as tf v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2 saver = tf.train.Saver() with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.save(sess, "Model/model.ckpt", global_step=3)
输出

1. global_step 放在文件名后面,起个标记作用
2. save方法输出4个文件
// checkpoint 里面是一堆路径,model_checkpoint_path 记录了最新模型的路径,all_model_checkpoint_paths 记录了之前模型的路径
// model.ckpt-3.data-00000-of-00001 存放的是模型参数
// model.ckpt-3.meta 存放的是计算图
3. 最多只能保存近5次模型,比如我们迭代100次,每次保存一下,最后只留下了最近的5次。
用saver对象的restore方法加载模型
加载的是所有变量,以name为准,假如保存的模型中有变量叫 a ,value是2,那么在加载后,即使重新建立变量a,并赋其他value,其value仍然是2
restore(
sess,
save_path
)
加载模型需要session,不需要初始化变量
用法示例(接前例)
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(7.0, shape=[1]), name="v2")
# v2 = tf.Variable(tf.constant(7.0, shape=[1]), name="v22") # Key v22 not found in checkpoint
result = v1 + v2 saver = tf.train.Saver()
#
with tf.Session() as sess:
saver.restore(sess, "./Model/model.ckpt-3") # 注意此处路径前添加"./"
print(sess.run(result)) # [ 3.]
1. 重新给 name为 v2的变量 赋值,其结果仍然是3,说明加载了之前的v2
2. 新建name为 v22 的变量,报错, 在保存的模型中没找到v2 。说明寻找变量以name为准,不以变量名为准
继续做如下尝试
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
# v2 = tf.Variable(tf.constant(7.0, shape=[1]), name="v2")
v3 = tf.Variable(tf.constant(7.0, shape=[1]), name="v22") # Key v22 not found in checkpoint
result = v1 + v3 saver = tf.train.Saver()
#
with tf.Session() as sess:
# sess.run(tf.global_variables_initializer()) # Key v22 not found in checkpoint
saver.restore(sess, "./Model/model.ckpt-3") # 注意此处路径前添加"./"
# sess.run(tf.global_variables_initializer()) # Key v22 not found in checkpoint
print(sess.run(result)) # [ 3.]
1. 新建name为v22的变量v3,仍然报错,说明新的变量没有被接受
2. 在加载模型前初始化v3,仍然报错,加载模型后初始化v3,仍然报错,这说明在加载的模型中不接受新的变量。
继续尝试
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
# v2 = tf.Variable(tf.constant(7.0, shape=[1]), name="v2")
v3 = tf.Variable(tf.constant(7.0, shape=[1]), name="v22") # Key v22 not found in checkpoint
result = v1 + v3 saver = tf.train.Saver()
#
with tf.Session() as sess:
sess.run(tf.global_variables_initializer()) # Key v22 not found in checkpoint
print(sess.run(v3)) # [7.]
saver.restore(sess, "./Model/model.ckpt-3") # 注意此处路径前添加"./"
sess.run(tf.global_variables_initializer()) # Key v22 not found in checkpoint
print(sess.run(result)) # [ 3.]
在加载模型前初始化变量,正确输出,但在加载后,报错,证实了我上面的说法,“不接受新的变量”
总结:
1. 模型加载加载的是所有变量,以name为准
2. 模型加载后不接受任何新的变量
3. 在加载模型时需要重新定义计算图上的所有节点,但是变量无需初始化
加载计算图
直接加载计算图就无需重新定义计算图上的节点
用法示例
saver = tf.train.import_meta_graph("Model/model.ckpt-3.meta")
with tf.Session() as sess:
saver.restore(sess, "./Model/model.ckpt-3") # 注意路径写法
print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0"))) # [3.]
# print(sess.run(sess.graph.get_tensor_by_name('add:0'))) # [3.]
重命名变量
在加载模型时不接受新的变量,这会造成很多麻烦。
为解决这个问题,加载模型时可以给变量重命名。
用法示例
u1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1")
u2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2")
result = u1 + u2 # 若直接声明Saver类对象,会报错变量找不到
# 使用一个字典dict重命名变量即可,{"已保存的变量的名称name": 重命名变量名}
# 原来名称name为v1的变量现在加载到变量u1(名称name为other-v1)中
saver = tf.train.Saver({"v1": u1, "v2": u2}) with tf.Session() as sess:
saver.restore(sess, "./Model/model.ckpt-3")
print(sess.run(result)) # [ 3.]
注意重命名格式 老变量的name: 新变量名
参考资料:
https://blog.csdn.net/marsjhao/article/details/72829635
https://blog.csdn.net/shuzfan/article/details/79197432
tf 模型保存的更多相关文章
- TensorFlow:tf.train.Saver()模型保存与恢复
1.保存 将训练好的模型参数保存起来,以便以后进行验证或测试.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf.train.S ...
- tensorflow的tf.train.Saver()模型保存与恢复
将训练好的模型参数保存起来,以便以后进行验证或测试.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf.train.Saver( ...
- TensorFlow构建卷积神经网络/模型保存与加载/正则化
TensorFlow 官方文档:https://www.tensorflow.org/api_guides/python/math_ops # Arithmetic Operators import ...
- 10 Tensorflow模型保存与读取
我们的模型训练出来想给别人用,或者是我今天训练不完,明天想接着训练,怎么办?这就需要模型的保存与读取.看代码: import tensorflow as tf import numpy as np i ...
- tensorflow 模型保存与加载 和TensorFlow serving + grpc + docker项目部署
TensorFlow 模型保存与加载 TensorFlow中总共有两种保存和加载模型的方法.第一种是利用 tf.train.Saver() 来保存,第二种就是利用 SavedModel 来保存模型,接 ...
- TensorFlow模型保存和加载方法
TensorFlow模型保存和加载方法 模型保存 import tensorflow as tf w1 = tf.Variable(tf.constant(2.0, shape=[1]), name= ...
- TensorFlow进阶(六)---模型保存与恢复、自定义命令行参数
模型保存与恢复.自定义命令行参数. 在我们训练或者测试过程中,总会遇到需要保存训练完成的模型,然后从中恢复继续我们的测试或者其它使用.模型的保存和恢复也是通过tf.train.Saver类去实现,它主 ...
- Sklearn,TensorFlow,keras模型保存与读取
一.sklearn模型保存与读取 1.保存 from sklearn.externals import joblib from sklearn import svm X = [[0, 0], [1, ...
- TensorFlow模型保存和提取方法
一.TensorFlow模型保存和提取方法 1. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取.tf.train.Saver对象saver的save方法将Tens ...
随机推荐
- 宿主iis部署wcf
WCF学习笔记(4)——宿主iis部署wcf 本文将部署一个wcf+silverlight简单实例,以下是详细步骤: (环境:服务端win2003,iis6.0,asp.net4.0:客户端winXP ...
- DedecmsV5.7本地上传缩略图无法自动添加水印的解决方法
问题:dedecms后台 系统->图片水印设置 图片水印设置有开启了,但是本地上传缩略图无法自动添加水印 网上有很多资料,所以记录一下 1.打开dede(实际项目后台文件夹)/archives_ ...
- 『TensorFlow』第九弹_图像预处理_不爱红妆爱武装
部分代码单独测试: 这里实践了图像大小调整的代码,值得注意的是格式问题: 输入输出图像时一定要使用uint8编码, 但是数据处理过程中TF会自动把编码方式调整为float32,所以输入时没问题,输出时 ...
- 别忘了Nologging哦
别忘了Nologging哦
- mysqldump导出报错"mysqldump: Error 2013: Lost connection to MySQL server during query when dumping table `file_storage` at row: 29"
今天mysql备份的crontab自动运行的时候,出现了报警,报警内容如下 mysqldump: Error 2013: Lost connection to MySQL server during ...
- Mybatis中,当插入数据后,返回最新主键id的几种方法,及具体用法
insert元素 属性详解 其属性如下: parameterType ,入参的全限定类名或类型别名 keyColumn ,设置数据表自动生成的主键名.对特定数据库(如PostgreSQL),若自动生成 ...
- asm ftp utilty and usage
Oracle 11g ASM supports ASM FTP, by which operations on ASM files and directories can be performed s ...
- PE文件结构解析
说明:本文件中各种文件头格式截图基本都来自看雪的<加密与解密>:本文相当<加密与解密>的阅读笔记. 1.PE文件总体结构 PE文件框架结构,就是exe文件的排版结构.也就是说我 ...
- 关于TCP长连接和发送心跳的一些理解
原因 TCP是一种有连接的协议,但是这个连接并不是指有一条实际的电路,而是一种虚拟的电路.TCP的建立连接和断开连接都是通过发送数据实现的,也就是我们常说的三次握手.四次挥手.TCP两端保存了一种数据 ...
- from…import 语句