Tensorflow模型保存与加载
在使用Tensorflow时,我们经常要将以训练好的模型保存到本地或者使用别人已训练好的模型,因此,作此笔记记录下来。
TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取。tf.train.Saver对象saver的save方法将TensorFlow模型保存到指定路径中,如:saver.save(sess, "/Model/model"), 执行完,在相应的目录下将会有4个文件:
meta:文件保存的是图结构信息,meta文件是pb(protocol buffer)格式文件,包含变量、op、集合等。
ckpt保存每个变量的取值,此处文件名的写入方式会因不同参数的设置而不同。是二进制文件,保存了所有的weights、biases、gradients等变量。在tensorflow 0.11之 前,保存在.ckpt文件中。0.11后,通过两个文件保存,如:.data-00000-of-00001和.index文件
checkpoint文件:checkpoint_dir目录下还有checkpoint文件,该文件是个文本文件,里面记录了保存的最新的checkpoint文件以及其它checkpoint文件列表。在inference时,可以通过修改这个文件,指定使用哪个model。加载restore时的文件路径名是以checkpoint文件中的“model_checkpoint_path”值决定的。
保存模型时,只会保存变量的值,placeholder里面的值不会被保存。
关于save()方法的参数记录:
- sess:在tensorflow中,变量是存在于Session环境中,即只有在Session环境下才会存有变量值,因此,保存模型时需要传入session
- global_step:在n次迭代后,再保存模型,只需设置
global_step参数即可 - 由于图是不变的,没必要每次都去保存,可以在多次迭代过程中只用保存一次模型即可,可以通过设置write_meta_graph=False即可
- keep_checkpoint_every_n_hours:用来设置间隔时间来保存
- max_to_keep: 用来设置保存最近模型文件的个数
- 如果不想保存所有变量,而只保存一部分变量,可以通过指定variables/collections,默认是保存所有的变量。
tf.train.Saver类也支持在保存和加载时给变量重命名,声明Saver类对象的时候使用一个字典dict重命名变量即可,{"已保存的变量的名称name": 重命名变量名}。
导入模型
加载图:saver=tf.train.import_meta_graph(.meta文件)即可。
加载模型参数:aver.restore(sess, tf.train.latest_checkpoint('./checkpoint_dir'))
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}
注意w1:0是tensor的name,既可以指定变量名称,也可以指定操作名称。
其实,我们也可以只恢复图的一部分,并且再加入其它的op用于fine-tuning。只需通过graph.get_tensor_by_name()方法获取需要的op,并且在此基础上建立图即可。例如:假设我们想使用已经训练好的VGG模型,并且要更改部分层,如下:
saver = tf.train.import_meta_graph('vgg.meta')
# 访问图
graph = tf.get_default_graph()
#访问用于fine-tuning的output
fc7= graph.get_tensor_by_name('fc7:0')
#如果你想修改最后一层梯度,需要如下
fc7 = tf.stop_gradient(fc7) # It's an identity function
fc7_shape= fc7.get_shape().as_list()
new_outputs=2
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output)
Tensorflow模型保存与加载的更多相关文章
- tensorflow 模型保存与加载 和TensorFlow serving + grpc + docker项目部署
TensorFlow 模型保存与加载 TensorFlow中总共有两种保存和加载模型的方法.第一种是利用 tf.train.Saver() 来保存,第二种就是利用 SavedModel 来保存模型,接 ...
- 转 tensorflow模型保存 与 加载
使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我们可能也需要用到别人训练好的模型,并在这个基础上再次训练.这时候我们需要掌握如何操作这些模型数据.看完本文,相信你一定会有收获 ...
- tensorflow实现线性回归、以及模型保存与加载
内容:包含tensorflow变量作用域.tensorboard收集.模型保存与加载.自定义命令行参数 1.知识点 """ 1.训练过程: 1.准备好特征和目标值 2.建 ...
- [PyTorch 学习笔记] 7.1 模型保存与加载
本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_save.py https://githu ...
- sklearn模型保存与加载
sklearn模型保存与加载 sklearn模型的保存和加载API 线性回归的模型保存加载案例 保存模型 sklearn模型的保存和加载API from sklearn.externals impor ...
- TensorFlow构建卷积神经网络/模型保存与加载/正则化
TensorFlow 官方文档:https://www.tensorflow.org/api_guides/python/math_ops # Arithmetic Operators import ...
- TensorFlow的模型保存与加载
import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import tensorflow as tf #tensorboard --logdir=&qu ...
- tensorflow 之模型的保存与加载(一)
怎样让通过训练的神经网络模型得以复用? 本文先介绍简单的模型保存与加载的方法,后续文章再慢慢深入解读. #!/usr/bin/env python3 #-*- coding:utf-8 -*- ### ...
- TensorFlow保存、加载模型参数 | 原理描述及踩坑经验总结
写在前面 我之前使用的LSTM计算单元是根据其前向传播的计算公式手动实现的,这两天想要和TensorFlow自带的tf.nn.rnn_cell.BasicLSTMCell()比较一下,看看哪个训练速度 ...
随机推荐
- MySQL数据库安装和启动
目录 一.数据库介绍 二.数据库的分类 1. 关系型数据库系统 2. 当下的关系型数据库系统 3. 当下的非关系型数据库系统 4. 关系型和非关系型数据库系统的区别 三.MySQL的架构 四.MySQ ...
- 【NOIP2016提高组day1】换教室
题目 对于刚上大学的牛牛来说,他面临的第一个问题是如何根据实际情况申请合适的 课程. 在可以选择的课程中,有 2n 节课程安排在 n 个时间段上. 在第 i ( 1 ≤ i ≤ n )个 时间段上,两 ...
- 使用net命令启动MongoDB服务发生系统错误,返回值为5
使用net命令启动MongoDB服务发生系统错误,返回值为5 错误的截图如下: 解决的方案是以管理员的身份运行命令窗口,参考如下: https://www.cnblogs.com/fanblogs/p ...
- docker跨主机通信-overlay
使用consul 1,让两个网络环境下的容器互通,那么必然涉及到网络信息的同步,所以需要先配置一下consul. 直接运行下面命令.启动consul. docker run -d -p 8500:85 ...
- 【HDU6667】Roundgod and Milk Tea【贪心】
题目大意:给你ai,bi,限制ai不能流向bi,求最大流 题解:贪心,对于第i个班级,考虑前i-1个班级匹配完剩余多少a,b,将这些ab对第i个班级进行贪心匹配 匹配完若第i个班级还有剩余的ab,考虑 ...
- windows10 gcc编译C程序(简单编译)
参考:http://c.biancheng.net/view/660.html gcc可以一次性完成C语言源程序的编译,也可以分步骤完成:下面先介绍一次性编译过程. 1.生成可执行程序 cd xxx ...
- You Only Look Once Unified, Real-Time Object Detection(你只需要看一次统一的,实时的目标检测)
我们提出了一种新的目标检测方法YOLO.先前的目标检测工作重新利用分类器来执行检测.相反,我们将目标检测作为一个回归问题来处理空间分离的边界框和相关的类概率.单个神经网络在一次评估中直接从完整图像预测 ...
- 【转载】What is the difference between authorized_keys and known_hosts file for SSH?
The known_hosts file lets the client authenticate the server, to check that it isn't connecting to a ...
- wannafly 练习赛11 E 求最值(平面最近点对)
链接:https://www.nowcoder.com/acm/contest/59/E 时间限制:C/C++ 1秒,其他语言2秒 空间限制:C/C++ 32768K,其他语言65536K 64bit ...
- [UPC10525]:Dove打扑克(暴力+模拟)
题目描述 $Dove$和$Cicada$是好朋友,他们经常在一起打扑克来消遣时光,但是他们打的扑克有不同的玩法. 最开始时,牌桌上会有$n$个牌堆,每个牌堆有且仅有一张牌,第$i$个牌堆里里里那个扑克 ...