CoRR 2018 | Horovod: Fast and Easy Distributed Deep Learning in Tensorflow
将深度学习模型的训练从单GPU扩展到多GPU主要面临以下问题:(1)训练框架必须支持GPU间的通信,(2)用户必须更改大量代码以使用多GPU进行训练。为了克服这些问题,本文提出了Horovod,它通过Ring Allreduce实现高效的GPU间通信,而且仅仅更改少量代码就可以实现多GPU训练。
TensorFlow中提供了一些分布式训练的API,这些API适用于不同的环境。这就导致用户往往不知道如何更改代码以进行分布式训练,而且debug也很困难。再者,TensorFlow的分布式训练性能与理想的性能相差甚远,尤其是在大规模GPU环境下。如图1所示,随着GPU数量的增加,分布式TensorFlow的吞吐量与理想的吞吐量的差距逐渐增加,加速比逐渐降低。

因为目前单GPU可以容纳大部分深度学习模型,所以本文主要针对数据并行进行优化。首先来看一下数据并行的训练过程:
- 运行多个模型副本
(a) 读取一部分数据
(b) 把数据喂给模型,进行前向传播
(c) 反向传播,计算梯度 - 将多个模型的梯度进行平均
- 更新模型
- 重复上述步骤直到模型收敛

在标准的TensorFlow中,分布式训练使用参数服务器架构,如图3所示。在参数服务器架构中,主要有worker和server两种角色。worker负责处理数据,计算梯度然后把梯度传给server;server负责聚合梯度,更新模型,然后把模型传回worker。

在这上述两种模式下,主要有以下两个挑战:
- 如何确定worker和server的数量。如果只使用1台server,那么这台server可能成为计算和网络瓶颈;如果使用多台server,那么通信模式就类似于all-to-all,这样就不能完整利用网络带宽。
- 处理愈加复杂的TensorFlow程序。在TensorFlow中,必须显式地启动worker和server,传递一堆参数然后更新代码,这就使得分布式训练变得非常繁琐复杂。
所幸的是,2017年百度提出了一种名为Ring Allreduce的算法。在该算法中,所有worker组成一个环,每台worker只和相邻的两台worker通信,如图4所示。

在Ring Allreduce中,如果有\(N\)个节点,那么每个节点会通信\(2\times (N -1)\)次:前\(N-1\)次接收值并把它加到对应的buffer中,后\(N-1\)次接收并替换对应buffer中的值。Ring Allreduce算法是带宽最优的,也就是说,当buffer足够大时,它会最大限度地利用网络带宽。
综上所述,本文取长补短,使用Ring Allreduce算法优化TensorFlow的分布式训练过程。本文的实现流程如下:
- 将代码转换成独立的Python包,名为Horovod
- 将百度的Ring Allreduce实现替换为NCCL
- 增加了对单机多GPU训练的支持
- 根据反馈更新了部分API,还实现了一个广播操作,以在所有worker上进行强制一致性初始化
import tensorflow as tf
import horovod.tensorflow as hvd
# Initialize Horovod
hvd.init()
# Pin GPU to be used to process local rank (one GPU per process)
config = tf.ConfigProto()
config.gpu_options.visible_device_list = str(hvd.local_rank())
# Build model...
loss = ...
opt = tf.train.AdagradOptimizer(0.01)
# Add Horovod Distributed Optimizer
opt = hvd.DistributedOptimizer(opt)
# Add hook to broadcast variables from rank 0 to all other process
# during initialization.
hooks = [hvd.BroadcastGlobalVariablesHook(0)]
train_op = opt.minimize(loss)
# The MonitoredTrainingSession takes care of session initialization,
# restoring from a checkpoint, saving to a checkpoint, and closing
# when done or an error occurs.
with tf.train.MonitoredTrainingSession(checkpoint="/tmp/train_logs",
config=config, hooks=hooks) as mon_sess:
while not mon_sess.should_stop():
# Perform synchronous training
mon_sess.run(trian_op)
此外,Horovod还提供了一个名为Timeline的分析工具,它可以让用户每个节点在每次迭代时做了什么,效果如图5所示。

使用Timeline对一些模型进行分析后,发现当张量较小时,Ring Allreduce的效率并不高。因此,本文提出一种名为张量融合的技术来解决上述问题。
- 检测哪些张量将会被规约,选择适合缓冲区并具有相同数据类型的前几个张量
- 申请张量融合所需的缓冲区(如果之前没有申请的话),默认大小为64M
- 将选择的张量拷贝到融合缓冲区
- 在融合缓冲区执行allreduce操作
- 将数据从融合缓冲区拷贝到输出张量
- 重复上述步骤直到环中没有要被规约的向量
使用Horovod之后,Inception V3和ResNet-101模型的性能提升了约88%,如图6所示。

如图7,RDMA网络并没有比传统的TCP提升多少性能,只提升了约4%。

