Tensorflow中的图(tf.Graph)和会话(tf.Session)详解
Tensorflow中的图(tf.Graph)和会话(tf.Session)
Tensorflow编程系统

Tensorflow工具或者说深度学习本身就是一个连贯紧密的系统。一般的系统是一个自治独立的、能实现复杂功能的整体。系统的主要任务是对输入进行处理,以得到想要的输出结果。我们之前见过的很多系统都是线性的,就像汽车生产工厂的流水线一样,输入->系统处理->输出。系统内部由很多单一的基本部件构成,这些单一部件具有特定的功能,且需要稳定的特性;系统设计者通过特殊的连接方式,让这些简单部件进行连接,以使它们之间可以进行数据交流和信息互换,来达到相互配合而完成具体工作的目的。
对于任何一个系统来说,都应该拥有稳定、独立、能处理特殊任务的单一部件;且拥有一套良好的内部沟通机制,以让系统可以健康安全的运行。
现实中的很多系统都是线性的,被设计好的、不能进行更改的,比如工厂的流水线,这样的系统并不具备自我调整的能力,无法对外界的环境做出反应,因此也就不具备“智能”。
深度学习(神经网络)之所以具备智能,就是因为它具有反馈机制。深度学习具有一套对输出所做的评价函数(损失函数),损失函数在对神经网络做出评价后,会通过某种方式(梯度下降法)更新网络的组成参数,以期望系统得到更好的输出数据。

由此可见,神经网络的系统主要由以下几个方面组成:
- 输入
- 系统本身(神经网络结构),以及涉及到系统本身构建的问题:如网络构建方式、网络执行方式、变量维护、模型存储和恢复等等问题
- 损失函数
- 反馈方式:训练方式
定义好以上的组成部分,我们就可以用流程化的方式将其组合起来,让系统对输入进行学习,调整参数。因为该系统的反馈机制,所以,组成的方式肯定需要循环。
而对于Tensorflow来说,其设计理念肯定离不开神经网络本身。所以,学习Tensorflow之前,对神经网络有一个整体、深刻的理解也是必须的。如下图:Tensorflow的执行示意。

那么对于以上所列的几点,什么才是最重要的呢?我想肯定是有关系统本身所涉及到的问题。即如何构建、执行一个神经网络? 在Tensorflow中,用计算图来构建网络,用会话来具体执行网络。深入理解了这两点,我想,对于Tensorflow的设计思路,以及运行机制,也就略知一二了。
- 图(tf.Graph):计算图,主要用于构建网络,本身不进行任何实际的计算。计算图的设计启发是高等数学里面的链式求导法则的图。我们可以将计算图理解为是一个计算模板或者计划书。
会话(tf.session):会话,主要用于执行网络。所有关于神经网络的计算都在这里进行,它执行的依据是计算图或者计算图的一部分,同时,会话也会负责分配计算资源和变量存放,以及维护执行过程中的变量。
接下来,我们主要从计算图开始,看一看Tensorflow是如何构建、执行网络的。
计算图
在开始之前,我们先复习一下Tensorflow的几种基本数据类型:
tf.constant(value, dtype=None, shape=None, name='Const', verify_shape=False)
tf.Variable(initializer, name)
tf.placeholder(dtype, shape=None, name=None)
复习完毕。
graph = tf.Graph()
with graph.as_default():
img = tf.constant(1.0, shape=[1,5,5,3])
以上代码中定义了一个计算图,在该计算图中定义了一个常量。Tensorflow默认会创建一张计算图。所以上面代码中的前两行,可以省略。默认情况下,计算图是空的。

在执行完img = tf.constant(1.0, shape=[1,5,5,3])以后,计算图中生成了一个node,一个node结点由name, op, input, attrs组成,即结点名称、操作、输入以及一系列的属性(类型、形状、值)等组成,计算图就是由这样一个个的node组成的。对于tf.constant()函数,只会生成一个node,但对于有的函数,如tf.Variable(initializer, name)(注意其第一个参数是初始化器)就会生成多个node结点(后面会讲到)。
那么执行完img = tf.constant(1.0, shape=[1,5,5,3])后,计算图中就多一个node结点。(因为每个node的属性很多,我只表示name,op,input属性)

继续添加代码:
img = tf.constant(1.0, shape=[1,5,5,3])
k = tf.constant(1.0, shape=[3,3,3,1])
代码执行后的计算图如下:

