原文地址:

https://zhuanlan.zhihu.com/p/31308381

-------------------------------------------------------------------------------------

图(graph)是 tensorflow 用于表达计算任务的一个核心概念。从前端(python)描述神经网络的结构,到后端在多机和分布式系统上部署,到底层 Device(CPU、GPU、TPU)上运行,都是基于图来完成。然而我在实际使用过程中遇到了三对API,

  1. tf.train.Saver()  /  saver.restore()
  2. tf.train.export_meta_graph  /  tf.train.Import_meta_graph
  3. tf.train.write_graph() / tf.Import_graph_def()

他们都是用于对图的保存和恢复。同一个计算框架,为什么需要三对不同的API呢?他们保存/恢复的图在使用时又有什么区别呢?初学的时候,常常闹不清楚他们的区别,以至常常写出了错误的程序,经过一番研究,在本文中对Tensorflow中围绕Graph的核心概念进行了总结。

Graph

首先介绍一下关于 Tensorflow 中 Graph 和它的序列化表示 Graph_def。在Tensorflow的官方文档中,Graph 被定义为“一些 Operation 和 Tensor 的集合”。例如我们表达如下的一个计算的 python代码,

a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
c = tf.placeholder(tf.float32)
d = a*b+c
e = d*2

就会生成相应的一张图,在Tensorboard中看到的图大概如下这样。其中每一个圆圈表示一个Operation(输入处为Placeholder),椭圆到椭圆的边为Tensor,箭头的指向表示了这张图
Operation 输入输出 Tensor 的传递关系。

这张图所表达的数据流 与 python 代码中所表达的计算是对应的关系(为了称呼方便,我们下面将这张由Python表达式所描述的数据流动关系叫做 Python Graph)。然而在真实的 Tensorflow 运行中,Python 构建的“图”并不是启动一个Session之后始终不变的东西。因为Tensorflow在运行时,真实的计算会被下放到多CPU上,或者 GPU 等异构设备,或者ARM等上进行高性能/能效的计算。单纯使用 Python 肯定是无法有效完成的。实际上,Tensorflow而是首先将 python 代码所描绘的图转换(即“序列化”)成 Protocol Buffer,再通过 C/C++/CUDA 运行 Protocol Buffer 所定义的图。(Protocol Buffer的介绍可以参考这篇文章学习:https://www.ibm.com/developerworks/cn/linux/l-cn-gpb/

GraphDef

从 python Graph中序列化出来的图就叫做 GraphDef(这是一种不严格的说法,先这样进行理解)。而 GraphDef 又是由许多叫做 NodeDef 的 Protocol Buffer 组成。在概念上 NodeDef 与 (Python Graph 中的)Operation 相对应。如下就是 GraphDef 的 ProtoBuf,由许多node组成的图表示。这是与上文 Python 图对应的 GraphDef:


node {
name: "Placeholder" # 注释:这是一个叫做 "Placeholder" 的node
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
unknown_rank: true
}
}
}
}

node {
name: "Placeholder_1" # 注释:这是一个叫做 "Placeholder_1" 的node
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
unknown_rank: true
}
}
}
}

