当我们把训练好的tensorflow训练图拿来进行预测时,会有多个训练时生成的节点,这些节点是不必要的,我们需要在预测的时候进行删除。

下面以bert的图为例,进行优化

    def optimize_graph(self, checkpoint_file, model_config):
import json
tf = self.import_tf()
from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference config = tf.ConfigProto(device_count={'GPU': 0}, allow_soft_placement=True) init_checkpoint = checkpoint_file with tf.gfile.GFile(model_config, 'r') as f:
bert_config = modeling.BertConfig.from_dict(json.load(f)) input_ids = tf.placeholder(tf.int32, (None, MAX_SEQ_LENGTH), 'input_ids')
input_mask = tf.placeholder(tf.int32, (None, MAX_SEQ_LENGTH), 'input_mask')
input_type_ids = tf.placeholder(tf.int32, (None, MAX_SEQ_LENGTH), 'input_type_ids') import contextlib
jit_scope = contextlib.suppress with jit_scope():
input_tensors = [input_ids, input_mask, input_type_ids]
model = modeling.BertModel(
config=bert_config,
is_training=False,
input_ids=input_ids,
input_mask=input_mask,
token_type_ids=input_type_ids,
use_one_hot_embeddings=False) tvars = tf.trainable_variables() (assignment_map, initialized_variable_names
) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) # get output tensor
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
reader = tf.train.NewCheckpointReader(init_checkpoint)
output_weights = reader.get_tensor('output_weights')
output_bias = reader.get_tensor('output_bias')
output_layers = model.get_pooled_output()
pooled = tf.nn.softmax(tf.nn.bias_add(tf.matmul(output_layers, output_weights, transpose_b=True),
output_bias))
pooled = tf.identity(pooled, 'final_encodes') output_tensors = [pooled]
tmp_g = tf.get_default_graph().as_graph_def() # write graph to file
with tf.Session(config=config) as sess:
sess.run(tf.global_variables_initializer())
tmp_g = tf.graph_util.convert_variables_to_constants(sess, tmp_g, [n.name[:-2] for n in output_tensors])
dtypes = [n.dtype for n in input_tensors]
tmp_g = optimize_for_inference(
tmp_g,
[n.name[:-2] for n in input_tensors],
[n.name[:-2] for n in output_tensors],
[dtype.as_datatype_enum for dtype in dtypes],
False) import tempfile
tmp_file = tempfile.NamedTemporaryFile('w', delete=False, dir=r'optimize').name
with tf.gfile.GFile(tmp_file, 'wb') as f:
f.write(tmp_g.SerializeToString()) return tmp_file

返回一个gfile类型的文件,我们可以像原来导入模型文件时,恢复图,不过这个图是优化过的。

tensorflow 优化图的更多相关文章

  1. TensorFlow的图切割模块——Graph Partitioner

    背景 [作者:DeepLearningStack,阿里巴巴算法工程师,开源TensorFlow Contributor] 在经过TensorFlow的Placer策略模块调整之后,下一步就是根据Pla ...

  2. 现代英特尔® 架构上的 TensorFlow* 优化——正如去年参加Intel AI会议一样,Intel自己提供了对接自己AI CPU优化版本的Tensorflow,下载链接见后,同时可以基于谷歌官方的tf版本直接编译生成安装包

    现代英特尔® 架构上的 TensorFlow* 优化 转自:https://software.intel.com/zh-cn/articles/tensorflow-optimizations-on- ...

  3. TensorFlow从0到1之TensorFlow优化器(13)

    高中数学学过,函数在一阶导数为零的地方达到其最大值和最小值.梯度下降算法基于相同的原理,即调整系数(权重和偏置)使损失函数的梯度下降. 在回归中,使用梯度下降来优化损失函数并获得系数.本节将介绍如何使 ...

  4. TensorFlow优化器及用法

    TensorFlow优化器及用法 函数在一阶导数为零的地方达到其最大值和最小值.梯度下降算法基于相同的原理,即调整系数(权重和偏置)使损失函数的梯度下降. 在回归中,使用梯度下降来优化损失函数并获得系 ...

  5. TensorFlow优化器浅析

    本文基于tensorflow-v1.15分支,简单分析下TensorFlow中的优化器. optimizer = tf.train.GradientDescentOptimizer(learning_ ...

  6. tensorflow优化器-【老鱼学tensorflow】

    tensorflow中的优化器主要是各种求解方程的方法,我们知道求解非线性方程有各种方法,比如二分法.牛顿法.割线法等,类似的,tensorflow中的优化器也只是在求解方程时的各种方法. 比较常用的 ...

  7. tensorflow:图(Graph)的核心数据结构与通用函数(Utility function)

    Tensorflow一些常用基本概念与函数(2) 1. 图(Graph)的核心数据结构 tf.Graph.__init__:建立一个空图: tf.Graph.as_default():一个将某图设置为 ...

  8. Tensorflow 优化学习

    # coding: utf-8 import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data pr ...

  9. tensorflow 框架图

随机推荐

  1. classmethod,staticmethod

    '''1 绑定方法: 在类内部定义的函数,默认就是给对象来用,而且是绑定给对象用的,称为对象的绑定方法 绑定对象的方法特殊之处: 应该由对象来调用,对象来调用,会自动将对象当作第一个参数传入 绑定到类 ...

  2. 20169207《Linux内核原理与分析》第六周作业

    这周的作业同样分为两部分,第一部分的学习MOOC第四节[扒开系统调用的三层皮],并结合实验楼的实验四深入学习.第二部分阅读学习教材「Linux内核设计与实现 (Linux Kernel Develop ...

  3. hide handkerchief

    Problem Description The Children’s Day has passed for some days .Has you remembered something happen ...

  4. PAT甲级 1129. Recommendation System (25)

    1129. Recommendation System (25) 时间限制 400 ms 内存限制 65536 kB 代码长度限制 16000 B 判题程序 Standard 作者 CHEN, Yue ...

  5. PHP后台图片上传作品 接口

    //把新图片添加到文件夹里 public function info($file=''){ $info = $file->validate(['ext'=>'jpg'])->rule ...

  6. AngularJS controller as vm方式

    从AngularJS1.20开始引入了Controller as 新语法,以前版本在Controller 中必须注入$scope这个服务,才能在视图绑定中使用这些变量,$scope不是那么POJO(普 ...

  7. uniGUI 通过SessionList操作另外的登录用户

    参照bbs,写了这个方法,检查是否有同名用户已经登录:procedure TUniMainModule.CheckSameUser(aUserLoginCode: string);var  ASess ...

  8. CentOS7中配置vsftpd

    1.yum -y install vsftpd  安装vsftpd 2.配置vsftpd的配置文件(/etc/vsftpd/vsftpd.conf)主要修改以下配置内容 #不允许匿名访问 anonym ...

  9. [Word]让字符重合显示

    某些时候需要让字符重合显示,比如您好二字,显示为: 需要用到word的Advance域,他可以让后面的文字上下左右移动一定的磅. 譬如上面你好的显示:word中域代码为: 意思是好字向left移动了2 ...

  10. Java获取http和https网址对应html数据实例

    由于之前在公司一直用的C#做的软件开发,近些天有同学需要用Java做一个从指定网址获取信息的Java程序.正好不是很难,顺便复习了一下Java的知识. 要求如下,在https://www.marine ...