import torch
import torch.nn as nn
import ipdb class DataParallelModel(nn.Module): def __init__(self):
super().__init__()
self.block1 = nn.Linear(10, 20) def forward(self, x):
x = self.block1(x)
return x def data_parallel(module, input, device_ids, output_device=None):
if not device_ids:
return module(input) if output_device is None:
output_device = device_ids[0] replicas = nn.parallel.replicate(module, device_ids)
print(f"replicas:{replicas}") inputs = nn.parallel.scatter(input, device_ids)
print(f"inputs:{type(inputs)}")
for i in range(len(inputs)):
print(f"input {i}:{inputs[i].shape}") replicas = replicas[:len(inputs)]
outputs = nn.parallel.parallel_apply(replicas, inputs)
print(f"outputs:{type(outputs)}")
for i in range(len(outputs)):
print(f"output {i}:{outputs[i].shape}") result = nn.parallel.gather(outputs, output_device)
return result model = DataParallelModel()
x = torch.rand(16,10)
result = data_parallel(model.cuda(),x.cuda(), [0,1])
print(f"result:{type(result)}")

最后输出为

replicas:[DataParallelModel(
(block1): Linear(in_features=10, out_features=20, bias=True)
), DataParallelModel(
(block1): Linear(in_features=10, out_features=20, bias=True)
)]
inputs:<class 'tuple'>
input 0:torch.Size([8, 10])
input 1:torch.Size([8, 10])
outputs:<class 'list'>
output 0:torch.Size([8, 20])
output 1:torch.Size([8, 20])
result: torch.Size([16, 20])

可以看到整个流程如下:

  • replicas: 将模型复制若干份,这里只有两个GPU,所以复制两份
  • scatter: 将输入数据若干等分,这里划分成了两份,会返回一个tuple。因为batch size=16,所以刚好可以划分成8和8,那如果是15怎么办呢?没关系,它会自动划分成8和7,这个你自己可以做实验感受一下。
  • parallel_apply: 现在模型和数据都有了,所以当然就是并行化的计算咯,最后返回的是一个list,每个元素是对应GPU的计算结果。
  • gather:每个GPU计算完了之后需要将结果发送到第一个GPU上进行汇总,可以看到最终的tensor大小是[16,20],这符合预期。

MARSGGBO♥原创







2019-9-17

Pytorch并行计算:nn.parallel.replicate, scatter, gather, parallel_apply的更多相关文章

  1. JAVA NIO Scatter/Gather(矢量IO)

    矢量IO=Scatter/Gather:   在多个缓冲区上实现一个简单的IO操作.减少或避免了缓冲区拷贝和系统调用(IO)   write:Gather 数据从几个缓冲区顺序抽取并沿着通道发送,就好 ...

  2. 转:Java NIO系列教程(四) Scatter/Gather

    Java NIO开始支持scatter/gather,scatter/gather用于描述从Channel(译者注:Channel在中文经常翻译为通道)中读取或者写入到Channel的操作.分散(sc ...

  3. java的nio之:java的nio系列教程之Scatter/Gather

    一:Java NIO的scatter/gather应用概念 ===>Java NIO开始支持scatter/gather,scatter/gather用于描述从Channel(译者注:Chann ...

  4. Java基础知识强化之IO流笔记75:NIO之 Scatter / Gather

    1. Java NIO开始支持scatter/gather,scatter/gather用于描述从Channel(译者注:Channel在中文经常翻译为通道)中读取或者写入到Channel的操作. 分 ...

  5. Java NIO Scatter / Gather

    原文链接:http://tutorials.jenkov.com/java-nio/scatter-gather.html Java NIO发布时内置了对scatter / gather的支持.sca ...

  6. Java NIO中的通道Channel(二)分散/聚集 Scatter/Gather

    什么是Scatter/Gather scatter/gather指的在多个缓冲区上实现一个简单的I/O操作,比如从通道中读取数据到多个缓冲区,或从多个缓冲区中写入数据到通道: scatter(分散): ...

  7. NIO相关概念之Scatter / Gather

    Scatter /Gather 是java NIO中用来对channel的读取或者写入操作的特殊的形式的描述 Scatter(发散) 是指在读操作的时候,从chanel读取到的数据,写入到多个buff ...

  8. Java NIO系列教程(四) Scatter/Gather

    Java NIO开始支持scatter/gather,scatter/gather用于描述从Channel(译者注:Channel在中文经常翻译为通道)中读取或者写入到Channel的操作.分散(sc ...

  9. NIO学习笔记六:channel 之前数据传输及scatter/gather

    在Java NIO中,如果两个通道中有一个是FileChannel,那你可以直接将数据从一个channel传输到另外一个channel. FileChannel的transferFrom()方法可以将 ...

随机推荐

  1. leetcode 1110. 删点成林

    题目描述: 给出二叉树的根节点 root,树上每个节点都有一个不同的值. 如果节点值在 to_delete 中出现,我们就把该节点从树上删去,最后得到一个森林(一些不相交的树构成的集合). 返回森林中 ...

  2. CentOS 7搭建本地yum源和局域网yum源

    这两天在部署公司的测试环境,在安装各种中间件的时候,发现各种依赖都没有:后来一检查,发现安装的操作系统是CentOS Mini版,好吧,我认了:为了完成测试环境的搭建,我就搭建了一个局域网的yum源. ...

  3. [ASP.Net ]利用ashx搭建简易接口

    转载:https://blog.csdn.net/ZYD45/article/details/79939475 创建接口的方式有很多,像是Web api,nodejs等等 今天,主要介绍,利用ashx ...

  4. eclipse&myeclipse 生成jar包后,spring无法扫描到bean定义

    问题:eclipse&myeclipse 生成jar包后,spring无法扫描到bean定义 在使用getbean或者扫包时注入bean失败,但在IDE里是可以正常运行的? 原因:导出jar未 ...

  5. 【数据结构与算法】k-d tree算法

    k-d tree算法 k-d树(k-dimensional树的简称),是一种分割k维数据空间的数据结构.主要应用于多维空间关键数据的搜索(如:范围搜索和最近邻搜索). 应用背景 SIFT算法中做特征点 ...

  6. C语言注释风格

    注释风格 一.前言 注释是源码程序中非常重要的一部分,一般情况下,源程序有效注释量必须在20%以上. 注释的原则是有助于对程序的阅读理解,所以注释语言必须准确.易懂.简洁,注释不宜太多也不能太少,注释 ...

  7. Npoi 的使用

    npoi这个office写入,我个人有点不方便,但是因为需要使用所以不得不去用了. 原因: 1. 没文档 2. 网上的案例版本不同 3. 对于复杂列不好做处理 跟网上其他工具的对比,好处就是不需要依赖 ...

  8. python实现Huffman编码

    一.问题 利用二叉树的结构对Huffman树进行编码,实现最短编码 二.解决 # 构建节点类 class TreeNode: def __init__(self, data): "" ...

  9. Git系列 —— 记一次Mac上git push时总是403的错误

    问题: 今天从github上clone下一个项目,然后修改后git push时总是出现: remote:Permission to lixyou/rw-split-plugin.git defined ...

  10. reactiveX沉思(草稿)

    一.第一性原理 将异步的io.事件解释为observable.并借用observer的一些类概念进行处理. ReactiveX is a library for composing asynchron ...