node {
name: "mul" # 注释:一个 Mul(乘法)操作
op: "Mul"
input: "Placeholder" # 使用上面的node(即Placeholder和Placeholder_1)
input: "Placeholder_1" # 作为这个Node的输入
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
 

以上三个 NodeDef 定义了两个Placeholder和一个Multiply。Placeholder 通过 attr(attribute的缩写)来定义数据类型和 Tensor 的形状。Multiply通过 input 属性定义了两个placeholder作为其输入。无论是 Placeholder 还是 Multiply 都没有关于输出(output)的信息。其实 Tensorflow 中都是通过 Input 来定义 Node 之间的连接信息。

那么既然 tf.Operation 的序列化 ProtoBuf 是 NodeDef,那么 tf.Variable 呢?在这个 GraphDef 中只有网络的连接信息,却没有任何 Variables呀?没错,Graphdef
中不保存任何 Variable 的信息,所以如果我们从 graph_def 来构建图并恢复训练的话,是不能成功的。
比如以下代码,

with tf.Graph().as_default() as graph:
tf.import_graph_def("graph_def_path")
saver= tf.train.Saver()
with tf.Session() as sess:
tf.trainable_variables()

其中 tf.trainable_variables() 只会返回一个空的list。Tf.train.Saver() 也会报告 no variables to save。

然而,在实际线上 inference 中,通常就是使用 GraphDef。然而,GraphDef中连Variable都没有,怎么存储weight呢?原来GraphDef 虽然不能保存 Variable,但可以保存 Constant 。通过 tf.constant 将 weight 直接存储在 NodeDef 里,tensorflow 1.3.0 版本也提供了一套叫做 freeze_graph 的工具来自动的将图中的 Variable 替换成 constant 存储在 GraphDef 里面,并将该图导出为 Proto。可以查看以下链接获取更多信息,

tensorflow.org/extend/tool_developers/

https://www.tensorflow.org/mobile/prepare_models

tf.train.write_graph()     与    tf.Import_graph_def() 就是用来进行 GraphDef 读写的API。那么,我们怎么才能从序列化的图中,得到 Variables呢?这就要学习下一个重要概念,MetaGraph。

MetaGraph

Meta graph 的官方解释是:一个Meta Graph 由一个计算图和其相关的元数据构成。其包含了用于继续训练,实施评估和(在已训练好的的图上)做前向推断的信息。(A MetaGraph consists of both a computational graph and its associated metadata. A MetaGraph contains the information required to continue training, perform evaluation, or run inference on a previously trained graph. From <https://www.tensorflow.org/versions/r1.1/programmers_guide/> )

这一段看的云里雾里,不过这篇文章(https://www.tensorflow.org/versions/r1.1/programmers_guide/meta_graph)进一步解释说,Meta Graph在具体实现上就是一个 MetaGraphDef (同样是由 Protocol Buffer来定义的)。其包含了四种主要的信息,根据Tensorflow官网,这四种 Protobuf 分别是

    1. MetaInfoDef,存一些元信息(比如版本和其他用户信息)
    2. GraphDefMetaGraph的核心内容之一,我们刚刚介绍过
    3. SaverDef,图的Saver信息(比如最多同时保存的check-point数量,需保存的Tensor名字等,但并不保存Tensor中的实际内容)
    4. CollectionDef 任何需要特殊注意的 Python 对象,需要特殊的标注以方便import_meta_graph 后取回。(比如“train_op”,"prediction"
      等等)

在以上四种 ProtoBuf 里面,1 和 3 都比较容易理解,2 刚刚总结过。这里特别要讲一下 Collection(CollectionDef是对应的ProtoBuf)。

Tensorflow中并没有一个官方的定义说 collection 是什么。简单的理解,它就是为了方便用户对图中的操作和变量进行管理,而创建的一个概念。它可以说是一种“集合”,通过一个 key(string类型)来对一组 Python 对象进行命名的集合。这个key既可以是tensorflow在内部定义的一些key,也可以是用户自己定义的名字(string)。

Tensorflow 内部定义了许多标准 Key,全部定义在了 tf.GraphKeys 这个类中。其中有一些常用的,tf.GraphKeys.TRAINABLE_VARIABLES, tf.GraphKeys.GLOBAL_VARIABLES 等等。tf.trainable_variables()tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) 是等价的;tf.global_variables()tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) 是等价的。

对于用户定义的 key,我们举一个例子。例如:

pred = model_network(X)
loss=tf.reduce_mean(…, pred, …)
train_op=tf.train.AdamOptimizer(lr).minimize(loss)

这样一段 Tensorflow程序,用户希望特别关注 pred, loss, train_op 这几个操作,那么就可以使用如下代码,将这几个变量加入到 collection 中去。(假设我们将其命名为 “training_collection”)

tf.add_to_collection("training_collection", pred)
tf.add_to_collection("training_collection", loss)
tf.add_to_collection("training_collection", train_op)

并且可以通过 Train_collect = tf.get_collection(“training_collection”) 得到一个python list,其中的内容就是 pred, loss, train_op的 Tensor。这通常是为了在一个新的 session 中打开这张图时,方便我们获取想要的操作。比如我们可以直接工通过get_collection() 得到 train_op,然后通过 sess.run(train_op)来开启一段训练,而无需重新构建 loss 和optimizer。

通过export_meta_graph保存图,并且通过 add_to_collection 将 train_op 加入到 collection中:

with tf.Session() as sess:
pred = model_network(X)
loss=tf.reduce_mean(…,pred, …)
train_op=tf.train.AdamOptimizer(lr).minimize(loss)
tf.add_to_collection("training_collection", train_op)
Meta_graph_def =
tf.train.export_meta_graph(tf.get_default_graph(), 'my_graph.meta')

通过 import_meta_graph将图恢复(同时初始化为本 Session的 default 图),并且通过 get_collection 重新获得 train_op,以及通过 train_op 来开始一段训练( sess.run() )。

with tf.Session() as new_sess:
tf.train.import_meta_graph('my_graph.meta')
train_op = tf.get_collection("training_collection")[0]
new_sess.run(train_op)

更多的代码例子可以在这篇文档(https://www.tensorflow.org/api_guides/python/meta_graph)中的 Import a MetaGraph 章节中看到。

那么,从 Meta Graph 中恢复构建的图可以被训练吗?是可以的。Tensorflow的官方文档 https://www.tensorflow.org/api_guides/python/meta_graph 说明了使用方法。这里要特殊的说明一下,Meta Graph中虽然包含Variable的信息,却没有 Variable 的实际值。所以从Meta Graph中恢复的图,其训练是从随机初始化的值开始的。训练中Variable的实际值都保存在check-point中,如果要从之前训练的状态继续恢复训练,就要从check-point中restore。进一步读一下Export Meta Graph的代码,可以看到,事实上variables并没有被export到meta_graph 中

github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/training/saver.py (1872行)

github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/framework/meta_graph.py (829,845行)

export_meta_graph/Import_meta_graph 就是用来进行 Meta Graph 读写的API。

tf.train.saver.save() 在保存check-point的同时也会保存Meta Graph。但是在恢复图时,tf.train.saver.restore() 只恢复 Variable,如果要从MetaGraph恢复图,需要使用 import_meta_graph。这是其实为了方便用户,有时我们不需要从MetaGraph恢复的图,而是需要在 python 中构建神经网络图,并恢复对应的 Variable。

Check-point

Check-point 里全面保存了训练某时间截面的信息,包括参数,超参数,梯度等等。tf.train.Saver()/saver.restore() 则能够完完整整保存和恢复神经网络的训练。Check-point分为两个文件保存Variable的二进制信息。ckpt文件保存了Variable的二进制信息,index
文件用于保存 ckpt 文件中对应 Variable 的偏移量信息。

总结

Tensorflow 三种 API 所保存和恢复的图是不一样的。这三种图是从Tensorflow框架设计的角度出发而定义的。但是从用户的角度来看,TF文档的写作难免有些云里雾里,弄不清他们的区别。需要读一读Tensorflow的代码,做一些实验来对他们进行辨析。

简而言之,Tensorflow 在前端 Python 中构建图,并且通过将该图序列化到 ProtoBuf GraphDef,以方便在后端运行。在这个过程中,图的保存、恢复和运行都通过 ProtoBuf 来实现。GraphDef,MetaGraph,以及Variable,Collection 和 Saver 等都有对应的 ProtoBuf 定义。ProtoBuf 的定义也决定了用户能对图进行的操作。例如用户只能找到Node的前一个Node,却无法得知自己的输出会由哪个Node接收。

【转载】 TensorFlow - 框架实现中的三种 Graph图结构的更多相关文章

  1. TensorFlow - 框架实现中的三种 Graph

    文章目录 TensorFlow - 框架实现中的三种 Graph 1. Graph 2. GraphDef 3. MetaGraph 4. Checkpoint 5. 总结 TensorFlow - ...

  2. Tensorflow框架实现中的“三”种图

    https://zhuanlan.zhihu.com/p/31308381 图(graph)是 tensorflow 用于表达计算任务的一个核心概念.从前端(python)描述神经网络的结构,到后端在 ...

  3. Java三大框架之——Hibernate中的三种数据持久状态和缓存机制

    Hibernate中的三种状态   瞬时状态:刚创建的对象还没有被Session持久化.缓存中不存在这个对象的数据并且数据库中没有这个对象对应的数据为瞬时状态这个时候是没有OID. 持久状态:对象经过 ...

  4. 【转载】C#批量插入数据到Sqlserver中的三种方式

    引用:https://m.jb51.net/show/99543 这篇文章主要为大家详细介绍了C#批量插入数据到Sqlserver中的三种方式,具有一定的参考价值,感兴趣的小伙伴们可以参考一下 本篇, ...

  5. .net core 注入中的三种模式:Singleton、Scoped 和 Transient

    从上篇内容不如题的文章<.net core 并发下的线程安全问题>扩展认识.net core注入中的三种模式:Singleton.Scoped 和 Transient 我们都知道在 Sta ...

  6. Asp.Net中的三种分页方式

    Asp.Net中的三种分页方式 通常分页有3种方法,分别是asp.net自带的数据显示空间如GridView等自带的分页,第三方分页控件如aspnetpager,存储过程分页等. 第一种:使用Grid ...

  7. httpClient中的三种超时设置小结

    httpClient中的三种超时设置小结   本文章给大家介绍一下关于Java中httpClient中的三种超时设置小结,希望此教程能给各位朋友带来帮助. ConnectTimeoutExceptio ...

  8. MySQL buffer pool中的三种链

    三种page.三种list.LRU控制调优 一.innodb buffer pool中的三种页 1.free page:从未用过的页 2.clean page:干净的页,数据页的数据和磁盘一致 3.d ...

  9. 研究分析JS中的三种逻辑语句

    JS中的三种逻辑语句:顺序.分支和循环语句. 一.顺序语句 代码规范如下:1. <script type="text/javascript"> var a = 10;  ...

  10. JavaScript中的三种弹出对话框

    学习过js的小伙伴会发现,我们在一些实例中用到了alert()方法.prompt()方法.prompt()方法,他们都是在屏幕上弹出一个对话框,并且在上面显示括号内的内容,使用这种方法使得页面的交互性 ...

随机推荐

  1. 解读surging 的内存过高的原因

    前言 对于.NET开发人员来讲,一个程序占用内存过高,是极其糟糕,是一款不合格的程序软件,.NET开发人员也不会去使用服务器垃圾收集器(ServerGarbageCollection),而是选用工作站 ...

  2. JavaScript通过递归实现深拷贝

    思路 首先是用Object.prototype.toString.call(obj)来得到传入的值的类型,如果是几个基本类型,则直接返回值就可以了 如果是引用类型,则通过深拷贝函数递归进行再次拷贝. ...

  3. Eclipse build js卡死 Eclipse 编译太卡,耗时太长解决

    Eclipse build js卡死 Eclipse 编译太卡,耗时太长解决 问题描述:编译停止在js编译中,原来是js的问题 1.首选项-javaScript-Validator-Errors/Wa ...

  4. Service Mesh技术详解

    深入探讨Service Mesh的基本概念和核心技术,涵盖了服务发现.负载均衡.断路器与熔断机制,以及数据平面与控制平面的详细工作原理和实现方法. 关注作者,复旦博士,分享云服务领域全维度开发技术.拥 ...

  5. 小组合作实现的基于 jsp,servlet,mysql 编写的学校管理系统

    基本完成的页面--源代码在<文件>中可下载 文件地址:https://i.cnblogs.com/Files.aspx 学生管理模块各功能已实现 百度网盘下载地址: 链接:https:// ...

  6. Unity 中关于SubMesh的拾取问题

    问题背景 最近在开发一个功能,钻孔功能,每一层(段)都需要单独拾取,显示不同的颜色,使用不同材质 问题分析 对于这个功能,由于上述需求,很容易想到用submesh实现,但是主要问题是在于对于Subme ...

  7. package-lock.json 文件

    今天有同事找到我说,本地js 编译不过,编译不过的代码如下 const host = window?.location?.host || 'localhost'; 是option chaining, ...

  8. GCC8 编译优化 BUG 导致的内存泄漏

    1. 背景 1.1. 接手老系统 最近我们又接手了一套老系统,老系统的迭代效率和稳定性较差,我们打算做重构改造,但重构周期较长,在改造完成之前还有大量的需求迭代.因此我们打算先从稳定性和迭代效率出发做 ...

  9. Simple WPF: S3实现MINIO大文件上传并显示上传进度

    最新内容优先发布于个人博客:小虎技术分享站,随后逐步搬运到博客园. 创作不易,如果觉得有用请在Github上为博主点亮一颗小星星吧! 目的 早两天写了一篇S3简单上传文件的小工具,知乎上看到了一个问题 ...

  10. Netcode for Entities里如何对Ghost进行可见性筛选(1.2.3版本)

    一行代码省流:SystemAPI.GetSingleton() 当你需要按照区域.距离或者场景对Ghost进行筛选的时候,Netcode for Entities里并没有类似FishNet那样方便的过 ...