【转载】 Tensorflow学习笔记-模型保存与加载
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/lovelyaiq/article/details/78646401
————————————————
保存模型时,文件格式有两种,ckpt和pb格式,这两种格式的模型区别是什么呢?首先看一下英文的解释。并且我们的学习中也要养成看英文文档的习惯,其一:老外写的东西通俗易懂,其二,在翻译时,每个人的英文理解不同,原汁原味的道理就没有了。
The .ckpt is the model given by tensorflow which includes all the
weights/parameters in the model. The .pb file stores the computational
graph. To make tensorflow work we need both the graph and the
parameters. There are two ways to get the graph:
(1) use the python program that builds it in the first place (tensorflowNetworkFunctions.py).
(2) Use a .pb file (which would have to be generated by tensorflowNetworkFunctions.py).
.ckpt file is were all the intelligence is.
使用Tensorflow训练好模型之后,我们需要将训练好的模型保存起来,方便以后的使用,这就是Tensorflow模型的持久化。
保存
Tensorflow的模型保存时有几点需要注意:
1、利用tf.train.write_graph() 默认情况下只导出了网络的定义(没有权重weight)。
2、利用tf.train.Saver().save() 导出的文件graph_def与权重是分离的,就像上述英文的描述。
我们知道,graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标
import tensorflow as tf
v1 = tf.Variable(tf.constant(1,shape = [1]),name='v1')
v2 = tf.Variable(tf.constant(2,shape = [1]),name='v2')
result = v1 + v2
saver = tf.train.Saver()
with tf.Session() as sess:
tf.global_variables_initializer().run()
print(sess.run(v1))
print(sess.run(v2))
print(sess.run(result))
saver.save(sess,'model/model.ckpt')
模型保存后,在model目录将会有三个文件。在Tensorflow版本0.11之前,这三个文件为:meta、ckpt、checkpopint,它们保存的内容如下:
model.ckpt.meta保存计算图的结构,即神经网络的结构
checkpoint保存一个目录下所有的模型文件列表。
ckpt 保存程序中每一个变量的取值。
在Tensorflow版本0.11之后,有四个文件分别为:meta、.data、.index、checkpoint。其中.data文件为模型中的训练变量。
模型加载
模型加载包含两种方式,它们的区分以是否含有计算图上的所有运算。
包含所有运算
import tensorflow as tf v1 = tf.Variable(tf.constant(1,shape = [1]),name='v1')
v2 = tf.Variable(tf.constant(2,shape = [1]),name='v2')
result = v1 + v2 saver = tf.train.Saver() with tf.Session() as sess:
saver.restore(sess,'model/model.ckpt')
print(sess.run(v1+v2))
这种方法加载模型时和保存模型时的代码基本上是一致的,唯一不同的就是没有变量的初始化过程。
模型加载的时候,如果某个变量没有被加载,则系统将会报错。我们可否使用已经定义好的其它变量来加载呢?当然是可以了,因为Tensorflow是支持的,这需要通过字典的形式来完成,将模型中的变量名重名为我们已经定好的其它变量名。
import tensorflow as tf x = tf.Variable(tf.constant(1,shape = [1]),name='x')
y = tf.Variable(tf.constant(2,shape = [1]),name='y')
result = x + y # 通过字典将变量重命名
saver = tf.train.Saver(
{'v1':x,'v2':y}) with tf.Session() as sess:
saver.restore(sess,'model/model.ckpt')
out = tf.get_default_graph().get_tensor_by_name('add:0')
print(sess.run(out))
使用变量的滑动平均值的模型保存与加载详见:http://blog.csdn.net/lovelyaiq/article/details/78647850
不包含所有运算
import tensorflow as tf saver = tf.train.import_meta_graph('model/model.ckpt.meta')
with tf.Session() as sess:
saver.restore(sess,'model/model.ckpt') #获取节点名称
result = tf.get_default_graph().get_tensor_by_name("add:0")
print(sess.run(result))
Saver类
模型的加载与保存都使用到Saver类,该类的初始化参数为:
def __init__(self,
var_list=None,
reshape=False,
sharded=False,
max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0,
name=None,
restore_sequentially=False,
saver_def=None,
builder=None,
defer_build=False,
allow_empty=False,
write_version=saver_pb2.SaverDef.V2,
pad_step_number=False,
save_relative_paths=False,
filename=None):
这里面主要用到的参数:
max_to_keep:保存checkpoint文件的最大数量,默认值为5.
keep_checkpoint_every_n_hours:经过多长时间后,只保留一个checkpoint文件,这是方便验证模型训练多长时间后的性能。默认值为10000.0。
而tf.train.save的参数为:
def save(self,
sess,
save_path,
global_step=None,
latest_filename=None,
meta_graph_suffix="meta",
write_meta_graph=True,
write_state=True):
使用global_step和write_meta_graph两个参数可以很好的保存模型。
saver.save(sess, 'my_test_model',global_step=1000)
#保存的文件为:
#my_test_model-1000.index
#my_test_model-1000.meta
#my_test_model-1000.data-00000-of-00001
#checkpoint
模型在保存的时候,计算图在第一次已经保存过,并且随着训练的进行,计算图是不会改变的,因此以后的保存,就可以使用write_meta_graph=True不保存计算图。
saver.save(sess, 'my-model', global_step=step,write_meta_graph=False)
tf.train.Saver()默认保存与加载计算图上所有信息。但有时我们只需要保存或加载部分信息。比如在测试或离线预测时,只需知道如何从神经网络的输入层经过前向传播到输出层即可,而不需要类似于变量的初始化、模型保存等辅助节点的信息。而且有时将变量的取值与计算图分开保存是不方便的,因此就需要借助 convert_variables_to_constants 将计算图上所有的变量及其取值通过常量保存,这样整个计算图将会保存到一个文件中。
关于 convert_variables_to_constants 的源码定义如下:从解释中看出,当把网络完全转换为single GraphDef file,它可以删除与加载和保存变量相关的很多操作。
def convert_variables_to_constants(sess, input_graph_def, output_node_names,variable_names_whitelist=None,variable_names_blacklist=None):
"""Replaces all the variables in a graph with constants of the same values. If you have a trained graph containing Variable ops, it can be convenient to convert them all to Const ops holding the same values. This makes it possible to describe the network fully with a single GraphDef file, and allows the removal of a lot of ops related to loading and saving the variables.
import tensorflow as tf
from tensorflow.python.framework import graph_util v1 = tf.Variable(tf.constant(1,shape = [1]),name='v1')
v2 = tf.Variable(tf.constant(2,shape = [1]),name='v2')
result = v1 + v2 init_op = tf.global_variables_initializer() with tf.Session() as sess:
sess.run(init_op) # 导出计算图的GraphDef部分,只需要这一部分就可以完成从输入层到输出层的计算过程。
graph_def = tf.get_default_graph().as_graph_def() # print(graph_def) # 在这里我们只关心"add"节点,因此其它的节点就没有必要导出。
output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['add']) # 将导出的模型保存到本地
with tf.gfile.GFile('model/combined_model.pb','wb') as f:
f.write(output_graph_def.SerializeToString())
导出模型的恢复:
import tensorflow as tf
from tensorflow.python.framework import graph_util v1 = tf.Variable(tf.constant(1,shape = [1]),name='v1')
v2 = tf.Variable(tf.constant(2,shape = [1]),name='v2')
result = v1 + v2 init_op = tf.global_variables_initializer() with tf.Session() as sess:
model_filename = 'model/combined_model.pb'
with tf.gfile.FastGFile(model_filename,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# 将graph_def保存的图加入到当前默认的图
result = tf.import_graph_def(graph_def,return_elements=['add:0'])
print(sess.run(result))
上述方法有一个缺点,那就是我们不能自己定义一个网络输入的placeholder接口,这是不是很蛋筒,不要着急,Tensorflow是可以满足我们的需求。
import tensorflow as tf
from tensorflow.python.framework import graph_util
import numpy as np v1 = tf.Variable(tf.constant(1,shape = [1]),name='v1')
v2 = tf.Variable(tf.constant(2,shape = [1]),name='v2')
result = v1 + v2 with tf.variable_scope('foo'):
x = tf.get_variable('x',shape=[1],initializer=tf.constant_initializer(1.0))
y = tf.get_variable('y', shape=[1], initializer=tf.constant_initializer(2.0))
# v1 = tf.Variable(tf.constant(1.0,shape=[1]),name='v1')
# v2 = tf.Variable(tf.constant(2.0,shape=[1]),name='v2')
input_tensor = tf.placeholder(tf.float32,shape=[1],name='input-x')
new_tensor = tf.placeholder(tf.float32, shape=[1], name='input-y') result = tf.add((x+y),input_tensor,name='sum') data = np.array([15], dtype=np.float32) init_op = tf.global_variables_initializer() with tf.Session() as sess:
sess.run(init_op)
# print(sess.run(result,feed_dict={input_tensor:data}))
# print(sess.run(result))
graph_def = tf.get_default_graph().as_graph_def()
# print(graph_def)
output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['foo/sum'])
with tf.gfile.GFile('model/combined_model.pb','wb') as f:
f.write(output_graph_def.SerializeToString()) # 模型恢复
with tf.Session() as sess:
model_filename = 'model/combined_model.pb'
with tf.gfile.FastGFile(model_filename,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read()) # 使用input_map将模型中的placeholder通信映射到重新定义的placeholder。
result1 = tf.import_graph_def(graph_def ,input_map={'foo/input-x:0':new_tensor},return_elements=['foo/sum:0'],name='') # [array([ 18.], dtype=float32)]
print(sess.run(result1,feed_dict={new_tensor:data}))
这种模型恢复的方法在迁移学习中是常用的方法,至于什么是迁移学习,请参考博客:
【转载】 Tensorflow学习笔记-模型保存与加载的更多相关文章
- 深度学习-05(tensorflow模型保存与加载、文件读取、图像分类:手写体识别、服饰识别)
文章目录 深度学习-05 模型保存于加载 什么是模型保存与加载 模型保存于加载API 案例1:模型保存/加载 读取数据 文件读取机制 文件读取API 案例2:CSV文件读取 图片文件读取API 案例3 ...
- [PyTorch 学习笔记] 7.1 模型保存与加载
本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_save.py https://githu ...
- tensorflow 模型保存与加载 和TensorFlow serving + grpc + docker项目部署
TensorFlow 模型保存与加载 TensorFlow中总共有两种保存和加载模型的方法.第一种是利用 tf.train.Saver() 来保存,第二种就是利用 SavedModel 来保存模型,接 ...
- tensorflow实现线性回归、以及模型保存与加载
内容:包含tensorflow变量作用域.tensorboard收集.模型保存与加载.自定义命令行参数 1.知识点 """ 1.训练过程: 1.准备好特征和目标值 2.建 ...
- Flutter学习笔记(19)--加载本地图片
如需转载,请注明出处:Flutter学习笔记(19)--加载本地图片 上一篇博客正好用到了本地的图片,记录一下用法: 首先新建一个文件夹,这个文件夹要跟目录下 然后在pubspec.yaml里面声明出 ...
- [置顶] iOS学习笔记47——图片异步加载之EGOImageLoading
上次在<iOS学习笔记46——图片异步加载之SDWebImage>中介绍过一个开源的图片异步加载库,今天来介绍另外一个功能类似的EGOImageLoading,看名字知道,之前的一篇学习笔 ...
- sklearn模型保存与加载
sklearn模型保存与加载 sklearn模型的保存和加载API 线性回归的模型保存加载案例 保存模型 sklearn模型的保存和加载API from sklearn.externals impor ...
- Tensorflow学习笔记----模型的保存和读取(4)
一.模型的保存:tf.train.Saver类中的save TensorFlow提供了一个一个API来保存和还原一个模型,即tf.train.Saver类.以下代码为保存TensorFlow计算图的方 ...
- tensorflow学习笔记——模型持久化的原理,将CKPT转为pb文件,使用pb模型预测
由题目就可以看出,本节内容分为三部分,第一部分就是如何将训练好的模型持久化,并学习模型持久化的原理,第二部分就是如何将CKPT转化为pb文件,第三部分就是如何使用pb模型进行预测. 一,模型持久化 为 ...
- tensorflow学习笔记1:导出和加载模型
用一个非常简单的例子学习导出和加载模型: 导出 写一个y=a*x+b的运算,然后保存graph: import tensorflow as tf from tensorflow.python.fram ...
随机推荐
- kettle从入门到精通 第二十三课 kettle carte 错误(java.lang.OutOfMemoryError: GC overhead limit exceeded,Could not emit buffer due to lack of requests,java heap space)分析
1.Could not emit buffer due to lack of requests(无法发出缓冲区,因为请求不足.) 原因有两点:1)消费者处理数据能力较弱,如表输出步骤.2)消费者没有处 ...
- docker on windows v19 红色启动不了
遇到: error during connect: Get http://%2F%2F.%2Fpipe%2Fdocker_engine/v1.40/containers/json: open //./ ...
- 项目管理--PMBOK 读书笔记(9)【项目资源管理】
1.团队成员的角色与职责: 1)层级结构(OBS):与 WBS 交叉确认部门的全部项目指责,项目组织结构图: 2)矩阵结构(RAM):工作包(活动)与项目团队的关系,主要用于明确角色与期望(职责) 3 ...
- MapStruct - 注解汇总
@Mapper @Mapper 将接口或抽象类标记为映射器,并自动生成映射实现类代码. public @interface Mapper { // 引入其他其他映射器 Class<?>[] ...
- 学习ThreeJS
创建第一个应用 使用Three JS进行编程的时候,都是在调用new Three().XXX 来实现方法,让我们先根据官方文档创建一个demo https://threejs.org/docs/ind ...
- Asp.net core Swashbuckle Swagger 的常用配置
背景 .net core Swashbuckle Swagger 官方文档:https://github.com/domaindrivendev/Swashbuckle.AspNetCore 我们发现 ...
- CentOS7安装最新版ruby
背景 直接通过yum安装的ruby版本太低,不能满足redis.fpm等软件的需求. 系统环境 CentOS7 安装步骤 下载ruby http://www.ruby-lang.org/en/down ...
- Markdown常用语法详解
背景知识 什么是html html是一种网页标记语言.我们平常见到的那么好看的网页就是通过html语言来编写的. html语言举例: <h1>hello world</h1> ...
- Oracle自动化编译无效对象
问题描述:使用存储过程的方式对oracle数据库的无效对象,如视图或者同义词进行定期的编译,让他变成一个有效的对象,加上定时任务可以实现自动化的处理.同时在数据库内部创建一个记录表,用来记录被编译过的 ...
- ARM平台实现Docker容器技术
什么是Docker? (1)Docker的架构 Docker是一个开源的应用容器引擎,让开发者可打包他们的应用以及依赖包到一个可移植的镜像中,然后发布到任何流行的Linux或Windows机器上, ...