需要注意的是,如果没有对结点进行命名,Tensorflow自动会将其命名为:Const、Const_1、const_2......。其他类型的结点类同。
现在,我们添加一个变量:
img = tf.constant(1.0, shape=[1,5,5,3])
k = tf.constant(1.0, shape=[3,3,3,1])
kernel = tf.Variable(k)
该变量用一个常量作为初始化器。我们先看一下计算图:

如图所示:
执行完tf.Variable()函数后,一共产生了三个结点:
- Variable:变量维护(不存放实际的值)
- Variable/Assign:变量分配
- Variable/read:变量使用
图中只是完成了操作的定义,但并没有执行操作(如Variable/Assign结点的Assign操作,所以,此时候变量依然不可以使用,这就是为什么要在会话中初始化的原因)。
我们继续添加代码:
img = tf.constant(1.0, shape=[1,5,5,3])
k = tf.constant(1.0, shape=[3,3,3,1])
kernel = tf.Variable(k)
y = tf.nn.conv2d(img, kernel, strides=[1,2,2,1], padding="SAME")
得到的计算图如下:

可以看出,变量读取是通过Variable/read来进行的。
如果在这里我们直接开启会话,并执行计算图中的卷积操作,系统就会报错。
img = tf.constant(1.0, shape=[1,5,5,3])
k = tf.constant(1.0, shape=[3,3,3,1])
kernel = tf.Variable(k)
y2 = tf.nn.conv2d(img, kernel, strides=[1,2,2,1], padding="SAME")
with tf.Session() as sess:
sess.run(y2)
这段代码错误的原因在于,变量并没有初始化就被使用,而从图中清晰的可以看到,直接执行卷积,是回溯不到变量的值(Const_1)的(箭头方向)。
所以,在执行之前,要进行初始化,代码如下:
img = tf.constant(1.0, shape=[1,5,5,3])
k = tf.constant(1.0, shape=[3,3,3,1])
kernel = tf.Variable(k)
y2 = tf.nn.conv2d(img, kernel, strides=[1,2,2,1], padding="SAME")
init = tf.global_variables_initializer()
执行完tf.global_variables_initializer()函数以后,计算图如下:

