计图MPI分布式多卡

计图分布式基于MPI(Message Passing Interface),主要阐述使用计图MPI,进行多卡和分布式训练。目前计图分布式处于测试阶段。

计图MPI安装

计图依赖OpenMPI,用户可以使用如下命令安装OpenMPI:

sudo apt install openmpi-bin openmpi-common libopenmpi-dev

计图会自动检测环境变量中是否包含mpicc,如果计图成功的检测到了mpicc,输出如下信息:

[i 0502 14:09:55.758481 24 __init__.py:203] Found mpicc(1.10.2) at /usr/bin/mpicc

如果计图没有在环境变量中找到mpi,用户也可以手动指定mpicc的路径告诉计图,添加环境变量即可:export mpicc_path=/you/mpicc/path

OpenMPI安装完成以后,用户无需修改代码,需要做的仅仅是修改启动命令行,计图就会用数据并行的方式,自动完成并行操作。

# 单卡训练代码

python3.7 -m jittor.test.test_resnet

# 分布式多卡训练代码

mpirun -np 4 python3.7 -m jittor.test.test_resnet

# 指定特定显卡的多卡训练代码

CUDA_VISIBLE_DEVICES="2,3" mpirun -np 2 python3.7 -m jittor.test.test_resnet

便捷性的背后,计图的分布式算子的支撑,计图支持的mpi算子后端会使用nccl进行进一步的加速。计图所有分布式算法的开发,均在Python前端完成,让分布式算法的灵活度增强,开发分布式算法的难度也大大降低。

基于这些mpi算子接口,研发团队已经集成了如下三种分布式相关的算法:

  • 分布式数据并行加载
  • 分布式优化器
  • 分布式同步批归一化层

用户在使用MPI进行分布式训练时,计图内部的Dataset类会自动并行分发数据,需要注意的是Dataset类中设置的Batch size是所有节点的batch size之和,也就是总batch size,不是单个节点接收到的batch size。

MPI接口

目前MPI开放接口如下:

  • jt.mpi: 计图的MPI模块,当计图不在MPI环境下时,jt.mpi == None, 用户可以用这个判断是否在mpi环境下。
  • jt.Module.mpi_param_broadcast(root=0): 将模块的参数从root节点广播给其他节点。
  • jt.mpi.mpi_reduce(x, op='add', root=0): 将所有节点的变量x使用算子op,reduce到root节点。如果op是’add’或者’sum’,该接口会把所有变量求和,如果op是’mean’,该接口会取均值。

  • jt.mpi.mpi_broadcast(x, root=0): 将变量x从root节点广播到所有节点。

  • jt.mpi.mpi_all_reduce(x, op='add'): 将所有节点的变量x使用一起reduce,并且吧reduce的结果再次广播到所有节点。如果op是’add’或者’sum’,该接口会把所有变量求和,如果op是’mean’,该接口会取均值。

实例:MPI实现分布式同步批归一化层

下面的代码是使用计图实现分布式同步批,归一化层的实例代码,在原来批归一化层的基础上,只需增加三行代码,就可以实现分布式的batch norm,添加的代码如下:

# 将均值和方差,通过all reduce同步到所有节点

if self.sync and jt.mpi:

xmean = xmean.mpi_all_reduce("mean")

x2mean = x2mean.mpi_all_reduce("mean")

注:计图内部已经实现了同步的批归一化层,用户不需要自己实现

分布式同步批归一化层的完整代码:

class BatchNorm(Module):

def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True):

assert affine == None

self.sync = sync

self.num_features = num_features

self.is_train = is_train

self.eps = eps

self.momentum = momentum

self.weight = init.constant((num_features,), "float32", 1.0)

self.bias = init.constant((num_features,), "float32", 0.0)

self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()

self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()

def execute(self, x):

if self.is_train:

xmean = jt.mean(x, dims=[0,2,3], keepdims=1)

x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)

# 将均值和方差,通过all reduce同步到所有节点

if self.sync and jt.mpi:

xmean = xmean.mpi_all_reduce("mean")

x2mean = x2mean.mpi_all_reduce("mean")

xvar = x2mean-xmean*xmean

norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)

self.running_mean += (xmean.sum([0,2,3])-self.running_mean)*self.momentum

self.running_var += (xvar.sum([0,2,3])-self.running_var)*self.momentum

else:

running_mean = self.running_mean.broadcast(x, [0,2,3])

running_var = self.running_var.broadcast(x, [0,2,3])

norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)

w = self.weight.broadcast(x, [0,2,3])

b = self.bias.broadcast(x, [0,2,3])

return norm_x * w + b

