CoRR 2018 | Horovod: Fast and Easy Distributed Deep Learning in Tensorflow
将深度学习模型的训练从单GPU扩展到多GPU主要面临以下问题:(1)训练框架必须支持GPU间的通信,(2)用户必须更改大量代码以使用多GPU进行训练。为了克服这些问题,本文提出了Horovod,它通过Ring Allreduce实现高效的GPU间通信,而且仅仅更改少量代码就可以实现多GPU训练。
TensorFlow中提供了一些分布式训练的API,这些API适用于不同的环境。这就导致用户往往不知道如何更改代码以进行分布式训练,而且debug也很困难。再者,TensorFlow的分布式训练性能与理想的性能相差甚远,尤其是在大规模GPU环境下。如图1所示,随着GPU数量的增加,分布式TensorFlow的吞吐量与理想的吞吐量的差距逐渐增加,加速比逐渐降低。

因为目前单GPU可以容纳大部分深度学习模型,所以本文主要针对数据并行进行优化。首先来看一下数据并行的训练过程:
- 运行多个模型副本
(a) 读取一部分数据
(b) 把数据喂给模型,进行前向传播
(c) 反向传播,计算梯度 - 将多个模型的梯度进行平均
- 更新模型
- 重复上述步骤直到模型收敛

在标准的TensorFlow中,分布式训练使用参数服务器架构,如图3所示。在参数服务器架构中,主要有worker和server两种角色。worker负责处理数据,计算梯度然后把梯度传给server;server负责聚合梯度,更新模型,然后把模型传回worker。

在这上述两种模式下,主要有以下两个挑战:
- 如何确定worker和server的数量。如果只使用1台server,那么这台server可能成为计算和网络瓶颈;如果使用多台server,那么通信模式就类似于all-to-all,这样就不能完整利用网络带宽。
- 处理愈加复杂的TensorFlow程序。在TensorFlow中,必须显式地启动worker和server,传递一堆参数然后更新代码,这就使得分布式训练变得非常繁琐复杂。
所幸的是,2017年百度提出了一种名为Ring Allreduce的算法。在该算法中,所有worker组成一个环,每台worker只和相邻的两台worker通信,如图4所示。

在Ring Allreduce中,如果有\(N\)个节点,那么每个节点会通信\(2\times (N -1)\)次:前\(N-1\)次接收值并把它加到对应的buffer中,后\(N-1\)次接收并替换对应buffer中的值。Ring Allreduce算法是带宽最优的,也就是说,当buffer足够大时,它会最大限度地利用网络带宽。
综上所述,本文取长补短,使用Ring Allreduce算法优化TensorFlow的分布式训练过程。本文的实现流程如下:
- 将代码转换成独立的Python包,名为Horovod
- 将百度的Ring Allreduce实现替换为NCCL
- 增加了对单机多GPU训练的支持
- 根据反馈更新了部分API,还实现了一个广播操作,以在所有worker上进行强制一致性初始化
import tensorflow as tf
import horovod.tensorflow as hvd
# Initialize Horovod
hvd.init()
# Pin GPU to be used to process local rank (one GPU per process)
config = tf.ConfigProto()
config.gpu_options.visible_device_list = str(hvd.local_rank())
# Build model...
loss = ...
opt = tf.train.AdagradOptimizer(0.01)
# Add Horovod Distributed Optimizer
opt = hvd.DistributedOptimizer(opt)
# Add hook to broadcast variables from rank 0 to all other process
# during initialization.
hooks = [hvd.BroadcastGlobalVariablesHook(0)]
train_op = opt.minimize(loss)
# The MonitoredTrainingSession takes care of session initialization,
# restoring from a checkpoint, saving to a checkpoint, and closing
# when done or an error occurs.
with tf.train.MonitoredTrainingSession(checkpoint="/tmp/train_logs",
config=config, hooks=hooks) as mon_sess:
while not mon_sess.should_stop():
# Perform synchronous training
mon_sess.run(trian_op)
此外,Horovod还提供了一个名为Timeline的分析工具,它可以让用户每个节点在每次迭代时做了什么,效果如图5所示。

使用Timeline对一些模型进行分析后,发现当张量较小时,Ring Allreduce的效率并不高。因此,本文提出一种名为张量融合的技术来解决上述问题。
- 检测哪些张量将会被规约,选择适合缓冲区并具有相同数据类型的前几个张量
- 申请张量融合所需的缓冲区(如果之前没有申请的话),默认大小为64M
- 将选择的张量拷贝到融合缓冲区
- 在融合缓冲区执行allreduce操作
- 将数据从融合缓冲区拷贝到输出张量
- 重复上述步骤直到环中没有要被规约的向量
使用Horovod之后,Inception V3和ResNet-101模型的性能提升了约88%,如图6所示。

