TensorFlow分布式并行基于gRPC通信框架,一个master负责创建Session,多个worker负责执行计算图任务。

先创建TensorFlow Cluster对象,包含一组task(每个task一台独立机器),分布式执行TensorFlow计算图。一个Cluster切分多个job,一个job是一类特定任务(parameter server ps,worker),每个job可以包含多个task。每个task创建一个server,连接到Cluster,每个task执行在不同机器。也可以一台机器执行多个task(不同GPU)。tf.train.ClusterSpec初始化Cluster对象,初始化信息是Python dict,tf.train.ClusterSpec({"ps":["192.168.233.201:2222"],"worker":["192.168.233.202:2222","192.168.233.203:2222"]}),代表一个parameter server和两个worker,分别在三个不同机器上。每个task,定义自己身份,如server=tf.train.Server(cluster,job_name="ps",task_index=0),机器job定义ps第0台机器。程序中with tf.device("/job:worker/task:7"),限定Variable存放在哪个task或机器。

TensorFlow分布式模式:In-graph replication模型并行,模型计算图不同部分放在不同机器执行;Between-graph replication数据并行,每台机器相同计算图,计算不同batch数据。异步并行,每台机器独立计算梯度,计算完更新到parameter server,不等其他机器。同步并行,等所有机器都完成梯度计算,多个梯度合成统一更新模型参数。同步并行训练,loss下降速度更快,可达到最大精度更高,同步并行速度取决最慢机器,设备速度一致,效率较高。

TensorFlow实现包含1个paramter server、2个worker分布式并行训练程序,MNIST手写数据识别任务示例。写一个完整Python文件,在不同机器不同task执行。载入依赖库。

tf.app.flags定义标记,命令行执行TensorFlow设置参数。命令行指定参数被TensorFlow读取,转flags。设定数据储存目录data_dir默认/tmp/mnist-data,隐藏节点数默认100,训练最大步数train_steps默认1000000,batch_size默认100,学习速率默认0.01。

设定是否使用同步并行标记sync_replicas默认False,命令行执行时可设True开户同步并行。设定需要累计梯度个数更新模型值默认None,代表同步并行积累多少个batch梯度再进行参数更新,设None 为worker数量,所有worker完成一个batch训练后再更新模型参数。

定义ps地址,默认192.168.233.201:2222,根据集群实际情况配置。worker地址设置192.168.233.202:2222和192.168.233.203:2222.设置job_name和task_index FLAG。

flags.FLAGS直接命名FLAGS,简化使用。设置图片尺寸IMAGE_PIXELS 28。

编写程序主函数main,input_data.read_data_sets下载读取MNIST数据集,设置one_hot编码格式。检测命令行输入参数,确保job_name和task_index两个必备参数。显示job_name和task_index,ps、worker所有地址解析成列表ps_spec、worker_spec。

计算总共worker数量,tf.train.ClusterSpec生成TensorFlow Cluster对象,传入参数ps地址信息和worker地址信息。tf.train.Server创建当前机器server,连接Cluster。如当前节点是parameter server,不进行后续操作,server.join等待worker工作。

判断当前机器是否主节点,task_index是否0。定义当前机器worker_device,格式"job:worker/task:0/gpu:0"。多台机器,每台机器1块GPU,总共需要机器数量worker。如一台机器多GPU,一个task管理多个GPU或多个task分别管理。tf.train.replica_device_setter()设置worker资源,worker_device计算资源,ps_device存储模型参数资源。replica_device_setter将模型参数部署在独立ps服务器"/job:ps/cpu:0",训练操作部署在"/job:worker/task:0/gpu:0",本机GPU。创建记录全局训练步数变量global_step。

定义神经网络模型,tf.truncated_normal初始化权重,tf.zeros初始化偏置,创建输入 placeholder,tf.nn.xw_plus_b输入矩阵乘法、偏置操作,ReLU激活函数处理,得到第一个隐层输出hid。tf.nn.xw_plus_b、tf.nn.softmax对第一层输出hid处理,得到网络最终输出y。最后计算损失cross_entropy,定义优化器Adam。

判断是否设置同步训练模式sync_replicas,如果同步模型,先获取同步更新模型参数需要副本数replicas_to_aggregate;如果没有单独设置,worker数作默认值。tf.train.SyncReplicasOptimizer创建同步训练优化器,对原有优化器扩展。传入原有优化器及其他参数(replicas_to_aggregate、total_num_replicas、replica_id),原有优化器改造为同步分布式训练版本。用普通(异步)或同步优化器优化损失cross_entropy。

同步训练模式,主节点,opt.get_chief_queue_runner创建队列执行器,opt.get_init_tokens_op创建全局参数初始化器。

