【转载】 Tensorflow学习笔记-模型保存与加载
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/lovelyaiq/article/details/78646401
————————————————
保存模型时,文件格式有两种,ckpt和pb格式,这两种格式的模型区别是什么呢?首先看一下英文的解释。并且我们的学习中也要养成看英文文档的习惯,其一:老外写的东西通俗易懂,其二,在翻译时,每个人的英文理解不同,原汁原味的道理就没有了。
The .ckpt is the model given by tensorflow which includes all the
weights/parameters in the model. The .pb file stores the computational
graph. To make tensorflow work we need both the graph and the
parameters. There are two ways to get the graph:
(1) use the python program that builds it in the first place (tensorflowNetworkFunctions.py).
(2) Use a .pb file (which would have to be generated by tensorflowNetworkFunctions.py).
.ckpt file is were all the intelligence is.
使用Tensorflow训练好模型之后,我们需要将训练好的模型保存起来,方便以后的使用,这就是Tensorflow模型的持久化。
保存
Tensorflow的模型保存时有几点需要注意:
1、利用tf.train.write_graph() 默认情况下只导出了网络的定义(没有权重weight)。
2、利用tf.train.Saver().save() 导出的文件graph_def与权重是分离的,就像上述英文的描述。
我们知道,graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标
import tensorflow as tf
v1 = tf.Variable(tf.constant(1,shape = [1]),name='v1')
v2 = tf.Variable(tf.constant(2,shape = [1]),name='v2')
result = v1 + v2
saver = tf.train.Saver()
with tf.Session() as sess:
tf.global_variables_initializer().run()
print(sess.run(v1))
print(sess.run(v2))
print(sess.run(result))
saver.save(sess,'model/model.ckpt')
模型保存后,在model目录将会有三个文件。在Tensorflow版本0.11之前,这三个文件为:meta、ckpt、checkpopint,它们保存的内容如下:
model.ckpt.meta保存计算图的结构,即神经网络的结构
checkpoint保存一个目录下所有的模型文件列表。
ckpt 保存程序中每一个变量的取值。
在Tensorflow版本0.11之后,有四个文件分别为:meta、.data、.index、checkpoint。其中.data文件为模型中的训练变量。
模型加载
模型加载包含两种方式,它们的区分以是否含有计算图上的所有运算。
包含所有运算
import tensorflow as tf v1 = tf.Variable(tf.constant(1,shape = [1]),name='v1')
v2 = tf.Variable(tf.constant(2,shape = [1]),name='v2')
result = v1 + v2 saver = tf.train.Saver() with tf.Session() as sess:
saver.restore(sess,'model/model.ckpt')
print(sess.run(v1+v2))
这种方法加载模型时和保存模型时的代码基本上是一致的,唯一不同的就是没有变量的初始化过程。
模型加载的时候,如果某个变量没有被加载,则系统将会报错。我们可否使用已经定义好的其它变量来加载呢?当然是可以了,因为Tensorflow是支持的,这需要通过字典的形式来完成,将模型中的变量名重名为我们已经定好的其它变量名。
import tensorflow as tf x = tf.Variable(tf.constant(1,shape = [1]),name='x')
y = tf.Variable(tf.constant(2,shape = [1]),name='y')
result = x + y # 通过字典将变量重命名
saver = tf.train.Saver(
{'v1':x,'v2':y}) with tf.Session() as sess:
saver.restore(sess,'model/model.ckpt')
out = tf.get_default_graph().get_tensor_by_name('add:0')
print(sess.run(out))
使用变量的滑动平均值的模型保存与加载详见:http://blog.csdn.net/lovelyaiq/article/details/78647850
不包含所有运算
import tensorflow as tf
saver = tf.train.import_meta_graph('model/model.ckpt.meta')
with tf.Session() as sess:
saver.restore(sess,'model/model.ckpt')
#获取节点名称
result = tf.get_default_graph().get_tensor_by_name("add:0")
print(sess.run(result))
Saver类
模型的加载与保存都使用到Saver类,该类的初始化参数为:
def __init__(self,
var_list=None,
reshape=False,
sharded=False,
max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0,
name=None,
restore_sequentially=False,
saver_def=None,
builder=None,
defer_build=False,
allow_empty=False,
write_version=saver_pb2.SaverDef.V2,
pad_step_number=False,
save_relative_paths=False,
filename=None):
这里面主要用到的参数:
max_to_keep:保存checkpoint文件的最大数量,默认值为5.
keep_checkpoint_every_n_hours:经过多长时间后,只保留一个checkpoint文件,这是方便验证模型训练多长时间后的性能。默认值为10000.0。
而tf.train.save的参数为:
def save(self,
sess,
save_path,
global_step=None,
latest_filename=None,
meta_graph_suffix="meta",
write_meta_graph=True,
write_state=True):
使用global_step和write_meta_graph两个参数可以很好的保存模型。
saver.save(sess, 'my_test_model',global_step=1000)
#保存的文件为:
#my_test_model-1000.index
#my_test_model-1000.meta
#my_test_model-1000.data-00000-of-00001
#checkpoint
模型在保存的时候,计算图在第一次已经保存过,并且随着训练的进行,计算图是不会改变的,因此以后的保存,就可以使用write_meta_graph=True不保存计算图。
saver.save(sess, 'my-model', global_step=step,write_meta_graph=False)
tf.train.Saver()默认保存与加载计算图上所有信息。但有时我们只需要保存或加载部分信息。比如在测试或离线预测时,只需知道如何从神经网络的输入层经过前向传播到输出层即可,而不需要类似于变量的初始化、模型保存等辅助节点的信息。而且有时将变量的取值与计算图分开保存是不方便的,因此就需要借助 convert_variables_to_constants 将计算图上所有的变量及其取值通过常量保存,这样整个计算图将会保存到一个文件中。
关于 convert_variables_to_constants 的源码定义如下:从解释中看出,当把网络完全转换为single GraphDef file,它可以删除与加载和保存变量相关的很多操作。
def convert_variables_to_constants(sess, input_graph_def, output_node_names,variable_names_whitelist=None,variable_names_blacklist=None):
"""Replaces all the variables in a graph with constants of the same values. If you have a trained graph containing Variable ops, it can be convenient to convert them all to Const ops holding the same values. This makes it possible to describe the network fully with a single GraphDef file, and allows the removal of a lot of ops related to loading and saving the variables.
import tensorflow as tf
from tensorflow.python.framework import graph_util v1 = tf.Variable(tf.constant(1,shape = [1]),name='v1')
v2 = tf.Variable(tf.constant(2,shape = [1]),name='v2')
result = v1 + v2 init_op = tf.global_variables_initializer() with tf.Session() as sess:
sess.run(init_op) # 导出计算图的GraphDef部分,只需要这一部分就可以完成从输入层到输出层的计算过程。
graph_def = tf.get_default_graph().as_graph_def() # print(graph_def) # 在这里我们只关心"add"节点,因此其它的节点就没有必要导出。
output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['add']) # 将导出的模型保存到本地
with tf.gfile.GFile('model/combined_model.pb','wb') as f:
f.write(output_graph_def.SerializeToString())
导出模型的恢复:
import tensorflow as tf
from tensorflow.python.framework import graph_util v1 = tf.Variable(tf.constant(1,shape = [1]),name='v1')
v2 = tf.Variable(tf.constant(2,shape = [1]),name='v2')
result = v1 + v2 init_op = tf.global_variables_initializer() with tf.Session() as sess:
model_filename = 'model/combined_model.pb'
with tf.gfile.FastGFile(model_filename,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# 将graph_def保存的图加入到当前默认的图
result = tf.import_graph_def(graph_def,return_elements=['add:0'])
print(sess.run(result))
上述方法有一个缺点,那就是我们不能自己定义一个网络输入的placeholder接口,这是不是很蛋筒,不要着急,Tensorflow是可以满足我们的需求。
import tensorflow as tf
from tensorflow.python.framework import graph_util
import numpy as np v1 = tf.Variable(tf.constant(1,shape = [1]),name='v1')
v2 = tf.Variable(tf.constant(2,shape = [1]),name='v2')
result = v1 + v2 with tf.variable_scope('foo'):
x = tf.get_variable('x',shape=[1],initializer=tf.constant_initializer(1.0))
y = tf.get_variable('y', shape=[1], initializer=tf.constant_initializer(2.0))
# v1 = tf.Variable(tf.constant(1.0,shape=[1]),name='v1')
# v2 = tf.Variable(tf.constant(2.0,shape=[1]),name='v2')
input_tensor = tf.placeholder(tf.float32,shape=[1],name='input-x')
new_tensor = tf.placeholder(tf.float32, shape=[1], name='input-y') result = tf.add((x+y),input_tensor,name='sum') data = np.array([15], dtype=np.float32) init_op = tf.global_variables_initializer() with tf.Session() as sess:
sess.run(init_op)
# print(sess.run(result,feed_dict={input_tensor:data}))
# print(sess.run(result))
graph_def = tf.get_default_graph().as_graph_def()
# print(graph_def)
output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['foo/sum'])
with tf.gfile.GFile('model/combined_model.pb','wb') as f:
f.write(output_graph_def.SerializeToString()) # 模型恢复
with tf.Session() as sess:
model_filename = 'model/combined_model.pb'
with tf.gfile.FastGFile(model_filename,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read()) # 使用input_map将模型中的placeholder通信映射到重新定义的placeholder。
result1 = tf.import_graph_def(graph_def ,input_map={'foo/input-x:0':new_tensor},return_elements=['foo/sum:0'],name='') # [array([ 18.], dtype=float32)]
print(sess.run(result1,feed_dict={new_tensor:data}))
这种模型恢复的方法在迁移学习中是常用的方法,至于什么是迁移学习,请参考博客:
【转载】 Tensorflow学习笔记-模型保存与加载的更多相关文章
- 深度学习-05(tensorflow模型保存与加载、文件读取、图像分类:手写体识别、服饰识别)
文章目录 深度学习-05 模型保存于加载 什么是模型保存与加载 模型保存于加载API 案例1:模型保存/加载 读取数据 文件读取机制 文件读取API 案例2:CSV文件读取 图片文件读取API 案例3 ...
- [PyTorch 学习笔记] 7.1 模型保存与加载
本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_save.py https://githu ...
- tensorflow 模型保存与加载 和TensorFlow serving + grpc + docker项目部署
TensorFlow 模型保存与加载 TensorFlow中总共有两种保存和加载模型的方法.第一种是利用 tf.train.Saver() 来保存,第二种就是利用 SavedModel 来保存模型,接 ...
- tensorflow实现线性回归、以及模型保存与加载
内容:包含tensorflow变量作用域.tensorboard收集.模型保存与加载.自定义命令行参数 1.知识点 """ 1.训练过程: 1.准备好特征和目标值 2.建 ...
- Flutter学习笔记(19)--加载本地图片
如需转载,请注明出处:Flutter学习笔记(19)--加载本地图片 上一篇博客正好用到了本地的图片,记录一下用法: 首先新建一个文件夹,这个文件夹要跟目录下 然后在pubspec.yaml里面声明出 ...
- [置顶] iOS学习笔记47——图片异步加载之EGOImageLoading
上次在<iOS学习笔记46——图片异步加载之SDWebImage>中介绍过一个开源的图片异步加载库,今天来介绍另外一个功能类似的EGOImageLoading,看名字知道,之前的一篇学习笔 ...
- sklearn模型保存与加载
sklearn模型保存与加载 sklearn模型的保存和加载API 线性回归的模型保存加载案例 保存模型 sklearn模型的保存和加载API from sklearn.externals impor ...
- Tensorflow学习笔记----模型的保存和读取(4)
一.模型的保存:tf.train.Saver类中的save TensorFlow提供了一个一个API来保存和还原一个模型,即tf.train.Saver类.以下代码为保存TensorFlow计算图的方 ...
- tensorflow学习笔记——模型持久化的原理,将CKPT转为pb文件,使用pb模型预测
由题目就可以看出,本节内容分为三部分,第一部分就是如何将训练好的模型持久化,并学习模型持久化的原理,第二部分就是如何将CKPT转化为pb文件,第三部分就是如何使用pb模型进行预测. 一,模型持久化 为 ...
- tensorflow学习笔记1:导出和加载模型
用一个非常简单的例子学习导出和加载模型: 导出 写一个y=a*x+b的运算,然后保存graph: import tensorflow as tf from tensorflow.python.fram ...
随机推荐
- tomcat部署Jenkins
安装环境 jdk 1.8 tomcat 9.0 jenkins 2.290 准备工作 安装好Tomcat,8080端口启动 安装好jdk,配置好环境变量 ECS服务器安全组放开8080端口 关闭防火墙 ...
- java: 找不到符号 java: Compilation failed: internal java compiler error
java: 找不到符号 java: Compilation failed: internal java compiler error 1.File---->Setting------>ja ...
- 简单易懂的JSON框架
分享一个由本人编写的JSON框架. JSON反序列化使用递归方式来解析JSON字符串,不使用任何第三方JAR包,只使用JAVA的反射来创建对象(必须要有无参构造器),赋值,编写反射缓存来提升性 ...
- Css var 简述
Css var 语法 var(custom-property-name, value) - custom-property-name 必须 变量必须以 --开头 后面可以是英文.数字连接符,区分大小写 ...
- CNN -- Simple Residual Network
Smiling & Weeping ---- 我爱你,从这里一直到月亮,再绕回来 说明: 1.要解决的问题:梯度消失 2. 跳连接,H(x) = F(x)+x,张量维度必须一致,加完后再激活. ...
- 02-CentOS7基础
基础知识介绍 shell shell俗称壳,它包裹在内核的外面,是用户命令的翻译官. 作用:接收用户的命令,翻译后(处理一下)交给Linux内核处理. 命令 -> shell -> 内核 ...
- MoneyPrinterPlus:AI自动短视频生成工具-微软云配置详解
MoneyPrinterPlus可以使用大模型自动生成短视频,我们可以借助Azure提供的语音服务来实现语音合成和语音识别的功能. Azure的语音服务应该是我用过的效果最好的服务了,微软还得是微软. ...
- hynitron ts 驱动分析
# hynitron ts 驱动分析 背景 在公司项目中搞LCD移植的时候,在TP功能上,有时候频繁操作屏幕时会导致i2c总线返回-2错误. 问题描述: 1.安卓桌面起来以后,点击屏幕有响应. 2.此 ...
- 【论文阅读】VDBFusion: Flexible and Efficient TSDF Integration of Range Sensor Data
Type: Sensors Year: 2022 tag: Mapping 组织: Bonn 参考与前言 论文链接:https://www.ncbi.nlm.nih.gov/pmc/articles/ ...
- 深度解读昇腾CANN多流并行技术,提高硬件资源利用率
本文分享自华为云社区<深度解读昇腾CANN多流并行技术,提高硬件资源利用率>,作者:昇腾CANN. 随着人工智能应用日益成熟,文本.图片.音频.视频等非结构化数据的处理需求呈指数级增长,数 ...