当我们在大型的数据集上面进行深度学习的训练时,往往需要大量的运行资源,而且还要花费大量时间才能完成训练。

1.分布式TensorFlow的角色与原理

在分布式的TensorFlow中的角色分配如下:

PS:作为分布式训练的服务端,等待各个终端(supervisors)来连接。

worker:在TensorFlow的代码注释中被称为终端(supervisors),作为分布式训练的计算资源终端。

chief supervisors:在众多的运算终端中必须选择一个作为主要的运算终端。该终端在运算终端中最先启动,它的功能是合并各个终端运算后的学习参数,将其保存或者载入。

每个具体的网络标识都是唯一的,即分布在不同IP的机器上(或者同一个机器的不同端口)。在实际的运行中,各个角色的网络构建部分代码必须100%的相同。三者的分工如下:

服务端作为一个多方协调者,等待各个运算终端来连接。

chief supervisors会在启动时同一管理全局的学习参数,进行初始化或者从模型载入。

其他的运算终端只是负责得到其对应的任务并进行计算,并不会保存检查点,用于TensorBoard可视化中的summary日志等任何参数信息。

在整个过程都是通过RPC协议来进行通信的。

2.分布部署TensorFlow的具体方法

配置过程中,首先建立一个server,在server中会将ps及所有worker的IP端口准备好。接着,使用tf.train.Supervisor中的managed_ssion来管理一个打开的session。session中只是负责运算,而通信协调的事情就都交给supervisor来管理了。

3.部署训练实例

下面开始实现一个分布式训练的网络模型,以线性回归为例,通过3个端口来建立3个终端,分别是一个ps,两个worker,实现TensorFlow的分布式运算。

1. 为每个角色添加IP地址和端口,创建sever,在一台机器上开3个不同的端口,分别代表PS,chief supervisor和worker。角色的名称用strjob_name表示,以ps为例,代码如下:

# 定义IP和端口
strps_hosts = 'localhost:1681'
strworker_hosts = 'localhost:1682,localhost:1683' # 定义角色名称
strjob_name = 'ps'
task_index = 0 # 将字符串转数组
ps_hosts = strps_hosts.split(',')
worker_hosts = strps_hosts.split(',') cluster_spec = tf.train.ClusterSpec({'ps': ps_hosts, 'worker': worker_hosts}) # 创建server
server = tf.train.Server({'ps':ps_hosts, 'worker':worker_hosts}, job_name=strjob_name, task_index=task_index)

2为ps角色添加等待函数

ps角色使用server.join函数进行线程挂起,开始接受连续消息。

# ps角色使用join进行等待
if strjob_name == 'ps':
print("wait")
server.join()

3.创建网络的结构

与正常的程序不同,在创建网络结构时,使用tf.device函数将全部的节点都放在当前任务下。在tf.device函数中的任务是通过tf.train.replica_device_setter来指定的。在tf.train.replica_device_setter中使用worker_device来定义具体任务名称;使用cluster的配置来指定角色及对应的IP地址,从而实现管理整个任务下的图节点。代码如下:

with tf.device(tf.train.replica_device_setter(worker_device='/job:worker/task:%d'%task_index,
cluster=cluster_spec)):
X = tf.placeholder('float')
Y = tf.placeholder('float')
# 模型参数
w = tf.Variable(tf.random_normal([1]), name='weight')
b = tf.Variable(tf.zeros([1]), name='bias') global_step = tf.train.get_or_create_global_step() # 获取迭代次数 z = tf.multiply(X, w) + b
tf.summary('z', z)
cost = tf.reduce_mean(tf.square(Y - z))
tf.summary.scalar('loss_function', cost)
learning_rate = 0.001 optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost, global_step=global_step) saver = tf.train.Saver(max_to_keep=1) merged_summary_op = tf.summary.merge_all() # 合并所有summary init = tf.global_variables_initializer()

