一份快速完整的Tensorflow模型保存和恢复教程(译)(转载)
该文章转自https://blog.csdn.net/sinat_34474705/article/details/78995196
我在进行图像识别使用ckpt文件预测的时候,这个文章给我提供了极大的帮助,因此我决定把它记录下来。
原文链接A quick complete tutorial to save and restore Tensorflow models–by ANKIT SACHAN
(英文水平有限,有翻译不当的地方请见谅)
在本教程中,我将介绍:
- tensorflow模型是什么样子的?
- 如何保存一个Tensorflow模型?
- 如何恢复一个Tensorflow模型用于预测/迁移学习?
- 如何导入预训练的模型进行微调和修改?
本教程假设你已经对训练一个神经网络有一定了解。否则请先看这篇教程Tensorflow Tutorial 2: image classifier using convolutional neural network再看本教程。
什么是Tensorflow模型?
当你训练好一个神经网络后,你会想保存好你的模型便于以后使用并且用于生产。因此,什么是Tensorflow模型?Tensorflow模型主要包含网络设计(或者网络图)和训练好的网络参数的值。所以Tensorflow模型有两个主要的文件:
a) Meta图:
Meta图是一个协议缓冲区(protocol buffer),它保存了完整的Tensorflow图;比如所有的变量、运算、集合等。这个文件的扩展名是.meta。
b) Checkpoint 文件
这是一个二进制文件,它保存了权重、偏置项、梯度以及其他所有的变量的取值,扩展名为.ckpt。但是, 从0.11版本开始,Tensorflow对改文件做了点修改,checkpoint文件不再是单个.ckpt文件,而是如下两个文件:
mymodel.data-00000-of-00001
mymodel.index
其中, .data文件包含了我们的训练变量。除此之外,还有一个叫checkpoint的文件,它保留了最新的checkpoint文件的记录。
对于0.11之后的版本,其包含四个文件:
model.ckpt.meta
model.ckpt.index
checkpoint
model.ckpt.data-00000-of-00001
现在我们已经知道Tensorflow模型是什么样子的,让我们继续学习如何保存模型。
保存Tensorflow模型
假如你正在训练一个用于图像分类的卷积神经网络(training a convolutional neural network for image classification)。通常你会先观察损失和准确率,一旦发现网络收敛,就可以手动停止训练过程或者直接训练固定迭代次数。当训练完成后,我们想要保存所有的变量和网络图便于以后使用。因此在Tensorflow中, 为了保存网络图和所有参数的值,我们应该创建tf.train.Saver()这个类的一个对象。
saver = tf.train.Saver()
记住Tensorflow变量只有在会话(session)中才能激活。因此,你需要在会话中调用你刚创建的对象的保存方法。
saver.save(sess, "my-test-model")
这里,sess是一个session对象,“my-test-model”是你的模型名字。让我们看一个完整的例子:
import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model') # This will save following files in Tensorflow v >= 0.11
# my_test_model.data-00000-of-00001
# my_test_model.index
# my_test_model.meta
# checkpoint
如果我们要在1000次迭代后保存模型,我们应该在调用保存方法时传入步数计数:
saver.save(sess, "my_test_model", global_step=1000)
这会在模型名称后加一个“-1000”并且会创建如下文件:
my_test_model-1000.index
my_test_model-1000.meta
my_test_model-1000.data-00000-of-00001
checkpoint
假设在训练过程中,我们要每1000次迭代保存我们的模型,因此.meta文件会在第一次(1000次迭代)时创建,我们并不需要之后每1000次迭代都保存一遍这个文件(我们在2000,3000…迭代时都不需要保存这个文件,因为这个文件始终不变)。我们只需要保存这个模型供以后使用,因为模型图不会变化。所以,当我们不想重写meta图的时候,我们这样写:
saver.save(sess, "my-model", global_step=step, write_meta_graph=False)
如果你只想保留4个最新的模型并且在训练过程中每过2小时保存一次模型,你可以使用max_to_keep和keep_checkpoint_every_n_hours,就像这样:
#saves a model every 2 hours and maximum 4 latest models are saved.
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)
注意,如果我们在tf.train.Saver()中不指定任何东西,它将保存所有的变量。要是我们不想保存所有的变量而只是一部分变量。我们可以指定我们想要保存的变量/集合。当创建tf.train.Saver()对象的时候,我们给它传递一个我们想要保存的变量的字典列表。我们来看一个例子:
import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver([w1,w2])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model',global_step=1000)
当需要的时候,这个代码可以用来保存Tensorflow图中的指定部分。
导入预训练模型
如果你想要用其他人预训练的模型进行微调,需要做两件事:
a) 创建网络
你可以写python代码来手动创建和原来一样的模型。但是,想想看,我们已经将原始网络保存在了.meta文件中,可以用tf.train.import()函数来重建网络:
saver = tf.train.import_meta_graph("my_test_model-1000.meta")
记住,import_meta_graph函数将只将定义在.meta文件中的网络添加到当前的图上。因此,它虽然帮你创建了额图/网络,但我们还是需要导入我们在这个图上训练好的模型的参数。
b) 导入参数
我们可以调用由tf.train.Saver()创建的对象saver中的restore方法来恢复网络中的参数。
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
这样,张量的值(如w1和w2)就被恢复并且可以访问了:
with tf.Session() as sess:
saver = tf.train.import_meta_graph('my-model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
print(sess.run('w1:0'))
##Model has been restored. Above statement will print the saved value of w1.
现在你已经理解了如何保存和导入Tensorflow模型。在下一节,我会介绍一个实际应用即导入任何预训练好的模型。
使用恢复的模型
现在你已经理解如何保存和恢复Tensorflow模型,我们来写一个实际的示例来恢复任何预训练的模型并用它来预测、微调或者进一步训练。无论你什么时候用Tensorflow,你都会定义一个网络,它有一些样本(训练数据)和超参数(如学习率、迭代次数等)。通常用一个占位符(placeholder)来将所有的训练数据和超参数输入给网络。下面我们用占位符建立一个小型网络并保存它。注意,当网络被保存的时候,占位符中的值并没有被保存。
import tensorflow as tf #Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8} #Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer()) #Create a saver object which will save all the variables
saver = tf.train.Saver() #Run the operation by feeding input
print sess.run(w4,feed_dict)
#Prints 24 which is sum of (w1+w2)*b1 #Now, save the graph
saver.save(sess, 'my_test_model',global_step=1000)
当我们想要恢复这个网络的时候,我们不仅需要恢复图和权重,还需要准备一个新的feed_dict来将训练数据输入到网络中。我们可以通过graph.get_tensor_by_name方法来引用这些保存的运算和占位符变量。
#How to access saved variable/Tensor/placeholders
w1 = graph.get_tensor_by_name("w1:0") ## How to access saved operation
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
如果我们只是想用不同的数据运行相同的网络,你可以方便地用feed_dict将新的数据送到网络中。
import tensorflow as tf sess=tf.Session()
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./')) # Now, let's access and create placeholders variables and
# create feed-dict to feed new data graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0} #Now, access the op that you want to run.
op_to_restore = graph.get_tensor_by_name("op_to_restore:0") print sess.run(op_to_restore,feed_dict)
#This will print 60 which is calculated
#using new values of w1 and w2 and saved value of b1.
要是你想在原来的计算图中通过添加更多的层来增加更多的运算并且训练。当然也可以实现,如下:
import tensorflow as tf sess=tf.Session()
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./')) # Now, let's access and create placeholders variables and
# create feed-dict to feed new data graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0} #Now, access the op that you want to run.
op_to_restore = graph.get_tensor_by_name("op_to_restore:0") #Add more to the current graph
add_on_op = tf.multiply(op_to_restore,2) print sess.run(add_on_op,feed_dict)
#This will print 120.
但是,我们能够只恢复原来图中的一部分然后添加一些其它层来微调吗?当然可以,只要通过graph.get_tensor_by_name()方法来获取原网络的部分计算图并在上面继续建立新计算图。这里给出了一个实际的例子。我们用meta图导入了一个预训练的vgg网络,然后将最后一层的输出个数改成2用于微调新的数据。
......
......
saver = tf.train.import_meta_graph('vgg.meta')
# Access the graph
graph = tf.get_default_graph()
## Prepare the feed_dict for feeding data for fine-tuning #Access the appropriate output for fine-tuning
fc7= graph.get_tensor_by_name('fc7:0') #use this if you only want to change gradients of the last layer
fc7 = tf.stop_gradient(fc7) # It's an identity function
fc7_shape= fc7.get_shape().as_list() num_outputs=2
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output) # Now, you run this with fine-tuning data in sess.run()
希望本文能够让你清楚地理解Tensorflow是如何被保存和微调的。请在评论区自由分享你的问题或者疑问。
另外,为了便于理解,我上传了一份用MNIST数据集训练及调用模型的例子,见链接:https://pan.baidu.com/s/1C-l3YZGbEsAFIClgSQN46Q 密码:3iq8
一份快速完整的Tensorflow模型保存和恢复教程(译)(转载)的更多相关文章
- TensorFlow使用记录 (九): 模型保存与恢复
模型文件 tensorflow 训练保存的模型注意包含两个部分:网络结构和参数值. .meta .meta 文件以 “protocol buffer”格式保存了整个模型的结构图,模型上定义的操作等信息 ...
- tensorflow 模型保存与加载 和TensorFlow serving + grpc + docker项目部署
TensorFlow 模型保存与加载 TensorFlow中总共有两种保存和加载模型的方法.第一种是利用 tf.train.Saver() 来保存,第二种就是利用 SavedModel 来保存模型,接 ...
- TensorFlow模型保存和加载方法
TensorFlow模型保存和加载方法 模型保存 import tensorflow as tf w1 = tf.Variable(tf.constant(2.0, shape=[1]), name= ...
- TensorFlow进阶(六)---模型保存与恢复、自定义命令行参数
模型保存与恢复.自定义命令行参数. 在我们训练或者测试过程中,总会遇到需要保存训练完成的模型,然后从中恢复继续我们的测试或者其它使用.模型的保存和恢复也是通过tf.train.Saver类去实现,它主 ...
- TensorFlow模型保存和提取方法
一.TensorFlow模型保存和提取方法 1. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取.tf.train.Saver对象saver的save方法将Tens ...
- TensorFlow 模型保存/载入
我们在上线使用一个算法模型的时候,首先必须将已经训练好的模型保存下来.tensorflow保存模型的方式与sklearn不太一样,sklearn很直接,一个sklearn.externals.jobl ...
- Tensorflow模型保存与加载
在使用Tensorflow时,我们经常要将以训练好的模型保存到本地或者使用别人已训练好的模型,因此,作此笔记记录下来. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提 ...
- 10 Tensorflow模型保存与读取
我们的模型训练出来想给别人用,或者是我今天训练不完,明天想接着训练,怎么办?这就需要模型的保存与读取.看代码: import tensorflow as tf import numpy as np i ...
- TensorFlow:tf.train.Saver()模型保存与恢复
1.保存 将训练好的模型参数保存起来,以便以后进行验证或测试.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf.train.S ...
随机推荐
- tcl脚本直接执行脚本中中文识别不了的处理
上一篇说tcl中文乱码是因为我写了个bat调用该脚本,但是脚本中的中文路径是乱码.今天刚好有时间进行解决下: 首先看看调用代码 "./bin/base-tcl8.6-thread-win32 ...
- mpvue 解析
前言 mpvue是一款使用Vue.js开发微信小程序的前端框架. 总结 生命周期的理解 文档 一次前后端实践 使用此框架,开发者将得到完整的 Vue.js 开发体验,同时为H5和小程序提供了代码复 ...
- MySQL常用内置函数
本篇博客源自以下博客地址: http://www.mamicode.com/info-detail-250393.html
- 阿里云centos7使用yum安装mysql的正确姿势
yum快速安装mysql 新增yum源 rpm -Uvh http://dev.mysql.com/get/mysql-community-release-el7-5.noarch.rpm 查看可用的 ...
- css3中trastion,transform,animation基本的了解
毕业答辩一耽误就是一个月的时间,感觉自己浪费好多时间,而且学习劲头都没有以前的好,学习是个漫长艰苦的事情,也出现了好多问题,希望自己有则改之,无则加冕,曾国藩曾说过:悔者,所以守其缺而禾取求全也.虽然 ...
- c# ef
找出不同项 ).ToList(); resultMsg = string.Join(",", query.select(p=>p.key).ToList())
- PTA寒假一
7-1 打印沙漏 (20 分) 本题要求你写个程序把给定的符号打印成沙漏的形状.例如给定17个"*",要求按下列格式打印 所谓"沙漏形状",是指每行输出奇数个符 ...
- robot framework中的timeout的关键词
1.默认robotframework中的含有等待的关键词(如:Wait Until Element Is Enabled),未手动设置时默认该参数为5sec 2.关键词:sleep A)一般在调试的时 ...
- requestmapping等相关知识
@responseBody注解的使用 1. @responseBody注解的作用是将controller的方法返回的对象通过适当的转换器转换为指定的格式之后,写入到response对象的body区 ...
- js正则表达式讲的最好的
https://www.cnblogs.com/chenmeng0818/p/6370819.html