生成本地参数初始化操作init_op,创建临时训练目录,tf.train_Supervisor创建分布式训练监督器,传入参数is_chief、train_dir、init_op。Supervisor管理task参与到分布式训练。

设置Session参数,allow_soft_placement设True,代表操作在指定device不能执行时转到其他device执行。

如果主节点,显示初始化Session,其他节点显示等待主节点初始化操作。执行sv.prepate_or_wait_for_session()。

如果处于同步模型主节点,sv.start_queue_runners执行队列化执行器chief_queue_runner,执行全局参数初始化器init_tokens_op。

训练过程,记录worker执行训练启动时间,初始化本地训练步数local_step,进入训练循环。每步训练,从nnist.train.next_batch读取一个batch数据,生成feed_dict,调train_step执行训练。当全局训练步数达到预设最大值,停止训练。

训练结束,展示总训练时间,在验证数据上计算预测结果损失cross_entropy,展示。

在主程序执行tf.app.run()启动main函数,全部代码保存到distributed.py文件。3台不同机器分别执行distributed.py启动3个task,每次执行distributed.py,传入job_name、task_index指定worker身份。

分别在三台机器192.168.233.201、192.168.233.202、192.168.233.203执行python distributed.py。

同步模式,加上--sync_replicas=True。global_step,异步时,全局步数是所有worker训练步数和,同步时是多少轮并行训练。

    #from __future__ import absolute_import
