一、TensorFlow模型保存和提取方法

1. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取。tf.train.Saver对象saver的save方法将TensorFlow模型保存到指定路径中,saver.save(sess,"Model/model.ckpt"),实际在这个文件目录下会生成4个人文件:

checkpoint文件保存了一个录下多有的模型文件列表,model.ckpt.meta保存了TensorFlow计算图的结构信息,model.ckpt保存每个变量的取值,此处文件名的写入方式会因不同参数的设置而不同,但加载restore时的文件路径名是以checkpoint文件中的“model_checkpoint_path”值决定的。

2. 加载这个已保存的TensorFlow模型的方法是saver.restore(sess,"./Model/model.ckpt"),加载模型的代码中也要定义TensorFlow计算图上的所有运算并声明一个tf.train.Saver类,不同的是加载模型时不需要进行变量的初始化,而是将变量的取值通过保存的模型加载进来,注意加载路径的写法。若不希望重复定义计算图上的运算,可直接加载已经持久化的图,saver =tf.train.import_meta_graph("Model/model.ckpt.meta")。

3.tf.train.Saver类也支持在保存和加载时给变量重命名,声明Saver类对象的时候使用一个字典dict重命名变量即可,{"已保存的变量的名称name": 重命名变量名},saver = tf.train.Saver({"v1":u1, "v2": u2})即原来名称name为v1的变量现在加载到变量u1(名称name为other-v1)中。

4. 上一条做的目的之一就是方便使用变量的滑动平均值。如果在加载模型时直接将影子变量映射到变量自身,则在使用训练好的模型时就不需要再调用函数来获取变量的滑动平均值了。载入时,声明Saver类对象时通过一个字典将滑动平均值直接加载到新的变量中,saver = tf.train.Saver({"v/ExponentialMovingAverage": v}),另通过tf.train.ExponentialMovingAverage的variables_to_restore()函数获取变量重命名字典。

此外,通过convert_variables_to_constants函数将计算图中的变量及其取值通过常量的方式保存于一个文件中。

二、TensorFlow程序实现

