学习过程是Tensorflow 实战google深度学习框架一书的第六章的迁移学习环节。

具体见我提出的问题:https://www.tensorflowers.cn/t/5314

参考https://blog.csdn.net/zhuiqiuk/article/details/53376283后,对代码进行了修改。

问题的跟踪情况记录:

1 首先是保存模型:

import tensorflow as tf
from tensorflow.python.framework import graph_util
v1=tf.constant([10000.0],name='v1')
#v1 = tf.placeholder(tf.float32,shape=[1],name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "v2")
result = v1 + v2 init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op,{v1:[100]})
print (sess.run(result,{v1:[1000]}))
writer = tf.summary.FileWriter('./graphs/model_graph', sess.graph)
writer.close()
graph_def = tf.get_default_graph().as_graph_def()
output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def,['add'])
with tf.gfile.GFile("Saved_model/combined_model.pb", "wb") as f:
f.write(output_graph_def.SerializeToString())

因为inception v3接受输入的tensor是Decode/Content:0,是一个const类型,就是tf.constant类型,而一开始,我并不明白问题的所在,就将tf.placeholder改为了tf.constant,而实际上,两个都可以。问题的本身不是出在这里,而是对书本有错误的理解。

书上因为获取的是两个return elements,会自动从列表中取出元素。

而我获得的是一个retrun elelment,则只能返回一个列表。

2 使用并加载持久化模型,直接调用模型的训练参数进行计算。

#这是我以前写的程序,是错误的
"""
因为我以前写的只是获取一个值。
而现在修正的v1则是一个tensor。我们可以修正tensor的值。
所以,tensorflow 实战google深度学习框架中有重大bug。
不懂的联系我手机 18627711314 杰
"""
import tensorflow as tf
import numpy as np
from numpy.random import RandomState
from tensorflow.python.platform import gfile
with tf.Session() as sess:
model_filename = "Saved_model/combined_model.pb"
#model_filename = "inception_dec_2015/tensorflow_inception_graph.pb" with gfile.FastGFile(model_filename, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
"""将模型的相关信息写入文件,和利用tensorboard进行可视化
f = open("xiaojie2.txt", "w")
print ("xiaojie2\n",file = f)
print (graph_def,file=f)
f.close()
writer = tf.summary.FileWriter('./graphs/model_graph2', graph_def)
writer.close()
"""
"""④输出所有可训练的变量名称,也就是神经网络的参数"""
trainable_variables=tf.trainable_variables()
variable_list_name = [c.name for c in tf.trainable_variables()]
variable_list = sess.run(variable_list_name)
for k,v in zip(variable_list_name,variable_list):
print("variable name:",k)
print("shape:",v.shape)
#print(v)
"""④输出所有可训练的变量名称,也就是神经网络的参数""" v1= tf.import_graph_def(graph_def, return_elements=["v1:0"])
print (v1)
print (sess.run(v1))
v2= tf.import_graph_def(graph_def, return_elements=["v2:0"])
print (sess.run(v2))
result = tf.import_graph_def(graph_def, return_elements=["add:0"])
print (sess.run(result))
x=np.array([2000.0])
print (sess.run(result,feed_dict={v1: x}))

这些都是参照书上使用inception模型时的做法,我参照着自己写了一个模型,但是有重大bug

运行结果总是提示:

因为总是无法用feed_dict传入我想计算的输入。我一开始因为是tf.placeholder的原因,就参照inception的改为tf.constant。但是还是不行。后来在网上看到别人加载pb文件的一段代码https://blog.csdn.net/zhuiqiuk/article/details/53376283,重新对代码进行了修正。如下:

