Tensorflow学习笔记4:分布式Tensorflow
简介
Tensorflow API提供了Cluster、Server以及Supervisor来支持模型的分布式训练。
关于Tensorflow的分布式训练介绍可以参考Distributed Tensorflow。简单的概括说明如下:
- Tensorflow分布式Cluster由多个Task组成,每个Task对应一个tf.train.Server实例,作为Cluster的一个单独节点;
- 多个相同作用的Task可以被划分为一个job,例如ps job作为参数服务器只保存Tensorflow model的参数,而worker job则作为计算节点只执行计算密集型的Graph计算。
- Cluster中的Task会相对进行通信,以便进行状态同步、参数更新等操作。
Tensorflow分布式集群的所有节点执行的代码是相同的。分布式任务代码具有固定的模式:
# 第1步:命令行参数解析,获取集群的信息ps_hosts和worker_hosts,以及当前节点的角色信息job_name和task_index # 第2步:创建当前task结点的Server
cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) # 第3步:如果当前节点是ps,则调用server.join()无休止等待;如果是worker,则执行第4步。
if FLAGS.job_name == "ps":
server.join() # 第4步:则构建要训练的模型
# build tensorflow graph model # 第5步:创建tf.train.Supervisor来管理模型的训练过程
# Create a "supervisor", which oversees the training process.
sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0), logdir="/tmp/train_logs")
# The supervisor takes care of session initialization and restoring from a checkpoint.
sess = sv.prepare_or_wait_for_session(server.target)
# Loop until the supervisor shuts down
while not sv.should_stop()
# train model
Tensorflow分布式训练代码框架
根据上面说到的Tensorflow分布式训练代码固定模式,如果要编写一个分布式的Tensorlfow代码,其框架如下所示。
import tensorflow as tf # Flags for defining the tf.train.ClusterSpec
tf.app.flags.DEFINE_string("ps_hosts", "",
"Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("worker_hosts", "",
"Comma-separated list of hostname:port pairs") # Flags for defining the tf.train.Server
tf.app.flags.DEFINE_string("job_name", "", "One of 'ps', 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job") FLAGS = tf.app.flags.FLAGS def main(_):
ps_hosts = FLAGS.ps_hosts.split(",")
worker_hosts = FLAGS.worker_hosts(",") # Create a cluster from the parameter server and worker hosts.
cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts}) # Create and start a server for the local task.
server = tf.train.Server(cluster,
job_name=FLAGS.job_name,
task_index=FLAGS.task_index) if FLAGS.job_name == "ps":
server.join()
elif FLAGS.job_name == "worker":
# Assigns ops to the local worker by default.
with tf.device(tf.train.replica_device_setter(
worker_device="/job:worker/task:%d" % FLAGS.task_index,
cluster=cluster)): # Build model...
loss = ...
global_step = tf.Variable(0) train_op = tf.train.AdagradOptimizer(0.01).minimize(
loss, global_step=global_step) saver = tf.train.Saver()
summary_op = tf.merge_all_summaries()
init_op = tf.initialize_all_variables() # Create a "supervisor", which oversees the training process.
sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
logdir="/tmp/train_logs",
init_op=init_op,
summary_op=summary_op,
saver=saver,
global_step=global_step,
save_model_secs=600) # The supervisor takes care of session initialization and restoring from
# a checkpoint.
sess = sv.prepare_or_wait_for_session(server.target) # Start queue runners for the input pipelines (if any).
sv.start_queue_runners(sess) # Loop until the supervisor shuts down (or 1000000 steps have completed).
step = 0
while not sv.should_stop() and step < 1000000:
# Run a training step asynchronously.
# See `tf.train.SyncReplicasOptimizer` for additional details on how to
# perform *synchronous* training.
_, step = sess.run([train_op, global_step]) if __name__ == "__main__":
tf.app.run()
对于所有Tensorflow分布式代码,可变的只有两点:
- 构建tensorflow graph模型代码;
- 每一步执行训练的代码
分布式MNIST任务
我们通过修改tensorflow/tensorflow提供的mnist_softmax.py来构造分布式的MNIST样例来进行验证。修改后的代码请参考mnist_dist.py。
我们同样通过tensorlfow的Docker image来启动一个容器来进行验证。
$ docker run -d -v /path/to/your/code:/tensorflow/mnist --name tensorflow tensorflow/tensorflow
启动tensorflow之后,启动4个Terminal,然后通过下面命令进入tensorflow容器,切换到/tensorflow/mnist目录下
$ docker exec -ti tensorflow /bin/bash
$ cd /tensorflow/mnist
然后在四个Terminal中分别执行下面一个命令来启动Tensorflow cluster的一个task节点,
# Start ps
python mnist_dist.py --ps_hosts=localhost:,localhost: --worker_hosts=localhost:,localhost: --job_name=ps --task_index= # Start ps
python mnist_dist.py --ps_hosts=localhost:,localhost: --worker_hosts=localhost:,localhost: --job_name=ps --task_index= # Start worker
python mnist_dist.py --ps_hosts=localhost:,localhost: --worker_hosts=localhost:,localhost: --job_name=worker --task_index= # Start worker
python mnist_dist.py --ps_hosts=localhost:,localhost: --worker_hosts=localhost:,localhost: --job_name=worker --task_index=
具体效果自己验证哈。
Tensorflow学习笔记4:分布式Tensorflow的更多相关文章
- TensorFlow学习笔记0-安装TensorFlow环境
TensorFlow学习笔记0-安装TensorFlow环境 作者: YunYuan 转载请注明来源,谢谢! 写在前面 系统: Windows Enterprise 10 x64 CPU:Intel( ...
- 学习笔记TF061:分布式TensorFlow,分布式原理、最佳实践
分布式TensorFlow由高性能gRPC库底层技术支持.Martin Abadi.Ashish Agarwal.Paul Barham论文<TensorFlow:Large-Scale Mac ...
- 【学习笔记】分布式Tensorflow
目录 分布式原理 单机多卡 多机多卡(分布式) 分布式的架构 节点之间的关系 分布式的模式 数据并行 同步更新和异步更新 分布式API 分布式案例 Tensorflow的一个特色就是分布式计算.分布式 ...
- tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)
tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...
- Tensorflow学习笔记2:About Session, Graph, Operation and Tensor
简介 上一篇笔记:Tensorflow学习笔记1:Get Started 我们谈到Tensorflow是基于图(Graph)的计算系统.而图的节点则是由操作(Operation)来构成的,而图的各个节 ...
- Tensorflow学习笔记2019.01.22
tensorflow学习笔记2 edit by Strangewx 2019.01.04 4.1 机器学习基础 4.1.1 一般结构: 初始化模型参数:通常随机赋值,简单模型赋值0 训练数据:一般打乱 ...
- Tensorflow学习笔记2019.01.03
tensorflow学习笔记: 3.2 Tensorflow中定义数据流图 张量知识矩阵的一个超集. 超集:如果一个集合S2中的每一个元素都在集合S1中,且集合S1中可能包含S2中没有的元素,则集合S ...
- TensorFlow学习笔记之--[compute_gradients和apply_gradients原理浅析]
I optimizer.minimize(loss, var_list) 我们都知道,TensorFlow为我们提供了丰富的优化函数,例如GradientDescentOptimizer.这个方法会自 ...
- 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识
深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...
- 深度学习-tensorflow学习笔记(2)-MNIST手写字体识别
深度学习-tensorflow学习笔记(2)-MNIST手写字体识别超级详细版 这是tf入门的第一个例子.minst应该是内置的数据集. 前置知识在学习笔记(1)里面讲过了 这里直接上代码 # -*- ...
随机推荐
- MongoDB学习笔记——Master/Slave主从复制
Master/Slave主从复制 主从复制MongoDB中比较常用的一种方式,如果要实现主从复制至少应该有两个MongoDB实例,一个作为主节点负责客户端请求,另一个作为从节点负责从主节点映射数据,提 ...
- 曲演杂坛--使用CTE时踩的小坑:No Join Predicate
在一次系统优化中,意外发现一个比较“坑”的SQL,拿出来供大家分享. 生成演示数据: --====================================== --检查测试表是否存在 IF(O ...
- 烂泥:SQL Server 2005数据库安装
本文由秀依林枫提供友情赞助,首发于烂泥行天下. 为了能更好的利用服务器,所以打算把该业务进行迁移.因为该业务比较特殊,需要服务器上有相应的硬件支持,所以打算直接升级该服务器目前的操作系统.目前公司服务 ...
- Android横竖屏切换重载问题与小结
(转自:http://www.cnblogs.com/franksunny/p/3714442.html) (老样子,图片啥的详细文档,可以下载后观看 http://files.cnblogs.com ...
- 【原】css实现两端对齐的3种方法
说到两端对齐,大家并不陌生,在word.powerpoint.outlook等界面导航处,其实都有一个两端对齐(分散对齐)的按钮,平时使用的也不多,我们更习惯与左对齐.居中对齐.右对齐的方式来对齐页面 ...
- JAVA中关于并发的一些理解
一,JAVA线程是如何实现的? 同步,涉及到多线程操作,那在JAVA中线程是如何实现的呢? 操作系统中讲到,线程的实现(线程模型)主要有三种方式: ①使用内核线程实现 ②使用用户线程实现 ③使用用户线 ...
- 第9章 用内核对象进行线程同步(2)_可等待计时器(WaitableTimer)
9.4 可等待的计时器内核对象——某个指定的时间或每隔一段时间触发一次 (1)创建可等待计时器:CreateWaitableTimer(使用时应把常量_WIN32_WINNT定义为0x0400) 参数 ...
- MySQL 5.7.x 配置教程
软件环境 操作系统:windows 10 x64 企业版 MySQL:mysql-5.7.11-winx64 MySQL官网下载:http://downloads.mysql.com/archives ...
- ANE接入平台心得记录(安卓)
开发环境:FlashBuilder4.7 AIR13.0 Eclipse 由于我懒得陪安卓的开发环境所以我下载了包含安卓SDK Manager的Eclipse,其实直接用FlashBuilder开发A ...
- Quartz集群
为什么选择Quartz: 1)资历够老,创立于1998年,比struts1还早,但是一直在更新(27 April 2012: Quartz 2.1.5 Released),文档齐全. 2)完全由Jav ...