版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/lovelyaiq/article/details/78646401

————————————————

保存模型时,文件格式有两种,ckpt和pb格式,这两种格式的模型区别是什么呢?首先看一下英文的解释。并且我们的学习中也要养成看英文文档的习惯,其一:老外写的东西通俗易懂,其二,在翻译时,每个人的英文理解不同,原汁原味的道理就没有了。

The .ckpt is the model given by tensorflow which includes all the
weights/parameters in the model. The .pb file stores the computational
graph. To make tensorflow work we need both the graph and the
parameters. There are two ways to get the graph:
(1) use the python program that builds it in the first place (tensorflowNetworkFunctions.py).
(2) Use a .pb file (which would have to be generated by tensorflowNetworkFunctions.py).
.ckpt file is were all the intelligence is.

使用Tensorflow训练好模型之后,我们需要将训练好的模型保存起来,方便以后的使用,这就是Tensorflow模型的持久化。

保存

Tensorflow的模型保存时有几点需要注意:
  1、利用tf.train.write_graph() 默认情况下只导出了网络的定义(没有权重weight)。
  2、利用tf.train.Saver().save() 导出的文件graph_def权重是分离的,就像上述英文的描述。
  我们知道,graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标

import tensorflow as tf

v1 = tf.Variable(tf.constant(1,shape = [1]),name='v1')
v2 = tf.Variable(tf.constant(2,shape = [1]),name='v2')
result = v1 + v2

saver = tf.train.Saver()

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    print(sess.run(v1))
    print(sess.run(v2))
    print(sess.run(result))
    saver.save(sess,'model/model.ckpt')

模型保存后,在model目录将会有三个文件。在Tensorflow版本0.11之前,这三个文件为:meta、ckpt、checkpopint,它们保存的内容如下:
    model.ckpt.meta保存计算图的结构,即神经网络的结构
    checkpoint保存一个目录下所有的模型文件列表。
    ckpt 保存程序中每一个变量的取值。

在Tensorflow版本0.11之后,有四个文件分别为:meta、.data、.index、checkpoint。其中.data文件为模型中的训练变量。

模型加载

  模型加载包含两种方式,它们的区分以是否含有计算图上的所有运算。

包含所有运算

import tensorflow as tf

v1 = tf.Variable(tf.constant(1,shape = [1]),name='v1')
v2 = tf.Variable(tf.constant(2,shape = [1]),name='v2')
result = v1 + v2 saver = tf.train.Saver() with tf.Session() as sess:
saver.restore(sess,'model/model.ckpt')
print(sess.run(v1+v2))

这种方法加载模型时和保存模型时的代码基本上是一致的,唯一不同的就是没有变量的初始化过程。

模型加载的时候,如果某个变量没有被加载,则系统将会报错。我们可否使用已经定义好的其它变量来加载呢?当然是可以了,因为Tensorflow是支持的,这需要通过字典的形式来完成,将模型中的变量名重名为我们已经定好的其它变量名。

import tensorflow as tf

x = tf.Variable(tf.constant(1,shape = [1]),name='x')
y = tf.Variable(tf.constant(2,shape = [1]),name='y')
result = x + y # 通过字典将变量重命名
saver = tf.train.Saver(
{'v1':x,'v2':y}) with tf.Session() as sess:
saver.restore(sess,'model/model.ckpt')
out = tf.get_default_graph().get_tensor_by_name('add:0')
print(sess.run(out))

使用变量的滑动平均值的模型保存与加载详见:http://blog.csdn.net/lovelyaiq/article/details/78647850

不包含所有运算 

import tensorflow as tf

saver = tf.train.import_meta_graph('model/model.ckpt.meta')
with tf.Session() as sess:
saver.restore(sess,'model/model.ckpt') #获取节点名称
result = tf.get_default_graph().get_tensor_by_name("add:0")
print(sess.run(result))

Saver类

  模型的加载与保存都使用到Saver类,该类的初始化参数为:

  def __init__(self,
var_list=None,
reshape=False,
sharded=False,
max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0,
name=None,
restore_sequentially=False,
saver_def=None,
builder=None,
defer_build=False,
allow_empty=False,
write_version=saver_pb2.SaverDef.V2,
pad_step_number=False,
save_relative_paths=False,
filename=None):

这里面主要用到的参数: 

    max_to_keep:保存checkpoint文件的最大数量,默认值为5.
    keep_checkpoint_every_n_hours:经过多长时间后,只保留一个checkpoint文件,这是方便验证模型训练多长时间后的性能。默认值为10000.0。

而tf.train.save的参数为:

  def save(self,
sess,
save_path,
global_step=None,
latest_filename=None,
meta_graph_suffix="meta",
write_meta_graph=True,
write_state=True):

  使用global_stepwrite_meta_graph两个参数可以很好的保存模型。

saver.save(sess, 'my_test_model',global_step=1000)
#保存的文件为:
#my_test_model-1000.index
#my_test_model-1000.meta
#my_test_model-1000.data-00000-of-00001
#checkpoint