#from __future__ import division
#from __future__ import print_function
import math
#import sys
import tempfile
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
flags = tf.app.flags
flags.DEFINE_string("data_dir", "/tmp/mnist-data",
"Directory for storing mnist data")
#flags.DEFINE_boolean("download_only", False,
# "Only perform downloading of data; Do not proceed to "
# "session preparation, model definition or training")
flags.DEFINE_integer("task_index", None,
"Worker task index, should be >= 0. task_index=0 is "
"the master worker task the performs the variable "
"initialization ")
#flags.DEFINE_integer("num_gpus", 2,
# "Total number of gpus for each machine."
# "If you don't use GPU, please set it to '0'")
flags.DEFINE_integer("replicas_to_aggregate", None,
"Number of replicas to aggregate before parameter update"
"is applied (For sync_replicas mode only; default: "
"num_workers)")
flags.DEFINE_integer("hidden_units", 100,
"Number of units in the hidden layer of the NN")
flags.DEFINE_integer("train_steps", 1000000,
"Number of (global) training steps to perform")
flags.DEFINE_integer("batch_size", 100, "Training batch size")
flags.DEFINE_float("learning_rate", 0.01, "Learning rate")
flags.DEFINE_boolean("sync_replicas", False,
"Use the sync_replicas (synchronized replicas) mode, "
"wherein the parameter updates from workers are aggregated "
"before applied to avoid stale gradients")
#flags.DEFINE_boolean(
# "existing_servers", False, "Whether servers already exists. If True, "
# "will use the worker hosts via their GRPC URLs (one client process "
# "per worker host). Otherwise, will create an in-process TensorFlow "
# "server.")
flags.DEFINE_string("ps_hosts","192.168.233.201:2222",
"Comma-separated list of hostname:port pairs")
flags.DEFINE_string("worker_hosts", "192.168.233.202:2223,192.168.233.203:2224",
"Comma-separated list of hostname:port pairs")
flags.DEFINE_string("job_name", None,"job name: worker or ps")
FLAGS = flags.FLAGS
IMAGE_PIXELS = 28
def main(unused_argv):
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
# if FLAGS.download_only:
# sys.exit(0)
if FLAGS.job_name is None or FLAGS.job_name == "":
raise ValueError("Must specify an explicit `job_name`")
if FLAGS.task_index is None or FLAGS.task_index =="":
raise ValueError("Must specify an explicit `task_index`")
print("job name = %s" % FLAGS.job_name)
print("task index = %d" % FLAGS.task_index)
#Construct the cluster and start the server
ps_spec = FLAGS.ps_hosts.split(",")
worker_spec = FLAGS.worker_hosts.split(",")
# Get the number of workers.
num_workers = len(worker_spec)
cluster = tf.train.ClusterSpec({
"ps": ps_spec,
"worker": worker_spec})
#if not FLAGS.existing_servers:
# Not using existing servers. Create an in-process server.
server = tf.train.Server(
cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
if FLAGS.job_name == "ps":
server.join()
is_chief = (FLAGS.task_index == 0) # if FLAGS.num_gpus > 0:
# if FLAGS.num_gpus < num_workers:
# raise ValueError("number of gpus is less than number of workers")
# # Avoid gpu allocation conflict: now allocate task_num -> #gpu
# # for each worker in the corresponding machine
# gpu = (FLAGS.task_index % FLAGS.num_gpus)
# worker_device = "/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu)
# elif FLAGS.num_gpus == 0:
# # Just allocate the CPU to worker server
# cpu = 0
# worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)
# # The device setter will automatically place Variables ops on separate
# # parameter servers (ps). The non-Variable ops will be placed on the workers.
# # The ps use CPU and workers use corresponding GPU worker_device = "/job:worker/task:%d/gpu:0" % FLAGS.task_index
with tf.device(
tf.train.replica_device_setter(
worker_device=worker_device,
ps_device="/job:ps/cpu:0",
cluster=cluster)):
global_step = tf.Variable(0, name="global_step", trainable=False)
# Variables of the hidden layer
hid_w = tf.Variable(
tf.truncated_normal(
[IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
stddev=1.0 / IMAGE_PIXELS),
name="hid_w")
hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b")
# Variables of the softmax layer
sm_w = tf.Variable(
tf.truncated_normal(
[FLAGS.hidden_units, 10],
stddev=1.0 / math.sqrt(FLAGS.hidden_units)),
name="sm_w")
sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
# Ops: located on the worker specified with FLAGS.task_index
x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])
y_ = tf.placeholder(tf.float32, [None, 10])
hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
hid = tf.nn.relu(hid_lin)
y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
if FLAGS.sync_replicas:
if FLAGS.replicas_to_aggregate is None:
replicas_to_aggregate = num_workers
else:
replicas_to_aggregate = FLAGS.replicas_to_aggregate
opt = tf.train.SyncReplicasOptimizer(
opt,
replicas_to_aggregate=replicas_to_aggregate,
total_num_replicas=num_workers,
replica_id=FLAGS.task_index,
name="mnist_sync_replicas")
train_step = opt.minimize(cross_entropy, global_step=global_step)
if FLAGS.sync_replicas and is_chief:
# Initial token and chief queue runners required by the sync_replicas mode
chief_queue_runner = opt.get_chief_queue_runner()
init_tokens_op = opt.get_init_tokens_op()
init_op = tf.global_variables_initializer()
train_dir = tempfile.mkdtemp()
sv = tf.train.Supervisor(
is_chief=is_chief,
logdir=train_dir,
init_op=init_op,
recovery_wait_secs=1,
global_step=global_step)
sess_config = tf.ConfigProto(
allow_soft_placement=True,
log_device_placement=False,
device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index])
# The chief worker (task_index==0) session will prepare the session,
# while the remaining workers will wait for the preparation to complete.
if is_chief:
print("Worker %d: Initializing session..." % FLAGS.task_index)
else:
print("Worker %d: Waiting for session to be initialized..." %
FLAGS.task_index)
# if FLAGS.existing_servers:
# server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index]
# print("Using existing server at: %s" % server_grpc_url)
#
# sess = sv.prepare_or_wait_for_session(server_grpc_url, config=sess_config)
# else:
sess = sv.prepare_or_wait_for_session(server.target,
config=sess_config)
print("Worker %d: Session initialization complete." % FLAGS.task_index)
if FLAGS.sync_replicas and is_chief:
# Chief worker will start the chief queue runner and call the init op
print("Starting chief queue runner and running init_tokens_op")
sv.start_queue_runners(sess, [chief_queue_runner])
sess.run(init_tokens_op)
# Perform training
time_begin = time.time()
print("Training begins @ %f" % time_begin)
local_step = 0
while True:
# Training feed
batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
train_feed = {x: batch_xs, y_: batch_ys}
_, step = sess.run([train_step, global_step], feed_dict=train_feed)
local_step += 1
now = time.time()
print("%f: Worker %d: training step %d done (global step: %d)" %
(now, FLAGS.task_index, local_step, step))
if step >= FLAGS.train_steps:
break
time_end = time.time()
print("Training ends @ %f" % time_end)
training_time = time_end - time_begin
print("Training elapsed time: %f s" % training_time)
# Validation feed
val_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
val_xent = sess.run(cross_entropy, feed_dict=val_feed)
print("After %d training step(s), validation cross entropy = %g" %
(FLAGS.train_steps, val_xent))
if __name__ == "__main__":
tf.app.run()

参考资料:
《TensorFlow实战》

欢迎付费咨询(150元每小时),我的微信:qingxingfengzi

