用单机单卡训练模型的时代已经过去,单机多卡已经成为主流配置。如何最大化发挥多卡的作用呢?本文介绍Pytorch中的DistributedDataParallel方法。

1. DataParallel

其实Pytorch早就有数据并行的工具DataParallel,它是通过单进程多线程的方式实现数据并行的。

简单来说,DataParallel有一个参数服务器的概念,参数服务器所在线程会接受其他线程传回来的梯度与参数,整合后进行参数更新,再将更新后的参数发回给其他线程,这里有一个单对多的双向传输。因为Python语言有GIL限制,所以这种方式并不高效,比方说实际上4卡可能只有2~3倍的提速。

2. DistributedDataParallel

Pytorch目前提供了更加高效的实现,也就是DistributedDataParallel。从命名上比DataParallel多了一个分布式的概念。首先 DistributedDataParallel是能够实现多机多卡训练的,但考虑到大部分的用户并没有多机多卡的环境,本篇博文主要介绍单机多卡的用法。

从原理上来说,DistributedDataParallel采用了多进程,避免了python多线程的效率低问题。一般来说,每个GPU都运行在一个单独的进程内,每个进程会独立计算梯度。

同时DistributedDataParallel抛弃了参数服务器中一对多的传输与同步问题,而是采用了环形的梯度传递,这里引用知乎上的图例。这种环形同步使得每个GPU只需要和自己上下游的GPU进行进程间的梯度传递,避免了参数服务器一对多时可能出现的信息阻塞。

3. DistributedDataParallel示例

下面给出一个非常精简的单机多卡示例,分为六步实现单机多卡训练。

第一步,首先导入相关的包。

import argparse
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

第二步,加一个参数,local_rank。这比较好理解,相当于就是告知当前的程序跑在那一块GPU上,也就是下面的第三行代码。local_rank是通过pytorch的一个启动脚本传过来的,后面将说明这个脚本是啥。最后一句是指定通信方式,这个选nccl就行。

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=-1, type=int)
args = parser.parse_args() torch.cuda.set_device(args.local_rank) dist.init_process_group(backend='nccl')

第三步,包装Dataloader。这里需要的是将sampler改为DistributedSampler,然后赋给DataLoader里面的sampler。

为什么需要这样做呢?因为每个GPU,或者说每个进程都会从DataLoader里面取数据,指定DistributedSampler能够让每个GPU取到不重叠的数据。

读者可能会比较好奇,在下面指定了batch_size为24,这是说每个GPU都会被分到24个数据,还是所有GPU平分这24条数据呢?答案是,每个GPU在每个iter时都会得到24条数据,如果你是4卡,一个iter中总共会处理24*4=96条数据。

train_sampler = torch.utils.data.distributed.DistributedSampler(my_trainset)

trainloader = torch.utils.data.DataLoader(my_trainset,batch_size=24,num_workers=4,sampler=train_sampler)

第四步,使用DDP包装模型。device_id仍然是args.local_rank。

model = DDP(model, device_ids=[args.local_rank])

第五步,将输入数据放到指定GPU。后面的前后向传播和以前相同。

for imgs,labels in trainloader:

    imgs=imgs.to(args.local_rank)
labels=labels.to(args.local_rank) optimizer.zero_grad()
output=net(imgs)
loss_data=loss(output,labels)
loss_data.backward()
optimizer.step()

第六步,启动训练。torch.distributed.launch就是启动脚本,nproc_per_node是GPU数。

python -m torch.distributed.launch --nproc_per_node 2 main.py

通过以上六步,我们就让模型跑在了单机多卡上。是不是也没有那么麻烦,但确实要比DataParallel复杂一些,考虑到加速效果,不妨试一试。

4. DistributedDataParallel注意点

DistributedDataParallel是多进程方式执行的,那么有些操作就需要小心了。如果你在代码中写了一行print,并使用4卡训练,那么你将会在控制台看到四行print。我们只希望看到一行,那该怎么做呢?

像下面一样加一个判断即可,这里的get_rank()得到的是进程的标识,所以输出操作只会在进程0中执行。

if dist.get_rank() == 0:
print("hah")

你会经常需要dist.get_rank()的。因为有很多操作都只需要在一个进程里执行,比如保存模型,如果不加以上判断,四个进程都会写模型,可能出现写入错误;另外load预训练模型权重时,也应该加入判断,只load一次;还有像输出loss等一些场景。

【参考】https://zhuanlan.zhihu.com/p/178402798