如图7,RDMA网络并没有比传统的TCP提升多少性能,只提升了约4%。

未来的工作主要包括:
- 让MPI的安装变得更容易
- 分布式深度学习模型调参经验的收集与分享
- 增加大型模型的示例
CoRR 2018 | Horovod: Fast and Easy Distributed Deep Learning in Tensorflow的更多相关文章
- (转)分布式深度学习系统构建 简介 Distributed Deep Learning
HOME ABOUT CONTACT SUBSCRIBE VIA RSS DEEP LEARNING FOR ENTERPRISE Distributed Deep Learning, Part ...
- 英特尔深度学习框架BigDL——a distributed deep learning library for Apache Spark
BigDL: Distributed Deep Learning on Apache Spark What is BigDL? BigDL is a distributed deep learning ...
- Summary on deep learning framework --- TensorFlow
Summary on deep learning framework --- TensorFlow Updated on 2018-07-22 21:28:11 1. Check failed: s ...
- Distributed Deep Learning
安利一下刘铁岩老师的<分布式机器学习>这本书 以及一个大神的blog: https://zhuanlan.zhihu.com/p/29032307 https://zhuanlan.zhi ...
- Comparing deep learning frameworks: Tensorflow, CNTK, MXNet, & Caffe
https://imaginghub.com/blog/10-a-comparison-of-four-deep-learning-frameworks-tensorflow-cntk-mxnet-a ...
- Install PaddlePaddle (Parallel Distributed Deep Learning)
Step 1: Install docker on your linux system (My linux is fedora) https://docs.docker.com/engine/inst ...
- NeurIPS 2017 | TernGrad: Ternary Gradients to Reduce Communication in Distributed Deep Learning
在深度神经网络的分布式训练中,梯度和参数同步时的网络开销是一个瓶颈.本文提出了一个名为TernGrad梯度量化的方法,通过将梯度三值化为\({-1, 0, 1}\)来减少通信量.此外,本文还使用逐层三 ...
- [ Deep Learning ] Keras & TensorFlow安装依赖包
OS:Mac Python:3.6 一.先安装Keras,再安装TensorFlow 1. 安装Keras Package Version---------- -------h5py 2.7.1 Ke ...
- 【深度学习Deep Learning】资料大全
最近在学深度学习相关的东西,在网上搜集到了一些不错的资料,现在汇总一下: Free Online Books by Yoshua Bengio, Ian Goodfellow and Aaron C ...
随机推荐
- spring boot 中使用swagger
一.pom.xml <dependency> <groupId>io.springfox</groupId> <artifactId>springfox ...
- List接口的实现类
(一): ArrayList 构造方法 特有的方法: LinkedList 特点: 可以调用Collections类的静态方法 synchronizedCollection转换成线程安全的
- 打开Cmd的方式与基础Dos命令
基础的Dos命令 打开Cmd的方式 开始->Windows系统->命令提示符 Win键 + R输入cmd打开控制台 在任意的文件夹下面,按住shift键+鼠标右键点击在此处打开powers ...
- 博客新手:图片URL的生成
作为一名博客小白,本人是在美化自己的博客时,发现自定义背景等操作需要提供图片的URL,而不是直接上传图片.那么什么是URL呢?我们又该如何获取它呢? 什么是URL 根据维基百科:统一资源定位符(英语: ...
- 【小记录】android下opencv的cv::dft()函数,CPU版本与opencl版本的性能相差16倍
cv::dft 相差15.9倍 cpu版本 单次调用 0.029448 毫秒 opencl版本 单次调用 0.468688 毫秒 差别仅 ...
- 华为matebook x pro蓝屏和拆机更换固态硬盘
华为老版本的笔记本电脑现在总是蓝屏. 情况 原因 我个人认为是建兴的固态硬盘的缘故. 我的笔记本几乎没用过,因为考研.如果玩游戏使用的老ThinkPad S5.matebook我这个丐版因为没有独立显 ...
- unity3d发布安卓出错plese set the package name
发布时报错 参考https://forum.unity.com/threads/where-is-package-name-setting.318839/ 参考https://answers.unit ...
- gin源码解读2-揭开gin的神秘面纱
数据如何在gin中流转 func main() { gin.SetMode(gin.DebugMode) // 设置为开发模式 router := gin.Default() _ = router.S ...
- 测试udp端口
yum -y install nc 在a机器上执行: nc -ul 1080 在b机器上执行:nc -u 服务器ip 1080 a机器可以接收到报文则代表端口正常.
- Jupyter Notebook 更改字体、字体大小、行高
(废话):今天在做实验的时候遇到了一点问题,就问了问本科的室友,结果室友推荐我使用Jupyter Notebook来写代码,以前看其他同学使用过,但是一直在用Pycharm写,需要的时候顶多是Debu ...