计图MPI分布式多卡的更多相关文章

  1. 计图(Jittor) 1.1版本:新增骨干网络、JIT功能升级、支持多卡训练

    计图(Jittor) 1.1版本:新增骨干网络.JIT功能升级.支持多卡训练 深度学习框架-计图(Jittor),Jittor的新版本V1.1上线了.主要变化包括: 增加了大量骨干网络的支持,增强了辅 ...

  2. openlayers-统计图显示(中国区域高亮)

    openlayers版本: v3.19.1-dist 统计图效果:         案例下载地址:https://gitee.com/kawhileonardfans/openlayers-examp ...

  3. 用动图讲解分布式 Raft

    一.Raft 概述 Raft 算法是分布式系统开发首选的共识算法.比如现在流行 Etcd.Consul. 如果掌握了这个算法,就可以较容易地处理绝大部分场景的容错和一致性需求.比如分布式配置系统.分布 ...

  4. 8.3 MPI

    MPI 模型 如图MPI的各个运算节点是分布式的.每一个节点可以视为是一个“Thread”,但这里的不同之处在于这些节点没有所谓的共享内存,或者说Global Memory.所以,在后面也会看到,一般 ...

  5. Horovod 分布式深度学习框架相关

    最近需要 Horovod 相关的知识,在这里记录一下,进行备忘: 分布式训练,分为数据并行和模型并行两种: 模型并行:分布式系统中的不同GPU负责网络模型的不同部分.神经网络模型的不同网络层被分配到不 ...

  6. Samsung S4卡屏卡在开机画面的不拆机恢复照片一例

    大家好!欢迎再次来到我Dr.wonder的世界, 今天我给你们带来Samsung S4 I9508 卡屏开在开机画面的恢复!非常de经典. 首先看图 他开机一直卡在这里, 然后 ,我们使用专业仪器,在 ...

  7. 云时代的分布式数据库:阿里分布式数据库服务DRDS

    发表于2015-07-15 21:47| 10943次阅读| 来源<程序员>杂志| 27 条评论| 作者王晶昱 <程序员>杂志数据库DRDS分布式沈询 摘要:伴随着系统性能.成 ...

  8. Spark入门实战系列--9.Spark图计算GraphX介绍及实例

    [注]该系列文章以及使用到安装包/测试数据 可以在<倾情大奉送--Spark入门实战系列>获取 .GraphX介绍 1.1 GraphX应用背景 Spark GraphX是一个分布式图处理 ...

  9. 学习笔记:The Log(我所读过的最好的一篇分布式技术文章)

    前言 这是一篇学习笔记. 学习的材料来自Jay Kreps的一篇讲Log的博文. 原文很长,但是我坚持看完了,收获颇多,也深深为Jay哥的技术能力.架构能力和对于分布式系统的理解之深刻所折服.同时也因 ...

随机推荐

  1. 使用DirectX截屏

    网上有很多关于DirectX截屏的文章,但大都是屏幕截图,很少有窗口截图,本文则两者都涉及到,先讲如何截取整个屏幕,再讲如何截取某个窗口,其实二者的区别不大,只是某个参数的设置不同而已,最后我们还将扩 ...

  2. apk 脱壳

    在理解android的类加载后,我们可以愉快对apk来脱壳了.脱壳重要的是断点: 断点:在哪个位置脱壳,这里着重指的是在哪个方法 先介绍断点,我们只要知道加壳是用哪个方法来加载dex的,hook这个方 ...

  3. C#-获取磁盘,cpu,内存信息

    获取磁盘信息 zongdaxiao = GetHardDiskSpace("C") * 1.0 / 1024; user = GetHardDiskFreeSpace(" ...

  4. MinGW 可以编译驱动的

    #include <ddk/ntddk.h> static VOID STDCALLmy_unload( IN PDRIVER_OBJECT DriverObject ) {} NTSTA ...

  5. CentOS安装Redis报错[server.o] Error 1

    原因 准备安装的Redis服务版本为6.0.8, gcc的版本为4.8.5,可能是gcc版本过低到导致的 解决办法 安装低版本Redis或者安装高版本gcc

  6. Object划分

    Object划分 1.PO(persistantobject)持久对象 PO就是对应数据库中某个表中的一条记录,多个记录可以用PO的集合.PO中应该不包 含任何对数据库的操作. 2.DO(Domain ...

  7. spring中注解@Resource 与@Autowire 区别

    ① .@Resource 是根据名字进行自动装配:@Autowire是通过类型进行装配. ②. @Resource 注解是 jdk 的:@Autowire 是spring的.

  8. Raspberry PI 4B 安装和配置 Raspbian

    做记录,以备之后需要,待完成中 目录 做记录,以备之后需要,待完成中 下载镜像和安装程序 ssh 远程访问 下载镜像和安装程序 Raspbian: installer: ssh 远程访问 开启ssh ...

  9. IOS Widget(5):小组件刷新机制

    引言   前面的章节学完已经让我们可以顺利实现一个小组件了,但是小组件里面的数据如何刷新的呢,本节内容将讲解IOS的刷新机制. 大纲 系统如何管理小组件刷新 Timeline刷新机制 Timeline ...

  10. Python JWT 介绍

    Python JWT 介绍 目录 Python JWT 介绍 1. JWT 介绍 2. JWT 创建 token 2.1 JWT 生成原理 2.2 JWT 校验 token 原理 3. 代码实现 4. ...