参考:https://github.com/tflearn/tflearn/issues/964

解决方法:

"""
Tensorflow graph freezer
Converts Tensorflow trained models in .pb Code adapted from:
https://gist.github.com/morgangiraud/249505f540a5e53a48b0c1a869d370bf#file-medium-tffreeze-1-py
""" import os, argparse
os.environ['TF_CPP_MIN_LOG_LEVEL'] = ''
import tensorflow as tf
from tensorflow.python.framework import graph_util def freeze_graph(model_folder,output_graph="frozen_model.pb"):
# We retrieve our checkpoint fullpath
try:
checkpoint = tf.train.get_checkpoint_state(model_folder)
input_checkpoint = checkpoint.model_checkpoint_path
print("[INFO] input_checkpoint:", input_checkpoint)
except:
input_checkpoint = model_folder
print("[INFO] Model folder", model_folder) # Before exporting our graph, we need to precise what is our output node
# This is how TF decides what part of the Graph he has to keep and what part it can dump
output_node_names = "FullyConnected/Softmax" # NOTE: Change here # We clear devices to allow TensorFlow to control on which device it will load operations
clear_devices = True # We import the meta graph and retrieve a Saver
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices) # We retrieve the protobuf graph definition
graph = tf.get_default_graph()
input_graph_def = graph.as_graph_def() # We start a session and restore the graph weights
with tf.Session() as sess:
saver.restore(sess, input_checkpoint) # We use a built-in TF helper to export variables to constants
output_graph_def = graph_util.convert_variables_to_constants(
sess, # The session is used to retrieve the weights
input_graph_def, # The graph_def is used to retrieve the nodes
output_node_names.split(",") # The output node names are used to select the usefull nodes
) # Finally we serialize and dump the output graph to the filesystem
with tf.gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())
print("%d ops in the final graph." % len(output_graph_def.node)) print("[INFO] output_graph:",output_graph)
print("[INFO] all done") if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Tensorflow graph freezer\nConverts trained models to .pb file",
prefix_chars='-')
parser.add_argument("--mfolder", type=str, help="model folder to export")
parser.add_argument("--ograph", type=str, help="output graph name", default="frozen_model.pb") args = parser.parse_args()
print(args,"\n") freeze_graph(args.mfolder,args.ograph) # However, before doing model.save(...) on TFLearn i have to do
# ************************************************************
# del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]
# ************************************************************ """
Then I call this command
python tf_freeze.py --mfolder=<path_to_tflearn_model> Note The <path_to_tflearn_model> must not have the ".data-00000-of-00001".
The output_node_names variable may change depending on your architecture. The thing is that you must reference the layer that has the softmax activation function.
"""

注意:

1、需要在 tflearn的model.save 前:

del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]

作用:去除模型里训练OP。

参考:https://github.com/tflearn/tflearn/issues/605#issuecomment-298478314

2、如果是有batch normalzition,或者残差网络层,会出现:

Error when loading the frozen graph with tensorflow.contrib.layers.python.layers.batch_norm
ValueError: graph_def is invalid at node u'BatchNorm/cond/AssignMovingAvg/Switch': Input tensor 'BatchNorm/moving_mean:0' Cannot convert a tensor of type float32 to an input of type float32_ref
freeze_graph.py doesn't seem to store moving_mean and moving_variance properly

An ugly way to get it working:
manually replace the wrong node definitions in the frozen graph
RefSwitch --> Switch + add '/read' to the input names
AssignSub --> Sub + remove use_locking attributes

则需要在restore模型后加入:

# fix batch norm nodes
for node in gd.node:
if node.op == 'RefSwitch':
node.op = 'Switch'
for index in xrange(len(node.input)):
if 'moving_' in node.input[index]:
node.input[index] = node.input[index] + '/read'
elif node.op == 'AssignSub':
node.op = 'Sub'
if 'use_locking' in node.attr: del node.attr['use_locking']

参考:https://github.com/tensorflow/tensorflow/issues/3628

I met the same issue when I was trying to export graph and variables by saved_model module. And finally I found a walk around to fix this issue:

Remove the TRAIN_OPS collections from graph collection. e.g.:

with dnn.graph.as_default():
del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]

The dumped graph may not be available for training again (by tflearn), but should be able to perform prediction and evaluation. This is useful when serving model by another module or language (e.g. tensorflow serving or tensorflow go binding). I'll do more further tests about this.

If you wanna re-train the model, please use the builtin "save" method and re-construction the graph and load the saved data when re-training.

2、可能需要在代码修改这行,

output_node_names = "FullyConnected/Softmax" # NOTE: Change here

参考:https://gist.github.com/morgangiraud/249505f540a5e53a48b0c1a869d370bf#file-medium-tffreeze-1-py

@vparikh10 @ratfury @rakashi I faced the same situation just like you.
From what I understood, you may have to change this line according to your network definition.
In my case, instead of having output_node_names = "Accuracy/prediction", I have output_node_names = "FullyConnected_2/Softmax".



I made this change after reading this suggestion

对我自己而言,写成softmax或者Softmax都是不行的!然后我将所有的node names打印出来:
打印方法:
    with tf.Session() as sess:
model = get_cnn_model(max_len, volcab_size)
model.fit(trainX, trainY, validation_set=(testX, testY), show_metric=True, batch_size=1000, n_epoch=1)
init_op = tf.initialize_all_variables()
sess.run(init_op) for v in sess.graph.get_operations():
print(v.name)

然后确保output_node_names在里面。


附:gist里的代码,将output node names转换为参数
import os, argparse

import tensorflow as tf

# The original freeze_graph function
# from tensorflow.python.tools.freeze_graph import freeze_graph dir = os.path.dirname(os.path.realpath(__file__)) def freeze_graph(model_dir, output_node_names):
"""Extract the sub graph defined by the output nodes and convert
all its variables into constant
Args:
model_dir: the root folder containing the checkpoint state file
output_node_names: a string, containing all the output node's names,
comma separated
"""
if not tf.gfile.Exists(model_dir):
raise AssertionError(
"Export directory doesn't exists. Please specify an export "
"directory: %s" % model_dir) if not output_node_names:
print("You need to supply the name of a node to --output_node_names.")
return -1 # We retrieve our checkpoint fullpath
checkpoint = tf.train.get_checkpoint_state(model_dir)
input_checkpoint = checkpoint.model_checkpoint_path # We precise the file fullname of our freezed graph
absolute_model_dir = "/".join(input_checkpoint.split('/')[:-1])
output_graph = absolute_model_dir + "/frozen_model.pb" # We clear devices to allow TensorFlow to control on which device it will load operations
clear_devices = True # We start a session using a temporary fresh Graph
with tf.Session(graph=tf.Graph()) as sess:
# We import the meta graph in the current default Graph
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices) # We restore the weights
saver.restore(sess, input_checkpoint) # We use a built-in TF helper to export variables to constants
output_graph_def = tf.graph_util.convert_variables_to_constants(
sess, # The session is used to retrieve the weights
tf.get_default_graph().as_graph_def(), # The graph_def is used to retrieve the nodes
output_node_names.split(",") # The output node names are used to select the usefull nodes
) # Finally we serialize and dump the output graph to the filesystem
with tf.gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())
print("%d ops in the final graph." % len(output_graph_def.node)) return output_graph_def if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--model_dir", type=str, default="", help="Model folder to export")
parser.add_argument("--output_node_names", type=str, default="", help="The name of the output nodes, comma separated.")
args = parser.parse_args() freeze_graph(args.model_dir, args.output_node_names)

												

将tflearn的模型保存为pb,给TensorFlow使用的更多相关文章

  1. tflearn 中文汉字识别,训练后模型存为pb给TensorFlow使用——模型层次太深,或者太复杂训练时候都不会收敛

    tflearn 中文汉字识别,训练后模型存为pb给TensorFlow使用. 数据目录在data,data下放了汉字识别图片: data$ ls0  1  10  11  12  13  14  15 ...

  2. tensorflow 模型保存与加载 和TensorFlow serving + grpc + docker项目部署

    TensorFlow 模型保存与加载 TensorFlow中总共有两种保存和加载模型的方法.第一种是利用 tf.train.Saver() 来保存,第二种就是利用 SavedModel 来保存模型,接 ...

  3. TensorFlow模型保存和提取方法

    一.TensorFlow模型保存和提取方法 1. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取.tf.train.Saver对象saver的save方法将Tens ...

  4. keras中的模型保存和加载

    tensorflow中的模型常常是protobuf格式,这种格式既可以是二进制也可以是文本.keras模型保存和加载与tensorflow不同,keras中的模型保存和加载往往是保存成hdf5格式. ...

  5. TensorFlow模型保存和加载方法

    TensorFlow模型保存和加载方法 模型保存 import tensorflow as tf w1 = tf.Variable(tf.constant(2.0, shape=[1]), name= ...

  6. Tensorflow模型保存与加载

    在使用Tensorflow时,我们经常要将以训练好的模型保存到本地或者使用别人已训练好的模型,因此,作此笔记记录下来. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提 ...

  7. tensorflow 三种模型:ckpt、pb、pb-savemodel

    1.CKPT 目录结构 checkpoint: model.ckpt-1000.index model.ckpt-1000.data-00000-of-00001 model.ckpt-1000.me ...

  8. [MISS静IOS开发原创文摘]-AppDelegate存储全局变量和 NSUserDefaults standardUserDefaults 通过模型保存和读取数据,存储自定义的对象

    由于app开发的需求,需要从api接口获得json格式数据并保存临时的 app的主题颜色 和 相关url 方案有很多种: 1, 通过AppDelegate保存为全局变量,再获取 2,使用NSUSerD ...

  9. TensorFlow构建卷积神经网络/模型保存与加载/正则化

    TensorFlow 官方文档:https://www.tensorflow.org/api_guides/python/math_ops # Arithmetic Operators import ...

随机推荐

  1. service里设置websocket心跳并向fragment发送数据

    垃圾小白写了自己看的 /** * service 文件 */ public class SocketService extends Service { //自己定义接口用来传参 private sta ...

  2. EF code first Acceleration - CodeFirst 加速

    EntityFramework Code First 用起来很方便,可是有时感觉卡,就是有点慢.可以采用以下措施来加速一下,原来取出1万条记录并显示在Winform窗体上第一次需要1.9秒的时间,加速 ...

  3. html5——网络状态

    我们可以通过window.onLine来检测,用户当前的网络状况,返回一个布尔值 window.addEventListener("online",function(){ aler ...

  4. CNN-CV识别简史2012-2017:从 AlexNet、ResNet 到 Mask RCNN

    原文:计算机视觉识别简史:从 AlexNet.ResNet 到 Mask RCNN 总是找不到原文,标记一下.        一切从这里开始:现代物体识别随着ConvNets的发展而发展,这一切始于2 ...

  5. SpringMVC知识点总结一(非注解方式的处理器与映射器配置方法)

    一.SpringMVC处理请求原理图(参见以前博客) 1.  用户发送请求至前端控制器DispatcherServlet 2.  DispatcherServlet收到请求调用HandlerMappi ...

  6. ubuntu14.0开机guest账号禁用方法

    在终端里进入/usr/share/lightdm/lightdm.conf.d/目录 sudo vim 50-unity-greeter.conf 然后在文件里输入: [SeatDefaults] a ...

  7. Vue(八)全局变量和全局方法

    一.在main.js同级目录建立一个common.js文件 // 全局变量 const globalObj = {}; // 定义公共变量 globalObj.name = '小明'; // 定义公共 ...

  8. [系统资源攻略]CPU

    linux系统中如何查看cpu信息? 查看linux版本.cpu.位数.内核.内存等信息 linux下查看CPU,内存,机器型号,网卡等信息的方法 查看服务器物理CPU数和CPU核数方法介绍 可以用/ ...

  9. 第一节:web爬虫之requests

    Requests库是用Python编写的,并且Requests是一个优雅而简单的Python HTTP库,在使用Requests库时更加方便,可以节约我们大量的工作,完全满足HTTP测试需求.

  10. BZOJ 1016 最小生成树计数 【模板】最小生成树计数

    [题解] 对于不同的最小生成树,每种权值的边使用的数量是一定的,每种权值的边的作用是确定的 我们可以先做一遍Kruskal,求出每种权值的边的使用数量num 再对于每种权值的边,2^num搜索出合法使 ...