import tensorflow as tf
import numpy as np
from numpy.random import RandomState
from tensorflow.python.platform import gfile
with tf.Graph().as_default():
output_graph_def = tf.GraphDef()
output_graph_path='Saved_model/combined_model.pb'
with open(output_graph_path, "rb") as f:
output_graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(output_graph_def, name="")
with tf.Session() as sess:
"""④输出所有可训练的变量名称,也就是神经网络的参数"""
trainable_variables=tf.trainable_variables()
variable_list_name = [c.name for c in tf.trainable_variables()]
variable_list = sess.run(variable_list_name)
for k,v in zip(variable_list_name,variable_list):
print("variable name:",k)
print("shape:",v.shape)
#print(v)
"""④输出所有可训练的变量名称,也就是神经网络的参数"""
input_x = sess.graph.get_tensor_by_name("v1:0")
result = tf.import_graph_def(graph_def, return_elements=["add:0"])
print (input_x)
x=np.array([2000.0])
print (sess.run(result,feed_dict={input_x: x}))

问题的根本在于:

V1= tf.import_graph_def(graph_def, return_elements=["v1:0"])获取的是

[<tf.Tensor 'import/v1:0' shape=(1,) dtype=float32>],是一个列表

而:input_x = sess.graph.get_tensor_by_name("v1:0")

获取的是一个Tensor,即Tensor("v1:0", shape=(1,), dtype=float32)。

使用sess.run的时候,feed_dict要修正的是tensor,而不是一个list。因此,总会提出unhashable type:listd的报错。

3 将原始错误程序的代码改为:

v1= tf.import_graph_def(graph_def, return_elements=["v1:0"])
print (v1)
print (sess.run(v1))
v2= tf.import_graph_def(graph_def, return_elements=["v2:0"])
print (sess.run(v2))
result = tf.import_graph_def(graph_def, return_elements=["add:0"])
print (sess.run(result))
x=np.array([2000.0]) #print (sess.run(result,feed_dict={v1: x}))
print (sess.run(result,feed_dict={v1[0]: x}))

也可以正确运行

4 后来正确的程序还可以改为:

import tensorflow as tf
import numpy as np
from numpy.random import RandomState
from tensorflow.python.platform import gfile
output_graph_def = tf.GraphDef()
output_graph_path='Saved_model/combined_model.pb'
with open(output_graph_path, "rb") as f:
output_graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(output_graph_def, name="")
with tf.Session() as sess:
"""④输出所有可训练的变量名称,也就是神经网络的参数"""
trainable_variables=tf.trainable_variables()
variable_list_name = [c.name for c in tf.trainable_variables()]
variable_list = sess.run(variable_list_name)
for k,v in zip(variable_list_name,variable_list):
print("variable name:",k)
print("shape:",v.shape)
#print(v)
"""④输出所有可训练的变量名称,也就是神经网络的参数"""
input_x = sess.graph.get_tensor_by_name("v1:0")
result = tf.import_graph_def(graph_def, return_elements=["add:0"])
print (input_x)
x=np.array([2000.0])
print (sess.run(result,feed_dict={input_x: x}))

需要注意的是:

首先,无论如何,加载pb以后,输出所有可训练的变量,都不可能输出持久化模型中的变量。这一点以前就说过。以前说过,只能使用train.saver的方式。

其次,如果使用后一种方式,即sess.graph.get_tensor_by_name,则必须要有红黄标注的那一幕。即:_ = tf.import_graph_def(output_graph_def, name="")

程序附件

链接:https://pan.baidu.com/s/11YtyDEyV84jONPi9tO2TCw 密码:8mfj