4.创建Supercisor,管理session

在tf.train.Supervisor函数中,is_chief表明为是否为chief Supervisor角色,这里将task_index=0的worker设置成chief Supervisor。saver需要将保存检查点的saver对象传入。init_op表示使用初始化变量的函数。

training_epochs = 2000
display_step = 2 sv = tf.train.Supervisor(is_chief=(task_index == 0),# 0号为chief
logdir='log/spuer/',
init_op=init,
summary_op=None,
saver=saver,
global_step=global_step,
save_model_secs=5) # 连接目标角色创建session
with sv.managed_session(saver.target) as sess:

5迭代训练

session中的内容与以前一样,直接迭代训练即可。由于使用了supervisor管理session,将使用sv.summary_computed函数来保存summary文件。

print('sess ok')
print(global_step.eval(session=sess)) for epoch in range(global_step.eval(session=sess), training_epochs*len(train_x)):
for (x, y) in zip(train_x, train_y):
_, epoch = sess.run([optimizer, global_step], feed_dict={X: x, Y: y})
summary_str = sess.run(merged_summary_op, feed_dict={X: x, Y: y})
sv.summary_computed(sess, summary_str, global_step=epoch)
if epoch % display_step == 0:
loss = sess.run(cost, feed_dict={X:train_x, Y:train_y})
print("Epoch:", epoch+1, 'loss:', loss, 'W=', sess.run(w), w, 'b=', sess.run(b)) print(' finished ')
sv.saver.save(sess, 'log/linear/' + "sv.cpk", global_step=epoch) sv.stop()

(1)在设置自动保存检查点文件后,手动保存仍然有效,

(2)在运行一半后,在运行supervisor时会自动载入模型的参数,不需要手动调用restore。

(3)在session中不需要进行初始化的操作。

6.建立worker文件

新建两个py文件,设置task_index分别为0和1,其他的部分和上述的代码相一致。

strjob_name = 'worker'
task_index = 1 strjob_name = 'worker'
task_index = 0

7.运行

我们分别启动写好的三个文件,在运行结果中,我们可以看到循环的次数不是连续的,显示结果中会有警告,这是因为在构建supervisor时没有填写local_init_op参数,该参数的含义是在创建worker实例时,初始化本地变量,上述代码中没有设置,系统会自动初始化,并给出警告提示。

分布运算的目的是为了提高整体运算速度,如果同步epoch的准确率需要牺牲总体运行速度为代价,自然很不合适。

在ps的文件中,它只是负责连接,并不参与运算。

TensorFlow——分布式的TensorFlow运行环境的更多相关文章

  1. [源码解析] TensorFlow 分布式 DistributedStrategy 之基础篇

    [源码解析] TensorFlow 分布式 DistributedStrategy 之基础篇 目录 [源码解析] TensorFlow 分布式 DistributedStrategy 之基础篇 1. ...

  2. [源码解析] TensorFlow 分布式之 MirroredStrategy

    [源码解析] TensorFlow 分布式之 MirroredStrategy 目录 [源码解析] TensorFlow 分布式之 MirroredStrategy 1. 设计&思路 1.1 ...

  3. [源码解析] TensorFlow 分布式之 MirroredStrategy 分发计算

    [源码解析] TensorFlow 分布式之 MirroredStrategy 分发计算 目录 [源码解析] TensorFlow 分布式之 MirroredStrategy 分发计算 0x1. 运行 ...

  4. [源码解析] TensorFlow 分布式之 ClusterCoordinator

    [源码解析] TensorFlow 分布式之 ClusterCoordinator 目录 [源码解析] TensorFlow 分布式之 ClusterCoordinator 1. 思路 1.1 使用 ...

  5. [源码解析] TensorFlow 分布式环境(1) --- 总体架构

    [源码解析] TensorFlow 分布式环境(1) --- 总体架构 目录 [源码解析] TensorFlow 分布式环境(1) --- 总体架构 1. 总体架构 1.1 集群角度 1.1.1 概念 ...

  6. [源码解析] TensorFlow 分布式环境(2)---Master 静态逻辑

    [源码解析] TensorFlow 分布式环境(2)---Master 静态逻辑 目录 [源码解析] TensorFlow 分布式环境(2)---Master 静态逻辑 1. 总述 2. 接口 2.1 ...

  7. [源码解析] TensorFlow 分布式环境(3)--- Worker 静态逻辑

    [源码解析] TensorFlow 分布式环境(3)--- Worker 静态逻辑 目录 [源码解析] TensorFlow 分布式环境(3)--- Worker 静态逻辑 1. 继承关系 1.1 角 ...

  8. [源码解析] TensorFlow 分布式环境(4) --- WorkerCache

    [源码解析] TensorFlow 分布式环境(4) --- WorkerCache 目录 [源码解析] TensorFlow 分布式环境(4) --- WorkerCache 1. WorkerCa ...

  9. [源码解析] TensorFlow 分布式环境(5) --- Session

    [源码解析] TensorFlow 分布式环境(5) --- Session 目录 [源码解析] TensorFlow 分布式环境(5) --- Session 1. 概述 1.1 Session 分 ...

