Pytorch之Dataparallel源码解析
之前对Pytorch 1.0 的Dataparallel的使用方法一直似懂非懂,总是会碰到各种莫名其妙的问题,今天就好好从源头梳理一下,更好地理解它的原理或者说说下步骤。
源码地址: https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/data_parallel.py
初始化
首先我们一行一行地来看一下Dataparallel是如何初始化的。
super
就是继承torch.nn.Module父类,这里不做解释- 第一个if判断语句:检查是否有可用GPU
- 第二个if判断语句:如果没有指定GPU,则默认使用所有可用的GPU
- 第三个if判断语句:
output_device
表示输出到哪一个GPU上,默认是第一个GPU,注意这个第一个是device_ids列表上的第一个,所以如果你有三个GPU,而你在将model复制到cuda上时写的代码是model.cuda(1)
或者model.cuda(2)
,则会报错,因为device_ids
是[0,1,2].其第一个元素是0。这一点可以在后面的forward
函数中看到。 - emm,后面每行代码的作用很清楚,就不再一一解释了。
def __init__(self, module, device_ids=None, output_device=None, dim=0):
super(DataParallel, self).__init__()
if not torch.cuda.is_available():
self.module = module
self.device_ids = []
return
if device_ids is None:
device_ids = list(range(torch.cuda.device_count()))
if output_device is None:
output_device = device_ids[0]
self.dim = dim
self.module = module
self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
self.output_device = _get_device_index(output_device, True)
self.src_device_obj = torch.device("cuda:{}".format(self.device_ids[0]))
_check_balance(self.device_ids)
if len(self.device_ids) == 1:
self.module.cuda(device_ids[0])
前向传播
下面进入到重头戏:Dataparallel的forward函数。
def forward(self, *inputs, **kwargs):
if not self.device_ids:
return self.module(*inputs, **kwargs)
for t in chain(self.module.parameters(), self.module.buffers()):
if t.device != self.src_device_obj:
raise RuntimeError("module must have its parameters and buffers "
"on device {} (device_ids[0]) but found one of "
"them on device: {}".format(self.src_device_obj, t.device))
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
return self.module(*inputs[0], **kwargs[0])
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
outputs = self.parallel_apply(replicas, inputs, kwargs)
return self.gather(outputs, self.output_device)
- 第一个if判断语句:如果没有可用的GPU设备,则使用原来的module进行计算。
- for循环就是对应了前面提到的问题,用于检查model和input是不是放在第一个GPU上
- 之后下一步就是将将input平均划分到每个GPU上,用到的是下面的
scatter
函数
def scatter(inputs, target_gpus, dim=0):
r"""
Slices tensors into approximately equal chunks and
distributes them across given GPUs. Duplicates
references to objects that are not tensors.
"""
def scatter_map(obj):
if isinstance(obj, torch.Tensor):
return Scatter.apply(target_gpus, None, dim, obj)
if isinstance(obj, tuple) and len(obj) > 0:
return list(zip(*map(scatter_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
return list(map(list, zip(*map(scatter_map, obj))))
if isinstance(obj, dict) and len(obj) > 0:
return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
return [obj for targets in target_gpus]
# After scatter_map is called, a scatter_map cell will exist. This cell
# has a reference to the actual function scatter_map, which has references
# to a closure that has a reference to the scatter_map cell (because the
# fn is recursive). To avoid this reference cycle, we set the function to
# None, clearing the cell
try:
res = scatter_map(inputs)
finally:
scatter_map = None
return res
- 数据划分之后呢,再判断一下有几个可用的GPU(前面是判断有没有,这里是判断有几个),如果只有一个GPU,那就不用进入到下一步了。
- 如果有多个GPU,那么就需要用到
replica
函数,这个函数比较复杂,就不解释了,感兴趣的可以阅读一下源码:https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/replicate.py 。不过它的主要作用就是将模型复制到多个GPU上。 - 下一步中的
parallel_apply
作用就是并行地在多个GPU上计算模型,每个模型是一样的,只不过输入数据是不一样的,因为前面将数据平均划分了。例如你有两个GPU,一个batch大小是64,那么两个GPU分别处理batch大小为32的数据。 - 最后就是将输出值
gather
到一起,传送到output_device
,即第一个GPU设备上。
Pytorch之Dataparallel源码解析的更多相关文章
- [源码解析] PyTorch 分布式(2) ----- DataParallel(上)
[源码解析] PyTorch 分布式(2) ----- DataParallel(上) 目录 [源码解析] PyTorch 分布式(2) ----- DataParallel(上) 0x00 摘要 0 ...
- [源码解析] PyTorch 分布式(3) ----- DataParallel(下)
[源码解析] PyTorch 分布式(3) ----- DataParallel(下) 目录 [源码解析] PyTorch 分布式(3) ----- DataParallel(下) 0x00 摘要 0 ...
- [源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler
[源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler 目录 [源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampl ...
- [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader
[源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader 目录 [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader 0x00 摘要 0x01 ...
- [源码解析] PyTorch 流水线并行实现 (1)--基础知识
[源码解析] PyTorch 流水线并行实现 (1)--基础知识 目录 [源码解析] PyTorch 流水线并行实现 (1)--基础知识 0x00 摘要 0x01 历史 1.1 GPipe 1.2 t ...
- [源码解析] PyTorch 流水线并行实现 (5)--计算依赖
[源码解析] PyTorch 流水线并行实现 (5)--计算依赖 目录 [源码解析] PyTorch 流水线并行实现 (5)--计算依赖 0x00 摘要 0x01 前文回顾 0x02 计算依赖 0x0 ...
- [源码解析] PyTorch 分布式(1)------历史和概述
[源码解析] PyTorch 分布式(1)------历史和概述 目录 [源码解析] PyTorch 分布式(1)------历史和概述 0x00 摘要 0x01 PyTorch分布式的历史 1.1 ...
- [源码解析] PyTorch 如何使用GPU
[源码解析] PyTorch 如何使用GPU 目录 [源码解析] PyTorch 如何使用GPU 0x00 摘要 0x01 问题 0x02 移动模型到GPU 2.1 cuda 操作 2.2 Modul ...
- [源码解析] PyTorch 分布式(4)------分布式应用基础概念
[源码解析] PyTorch 分布式(4)------分布式应用基础概念 目录 [源码解析] PyTorch 分布式(4)------分布式应用基础概念 0x00 摘要 0x01 基本概念 0x02 ...
随机推荐
- #ifndef #define #endif
在一个大的软件工程里面,可能会有多个文件同时包含一个头文件,当这些文件编译链接成一个可执行文件时,就会出现大量重定义的错误.在头文件中实用#ifndef #define #endif能避免头文件的重定 ...
- c++隐式转换(implicit conversion)
1.缘由 最近在使用nlohmann的json,发现有些地方不是特别好用,所以就想自己修改一下(目的是为了增加类似jsoncpp中可以//增加注释的功能),在看源码的时候看到了一个迷惑的地方,就是解析 ...
- 冰多多团队Gamma阶段发布说明
Bingduoduo 语音Coding(Gamma):项目Github地址 Gamma版本新功能介绍 在gamma阶段我们推出了一个更加完整的IDE,完善了部分编辑器功能,并且优化了UI,增添了新的s ...
- 在导入pytorch时libmkl_intel_lp64.so找不到
安装或者更新完pytorch后,运行不了,显示错误: (base) xu@xusu:~$ python Python (default, Dec , ::) [GCC ] :: Anaconda, I ...
- cad.net 图层隐藏 IsHidden 用法 eDuplicateRecordName 报错
提要:影响图层显示的主要有:关闭 isOff冻结 IsFrozen 图层隐藏 isHidden视口冻结 FreezeLayersInViewport 今天小博发现了一件事情 ...
- kafka topic查看删除
1,查看kafka topic列表,使用--list参数 >bin/kafka-topics.sh --zookeeper 127.0.0.1:2181 --list __consumer_of ...
- JavaScript核心知识点
一.JavaScript 简介 一.JavaScript语言的介绍:JavaScript是基于对象和原型的一种动态.弱类型的脚本语言 二.JavaScript语言的组成:JavaScript是由核心语 ...
- ubuntu更改源为aliyun的源;ROS改为新加坡源
在修改source.list前,最好先备份一份 sudo cp /etc/apt/sources.list /etc/apt/sources.list.old 执行命令打开source.list文件 ...
- dockerfile 命令
FROM 功能为指定基础镜像,并且必须是第一条指令. 如果不以任何镜像为基础,那么写法为:FROM scratch. 同时意味着接下来所写的指令将作为镜像的第一层开始 语法: FROM <ima ...
- 【C++】const,static和static const类型成员变量声明及其初始化
1)const定义的常量在超出其作用域之后其空间会被释放,而static定义的静态常量在函数执行后不会释放其存储空间 void f1() { ; cout<<x<<endl; ...