用单机单卡训练模型的时代已经过去,单机多卡已经成为主流配置。如何最大化发挥多卡的作用呢?本文介绍Pytorch中的DistributedDataParallel方法。

1. DataParallel

其实Pytorch早就有数据并行的工具DataParallel,它是通过单进程多线程的方式实现数据并行的。

简单来说,DataParallel有一个参数服务器的概念,参数服务器所在线程会接受其他线程传回来的梯度与参数,整合后进行参数更新,再将更新后的参数发回给其他线程,这里有一个单对多的双向传输。因为Python语言有GIL限制,所以这种方式并不高效,比方说实际上4卡可能只有2~3倍的提速。

2. DistributedDataParallel

Pytorch目前提供了更加高效的实现,也就是DistributedDataParallel。从命名上比DataParallel多了一个分布式的概念。首先 DistributedDataParallel是能够实现多机多卡训练的,但考虑到大部分的用户并没有多机多卡的环境,本篇博文主要介绍单机多卡的用法。

从原理上来说,DistributedDataParallel采用了多进程,避免了python多线程的效率低问题。一般来说,每个GPU都运行在一个单独的进程内,每个进程会独立计算梯度。

同时DistributedDataParallel抛弃了参数服务器中一对多的传输与同步问题,而是采用了环形的梯度传递,这里引用知乎上的图例。这种环形同步使得每个GPU只需要和自己上下游的GPU进行进程间的梯度传递,避免了参数服务器一对多时可能出现的信息阻塞。

3. DistributedDataParallel示例

下面给出一个非常精简的单机多卡示例,分为六步实现单机多卡训练。

第一步,首先导入相关的包。

import argparse
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

第二步,加一个参数,local_rank。这比较好理解,相当于就是告知当前的程序跑在那一块GPU上,也就是下面的第三行代码。local_rank是通过pytorch的一个启动脚本传过来的,后面将说明这个脚本是啥。最后一句是指定通信方式,这个选nccl就行。

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=-1, type=int)
args = parser.parse_args() torch.cuda.set_device(args.local_rank) dist.init_process_group(backend='nccl')

第三步,包装Dataloader。这里需要的是将sampler改为DistributedSampler,然后赋给DataLoader里面的sampler。

为什么需要这样做呢?因为每个GPU,或者说每个进程都会从DataLoader里面取数据,指定DistributedSampler能够让每个GPU取到不重叠的数据。

读者可能会比较好奇,在下面指定了batch_size为24,这是说每个GPU都会被分到24个数据,还是所有GPU平分这24条数据呢?答案是,每个GPU在每个iter时都会得到24条数据,如果你是4卡,一个iter中总共会处理24*4=96条数据。

train_sampler = torch.utils.data.distributed.DistributedSampler(my_trainset)

trainloader = torch.utils.data.DataLoader(my_trainset,batch_size=24,num_workers=4,sampler=train_sampler)

第四步,使用DDP包装模型。device_id仍然是args.local_rank。

model = DDP(model, device_ids=[args.local_rank])

第五步,将输入数据放到指定GPU。后面的前后向传播和以前相同。

for imgs,labels in trainloader:

    imgs=imgs.to(args.local_rank)
labels=labels.to(args.local_rank) optimizer.zero_grad()
output=net(imgs)
loss_data=loss(output,labels)
loss_data.backward()
optimizer.step()

第六步,启动训练。torch.distributed.launch就是启动脚本,nproc_per_node是GPU数。

python -m torch.distributed.launch --nproc_per_node 2 main.py

通过以上六步,我们就让模型跑在了单机多卡上。是不是也没有那么麻烦,但确实要比DataParallel复杂一些,考虑到加速效果,不妨试一试。

4. DistributedDataParallel注意点

DistributedDataParallel是多进程方式执行的,那么有些操作就需要小心了。如果你在代码中写了一行print,并使用4卡训练,那么你将会在控制台看到四行print。我们只希望看到一行,那该怎么做呢?

像下面一样加一个判断即可,这里的get_rank()得到的是进程的标识,所以输出操作只会在进程0中执行。

if dist.get_rank() == 0:
print("hah")

你会经常需要dist.get_rank()的。因为有很多操作都只需要在一个进程里执行,比如保存模型,如果不加以上判断,四个进程都会写模型,可能出现写入错误;另外load预训练模型权重时,也应该加入判断,只load一次;还有像输出loss等一些场景。

【参考】https://zhuanlan.zhihu.com/p/178402798

