MXNet源码分析 | Gluon接口分布式训练流程
本文主要基于MXNet1.6.0版本,对Gluon接口的分布式训练过程进行简要分析。
众所周知,KVStore负责MXNet分布式训练过程中参数的同步,那么它究竟是如何应用在训练中的呢?下面我们将从Gluon.Trainer
这个接口入手,逐步分析分布式训练的梯度交换以及参数同步过程。下面这段代码摘自python/mxnet/gluno/trainer.py文件,相较于源代码删除了一些多余的信息(如某些判断、注释等),以便让我们更好地专注于通信过程。
代码中的step
函数是进行梯度交换以及参数更新的主体,它首先调用_init_kvstore
去初始化kvstore,然后调用_allreduce_grads
进行梯度传输,最后调用_update
实现参数更新。
class Trainer(object):
def step(self, batch_size, ignore_stale_grad=False):
if not self._kv_initialized:
self._init_kvstore()
if self._params_to_init:
self._init_params()
self._allreduce_grads()
self._update(ignore_stale_grad)
首先,_init_kvstore
这个函数会通过用户指定的参数来调用model.py中的_create_kvstore
来初始化kvstore
以及update_kv_store
这两个变量。其中kvstore
是KVStore
类的一个实例化对象,而update_on_kvstore
是一个布尔型变量,用来判断是否在ps端更新参数。换句话说,如果该变量为True,那么模型参数的更新发生在ps端;否则,模型参数的更新发生在worker端,ps端只做梯度的聚合操作(这种情况下,paramerter server是不是就变成了gradient server?)。然而,只有在同步训练模式下,我们才能设置update_on_kvstore=False
,异步训练并不支持在worker端更新参数。在update_kv_store=True
的情况下,我们需要告诉ps端训练过程中使用的优化器是什么,因此要调用kvstore.set_optimizer
把优化器从worker端传给ps端。
from ..model import _create_kvstore
class Trainer(object):
def _init_kvstore(self):
"""Create kvstore."""
config = self._kvstore_params
arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params}
kvstore, update_on_kvstore = _create_kvstore(config['kvstore'], len(self._contexts),
arg_arrays)
self._distributed = 'dist' in kvstore.type if kvstore else False
if self._distributed and 'async' in kvstore.type:
update_on_kvstore = True
# raise err if user provides unsupported configs
if config['update_on_kvstore'] is False:
raise ValueError("Please set update_on_kvstore=True "
"when training in async mode.")
if config['update_on_kvstore'] is not None:
update_on_kvstore = config['update_on_kvstore'
if kvstore:
if update_on_kvstore:
# optimizer preferably needs to be set before init for multiprecision
kvstore.set_optimizer(self._optimizer)
self._kvstore = kvstore
self._update_on_kvstore = update_on_kvstore
else:
self._kvstore = None
self._update_on_kvstore = None
self._kv_initialized = True
完成kvstore的初始化后,gluon.Trainer会调用_allreduce_grads
来实现梯度的交换。欸,前面不是说MXNet是参数服务器架构吗,为啥为扯到Allreduce上呢?考虑update_on_kvstore=False
的情况,最开始每个worker上都只有自己的本地梯度,把梯度push到ps并进行聚合后,每个worker从ps上pull回来的都是相同的、聚合后的梯度。整个过程中的push和pull操作,是不是就很像Reduce和Broadcast(worker上的梯度“Reduce”到ps上,然后ps端“Broadcast”聚合结果给各个worker)?观察_allreduce_grads
的实现,可以发现,无论update_on_kvstore
的值是什么,gluno.Trainer都会把梯度从worker端push到ps端,只不过当update_on_kvstore=True
时,gluon.Trainer把梯度从worker上push到ps后就完事儿了;而当updata_on_kvstore=False
时,gluon.Trainer还会从ps端把梯度的聚合结果pull回来,以便进行本地的参数更新。
class Trainer(object):
def _allreduce_grads(self):
if self._kvstore:
for i, param in enumerate(self._params):
if param.grad_req != 'null':
self._kvstore.push(i, param.list_grad(), priority=-i)
if not self._update_on_kvstore:
self._kvstore.pull(i, param.list_grad(), priority=-i,
ignore_sparse=self._distributed)
gluon.Trainer._update
函数会根据update_on_kvstore
的值进行相应的参数拉取以及更新操作。在单机训练(kvstore is None
)或者分布式训练的本地更新模式(update_on_kvstore=True
)下,gluon.Trainer会使用用户设定的优化器在本地更新参数,以进行下一步的训练。在分布式训练的情况下,当我们设置update_on_kvstore=True
时,模型参数会在ps端进行更新,所以在该函数只需要将模型参数从ps端pull到本地即可。
class Trainer(object):
def _update(self, ignore_stale_grad=False):
updates = [[] for _ in self._updaters]
for i, param in enumerate(self._params):
if self._kvstore and self._update_on_kvstore:
if param._stype == 'default':
# 'row_sparse' parameters are not pulled immediately - they're pulled
# in `Block.forward`
self._kvstore.pull(i, param.list_data(), priority=-i)
continue
for upd, arr, grad in zip(updates, param.list_data(), param.list_grad()):
if not ignore_stale_grad or arr._fresh_grad:
upd.append((i, grad, arr))
arr._fresh_grad = False
if not (self._kvstore and self._update_on_kvstore):
for updater, upd in zip(self._updaters, updates):
if upd:
i, w, g = zip(*upd)
updater(i, w, g)
到这里,我们基本上就把python端的kvstore调用流程讲完了。
MXNet源码分析 | Gluon接口分布式训练流程的更多相关文章
- [源码解析] 深度学习分布式训练框架 horovod (5) --- 融合框架
[源码解析] 深度学习分布式训练框架 horovod (5) --- 融合框架 目录 [源码解析] 深度学习分布式训练框架 horovod (5) --- 融合框架 0x00 摘要 0x01 架构图 ...
- [源码解析] 深度学习分布式训练框架 horovod (8) --- on spark
[源码解析] 深度学习分布式训练框架 horovod (8) --- on spark 目录 [源码解析] 深度学习分布式训练框架 horovod (8) --- on spark 0x00 摘要 0 ...
- [源码解析] 深度学习分布式训练框架 horovod (2) --- 从使用者角度切入
[源码解析] 深度学习分布式训练框架 horovod (2) --- 从使用者角度切入 目录 [源码解析] 深度学习分布式训练框架 horovod (2) --- 从使用者角度切入 0x00 摘要 0 ...
- [源码解析] 深度学习分布式训练框架 horovod (4) --- 网络基础 & Driver
[源码解析] 深度学习分布式训练框架 horovod (4) --- 网络基础 & Driver 目录 [源码解析] 深度学习分布式训练框架 horovod (4) --- 网络基础 & ...
- [源码解析] 深度学习分布式训练框架 horovod (3) --- Horovodrun背后做了什么
[源码解析] 深度学习分布式训练框架 horovod (3) --- Horovodrun背后做了什么 目录 [源码解析] 深度学习分布式训练框架 horovod (3) --- Horovodrun ...
- [源码解析] 深度学习分布式训练框架 horovod (13) --- 弹性训练之 Driver
[源码解析] 深度学习分布式训练框架 horovod (13) --- 弹性训练之 Driver 目录 [源码解析] 深度学习分布式训练框架 horovod (13) --- 弹性训练之 Driver ...
- [源码解析] 深度学习分布式训练框架 horovod (14) --- 弹性训练发现节点 & State
[源码解析] 深度学习分布式训练框架 horovod (14) --- 弹性训练发现节点 & State 目录 [源码解析] 深度学习分布式训练框架 horovod (14) --- 弹性训练 ...
- [源码解析] 深度学习分布式训练框架 horovod (15) --- 广播 & 通知
[源码解析] 深度学习分布式训练框架 horovod (15) --- 广播 & 通知 目录 [源码解析] 深度学习分布式训练框架 horovod (15) --- 广播 & 通知 0 ...
- [源码解析] 深度学习分布式训练框架 horovod (16) --- 弹性训练之Worker生命周期
[源码解析] 深度学习分布式训练框架 horovod (16) --- 弹性训练之Worker生命周期 目录 [源码解析] 深度学习分布式训练框架 horovod (16) --- 弹性训练之Work ...
随机推荐
- POJ3090Visible Lattice Points
http://poj.org/problem?id=3090 对于此题,观测点的数目,从小规模开始观察,可以得到每一个点,由一根无限长的绳子,绕着原点旋转,得到的第一个点.换另外一个思路,每一个观察到 ...
- Python常用的数据结构
一.list 列表 1.列表的特点 有序的,元素可以重复,列表中的元素可以进行增上改查,可以存放不同的数据类型 2.创建列表 中括号创建并填充 --->[] 通过构造函数创建 list() 列表 ...
- Javascript中定时器的使用方法
Javascript中定时器的使用方法 1.间隔定时器(每隔一段时间执行一次代码) 格式:setInterval(函数,时间) //时间单位是毫秒,每隔设置的时间执行函数里的内容一遍(一直执行) // ...
- 【分享数据】vm-insert的压缩比达到29倍
vm-insert采用remote-write的http协议来接收metric数据,然后按照一定算法转发到vm-storage群集. vm-insert到vm-storage这里是用了自己的二进制协议 ...
- 【记录一个问题】android opencl c++: 不要Context, CommandQueue类的赋值函数
一开始代码中这样写了: cl::Context ctx = cl::Context(CL_DEVICE_TYPE_GPU, NULL); cl::CommandQueue queue= cl::Com ...
- Centos 7 安装LAMP以及在Apache上安装positiveSSL。
简介 LAMP(linux , Apache, mysql , php)是集成动态网站经常使用的一套开源软件,实际包含linux操作系统,Apache web服务器,mysql(mariadb 分支) ...
- LaTex用法笔记(一)——LaTex源文件的基本结构
首先打开TeXstudio,创建一个新文件并保存 用\documentclass{article}引入一个文档类,也可以引用book/report/letter 然后用\begin{}和\end{}输 ...
- Redis持久化----RDB和AOF 的区别
关于Redis说点什么,目前都是使用Redis作为数据缓存,缓存的目标主要是那些需要经常访问的数据,或计算复杂而耗时的数据.缓存的效果就是减少了数据库读的次数,减少了复杂数据的计算次数,从而提高了服务 ...
- gin框架中设置信任代理IP并获取远程客户端IP
package main import ( "fmt" "github.com/gin-gonic/gin" ) func main() { gin.SetMo ...
- String类(获取,转换,判断,比较)
1 package cn.itcast.p1.string.demo; 2 3 import java.util.Iterator; 4 5 import com.sun.org.apache.xpa ...