[python] view plain copy

 
  1. # 本文件程序为配合教材及学习进度渐进进行,请按照注释分段执行
  2. # 执行时要注意IDE的当前工作过路径,最好每段重启控制器一次,输出结果更准确
  3. # Part1: 通过tf.train.Saver类实现保存和载入神经网络模型
  4. # 执行本段程序时注意当前的工作路径
  5. import tensorflow as tf
  6. v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
  7. v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
  8. result = v1 + v2
  9. saver = tf.train.Saver()
  10. with tf.Session() as sess:
  11. sess.run(tf.global_variables_initializer())
  12. saver.save(sess, "Model/model.ckpt")
  13. # Part2: 加载TensorFlow模型的方法
  14. import tensorflow as tf
  15. v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
  16. v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
  17. result = v1 + v2
  18. saver = tf.train.Saver()
  19. with tf.Session() as sess:
  20. saver.restore(sess, "./Model/model.ckpt") # 注意此处路径前添加"./"
  21. print(sess.run(result)) # [ 3.]
  22. # Part3: 若不希望重复定义计算图上的运算,可直接加载已经持久化的图
  23. import tensorflow as tf
  24. saver = tf.train.import_meta_graph("Model/model.ckpt.meta")
  25. with tf.Session() as sess:
  26. saver.restore(sess, "./Model/model.ckpt") # 注意路径写法
  27. print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0"))) # [ 3.]
  28. # Part4: tf.train.Saver类也支持在保存和加载时给变量重命名
  29. import tensorflow as tf
  30. # 声明的变量名称name与已保存的模型中的变量名称name不一致
  31. u1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1")
  32. u2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2")
  33. result = u1 + u2
  34. # 若直接生命Saver类对象,会报错变量找不到
  35. # 使用一个字典dict重命名变量即可,{"已保存的变量的名称name": 重命名变量名}
  36. # 原来名称name为v1的变量现在加载到变量u1(名称name为other-v1)中
  37. saver = tf.train.Saver({"v1": u1, "v2": u2})
  38. with tf.Session() as sess:
  39. saver.restore(sess, "./Model/model.ckpt")
  40. print(sess.run(result)) # [ 3.]
  41. # Part5: 保存滑动平均模型
  42. import tensorflow as tf
  43. v = tf.Variable(0, dtype=tf.float32, name="v")
  44. for variables in tf.global_variables():
  45. print(variables.name) # v:0
  46. ema = tf.train.ExponentialMovingAverage(0.99)
  47. maintain_averages_op = ema.apply(tf.global_variables())
  48. for variables in tf.global_variables():
  49. print(variables.name) # v:0
  50. # v/ExponentialMovingAverage:0
  51. saver = tf.train.Saver()
  52. with tf.Session() as sess:
  53. sess.run(tf.global_variables_initializer())
  54. sess.run(tf.assign(v, 10))
  55. sess.run(maintain_averages_op)
  56. saver.save(sess, "Model/model_ema.ckpt")
  57. print(sess.run([v, ema.average(v)])) # [10.0, 0.099999905]
  58. # Part6: 通过变量重命名直接读取变量的滑动平均值
  59. import tensorflow as tf
  60. v = tf.Variable(0, dtype=tf.float32, name="v")
  61. saver = tf.train.Saver({"v/ExponentialMovingAverage": v})
  62. with tf.Session() as sess:
  63. saver.restore(sess, "./Model/model_ema.ckpt")
  64. print(sess.run(v)) # 0.0999999
  65. # Part7: 通过tf.train.ExponentialMovingAverage的variables_to_restore()函数获取变量重命名字典
  66. import tensorflow as tf
  67. v = tf.Variable(0, dtype=tf.float32, name="v")
  68. # 注意此处的变量名称name一定要与已保存的变量名称一致
  69. ema = tf.train.ExponentialMovingAverage(0.99)
  70. print(ema.variables_to_restore())
  71. # {'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
  72. # 此处的v取自上面变量v的名称name="v"
  73. saver = tf.train.Saver(ema.variables_to_restore())
  74. with tf.Session() as sess:
  75. saver.restore(sess, "./Model/model_ema.ckpt")
  76. print(sess.run(v)) # 0.0999999
  77. # Part8: 通过convert_variables_to_constants函数将计算图中的变量及其取值通过常量的方式保存于一个文件中
  78. import tensorflow as tf
  79. from tensorflow.python.framework import graph_util
  80. v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
  81. v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
  82. result = v1 + v2
  83. with tf.Session() as sess:
  84. sess.run(tf.global_variables_initializer())
  85. # 导出当前计算图的GraphDef部分,即从输入层到输出层的计算过程部分
  86. graph_def = tf.get_default_graph().as_graph_def()
  87. output_graph_def = graph_util.convert_variables_to_constants(sess,
  88. graph_def, ['add'])
  89. with tf.gfile.GFile("Model/combined_model.pb", 'wb') as f:
  90. f.write(output_graph_def.SerializeToString())
  91. # Part9: 载入包含变量及其取值的模型
  92. import tensorflow as tf
  93. from tensorflow.python.platform import gfile
  94. with tf.Session() as sess:
  95. model_filename = "Model/combined_model.pb"
  96. with gfile.FastGFile(model_filename, 'rb') as f:
  97. graph_def = tf.GraphDef()
  98. graph_def.ParseFromString(f.read())
  99. result = tf.import_graph_def(graph_def, return_elements=["add:0"])
  100. print(sess.run(result)) # [array([ 3.], dtype=float32)]

