TF的模型文件
TF的模型文件
标签(空格分隔): TensorFlow
Saver
tensorflow模型保存函数为:
tf.train.Saver()
当然,除了上面最简单的保存方式,也可以指定保存的步数,多长时间保存一次,磁盘上最多保有几个模型(将前面的删除以保持固定个数),如下:
创建saver时指定参数:
saver = tf.train.Saver(savable_variables, max_to_keep=n, keep_checkpoint_every_n_hours=m)
其中:
- savable_variables指定待保存的变量,比如指定为tf.global_variables()保存所有global变量;指定为[v1, v2]保存v1和v2两个变量;如果省略,则保存所有;
- max_to_keep指定磁盘上最多保有几个模型;
- keep_checkpoint_every_n_hours指定多少小时保存一次。
保存模型时指定参数:
saver.save(sess, 'model_name', global_step=step,write_meta_graph=False)
如上,其中可以指定模型文件名,步数,write_meta_graph则用来指定是否保存meta文件记录graph等等。
示例:
import tensorflow as tf
v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
v3= tf.Variable(tf.zeros([100]), name="v3")
saver = tf.train.Saver()
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
saver.save(sess,"checkpoint/model.ckpt",global_step=1)
运行后,保存模型保存,得到四个文件:
- checkpoint
- model.ckpt-1.data-00000-of-00001
- model.ckpt-1.index
- model.ckpt-1.meta
checkpoint中记录了已存储(部分)和最近存储的模型:
model_checkpoint_path: "model.ckpt-1"
all_model_checkpoint_paths: "model.ckpt-1"
...
meta file保存了graph结构,包括 GraphDef,SaverDef等,当存在meta file,我们可以不在文件中定义模型,也可以运行,而如果没有meta file,我们需要定义好模型,再加载data file,得到变量值。
index file为一个string-string table,table的key值为tensor名,value为serialized BundleEntryProto。每个BundleEntryProto表述了tensor的metadata,比如那个data文件包含tensor、文件中的偏移量、一些辅助数据等。
data file保存了模型的所有变量的值,TensorBundle集合。
Restore
Restore模型的过程可以分为两个部分,首先是创建模型,可以手动创建,也可以从meta文件里加载graph进行创建。
模型加载为:
with tf.Session() as sess:
saver = tf.train.import_meta_graph('/xx/model.ckpt.meta')
saver.restore(sess, "/xx/model.ckpt")
.meta文件中保存了图的结构信息,因此需要在导入checkpoint之前导入它。否则,程序不知道checkpoint中的变量对应的变量。另外也可以:
# Recreate the EXACT SAME variables
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Now load the checkpoint variable values
with tf.Session() as sess:
saver = tf.train.Saver()
saver.restore(sess, "/xx/model.ckpt")
#saver.restore(sess, tf.train.latest_checkpoint('./'))
PS:不存在model.ckpt文件,saver.py中:Users only need to interact with the user-specified prefix... instead of any physical pathname.
当然,还有一点需要注意,并非所有的TensorFlow模型都能将graph输出到meta文件中或者从meta文件中加载进来,如果模型有部分不能序列化的部分,则此种方法可能会无效。
使用Restore的模型
查看模型的参数
with tf.Session() as sess:
saver = tf.train.import_meta_graph('model.ckpt-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
tvs = [v for v in tf.trainable_variables()]
for v in tvs:
print(v.name)
print(sess.run(v))
如名所言,以上是查看模型中的trainable variables;或者我们也可以查看模型中的所有tensor或者operations,如下:
with tf.Session() as sess:
saver = tf.train.import_meta_graph('model.ckpt-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
gv = [v for v in tf.global_variables()]
for v in gv:
print(v.name)
上面通过global_variables()获得的与前trainable_variables类似,只是多了一些非trainable的变量,比如定义时指定为trainable=False的变量,或Optimizer相关的变量。
下面则可以获得几乎所有的operations相关的tensor:
with tf.Session() as sess:
saver = tf.train.import_meta_graph('model.ckpt-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
ops = [o for o in sess.graph.get_operations()]
for o in ops:
print(o.name)
首先,上面的sess.graph.get_operations()可以换为tf.get_default_graph().get_operations(),二者区别无非是graph明确的时候可以直接使用前者,否则需要使用后者。
此种方法获得的tensor比较齐全,可以从中一窥模型全貌。不过,最方便的方法还是推荐使用tensorboard来查看,当然这需要你提前将sess.graph输出。
直接使用原始模型进行训练或测试
这种操作比较简单,无非是找到原始模型的输入、输出即可。
只要搞清楚输入输出的tensor名字,即可直接使用TensorFlow中graph的get_tensor_by_name函数,建立输入输出的tensor:
with tf.get_default_graph() as graph:
data = graph.get_tensor_by_name('data:0')
output = graph.get_tensor_by_name('output:0')
从模型中找到了输入输出之后,即可直接使用其继续train整个模型,或者将输入数据feed到模型里,并前传得到test输出了。
需要说明的是,有时候从一个graph里找到输入和输出tensor的名字并不容易,所以,在定义graph时,最好能给相应的tensor取上一个明显的名字,比如:
data = tf.placeholder(tf.float32, shape=shape, name='input_data')
preds = tf.nn.softmax(logits, name='output')
诸如此类。这样,就可以直接使用tf.get_tensor_by_name(‘input_data:0’)之类的来找到输入输出了。
扩展原始模型
除了直接使用原始模型,还可以在原始模型上进行扩展,比如对1中的output继续进行处理,添加新的操作,可以完成对原始模型的扩展,如:
with tf.get_default_graph() as graph:
data = graph.get_tensor_by_name('data:0')
output = graph.get_tensor_by_name('output:0')
logits = tf.nn.softmax(output)
使用原始模型的某部分
有时候,我们有对某模型的一部分进行fine-tune的需求,比如使用一个VGG的前面提取特征的部分,而微调其全连层,或者将其全连层更换为使用convolution来完成,等等。TensorFlow也提供了这种支持,可以使用TensorFlow的stop_gradient函数,将模型的一部分进行冻结。
with tf.get_default_graph() as graph:
graph.get_tensor_by_name('fc1:0')
fc1 = tf.stop_gradient(fc1)
# add new procedure on fc1
TF的模型文件的更多相关文章
- tensorflow c++ API加载.pb模型文件并预测图片
tensorflow python创建模型,训练模型,得到.pb模型文件后,用c++ api进行预测 #include <iostream> #include <map> # ...
- TensorFlow 模型文件
在这篇 TensorFlow 教程中,我们将学习如下内容: TensorFlow 模型文件是怎么样的? 如何保存一个 TensorFlow 模型? 如何恢复一个 TensorFlow 模型? 如何使用 ...
- h5模型文件转换成pb模型文件
本文主要记录Keras训练得到的.h5模型文件转换成TensorFlow的.pb文件 #*-coding:utf-8-* """ 将keras的.h5的模型文件,转换 ...
- 查看tensorflow pb模型文件的节点信息
查看tensorflow pb模型文件的节点信息: import tensorflow as tf with tf.Session() as sess: with open('./quantized_ ...
- 利用tensorboard可视化checkpoint模型文件参数分布
写在前面: 上周微调一个文本检测模型seglink,将特征提取层进行冻结,只训练分类回归层,然而查看tensorboard发现里面有histogram显示模型各个参数分布,看了目前这个训练模型参数分布 ...
- 模型文件(checkpoint)对模型参数的储存与恢复
1. 模型参数的保存: import tensorflow as tfw=tf.Variable(0.0,name='graph_w')ww=tf.Variable(tf.random_normal ...
- tensorflow lite 之生成 tflite 模型文件
下载最新的的tensorflow源码. 1.配置 tflite 文件转换所需环境 安装 bazel 编译工具 https://docs.bazel.build/versions/master/inst ...
- Away3D 学习笔记(一): 加载3DS格式的模型文件
加载外部的3DS文件分为两种: 1: 模型与贴图独立于程序的,也就是从外部的文件夹中读取 private function load3DSFile():Loader3D { loader = new ...
- 转 Django根据现有数据库,自动生成models模型文件
Django引入外部数据库还是比较方便的,步骤如下 : 创建一个项目,修改seting文件,在setting里面设置你要连接的数据库类型和连接名称,地址之类,和创建新项目的时候一致 运行下面代码可以自 ...
随机推荐
- XamarinAndroid组件教程设置动画的设置插值器
XamarinAndroid组件教程设置动画的设置插值器 为动画设置插值器,可以使用BaseItemAnimator抽象类中的SetInterpolator()方法,其语法形式如下: public v ...
- c#取数据库数据 ---两种方法
通常有以下两种方式 SqlDataReader 和SqlDataAdapter|DataSet方式 SqlDataReader 方式使用方式如下: using System; using System ...
- Codeforces.566E.Restoring Map(构造)
题目链接 \(Description\) 对于一棵树,定义某个点的邻居集合为所有距离它不超过\(2\)的点的集合(包括它自己). 给定\(n\)及\(n\)个点的邻居集合,要求构造一棵\(n\)个点的 ...
- Idea创建一个Springboot单模块项目
1.打开IDEA,创建新项目,选择Spring Initializr,选择SDK为你的java版本. 2.点击下一步,输入Artifact 3.点击下一步,选择web 4.finish 5.完成后id ...
- [CC-SEAPERM2]Sereja and Permutations
[CC-SEAPERM2]Sereja and Permutations 题目大意: 有一个\(n(n\le300)\)排列\(p\),将其中一个元素\(p_i\)拿掉,然后将原来大于\(p_i\)的 ...
- 可以直接用的“ html转字符串string”方法
//html转字符串 -(NSString *)filterHTMLString:(NSString *)html { NSScanner * scanner = [NSScanner scanner ...
- Ruby语法基础(一)
Ruby语法基础(一) Ruby是一种开源的面向对象程序设计的服务器端脚本语言,最初由松本行弘(Matz)设计开发,追求『快乐和生产力』,程序员友好型,被称为『human-oriented langu ...
- std::lock_guard/std::unique_lock
C++多线程编程中通常会对共享的数据进行写保护,以防止多线程在对共享数据成员进行读写时造成资源争抢导致程序出现未定义的行为.通常的做法是在修改共享数据成员的时候进行加锁--mutex.在使用锁的时候通 ...
- JS_高程3.基本概念(6)函数
1.ECMAScript中的函数使用function关键字来声明. eg: function sum (num1,num2){ alert(num1+num2); } sum(3,7); 注意: 在有 ...
- python算法练习
6. 约瑟夫环问题:已知n个人(以编号1,2,3...n分别表示)围坐在一张圆桌周围.从编号为k的人开始报数,数到k的那个人被杀掉:他的下一个人又从1开始报数,数到k的那个人又被杀掉:依此规律重复下去 ...