之前对Pytorch 1.0 的Dataparallel的使用方法一直似懂非懂,总是会碰到各种莫名其妙的问题,今天就好好从源头梳理一下,更好地理解它的原理或者说说下步骤。

源码地址: https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/data_parallel.py

初始化

首先我们一行一行地来看一下Dataparallel是如何初始化的。

  • super就是继承torch.nn.Module父类,这里不做解释
  • 第一个if判断语句:检查是否有可用GPU
  • 第二个if判断语句:如果没有指定GPU,则默认使用所有可用的GPU
  • 第三个if判断语句:output_device表示输出到哪一个GPU上,默认是第一个GPU,注意这个第一个device_ids列表上的第一个,所以如果你有三个GPU,而你在将model复制到cuda上时写的代码是model.cuda(1)或者model.cuda(2),则会报错,因为device_ids是[0,1,2].其第一个元素是0。这一点可以在后面的forward函数中看到。
  • emm,后面每行代码的作用很清楚,就不再一一解释了。
def __init__(self, module, device_ids=None, output_device=None, dim=0):
super(DataParallel, self).__init__() if not torch.cuda.is_available():
self.module = module
self.device_ids = []
return if device_ids is None:
device_ids = list(range(torch.cuda.device_count()))
if output_device is None:
output_device = device_ids[0] self.dim = dim
self.module = module
self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
self.output_device = _get_device_index(output_device, True)
self.src_device_obj = torch.device("cuda:{}".format(self.device_ids[0])) _check_balance(self.device_ids) if len(self.device_ids) == 1:
self.module.cuda(device_ids[0])

前向传播

下面进入到重头戏:Dataparallel的forward函数。

def forward(self, *inputs, **kwargs):
if not self.device_ids:
return self.module(*inputs, **kwargs) for t in chain(self.module.parameters(), self.module.buffers()):
if t.device != self.src_device_obj:
raise RuntimeError("module must have its parameters and buffers "
"on device {} (device_ids[0]) but found one of "
"them on device: {}".format(self.src_device_obj, t.device)) inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
return self.module(*inputs[0], **kwargs[0])
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
outputs = self.parallel_apply(replicas, inputs, kwargs)
return self.gather(outputs, self.output_device)
  • 第一个if判断语句:如果没有可用的GPU设备,则使用原来的module进行计算。
  • for循环就是对应了前面提到的问题,用于检查model和input是不是放在第一个GPU上
  • 之后下一步就是将将input平均划分到每个GPU上,用到的是下面的scatter函数