模型在保存的时候,计算图在第一次已经保存过,并且随着训练的进行,计算图是不会改变的,因此以后的保存,就可以使用write_meta_graph=True不保存计算图。

saver.save(sess, 'my-model', global_step=step,write_meta_graph=False)

tf.train.Saver()默认保存与加载计算图上所有信息。但有时我们只需要保存或加载部分信息。比如在测试或离线预测时,只需知道如何从神经网络的输入层经过前向传播到输出层即可,而不需要类似于变量的初始化、模型保存等辅助节点的信息。而且有时将变量的取值与计算图分开保存是不方便的,因此就需要借助  convert_variables_to_constants  将计算图上所有的变量及其取值通过常量保存,这样整个计算图将会保存到一个文件中。

关于 convert_variables_to_constants 的源码定义如下:从解释中看出,当把网络完全转换为single GraphDef file,它可以删除与加载和保存变量相关的很多操作。

def convert_variables_to_constants(sess, input_graph_def, output_node_names,variable_names_whitelist=None,variable_names_blacklist=None):
"""Replaces all the variables in a graph with constants of the same values. If you have a trained graph containing Variable ops, it can be convenient to convert them all to Const ops holding the same values. This makes it possible to describe the network fully with a single GraphDef file, and allows the removal of a lot of ops related to loading and saving the variables.
import tensorflow as tf
from tensorflow.python.framework import graph_util v1 = tf.Variable(tf.constant(1,shape = [1]),name='v1')
v2 = tf.Variable(tf.constant(2,shape = [1]),name='v2')
result = v1 + v2 init_op = tf.global_variables_initializer() with tf.Session() as sess:
sess.run(init_op) # 导出计算图的GraphDef部分,只需要这一部分就可以完成从输入层到输出层的计算过程。
graph_def = tf.get_default_graph().as_graph_def() # print(graph_def) # 在这里我们只关心"add"节点,因此其它的节点就没有必要导出。
output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['add']) # 将导出的模型保存到本地
with tf.gfile.GFile('model/combined_model.pb','wb') as f:
f.write(output_graph_def.SerializeToString())

  导出模型的恢复:

import tensorflow as tf
from tensorflow.python.framework import graph_util v1 = tf.Variable(tf.constant(1,shape = [1]),name='v1')
v2 = tf.Variable(tf.constant(2,shape = [1]),name='v2')
result = v1 + v2 init_op = tf.global_variables_initializer() with tf.Session() as sess:
model_filename = 'model/combined_model.pb'
with tf.gfile.FastGFile(model_filename,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# 将graph_def保存的图加入到当前默认的图
result = tf.import_graph_def(graph_def,return_elements=['add:0'])
print(sess.run(result))

上述方法有一个缺点,那就是我们不能自己定义一个网络输入的placeholder接口,这是不是很蛋筒,不要着急,Tensorflow是可以满足我们的需求。

import tensorflow as tf
from tensorflow.python.framework import graph_util
import numpy as np v1 = tf.Variable(tf.constant(1,shape = [1]),name='v1')
v2 = tf.Variable(tf.constant(2,shape = [1]),name='v2')
result = v1 + v2 with tf.variable_scope('foo'):
x = tf.get_variable('x',shape=[1],initializer=tf.constant_initializer(1.0))
y = tf.get_variable('y', shape=[1], initializer=tf.constant_initializer(2.0))
# v1 = tf.Variable(tf.constant(1.0,shape=[1]),name='v1')
# v2 = tf.Variable(tf.constant(2.0,shape=[1]),name='v2')
input_tensor = tf.placeholder(tf.float32,shape=[1],name='input-x')
new_tensor = tf.placeholder(tf.float32, shape=[1], name='input-y') result = tf.add((x+y),input_tensor,name='sum') data = np.array([15], dtype=np.float32) init_op = tf.global_variables_initializer() with tf.Session() as sess:
sess.run(init_op)
# print(sess.run(result,feed_dict={input_tensor:data}))
# print(sess.run(result))
graph_def = tf.get_default_graph().as_graph_def()
# print(graph_def)
output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['foo/sum'])
with tf.gfile.GFile('model/combined_model.pb','wb') as f:
f.write(output_graph_def.SerializeToString()) # 模型恢复
with tf.Session() as sess:
model_filename = 'model/combined_model.pb'
with tf.gfile.FastGFile(model_filename,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read()) # 使用input_map将模型中的placeholder通信映射到重新定义的placeholder。
result1 = tf.import_graph_def(graph_def ,input_map={'foo/input-x:0':new_tensor},return_elements=['foo/sum:0'],name='') # [array([ 18.], dtype=float32)]
print(sess.run(result1,feed_dict={new_tensor:data}))

这种模型恢复的方法在迁移学习中是常用的方法,至于什么是迁移学习,请参考博客:

【转载】 Tensorflow学习笔记-模型保存与加载的更多相关文章

  1. 深度学习-05(tensorflow模型保存与加载、文件读取、图像分类:手写体识别、服饰识别)

    文章目录 深度学习-05 模型保存于加载 什么是模型保存与加载 模型保存于加载API 案例1:模型保存/加载 读取数据 文件读取机制 文件读取API 案例2:CSV文件读取 图片文件读取API 案例3 ...

  2. [PyTorch 学习笔记] 7.1 模型保存与加载

    本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_save.py https://githu ...

  3. tensorflow 模型保存与加载 和TensorFlow serving + grpc + docker项目部署

    TensorFlow 模型保存与加载 TensorFlow中总共有两种保存和加载模型的方法.第一种是利用 tf.train.Saver() 来保存,第二种就是利用 SavedModel 来保存模型,接 ...

  4. tensorflow实现线性回归、以及模型保存与加载

    内容:包含tensorflow变量作用域.tensorboard收集.模型保存与加载.自定义命令行参数 1.知识点 """ 1.训练过程: 1.准备好特征和目标值 2.建 ...

  5. Flutter学习笔记(19)--加载本地图片

    如需转载,请注明出处:Flutter学习笔记(19)--加载本地图片 上一篇博客正好用到了本地的图片,记录一下用法: 首先新建一个文件夹,这个文件夹要跟目录下 然后在pubspec.yaml里面声明出 ...

  6. [置顶] iOS学习笔记47——图片异步加载之EGOImageLoading

    上次在<iOS学习笔记46——图片异步加载之SDWebImage>中介绍过一个开源的图片异步加载库,今天来介绍另外一个功能类似的EGOImageLoading,看名字知道,之前的一篇学习笔 ...

  7. sklearn模型保存与加载

    sklearn模型保存与加载 sklearn模型的保存和加载API 线性回归的模型保存加载案例 保存模型 sklearn模型的保存和加载API from sklearn.externals impor ...

  8. Tensorflow学习笔记----模型的保存和读取(4)

    一.模型的保存:tf.train.Saver类中的save TensorFlow提供了一个一个API来保存和还原一个模型,即tf.train.Saver类.以下代码为保存TensorFlow计算图的方 ...

  9. tensorflow学习笔记——模型持久化的原理,将CKPT转为pb文件,使用pb模型预测

    由题目就可以看出,本节内容分为三部分,第一部分就是如何将训练好的模型持久化,并学习模型持久化的原理,第二部分就是如何将CKPT转化为pb文件,第三部分就是如何使用pb模型进行预测. 一,模型持久化 为 ...

  10. tensorflow学习笔记1:导出和加载模型

    用一个非常简单的例子学习导出和加载模型: 导出 写一个y=a*x+b的运算,然后保存graph: import tensorflow as tf from tensorflow.python.fram ...