TensorFlow模型保存和提取方法的更多相关文章

  1. tensorflow 模型保存与加载 和TensorFlow serving + grpc + docker项目部署

    TensorFlow 模型保存与加载 TensorFlow中总共有两种保存和加载模型的方法.第一种是利用 tf.train.Saver() 来保存,第二种就是利用 SavedModel 来保存模型,接 ...

  2. TensorFlow模型保存和加载方法

    TensorFlow模型保存和加载方法 模型保存 import tensorflow as tf w1 = tf.Variable(tf.constant(2.0, shape=[1]), name= ...

  3. TensorFlow 模型保存/载入

    我们在上线使用一个算法模型的时候,首先必须将已经训练好的模型保存下来.tensorflow保存模型的方式与sklearn不太一样,sklearn很直接,一个sklearn.externals.jobl ...

  4. Tensorflow模型保存与加载

    在使用Tensorflow时,我们经常要将以训练好的模型保存到本地或者使用别人已训练好的模型,因此,作此笔记记录下来. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提 ...

  5. 10 Tensorflow模型保存与读取

    我们的模型训练出来想给别人用,或者是我今天训练不完,明天想接着训练,怎么办?这就需要模型的保存与读取.看代码: import tensorflow as tf import numpy as np i ...

  6. 一份快速完整的Tensorflow模型保存和恢复教程(译)(转载)

    该文章转自https://blog.csdn.net/sinat_34474705/article/details/78995196 我在进行图像识别使用ckpt文件预测的时候,这个文章给我提供了极大 ...

  7. 转 tensorflow模型保存 与 加载

    使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我们可能也需要用到别人训练好的模型,并在这个基础上再次训练.这时候我们需要掌握如何操作这些模型数据.看完本文,相信你一定会有收获 ...

  8. tensorflow 模型保存后的加载路径问题

    import tensorflow as tf #保存模型 saver = tf.train.Saver() saver.save(sess, "e://code//python//test ...

  9. tensorflow 模型保存

    1.首先 saver = tf.train.Saver(max_to_keep=1)新建一个saver,max_to_keep是说只保留最后一轮的训练结果 2.使用save方法保存模型 saver.s ...

随机推荐

  1. 开源项目Universal Image Loader for Android 说明文档 (1) 简介

     When developing applications for Android, one often facesthe problem of displaying some graphical ...

  2. canvas 创建渐变图形

    <!DOCTYPE html> <html> <head lang="en"> <meta charset="UTF-8&quo ...

  3. 学习Java有没有什么捷径?

    很多网友咨询学习Java有没有什么捷径,我说“ 无他,唯手熟尔 ”.但是愿意将一些经验写出来,以便后来者少走弯路,帮助别人是最大的快乐嘛! 要想学好Java,首先要知道Java的大致分类. 我们知道, ...

  4. Django之HttpRequest和HttpReponse

    当一个web请求链接进来时,django会创建一个HttpRequest对象来封装和保存所有请求相关的信息,并且会根据请求路由载入匹配的试图函数,每个请求的试图函数都会返回一个HttpResponse ...

  5. 剑指offer-第五章优化时间和空间效率(数组中的逆序对的总数)

    题目:在数组中如果两个数字的前面的数比后面的数大,则称为一对逆序对.输入一个数组求出数组中逆序对的总数. 以空间换时间:思路:借助一个辅助数组,将原来的数组复制到该数组中.然后将该数组分成子数组,然后 ...

  6. UVA11174 Stand in a Line

    题意 PDF 分析 \[ f(i)=f(c_1)f(c_2)\dots\times(s(i)-1)!/(s(c_1)!s(c_2)! \dots s(c_k)! )\\ f(root)=(s(root ...

  7. eclipse share project到svn时显示不被信任的证书,暂时接受也不行

    svn: 方法 OPTIONS 失败于 “https://eping.net/svn/testproject”: SSL handshake failed: SSL 错误:在证书中检测到违规的密钥用法 ...

  8. hadoop复合键排序使用方法

    在hadoop中处理复杂业务时,需要用到复合键,复合不同于单纯的继承Writable接口,而是继承了 WritableComparable<T>接口,而实际上,WritableCompar ...

  9. 解决----Word无法创建工作文件,请检查临时环境变量

    用户在运行Word2003或打开Word2003文档时,可能会出现“Word无法创建工作文件,请检查临时环境变量”的错误提示,此问题主要是由于Word2003的用户设置出现损坏而造成的.网上针对此问题 ...

  10. Java各种集合容器的总结

    Java容器指的是List,Set,Map这些类.由于翻译的问题,问到集合,Collection这些指的都是它们几个. List ArrayList 随机访问快 LinkedList 插入删除快 这个 ...