1 如何使用pb文件保存和恢复模型进行迁移学习(学习Tensorflow 实战google深度学习框架)的更多相关文章

  1. AI - TensorFlow - 示例05:保存和恢复模型

    保存和恢复模型(Save and restore models) 官网示例:https://www.tensorflow.org/tutorials/keras/save_and_restore_mo ...

  2. 第六节,TensorFlow编程基础案例-保存和恢复模型(中)

    在我们使用TensorFlow的时候,有时候需要训练一个比较复杂的网络,比如后面的AlexNet,ResNet,GoogleNet等等,由于训练这些网络花费的时间比较长,因此我们需要保存模型的参数. ...

  3. 吴裕雄 python 神经网络——TensorFlow pb文件保存方法

    import tensorflow as tf from tensorflow.python.framework import graph_util v1 = tf.Variable(tf.const ...

  4. TebsorFlow低阶API(五)—— 保存和恢复

    简介 tf.train.Saver 类提供了保存和恢复模型的方法.通过 tf.saved_model.simple_save 函数可以轻松地保存适合投入使用的模型.Estimator会自动保存和恢复 ...

  5. tensorflow学习笔记——模型持久化的原理,将CKPT转为pb文件,使用pb模型预测

    由题目就可以看出,本节内容分为三部分,第一部分就是如何将训练好的模型持久化,并学习模型持久化的原理,第二部分就是如何将CKPT转化为pb文件,第三部分就是如何使用pb模型进行预测. 一,模型持久化 为 ...

  6. tensorflow 1.0 学习:模型的保存与恢复(Saver)

    将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf. ...

  7. tensorflow 1.0 学习:模型的保存与恢复

    将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf. ...

  8. 保存与恢复变量和模型,tensorflow官方文档阅读笔记

    官方中文文档的网址先贴出来:https://tensorflow.google.cn/programmers_guide/saved_model tf.train.Saver 类别提供了保存和恢复模型 ...

  9. TensorFlow学习笔记:保存和读取模型

    TensorFlow 更新频率实在太快,从 1.0 版本正式发布后,很多 API 接口就发生了改变.今天用 TF 训练了一个 CNN 模型,结果在保存模型的时候居然遇到各种问题.Google 搜出来的 ...

随机推荐

  1. AC自动机模板题

    AC自动机学习博客 AC自动机理解要点: 1)fail指针指向的是每个节点,在字典树上和这个节点后缀相同的最长单词,每次都这样匹配,必定不会漏过答案. 2)字典树建立后,会在bfs求fail阶段把字典 ...

  2. C#根据工作经验来谈谈面向对象

    C#面向对象的三大特性:封装.继承.多态. 这是一种特性,更是官方给我们的学习语法,但是我们根据过去的经验来思考一下, 到底什么是面向对象? 面向对象在我们实际开发中到底起着什么作用? 我们什么时候要 ...

  3. PIE SDK与OpenCV结合说明文档

    1.功能简介 OpenCV是基于BSD许可(开源)发行的跨平台计算机视觉库,可以运行在Linux.Windows.Android和Mac OS操作系统上.它轻量级而且高效——由一系列 C 函数和少量 ...

  4. 转 oracle ASM中ASM_POWER_LIMIT参数

    https://daizj.iteye.com/blog/1753434 ASM_POWER_LIMIT 该初始化参数用于指定ASM例程平衡磁盘所用的最大权值,其数值范围为0~11,默认值为1.该初始 ...

  5. asp.net WebService技术简介

    1.什么是web service? 这里借助百度百科专业的解释:web service是一个平台独立的.低耦合的.自包含的.基于可编程的web应用程序(说简单点,就是使用web service继续不需 ...

  6. CentOS下安装官方RPM包的MySQL后找不到my.cnf

    PS:昨天一同事问我说他用MySQL 5.5官方的rpm包安装了,但是在/etc/下面没有my.cnf配置文件.官方rpm包安装的mysql默认确实是没有/etc/my.cnf. 为什么没有这个文件而 ...

  7. table定位

    Table定位 在 web 页面中经常会遇到 table 表格,特别是后台操作页面比较常见.本篇详细讲解 table 表格如何定位. 1.1 table特性 1.table 页面查看源码一般有这几个明 ...

  8. Ubuntu14.04下编译安装或apt-get方式安装搭建Apache或Httpd服务(图文详解)

    不多说,直接上干货! 写在前面的话 对于 在Ubuntu系统上,编译安装Apache它默认路径是在/usr/local/apache2/htdocs 或者编译安装httpd它默认路径是在/usr/lo ...

  9. Git学习系列之 Git 、CVS、SVN的比较

    Git .CVS.SVN比较 项目源代码的版本管理工具中,比较常用的主要有:CVS.SVN.Git 和 Mercurial  (其中,关于SVN,请参见我的博客:SVN学习系列) 目前Google C ...

  10. IE678下,select 诡异的样式

    我没有IE6,我用IE5测试的. IE5下的测试结果:貌似只能设置 width ,设置其他的都失效,连 height 都不例外. IE7下的测试结果:垂直居中失效.边框失效,宽高生效. IE8下的测试 ...