随机推荐

  1. Linux扩展篇-shell编程(八)-shell字符串截取

    shell字符串截取,一般包含从指定位置和从指定字符截取. 一.从指定位置截取 1) 从字符串左边开始计数 格式: ${string: start :length} 从 string 字符串的左边第 ...

  2. OpenSpeedTest-Server局域网速度测试服务程序

    OpenSpeedTest-Server局域网速度测试服务程序,局域网测速.

  3. 《Android开发卷——自定义日期选择器(一)》

    (小米手机) (中兴手机) 在实际开发中,Google官方提供的时间选择器API已经不能满足于我们的需要了,所以很多公司都是采用自定义的形式来实现日期选择器. 这个例子很简单,定义三个NumberPi ...

  4. elasticSearch RangeQuery范围查询from to的理解

    elasticSearch RangeQuery范围查询from to的理解 Elasticsearch Guide 选择版本号来查询对应的文档内容:https://www.elastic.co/gu ...

  5. Java api zookeeper

    package com.redis.demo.zookeeper; import java.io.Serializable; public class User implements Serializ ...

  6. apollo数据库表查询方法-可以通过批量更新mysql数据库-比如批量更新IP地址等

    select `Id`, `AppId`, `Name` from ApolloPortalDB.App; select `NamespaceId`, `Key`, `Value`, `Comment ...

  7. 制作Jdk镜像

    本文介绍用Dockerfile的方式构建Jdk镜像,请保证安装了Docker环境. 首先创建/opt/jdk目录,后续步骤都在该目录下进行操作. 准备好Jdk安装文件,放到/opt/jdk目录下. 编 ...

  8. .NET 高效灵活的API速率限制解决方案

    前言 FireflySoft.RateLimit是基于.NET Core和.NET Standard构建,支持多种速率限制算法和策略,包括固定窗口.滑动窗口.漏桶.令牌桶等.通过简单的配置和集成,开发 ...

  9. arm linux 移植 iperf3

    背景 新做的硬件需要有进行一些板级接口测试:关于网络的测试很多时候只是停留在 ping 通:能够使用就算了.不知道网络的丢包率,也不知道网络吞吐的性能. 因此,需要使用一些专业化的工具来进行测试:查阅 ...

  10. 【论文阅读】TRO 2021: Fail-Safe Motion Planning for Online Verification of Autonomous Vehicles Using Convex Optimization

    参考与前言 Last edited time: August 3, 2022 10:04 AM Status: Reading Type: TRO Year: 2021 论文链接:https://ie ...