Pytorch分布式训练
用单机单卡训练模型的时代已经过去,单机多卡已经成为主流配置。如何最大化发挥多卡的作用呢?本文介绍Pytorch中的DistributedDataParallel方法。
1. DataParallel
其实Pytorch早就有数据并行的工具DataParallel,它是通过单进程多线程的方式实现数据并行的。
简单来说,DataParallel有一个参数服务器的概念,参数服务器所在线程会接受其他线程传回来的梯度与参数,整合后进行参数更新,再将更新后的参数发回给其他线程,这里有一个单对多的双向传输。因为Python语言有GIL限制,所以这种方式并不高效,比方说实际上4卡可能只有2~3倍的提速。
2. DistributedDataParallel
Pytorch目前提供了更加高效的实现,也就是DistributedDataParallel。从命名上比DataParallel多了一个分布式的概念。首先 DistributedDataParallel是能够实现多机多卡训练的,但考虑到大部分的用户并没有多机多卡的环境,本篇博文主要介绍单机多卡的用法。
从原理上来说,DistributedDataParallel采用了多进程,避免了python多线程的效率低问题。一般来说,每个GPU都运行在一个单独的进程内,每个进程会独立计算梯度。
同时DistributedDataParallel抛弃了参数服务器中一对多的传输与同步问题,而是采用了环形的梯度传递,这里引用知乎上的图例。这种环形同步使得每个GPU只需要和自己上下游的GPU进行进程间的梯度传递,避免了参数服务器一对多时可能出现的信息阻塞。

