在使用Tensorflow时,我们经常要将以训练好的模型保存到本地或者使用别人已训练好的模型,因此,作此笔记记录下来。

   TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取。tf.train.Saver对象saver的save方法将TensorFlow模型保存到指定路径中,如:saver.save(sess, "/Model/model"), 执行完,在相应的目录下将会有4个文件:

    meta:文件保存的是图结构信息,meta文件是pb(protocol buffer)格式文件,包含变量、op、集合等。

    ckpt保存每个变量的取值,此处文件名的写入方式会因不同参数的设置而不同。是二进制文件,保存了所有的weights、biases、gradients等变量。在tensorflow 0.11之 前,保存在.ckpt文件中。0.11后,通过两个文件保存,如:.data-00000-of-00001和.index文件

    checkpoint文件:checkpoint_dir目录下还有checkpoint文件,该文件是个文本文件,里面记录了保存的最新的checkpoint文件以及其它checkpoint文件列表。在inference时,可以通过修改这个文件,指定使用哪个model。加载restore时的文件路径名是以checkpoint文件中的“model_checkpoint_path”值决定的。

    保存模型时,只会保存变量的值,placeholder里面的值不会被保存。

  关于save()方法的参数记录:

      • sess:在tensorflow中,变量是存在于Session环境中,即只有在Session环境下才会存有变量值,因此,保存模型时需要传入session
      • global_step:在n次迭代后,再保存模型,只需设置global_step参数即可
      • 由于图是不变的,没必要每次都去保存,可以在多次迭代过程中只用保存一次模型即可,可以通过设置write_meta_graph=False即可
      • keep_checkpoint_every_n_hours:用来设置间隔时间来保存
      • max_to_keep: 用来设置保存最近模型文件的个数
      • 如果不想保存所有变量,而只保存一部分变量,可以通过指定variables/collections,默认是保存所有的变量。

    tf.train.Saver类也支持在保存和加载时给变量重命名,声明Saver类对象的时候使用一个字典dict重命名变量即可,{"已保存的变量的名称name": 重命名变量名}。

  导入模型

    加载图:saver=tf.train.import_meta_graph(.meta文件)即可。

    加载模型参数:aver.restore(sess, tf.train.latest_checkpoint('./checkpoint_dir'))

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}
注意w1:0是tensor的name,既可以指定变量名称,也可以指定操作名称。

  其实,我们也可以只恢复图的一部分,并且再加入其它的op用于fine-tuning。只需通过graph.get_tensor_by_name()方法获取需要的op,并且在此基础上建立图即可。例如:假设我们想使用已经训练好的VGG模型,并且要更改部分层,如下:

saver = tf.train.import_meta_graph('vgg.meta')
# 访问图
graph = tf.get_default_graph() #访问用于fine-tuning的output
fc7= graph.get_tensor_by_name('fc7:0') #如果你想修改最后一层梯度,需要如下
fc7 = tf.stop_gradient(fc7) # It's an identity function
fc7_shape= fc7.get_shape().as_list() new_outputs=2
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output)

Tensorflow模型保存与加载的更多相关文章

  1. tensorflow 模型保存与加载 和TensorFlow serving + grpc + docker项目部署

    TensorFlow 模型保存与加载 TensorFlow中总共有两种保存和加载模型的方法.第一种是利用 tf.train.Saver() 来保存,第二种就是利用 SavedModel 来保存模型,接 ...

  2. 转 tensorflow模型保存 与 加载

    使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我们可能也需要用到别人训练好的模型,并在这个基础上再次训练.这时候我们需要掌握如何操作这些模型数据.看完本文,相信你一定会有收获 ...

  3. tensorflow实现线性回归、以及模型保存与加载

    内容:包含tensorflow变量作用域.tensorboard收集.模型保存与加载.自定义命令行参数 1.知识点 """ 1.训练过程: 1.准备好特征和目标值 2.建 ...

  4. [PyTorch 学习笔记] 7.1 模型保存与加载

    本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_save.py https://githu ...

  5. sklearn模型保存与加载

    sklearn模型保存与加载 sklearn模型的保存和加载API 线性回归的模型保存加载案例 保存模型 sklearn模型的保存和加载API from sklearn.externals impor ...

  6. TensorFlow构建卷积神经网络/模型保存与加载/正则化

    TensorFlow 官方文档:https://www.tensorflow.org/api_guides/python/math_ops # Arithmetic Operators import ...

  7. TensorFlow的模型保存与加载

    import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import tensorflow as tf #tensorboard --logdir=&qu ...

  8. tensorflow 之模型的保存与加载(一)

    怎样让通过训练的神经网络模型得以复用? 本文先介绍简单的模型保存与加载的方法,后续文章再慢慢深入解读. #!/usr/bin/env python3 #-*- coding:utf-8 -*- ### ...

  9. TensorFlow保存、加载模型参数 | 原理描述及踩坑经验总结

    写在前面 我之前使用的LSTM计算单元是根据其前向传播的计算公式手动实现的,这两天想要和TensorFlow自带的tf.nn.rnn_cell.BasicLSTMCell()比较一下,看看哪个训练速度 ...

随机推荐

  1. python操作 windows 锁屏与锁屏状态判断

    pip install ctypes from ctypes import * while True: u = windll.LoadLibrary('user32.dll') result = u. ...

  2. 【10】Python urllib、编码解码、requests、多线程、多进程、unittest初探、__file__、jsonpath

    1 urllib urllib是一个标准模块,直接import就可以使用 1.1get请求 from urllib.request import urlopen url='http://www.nnz ...

  3. 从mysql8.0.15升级到8.0.16

    从mysql8.0.15升级到8.0.16 环境简介 操作系统:Centos 6.10 64位 目前版本:8.0.15 MySQL Community Server 二进制 目的:升级为8.0.16 ...

  4. CentOS8 中文输入法

    CentOS8发布了,安装了下试试,结果发现中文输入法调不出来. 系统安装完成后,在系统[设置]的[Region&Language]里的[输入源]里可以添加汉语输入源,但是不能打中文字. 下面 ...

  5. toJSON() 方法,将 Date 对象转换为字符串,并格式化为 JSON 数据格式。

    JavaScript toJSON() 方法 定义和用法 toJSON() 方法可以将 Date 对象转换为字符串,并格式化为 JSON 数据格式. JSON 数据用同样的格式就像x ISO-8601 ...

  6. springmvc请求参数异常统一处理,结合钉钉报告信息定位bug位置

    参考之前一篇博客:springmvc请求参数异常统一处理 1.ExceptionHandlerController package com.oy.controller; import java.tex ...

  7. 使用 XSLT 显示 XML

    通过使用 XSLT,您可以向 XML 文档添加显示信息. 使用 XSLT 显示 XML XSLT 是首选的 XML 样式表语言. XSLT (eXtensible Stylesheet Languag ...

  8. 前端iPhone X适配总结

    屏幕尺寸 垂直方向上,iPhone X的显示宽度与iPhone 6,iPhone 7 和 iPhone 8 的 4.7 英寸一样,但是比4.7英寸的显示屏高145pt. 安全区域 安全区域指的是一个可 ...

  9. 一款基于jQuery的漂亮弹出层

    特别提示:本人博客部分有参考网络其他博客,但均是本人亲手编写过并验证通过.如发现博客有错误,请及时提出以免误导其他人,谢谢!欢迎转载,但记得标明文章出处:http://www.cnblogs.com/ ...

  10. squid的处理request和reply的流程

    request处理: Breakpoint , SQUID_MD5Final ( digest= { (gdb) bt # SQUID_MD5Final ( digest= # ) at store_ ...