未来的工作主要包括:
- 让MPI的安装变得更容易
- 分布式深度学习模型调参经验的收集与分享
- 增加大型模型的示例
CoRR 2018 | Horovod: Fast and Easy Distributed Deep Learning in Tensorflow的更多相关文章
- (转)分布式深度学习系统构建 简介 Distributed Deep Learning
HOME ABOUT CONTACT SUBSCRIBE VIA RSS DEEP LEARNING FOR ENTERPRISE Distributed Deep Learning, Part ...
- 英特尔深度学习框架BigDL——a distributed deep learning library for Apache Spark
BigDL: Distributed Deep Learning on Apache Spark What is BigDL? BigDL is a distributed deep learning ...
- Summary on deep learning framework --- TensorFlow
Summary on deep learning framework --- TensorFlow Updated on 2018-07-22 21:28:11 1. Check failed: s ...
- Distributed Deep Learning
安利一下刘铁岩老师的<分布式机器学习>这本书 以及一个大神的blog: https://zhuanlan.zhihu.com/p/29032307 https://zhuanlan.zhi ...
- Comparing deep learning frameworks: Tensorflow, CNTK, MXNet, & Caffe
https://imaginghub.com/blog/10-a-comparison-of-four-deep-learning-frameworks-tensorflow-cntk-mxnet-a ...
- Install PaddlePaddle (Parallel Distributed Deep Learning)
Step 1: Install docker on your linux system (My linux is fedora) https://docs.docker.com/engine/inst ...
- NeurIPS 2017 | TernGrad: Ternary Gradients to Reduce Communication in Distributed Deep Learning
在深度神经网络的分布式训练中,梯度和参数同步时的网络开销是一个瓶颈.本文提出了一个名为TernGrad梯度量化的方法,通过将梯度三值化为\({-1, 0, 1}\)来减少通信量.此外,本文还使用逐层三 ...
- [ Deep Learning ] Keras & TensorFlow安装依赖包
OS:Mac Python:3.6 一.先安装Keras,再安装TensorFlow 1. 安装Keras Package Version---------- -------h5py 2.7.1 Ke ...
- 【深度学习Deep Learning】资料大全
最近在学深度学习相关的东西,在网上搜集到了一些不错的资料,现在汇总一下: Free Online Books by Yoshua Bengio, Ian Goodfellow and Aaron C ...
随机推荐
- 大数据安全与RANGER学习和使用
概述 再说ranger之前需要明白一下大数据的安全体系的整体介绍,安全体系其实也就是权限可控,先说说权限:权限管理的目标,绝对不是简单的在技术层面建立起用户,密码和权限点的映射关系这么简单的事,更重要 ...
- nmap高级用法
nmap在信息收集中起着很大的作用,今天我来总结一些nmap常用的一些命令 常用探测主机存活方式 1.-sP:进行ping扫描 打印出对ping扫描做出响应的主机,不做进一步测试(如端口扫描或者操作系 ...
- T-SQL创建数据库常用方法2020年10月29日20:12:04网课笔记
2.接口的作用 第一.方便框架的设计.利于团队的开发. 第二.方便项目拓展.高内聚.低耦合. 3.反射 [1]反射的理解:通过读取程序集的信息,找到相关的类型和类型的成员,也可以得到相关的对象.而这种 ...
- [USB波形分析] 全速USB波形数据分析(三)
前面的两篇文章介绍和分析了USB的一些基本知识,结合前面的介绍,今天用实例介绍USB的枚举过程. 1 | 概况 硬件基于EK-TMC123GXL开发板,软件是TI提供的USB批量传输的简单例子,在PC ...
- 【数学】立个flag
想写一个叫做<机器学习中的数学基础>系列文章
- 【解决了一个小问题】golang中引用一个路径较长的库,导致goland中出现"module contains a go.mod file, so major version must be compatible: should be v0 or v1, not v2"
在项目中的go.mod文件中有这样一句: require ( github.com/xxx-devops/xx1/sdk/go v2.2.3 ) 项目的编译没有问题,但是goland中出现如下提示: ...
- 【重构前端知识体系之HTML】讲讲对HTML5的一大特性——语义化的理解
[重构前端知识体系之HTML]讲讲对HTML5的一大特性--语义化的理解 引言 在讲什么是语义化之前,先看看语义化的背景. 在之前的文章中提到HTML最重要的特性,那就是标签.但是项目一大,标签多的看 ...
- python中的rpc库
基于xml的rpc调用 rpcserver.py from xmlrpc.server import SimpleXMLRPCServer # python中类的命名方式遵循驼峰命名法 # 1. 没有 ...
- Go 指针,标识符命名规范及关键字
#### Go 指针,标识符命名规范,关键字,运算符回顾了一下之前写的文章,以及考虑到后期的内容较多, 从这篇开始逐渐增加文章内容; 这篇我们主要学习一Go 中的指针,标识符关键字以及运算符##### ...
- Selenium&PhantomJS 完成爬取网络代理
Selenium模块是一套完整的Web应用程序测试系统,它包含了测试的录制(SeleniumIDE).编写及运行(Selenium Remote Control)和测试的并行处理(Selenimu G ...