3. DistributedDataParallel示例
下面给出一个非常精简的单机多卡示例,分为六步实现单机多卡训练。
第一步,首先导入相关的包。
import argparse
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
第二步,加一个参数,local_rank。这比较好理解,相当于就是告知当前的程序跑在那一块GPU上,也就是下面的第三行代码。local_rank是通过pytorch的一个启动脚本传过来的,后面将说明这个脚本是啥。最后一句是指定通信方式,这个选nccl就行。
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=-1, type=int)
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl')
第三步,包装Dataloader。这里需要的是将sampler改为DistributedSampler,然后赋给DataLoader里面的sampler。
为什么需要这样做呢?因为每个GPU,或者说每个进程都会从DataLoader里面取数据,指定DistributedSampler能够让每个GPU取到不重叠的数据。
读者可能会比较好奇,在下面指定了batch_size为24,这是说每个GPU都会被分到24个数据,还是所有GPU平分这24条数据呢?答案是,每个GPU在每个iter时都会得到24条数据,如果你是4卡,一个iter中总共会处理24*4=96条数据。
train_sampler = torch.utils.data.distributed.DistributedSampler(my_trainset)
trainloader = torch.utils.data.DataLoader(my_trainset,batch_size=24,num_workers=4,sampler=train_sampler)
第四步,使用DDP包装模型。device_id仍然是args.local_rank。
model = DDP(model, device_ids=[args.local_rank])
第五步,将输入数据放到指定GPU。后面的前后向传播和以前相同。
for imgs,labels in trainloader:
imgs=imgs.to(args.local_rank)
labels=labels.to(args.local_rank)
optimizer.zero_grad()
output=net(imgs)
loss_data=loss(output,labels)
loss_data.backward()
optimizer.step()
第六步,启动训练。torch.distributed.launch就是启动脚本,nproc_per_node是GPU数。
python -m torch.distributed.launch --nproc_per_node 2 main.py
通过以上六步,我们就让模型跑在了单机多卡上。是不是也没有那么麻烦,但确实要比DataParallel复杂一些,考虑到加速效果,不妨试一试。
4. DistributedDataParallel注意点
DistributedDataParallel是多进程方式执行的,那么有些操作就需要小心了。如果你在代码中写了一行print,并使用4卡训练,那么你将会在控制台看到四行print。我们只希望看到一行,那该怎么做呢?
像下面一样加一个判断即可,这里的get_rank()得到的是进程的标识,所以输出操作只会在进程0中执行。
if dist.get_rank() == 0:
print("hah")
你会经常需要dist.get_rank()的。因为有很多操作都只需要在一个进程里执行,比如保存模型,如果不加以上判断,四个进程都会写模型,可能出现写入错误;另外load预训练模型权重时,也应该加入判断,只load一次;还有像输出loss等一些场景。
【参考】https://zhuanlan.zhihu.com/p/178402798
Pytorch分布式训练的更多相关文章
- [源码解析] 深度学习分布式训练框架 Horovod (1) --- 基础知识
[源码解析] 深度学习分布式训练框架 Horovod --- (1) 基础知识 目录 [源码解析] 深度学习分布式训练框架 Horovod --- (1) 基础知识 0x00 摘要 0x01 分布式并 ...
- [源码解析] 深度学习分布式训练框架 horovod (2) --- 从使用者角度切入
[源码解析] 深度学习分布式训练框架 horovod (2) --- 从使用者角度切入 目录 [源码解析] 深度学习分布式训练框架 horovod (2) --- 从使用者角度切入 0x00 摘要 0 ...
- [源码解析] 深度学习分布式训练框架 horovod (5) --- 融合框架
[源码解析] 深度学习分布式训练框架 horovod (5) --- 融合框架 目录 [源码解析] 深度学习分布式训练框架 horovod (5) --- 融合框架 0x00 摘要 0x01 架构图 ...
- [源码解析] 深度学习分布式训练框架 horovod (6) --- 后台线程架构
[源码解析] 深度学习分布式训练框架 horovod (6) --- 后台线程架构 目录 [源码解析] 深度学习分布式训练框架 horovod (6) --- 后台线程架构 0x00 摘要 0x01 ...
- [源码解析] PyTorch 分布式(9) ----- DistributedDataParallel 之初始化
[源码解析] PyTorch 分布式(9) ----- DistributedDataParallel 之初始化 目录 [源码解析] PyTorch 分布式(9) ----- DistributedD ...
- [源码解析] PyTorch 分布式(10)------DistributedDataParallel 之 Reducer静态架构
[源码解析] PyTorch 分布式(10)------DistributedDataParallel之Reducer静态架构 目录 [源码解析] PyTorch 分布式(10)------Distr ...
- [源码解析] PyTorch 分布式(11) ----- DistributedDataParallel 之 构建Reducer
[源码解析] PyTorch 分布式(11) ----- DistributedDataParallel 之 构建Reducer 目录 [源码解析] PyTorch 分布式(11) ----- Dis ...
- [源码解析] PyTorch 分布式(12) ----- DistributedDataParallel 之 前向传播
[源码解析] PyTorch 分布式(12) ----- DistributedDataParallel 之 前向传播 目录 [源码解析] PyTorch 分布式(12) ----- Distribu ...
- [源码解析] PyTorch 分布式(13) ----- DistributedDataParallel 之 反向传播
[源码解析] PyTorch 分布式(13) ----- DistributedDataParallel 之 反向传播 目录 [源码解析] PyTorch 分布式(13) ----- Distribu ...
随机推荐
- jQuery与JavaScript与Ajax三者的区别与联系
简单总结: 1.JS是一门 前端语言. 2.Ajax是一门 技术,它提供了异步更新的机制,使用客户端与服务器间交换数据而非整个页面文档,实现页面的局部更新. 3.jQuery是一个 框架,它对JS进行 ...
- .NET混合开发解决方案9 WebView2控件的导航事件
系列目录 [已更新最新开发文章,点击查看详细] WebView2控件应用详解系列博客 .NET桌面程序集成Web网页开发的十种解决方案 .NET混合开发解决方案1 WebView2简介 .NE ...
- 513. Find Bottom Left Tree Value - LeetCode
Question 513. Find Bottom Left Tree Value Solution 题目大意: 给一个二叉树,求最底层,最左侧节点的值 思路: 按层遍历二叉树,每一层第一个被访问的节 ...
- 好客租房11-为什么脚手架使用jsx语法
为什么脚手架中可以使用jsx语法 1jsx不是标准的ECMAScript ,他是ECMAScript的语法扩展 2需要使用babel编译处理后 才能在浏览器环境中使用 3create-react-ap ...
- 利用 Onekey Theater 改善屏幕显示效果
介绍 Onekey Theater(一键影音),它是联想笔记本带的一键影音功能,使用它能够更改笔记本的显示效果和音效,以此模仿电影院的效果,为用户带来更好是视听效果及享受. 作用 之前的联想笔记本自带 ...
- MIT 6.824(Spring 2020) Lab1: MapReduce 文档翻译
首发于公众号:努力学习的阿新 前言 大家好,这里是阿新. MIT 6.824 是麻省理工大学开设的一门关于分布式系统的明星课程,共包含四个配套实验,实验的含金量很高,十分适合作为校招生的项目经历,在文 ...
- Mysql中文存储、显示及不区分大小写控制
刚开始使用mysql,以为安装了完了就可以使用了,结果是我太天真了.mysql5.7版本,默认严格区分大小写,并且不支持中文存储. 严格区分大小写,即A表和a表示两个不同的表 实例 修改 在/etc/ ...
- Java学习-第一部分-第一阶段-第二节:变量
变量 变量介绍 为什么需要变量 变量是程序的基本组成单位 不论是使用哪种高级程序语言编写程序,变量都是其程序的基本组成单位,比如: //变量有三个基本要素(类型+名称+值) class Test{ p ...
- powershell命令总结
2021-07-21 初稿 ps命令采用动词-名词的方式命名,不区分大小写.默认当前文件夹为当前路径./.除去-match使用正则表达式匹配外,其他都使用*和?通配符. 速查 管道命令 前一个的输出作 ...
- SQLite数据库损坏及其修复探究
数据库如何发生损坏 SQLite 数据库具有很强的抗损坏能力.在执行事务时如果发生应用程序崩溃.操作系统崩溃甚至电源故障,那么在下次访问数据库文件时,会自动回滚部分写入的事务.恢复过程是全自动的, ...