学习笔记TF041:分布式并行的更多相关文章

  1. Hadoop学习笔记(3)——分布式环境搭建

    Hadoop学习笔记(3) ——分布式环境搭建 前面,我们已经在单机上把Hadoop运行起来了,但我们知道Hadoop支持分布式的,而它的优点就是在分布上突出的,所以我们得搭个环境模拟一下. 在这里, ...

  2. 学习笔记TF061:分布式TensorFlow,分布式原理、最佳实践

    分布式TensorFlow由高性能gRPC库底层技术支持.Martin Abadi.Ashish Agarwal.Paul Barham论文<TensorFlow:Large-Scale Mac ...

  3. Hadoop学习笔记—13.分布式集群中节点的动态添加与下架

    开篇:在本笔记系列的第一篇中,我们介绍了如何搭建伪分布与分布模式的Hadoop集群.现在,我们来了解一下在一个Hadoop分布式集群中,如何动态(不关机且正在运行的情况下)地添加一个Hadoop节点与 ...

  4. Hadoop学习笔记【分布式文件系统学习笔记】

    分布式文件系统介绍 分布式文件系统:Hadoop Distributed File System,简称HDFS. 一.HDFS简介 Hadoop分布式文件系统(HDFS)被设计成适合运行在通用硬件(c ...

  5. 【学习笔记】分布式Tensorflow

    目录 分布式原理 单机多卡 多机多卡(分布式) 分布式的架构 节点之间的关系 分布式的模式 数据并行 同步更新和异步更新 分布式API 分布式案例 Tensorflow的一个特色就是分布式计算.分布式 ...

  6. 大数据学习笔记2 - 分布式文件系统HDFS(待续)

    分布式文件系统结构 分布式文件系统是一种通过网络实现文件在多台主机上进行分布式存储的文件系统,采用C/S模式实现文件系统数据访问,目前广泛应用的分布式文件系统主要包括GFS和HDFS,后者是前者的开源 ...

  7. Redis学习笔记11--Redis分布式

    Redis-2.4.15目前没有提供集群的功能,Redis作者在博客中说将在3.0中实现集群机制.目前Redis实现集群的方法主要是采用一致性哈稀分片(Shard),将不同的key分配到不同的redi ...

  8. SRE学习笔记:分布式共识系统、Paxos协议

    最近阅读了<SRE Google运维解密>的第23章,有一些感触,记录一下. 日常工作中,我们经常需要一些服务分布式的运行.跨区域如跨城.跨洲部署运行分布式系统往往是容易的,但是如何保证各 ...

  9. 【学习笔记】分布式追踪Tracing

    在软件工程中,Tracing指使用特定的日志记录程序的执行信息,与之相近的还有两个概念,它们分别是Logging和Metrics. Logging:用于记录离散的事件,包含程序执行到某一点或某一阶段的 ...

随机推荐

  1. JProfiler - Java的性能监控工具

    简介 JProfiler是一款Java的性能监控工具.可以查看当前应用的对象.对象引用.内存.CPU使用情况.线程.线程运行情况(阻塞.等待等),同时可以查找应用内存使用得热点,即:哪个对象占用的内存 ...

  2. Selenium的简单安装和使用

    Selenium的安装 pip install selenium Selenium模块需要调用浏览器,需要配置selenium的浏览器驱动 Firefox(火狐) 下载对应版本的geckdriver. ...

  3. 【Python3之pymysql模块】

    一.什么是 PyMySQL? PyMySQL 是在 Python3.x 版本中用于连接 MySQL 服务器的一个库,Python2中则使用mysqldb. PyMySQL 遵循 Python 数据库 ...

  4. Java xml 解析

    1. XML框架结构 Java SE 6 平台提供的 XML 处理主要包括两个功能:XML 处理(JAXP,Java Architecture XML Processing)和 XML 绑定(JAXB ...

  5. 解决IE8下不支持document.getElementsByClassName的方法

    在代码前面加如下代码: if (!document.getElementsByClassName) { document.getElementsByClassName = function (clas ...

  6. css3-transition过渡属性

    transition主要是用于一个元素的一种状态到另一种状态的一个过渡的过程,不能够主动触发,必须依赖于事件,例如hover伪类选择器. 一,transition简写 transition:要过渡的属 ...

  7. Python项目实战:福布斯系列之数据采集

    1 数据采集概述 开始一个数据分析项目,首先需要做的就是get到原始数据,获得原始数据的方法有多种途径.比如: 获取数据集(dataset)文件 使用爬虫采集数据 直接获得excel.csv及其他数据 ...

  8. 【AngularJS】学习资料

    1. http://www.cnblogs.com/lcllao/tag/AngularJs/ http://www.ituring.com.cn/article/13474 http://www.a ...

  9. 【javascript】数组的操作

    一.常用操作 toString():把数组转换成一个字符串  toLocaleString():把数组转换成一个字符串  join():把数组转换成一个用符号连接的字符串  shift():将数组头部 ...

  10. [补档][Tyvj 1518]CPU监控

    [Tyvj 1518]CPU监控 题目 Bob需要一个程序来监视CPU使用率.这是一个很繁琐的过程,为了让问题更加简单,Bob会慢慢列出今天会在用计算机时做什么事. Bob会干很多事,除了跑暴力程序看 ...