Pytorch分布式训练的更多相关文章

  1. [源码解析] 深度学习分布式训练框架 Horovod (1) --- 基础知识

    [源码解析] 深度学习分布式训练框架 Horovod --- (1) 基础知识 目录 [源码解析] 深度学习分布式训练框架 Horovod --- (1) 基础知识 0x00 摘要 0x01 分布式并 ...

  2. [源码解析] 深度学习分布式训练框架 horovod (2) --- 从使用者角度切入

    [源码解析] 深度学习分布式训练框架 horovod (2) --- 从使用者角度切入 目录 [源码解析] 深度学习分布式训练框架 horovod (2) --- 从使用者角度切入 0x00 摘要 0 ...

  3. [源码解析] 深度学习分布式训练框架 horovod (5) --- 融合框架

    [源码解析] 深度学习分布式训练框架 horovod (5) --- 融合框架 目录 [源码解析] 深度学习分布式训练框架 horovod (5) --- 融合框架 0x00 摘要 0x01 架构图 ...

  4. [源码解析] 深度学习分布式训练框架 horovod (6) --- 后台线程架构

    [源码解析] 深度学习分布式训练框架 horovod (6) --- 后台线程架构 目录 [源码解析] 深度学习分布式训练框架 horovod (6) --- 后台线程架构 0x00 摘要 0x01 ...

  5. [源码解析] PyTorch 分布式(9) ----- DistributedDataParallel 之初始化

    [源码解析] PyTorch 分布式(9) ----- DistributedDataParallel 之初始化 目录 [源码解析] PyTorch 分布式(9) ----- DistributedD ...

  6. [源码解析] PyTorch 分布式(10)------DistributedDataParallel 之 Reducer静态架构

    [源码解析] PyTorch 分布式(10)------DistributedDataParallel之Reducer静态架构 目录 [源码解析] PyTorch 分布式(10)------Distr ...

  7. [源码解析] PyTorch 分布式(11) ----- DistributedDataParallel 之 构建Reducer

    [源码解析] PyTorch 分布式(11) ----- DistributedDataParallel 之 构建Reducer 目录 [源码解析] PyTorch 分布式(11) ----- Dis ...

  8. [源码解析] PyTorch 分布式(12) ----- DistributedDataParallel 之 前向传播

    [源码解析] PyTorch 分布式(12) ----- DistributedDataParallel 之 前向传播 目录 [源码解析] PyTorch 分布式(12) ----- Distribu ...

  9. [源码解析] PyTorch 分布式(13) ----- DistributedDataParallel 之 反向传播

    [源码解析] PyTorch 分布式(13) ----- DistributedDataParallel 之 反向传播 目录 [源码解析] PyTorch 分布式(13) ----- Distribu ...

随机推荐

  1. pycharm 打包py程序为exe

    传送门 在终端输入 pyinstaller -F xxx.py -n 新名字 --noconsole --noconsole 去掉cmd命令窗口 -F 打包成一个文件 -D 打包成一个文件夹 -i 加 ...

  2. [AcWing 777] 字符串乘方

    点击查看代码 #include<iostream> using namespace std; string str; int main() { while (cin >> st ...

  3. [题解][YZOJ50104] 密码 | 简单计数

    同步发表于 Mina! 题目大意 对于满足以下要求的长度为 \(n\) 的序列进行计数: 序列的值域为 \([1,k]\); 对于序列的任意位置 \(p\in[1,n]\),可以找到至少一个 \(i\ ...

  4. 深度长文:深入理解Ceph存储架构

    点击上方"开源Linux",选择"设为星标" 回复"学习"获取独家整理的学习资料! 本文是一篇Ceph存储架构技术文章,内容深入到每个存储特 ...

  5. Linux vs Unix - Linux与Unix到底有什么不同?

    来自:Linux迷链接:https://www.linuxmi.com/linux-vs-unix.html Linux和Unix这两个术语可以互换地用来指同一操作系统.这在很大程度上是由于他们惊人的 ...

  6. Oracle 19c单实例部署

    目录 Oracle 19c单实例部署: 1.配置yum: 2.安装rpm包: 3.设置hostname: 4.配置hostname解析: 5.配置时钟同步服务(ntp): 6.检查及配置内核参数: 7 ...

  7. 有了这10个GitHub仓库,开发者如同buff加持

    摘要:列出了10个极好的仓库,它们为所有web和软件开发人员提供了巨大的价值. 本文分享自华为云社区<所有开发者都应该知道的10个GitHub仓库>,作者: Ocean2022 . 除了作 ...

  8. 关于我学git这档子事(2)

    将本地main分支push到远程dev分支(不同名分支间的push) 远程dev分支还未创建 (在push同时创建远程dev分支,并将本地main分支内容上传) git push -u --set-u ...

  9. NBMiner42.1版本发布,完全解锁30系LHR版本显卡

    2021年下半年,NVIDIA发布了LHR版本显卡,对显卡算力进行了限制. 2022年5月8日,NBMiner发布NBMiner_41.0版本,在最新的内核中加入了100%LHR解锁器,适用于Wind ...

  10. linux挂载新硬盘并进行分区格式化

    最近要给小伙伴们写几篇文章,关于<linux下误删除文件之后该如何恢复>.对于没有进程占用的文件想要进行数据恢复,不同的文件系统格式需要使用不同的工具,比如:ext4.xfs等.我找遍了我 ...