def scatter(inputs, target_gpus, dim=0):
r"""
Slices tensors into approximately equal chunks and
distributes them across given GPUs. Duplicates
references to objects that are not tensors.
"""
def scatter_map(obj):
if isinstance(obj, torch.Tensor):
return Scatter.apply(target_gpus, None, dim, obj)
if isinstance(obj, tuple) and len(obj) > 0:
return list(zip(*map(scatter_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
return list(map(list, zip(*map(scatter_map, obj))))
if isinstance(obj, dict) and len(obj) > 0:
return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
return [obj for targets in target_gpus] # After scatter_map is called, a scatter_map cell will exist. This cell
# has a reference to the actual function scatter_map, which has references
# to a closure that has a reference to the scatter_map cell (because the
# fn is recursive). To avoid this reference cycle, we set the function to
# None, clearing the cell
try:
res = scatter_map(inputs)
finally:
scatter_map = None
return res
  • 数据划分之后呢,再判断一下有几个可用的GPU(前面是判断有没有,这里是判断有几个),如果只有一个GPU,那就不用进入到下一步了。
  • 如果有多个GPU,那么就需要用到replica函数,这个函数比较复杂,就不解释了,感兴趣的可以阅读一下源码:https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/replicate.py 。不过它的主要作用就是将模型复制到多个GPU上。
  • 下一步中的parallel_apply作用就是并行地在多个GPU上计算模型,每个模型是一样的,只不过输入数据是不一样的,因为前面将数据平均划分了。例如你有两个GPU,一个batch大小是64,那么两个GPU分别处理batch大小为32的数据。
  • 最后就是将输出值gather到一起,传送到output_device,即第一个GPU设备上。

微信公众号:AutoML机器学习

MARSGGBO♥原创

如有意合作或学术讨论欢迎私戳联系~
邮箱:marsggbo@foxmail.com






2019-6-2

Pytorch之Dataparallel源码解析的更多相关文章

  1. [源码解析] PyTorch 分布式(2) ----- DataParallel(上)

    [源码解析] PyTorch 分布式(2) ----- DataParallel(上) 目录 [源码解析] PyTorch 分布式(2) ----- DataParallel(上) 0x00 摘要 0 ...

  2. [源码解析] PyTorch 分布式(3) ----- DataParallel(下)

    [源码解析] PyTorch 分布式(3) ----- DataParallel(下) 目录 [源码解析] PyTorch 分布式(3) ----- DataParallel(下) 0x00 摘要 0 ...

  3. [源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler

    [源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler 目录 [源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampl ...

  4. [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader

    [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader 目录 [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader 0x00 摘要 0x01 ...

  5. [源码解析] PyTorch 流水线并行实现 (1)--基础知识

    [源码解析] PyTorch 流水线并行实现 (1)--基础知识 目录 [源码解析] PyTorch 流水线并行实现 (1)--基础知识 0x00 摘要 0x01 历史 1.1 GPipe 1.2 t ...

  6. [源码解析] PyTorch 流水线并行实现 (5)--计算依赖

    [源码解析] PyTorch 流水线并行实现 (5)--计算依赖 目录 [源码解析] PyTorch 流水线并行实现 (5)--计算依赖 0x00 摘要 0x01 前文回顾 0x02 计算依赖 0x0 ...

  7. [源码解析] PyTorch 分布式(1)------历史和概述

    [源码解析] PyTorch 分布式(1)------历史和概述 目录 [源码解析] PyTorch 分布式(1)------历史和概述 0x00 摘要 0x01 PyTorch分布式的历史 1.1 ...

  8. [源码解析] PyTorch 如何使用GPU

    [源码解析] PyTorch 如何使用GPU 目录 [源码解析] PyTorch 如何使用GPU 0x00 摘要 0x01 问题 0x02 移动模型到GPU 2.1 cuda 操作 2.2 Modul ...

  9. [源码解析] PyTorch 分布式(4)------分布式应用基础概念

    [源码解析] PyTorch 分布式(4)------分布式应用基础概念 目录 [源码解析] PyTorch 分布式(4)------分布式应用基础概念 0x00 摘要 0x01 基本概念 0x02 ...

随机推荐

  1. 疯了!同事又问我为什么不能用 isXXX

    最近在做Code Review,写下了这篇文章:代码写成这样,老夫无可奈何!,说多了都是泪啊.. 最近又有人同事跑过来质疑我: 为什么变量名取名不能用 isXXX 这种方式,这样有什么问题?! 醉了, ...

  2. bat修改文件内容

    #file.vbsSet fso = Wscript.CreateObject("Scripting.FileSystemObject")set f=fso.opentextfil ...

  3. jenkins pipeline使用方式

    pipeline 使用 使用groovy的一种DSL语言,流程控制 pipeline脚本同其他脚本语言一样,从上到下顺序执行,它的流程控制取决于Groovy表达式,为jenkins用户提供了更巨大的灵 ...

  4. Rsync学习之旅上

    rsync 简介 什么是rsync rsync是一款开源的,快速的,多功能的,可实现全量及增量的本地或远程数据同步备份的优秀工具. 全量:将全部数据,进行传输覆盖 增量:只传输差异部分的数据 实现增量 ...

  5. 【简记】修改Docker数据目录位置,包含镜像位置

    为啥要改? Docker安装后默认下载的位置在/var/lib/docker ,如果/var分区没有独立分出来,Linux下默认是与/根分区在一起.一般我们装Linux系统的时候,除了做邮件服务器外, ...

  6. wps金山文档在线编辑--.Net 接入指南

    一.申请成为服务商,对金山文档在线服务进行申请 ①进入官网 https://open.wps.cn/ ②申请后如下图,点击右下角的进入服务 ③申请成功后 ④数据回调URL一定是服务器地址,本次我使用的 ...

  7. CodeForces 1228F One Node is Gone

    洛谷题目页面传送门 & CodeForces题目页面传送门 给定一棵树\(T=(V,E),|V|=2^n-2,|E|=2^n-3\),输出所有的\(x\),使得存在一棵满二叉树\(T'\),将 ...

  8. Django--模型层进阶

    目录 QuerySet对象 可切片 可迭代 惰性查询 缓存机制 何时查询集不会被缓存? exists()与iterator()方法 exists() iterator() 中介模型 查询优化 表数据 ...

  9. Java 8 in Action

    https://www.cnblogs.com/HelloDeveloper/p/11404523.html /** * @param args */public static void main(S ...

  10. 【转载】如何查看本机电脑的公网IP

    在实际使用电脑的过程中,很多时候我们需要知道本地电脑的当前公网IP地址,我们都知道个人电脑的公网IP是不固定的,可能每天的对外公网IP都不一样,如果要查看当前本机电脑的对外公网IP,方法也很简单,直接 ...