tf.global_variables_initializer()产生了一个名为init的node,该结点将所有的Variable/Assign结点作为输入,以达到对整张计算图中的变量进行初始化。
所以,在开启会话后,执行的第一步操作,就是变量初始化(当然变量初始化的方式有很多种,我们也可以显示调用tf.assign()来完成对单个结点的初始化)。
完整代码如下:
img = tf.constant(1.0, shape=[1,5,5,3])
k = tf.constant(1.0, shape=[3,3,3,1])
kernel = tf.Variable(k)
y2 = tf.nn.conv2d(img, kernel, strides=[1,2,2,1], padding="SAME")
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
# do someting....
会话
在上述代码中,我已经使用会话(tf.session())来执行计算图了。在tf.session()中,我们重点掌握无所不能的sess.run()。
一个session()中包含了Operation被执行,以及Tensor被evaluated的环境。
tf.Session().run()函数的定义:
run(
fetches,
feed_dict=None,
options=None,
run_metadata=None
)
tf.Session().run()函数的功能为:执行fetches参数所提供的operation操作或计算其所提供的Tensor。
run()函数每执行一步,都会执行与fetches有关的图中的所有结点的计算,以完成fetches中的任务。其中,feed_dict提供了部分数据输入的功能。(和tf.Placeholder()搭配使用,很舒服)
参数说明:
- fetches:可以是图中的一个结点,也可以是一个List或者字典,此时候返回值与fetches格式一致;该参数还可以是一个Operation,此时候返回值为None。
- feed_dict:字典格式。给模型输入其计算过程中所需要的值。
当我们把模型的计算图构建好以后,就可以利用会话来进行执行训练了。
在明白了计算图是如何构建的,以及如何被会话正确的执行以后,我们就可以愉快的开始Tensorflow之旅啦。
参考链接
Tensorflow中的图(tf.Graph)和会话(tf.Session)详解的更多相关文章
- tensorflow中的图(02-1)
由于tensorflow版本迭代较快且不同版本的接口会有差距,我这里使用的是1.14.0的版本 安装指定版本的方法:pip install tensorflow==1.14.0 如果你之前安 ...
- 超详细的Tensorflow模型的保存和加载(理论与实战详解)
1.Tensorflow的模型到底是什么样的? Tensorflow模型主要包含网络的设计(图)和训练好的各参数的值等.所以,Tensorflow模型有两个主要的文件: a) Meta graph: ...
- 网络基础 http 会话(session)详解
http 会话(session)详解 by:授客 QQ:1033553122 会话(session)是一种持久网络协议,在用户(或用户代理)端和服务器端之间创建关联,从而起到交换数据包的作用机制 一. ...
- 巨人大哥谈Web应用中的Session(session详解)
巨人大哥谈Web应用中的Session(session详解) 虽然session机制在web应用程序中被采用已经很长时间了,但是仍然有很多人不清楚session机制的本质,以至不能正确的应用这一技术. ...
- PHP中IP地址与整型数字互相转换详解
这篇文章主要介绍了PHP中IP地址与整型数字互相转换详解,本文介绍了使用PHP函数ip2long与long2ip的使用,以及它们的BUG介绍,最后给出自己写的两个算法,需要的朋友可以参考下 IP转换成 ...
- ArcGIS中的北京54和西安80投影坐标系详解
ArcGIS中的北京54和西安80投影坐标系详解 1.首先理解地理坐标系(Geographic coordinate system),Geographic coordinate system直译为地理 ...
- [转]js中几种实用的跨域方法原理详解
转自:js中几种实用的跨域方法原理详解 - 无双 - 博客园 // // 这里说的js跨域是指通过js在不同的域之间进行数据传输或通信,比如用ajax向一个不同的域请求数据,或者通过js获取页面中不同 ...
- Nginx服务器中配置非80端口的端口转发方法详解
这篇文章主要介绍了Nginx服务器中配置非80端口的端口转发方法详解,文中使用到了Nginx中的proxy_pass配置项,需要的朋友可以参考下 nginx可以很方便的配置成反向代理服务器: 1 2 ...
- java使用POI操作XWPFDocument中的XWPFRun(文本)对象的属性详解
java使用POI操作XWPFDocument中的XWPFRun(文本)对象的属性详解 我用的是office word 2016版 XWPFRun是XWPFDocument中的一段文本对象(就是一段文 ...
随机推荐
- Linux用户登出之后保持后台进程(nohup)
使用&可以将进程置于后台,但是用户从Shell登出之后,进程会自动结束.想要在登出之后保持进程运行,就要结合nohup命令使用. 例如: nohup find -size +100k > ...
- C#基础提升系列——C#委托
C# 委托 委托是类型安全的类,它定义了返回类型和参数的类型,委托类可以包含一个或多个方法的引用.可以使用lambda表达式实现参数是委托类型的方法. 委托 当需要把一个方法作为参数传递给另一个方法时 ...
- 【NLP新闻-2013.06.16】Representative Reviewing
英语原文地址:http://nlp.hivefire.com/articles/share/40221/ 注:本人翻译NLP新闻只为学习专业英语和扩展视野,如果翻译的不好,请谅解! (实在是读不大懂, ...
- python不能运行
运行python文件出现,报错please select a valid interpreter 这是因为没有选择interpreter 就是更改目录时需要重新选择pytho解析器 解决方法如下 更 ...
- php rand()函数 语法
php rand()函数 语法 rand()函数怎么用? php rand()函数表示从参数范围内得到一个随机数,语法是rand(X,Y),从两个参数范围内得到一个随机数,随机数大于等于X或者小于等于 ...
- JS中数据结构之二叉查找树
树是一种非线性的数据结构,以分层的方式存储数据.在二叉树上进行查找非常快,为二叉树添加或删除元素也非常快. 一棵树最上面的节点称为根节点,如果一个节点下面连接多个节点,那么该节点称为父节点,它下面的节 ...
- java使用开源类库Tesseract实现图片识别
Tesseract-OCR支持中文识别,并且开源和提供全套的训练工具,是快速低成本开发的首选. Tess4J则是Tesseract在Java PC上的应用 Tesseract的OCR引擎最先由HP实验 ...
- 前端每日实战:23# 视频演示如何用纯 CSS 创作一个菜单反色填充特效
效果预览 按下右侧的"点击预览"按钮可以在当前页面预览,点击链接可以全屏预览.https://codepen.io/comehope/pen/qYMoPo 可交互视频教程 此视频是 ...
- IntelliJ IDEA更新maven依赖包
问题: IntelliJ IDEA自动载入Maven依赖的功能很好用,但有时候会碰到问题,导致pom文件修改却没有触发自动重新载入的动作,此时需要手动强制更新依赖. 方法: 方法一: ①.右键单击项目 ...
- HDU4762 Cut the Cake
HDU4762 Cut the Cake 思路:公式:n/m(n-1) //package acm; import java.awt.Container; import java.awt.geom.A ...