Pytorch分布式训练的更多相关文章

  1. [源码解析] 深度学习分布式训练框架 Horovod (1) --- 基础知识

    [源码解析] 深度学习分布式训练框架 Horovod --- (1) 基础知识 目录 [源码解析] 深度学习分布式训练框架 Horovod --- (1) 基础知识 0x00 摘要 0x01 分布式并 ...

  2. [源码解析] 深度学习分布式训练框架 horovod (2) --- 从使用者角度切入

    [源码解析] 深度学习分布式训练框架 horovod (2) --- 从使用者角度切入 目录 [源码解析] 深度学习分布式训练框架 horovod (2) --- 从使用者角度切入 0x00 摘要 0 ...

  3. [源码解析] 深度学习分布式训练框架 horovod (5) --- 融合框架

    [源码解析] 深度学习分布式训练框架 horovod (5) --- 融合框架 目录 [源码解析] 深度学习分布式训练框架 horovod (5) --- 融合框架 0x00 摘要 0x01 架构图 ...

  4. [源码解析] 深度学习分布式训练框架 horovod (6) --- 后台线程架构

    [源码解析] 深度学习分布式训练框架 horovod (6) --- 后台线程架构 目录 [源码解析] 深度学习分布式训练框架 horovod (6) --- 后台线程架构 0x00 摘要 0x01 ...

  5. [源码解析] PyTorch 分布式(9) ----- DistributedDataParallel 之初始化

    [源码解析] PyTorch 分布式(9) ----- DistributedDataParallel 之初始化 目录 [源码解析] PyTorch 分布式(9) ----- DistributedD ...

  6. [源码解析] PyTorch 分布式(10)------DistributedDataParallel 之 Reducer静态架构

    [源码解析] PyTorch 分布式(10)------DistributedDataParallel之Reducer静态架构 目录 [源码解析] PyTorch 分布式(10)------Distr ...

  7. [源码解析] PyTorch 分布式(11) ----- DistributedDataParallel 之 构建Reducer

    [源码解析] PyTorch 分布式(11) ----- DistributedDataParallel 之 构建Reducer 目录 [源码解析] PyTorch 分布式(11) ----- Dis ...

  8. [源码解析] PyTorch 分布式(12) ----- DistributedDataParallel 之 前向传播

    [源码解析] PyTorch 分布式(12) ----- DistributedDataParallel 之 前向传播 目录 [源码解析] PyTorch 分布式(12) ----- Distribu ...

  9. [源码解析] PyTorch 分布式(13) ----- DistributedDataParallel 之 反向传播

    [源码解析] PyTorch 分布式(13) ----- DistributedDataParallel 之 反向传播 目录 [源码解析] PyTorch 分布式(13) ----- Distribu ...

随机推荐

  1. 《Streaming Systems》第二章: 数据处理中的 What, Where, When, How

    本章中,我们将通过对 What,Where,When,How 这 4 个问题的回答,逐步揭开流处理过程的全貌. What:计算什么结果? 也就是我们进行数据处理的目的,答案是转换(transforma ...

  2. 容器内的Linux诊断工具0x.tools

    原创:扣钉日记(微信公众号ID:codelogs),欢迎分享,转载请保留出处. 简介 Linux上有大量的问题诊断工具,如perf.bcc等,但这些诊断工具,虽然功能强大,但却需要很高的权限才可以使用 ...

  3. 使用本地自签名证书为 React 项目启用 https 支持

    简介 现在是大前端的时代,我们在本地开发 React 项目非常方便.这不是本文的重点,今天要分享一个话题是,如何为这些本地的项目,添加 https 的支持.为什么要考虑这个问题呢?主要有几个原因 如果 ...

  4. 隐藏浏览器header中X-Powered-By: PHP信息

    在php程序中,默认会在http请求响应头中输出php版本信息.如下: HTTP/1.1 200 OK Content-Type: text/html; charset=utf-8 Date: Tue ...

  5. 1903021121-刘明伟-java十一周作业-java面向对象编程

    项目 内容 课程班级博客链接 19级信计班(本) 作业要求链接 第十一周作业 博客名称 1903021121-刘明伟-java十一周作业-java面向对象 要求 每道题要有题目,代码(使用插入代码,不 ...

  6. 渗透:wesside-ng

    WEP自动破解工具wesside-ng wesside-ng是aircrack-ng套件提供的一个概念验证工具.该工具可以自动扫描无线网络,发现WEP加密的AP.然后,尝试关联该AP.关联成功后,它会 ...

  7. 145_Power BI Report Server自定义Form登录

    博客:www.jiaopengzi.com 焦棚子的文章目录 请点击下载附件 1.背景 很久没有更新Power BI Report Server了,发现自己机器还是2021年1月版本的,现在更新了20 ...

  8. Servlet表单数据

    1.GET 方法 GET 方法向页面请求发送已编码的用户信息.页面和已编码的信息中间用 ? 字符分隔,如下所示: http://www.test.com/hello?key1=value1&k ...

  9. CA证书介绍与格式转换

    CA证书介绍与格式转换 概念 PKCS 公钥加密标准(Public Key Cryptography Standards, PKCS),此一标准的设计与发布皆由RSA资讯安全公司(英语:RSA Sec ...

  10. Jetpack架构组件学习(3)——Activity Results API使用

    原文地址:Jetpack架构组件学习(3)--Activity Results API使用 - Stars-One的杂货小窝 技术与时俱进,页面跳转传值一直使用的是startActivityForRe ...