随机推荐

  1. IOS 监控网络变化案例源码

    随着移动网络升级:2G->3G->4G甚至相传正在研发的5G,网络速度是越来越快,但这流量也像流水一般哗哗的溜走. 网上不是流传一个段子:睡觉忘记关流量,第二天房子就归移动了! 这固然是一 ...

  2. Percona Xtrabackup对数据库进行部分备份

    Xtrabackup也可以实现部分备份,即只备份某个或某些指定的数据库或某数据库中的某个或某些表.但要使用此功能,必须启用innodb_file_per_table选项,即每张表保存为一个独立的文件. ...

  3. PHP获得文件的大小并转换格式

    利用filesize($filename)函数获得一个文件的大小 参数$filename为文件的绝对路径,返回的值是文件的大小字节数. 文件较大的时候看起来不方便,下面是一个格式化方法 functio ...

  4. [iOS Reverse]logify日志追踪,锁定注入口-控制台查看

    前言 logify是theos的一个组件,路径是: /opt/theos/bin/logify.pl 我们还是以微信红包为例子,根据[iOS Hacking]运行时分析cycript得到的入口文件: ...

  5. java 循环document 通用替换某个字符串或特殊字符

    document 生成xml时 报错 XML-20100: (Fatal Error) Expected ';'.  查了半天发现是 特殊字符 & 不能直接转出,需要进行转换,因为是通用方法很 ...

  6. sessionStorage和localStorage存储的转换不了json

    先说说localStorage与sessionStorage的差别 sessionStorage是存储浏览器的暂时性的数据,当关闭浏览器下次再打开的时候就不能拿到之前存储的缓存了 localStora ...

  7. 定位IO瓶颈的方法,iowait低,IO就没有到瓶颈?

    通过分析mpstat的iowait和iostat的util%,判断IO瓶颈 IO瓶颈往往是我们可能会忽略的地方(我们常会看top.free.netstat等等,但经常会忽略IO的负载情况),今天给大家 ...

  8. 洛谷P1012 拼数【字符串+排序】

    设有nn个正整数(n≤20)(n≤20),将它们联接成一排,组成一个最大的多位整数. 例如:n=3n=3时,33个整数1313,312312,343343联接成的最大整数为:3433121334331 ...

  9. str使用注意点(未完待续)

    1.(切片)str[num1,num2],是指截取 str[num1] 到 str[num2-1] 直接的元素

  10. poj 2763(LCA + dfs序 +树状数组)

    算是模板题了 可以用dfs序维护点到根的距离 注意些LCA的时候遇到MAXM,要-1 #include<cstdio> #include<algorithm> #include ...