1.DataParallel layers (multi-GPU, distributed)

1)DataParallel

CLASS torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=)

实现模块级别的数据并行

该容器是通过在batch维度上将输入分到指定的device中来在给定的module应用上实现并行。在前向传播中,模块module将在每个设备device上都复制一个,然后每个复制体都会处理一部分的输入。在后向传播阶段,从每个复制体中计算得到的梯度会加在一起求和,然后传给最初的模块module

batch size的大小应该比GPU的个数大

也可见Use nn.DataParallel instead of multiprocessing

任意位置和关键字输入都允许被传给DataParallel,但是一些类型将进行特定处理。tensors将在指定的维度dim(默认为0)上被分散。tuple、list和dict类型将被浅复制。其他类型将在不同的线程中共享,且在写入模型的前向传播时可以被打断

在运行该DataParallel模块浅,并行module必须在device_ids[0]有自己的参数和缓冲区

警告⚠️

在每次前向传播中,module在每个设备中复制一遍,所以任何对该正在运行的module的forward的更新都将丢失。比如,如果module有一个计算属性,在每次forward后自增,那么该属性将保持在初始状态,因为已经进行更新的复制体在forward后将被摧毁。但是,DataParallel能保证在device[0]上的复制体的参数和缓冲区与基本并行module共享存储。因此对参数或device[0]上的缓冲区的内置更新将被保留。比如BatchNorm2dspectral_norm()就依赖于该特性更新缓冲区

警告⚠️

定义在module中的前向后向hooks和它的子模块将被调用len(device_ids)次,每一次都带着对应设备上的输入。尤其是hooks仅保证根据相应设备上的操作以正确顺序被执行。比如,不保证通过register_forward_pre_hook()被设置的hooks在所有len(device_ids)次forward()调用之前被执行,但是每个这样的hook将在该设备的相应forward()调用之前被执行

警告⚠️

当module在forward()返回scalar(比如0维度的tensor),该封装器将返回一个有着与使用在数据并行中的设备数量相同的长度向量,包含着来自每个设备的结果

注意⚠️

在封装在DataParallel中的模块module中使用pack sequence -> recurrent network -> unpack sequence模式是十分微妙的。详情可见My recurrent network doesn’t work with data parallelism

参数:

  • module(Module):并行模块module
  • device_ids(int类型列表或torch.device): CUDA设备(默认为所有设备)
  • output_device(int或torch.device):输出的设备位置(默认为device_ids[0])
Variables:

~DataParallel.module (Module) :并行模块

例子:

>>> net = torch.nn.DataParallel(model, device_ids=[, , ])
>>> output = net(input_var) # input_var can be on any device, including CPU

等价于下面的方法:

2.DataParallel functions (multi-GPU, distributed)

1)data_parallel

torch.nn.parallel.data_parallel(module, inputs, device_ids=None, output_device=None, dim=, module_kwargs=None)

该给定的device_ids的GPU上并行执行module(input)

这个是上面的DataParallel模块的函数版本

参数:

  • module(Module):并行模块module
  • input(Tensor):module的输入
  • device_ids(int类型列表或torch.device): CUDA设备(默认为所有设备)
  • output_device(int或torch.device):输出的设备位置(默认为device_ids[0])

返回:

包含着位于指定output_device上的model(input)的结果的Tensor

例子:

import torch.nn.parallel
...
def forward(self,input):
if self.ngpu>:
output=nn.parallel.data_parallel(self.model,input,range(self.ngpu)) #在多个gpu上运行模型,并行计算
else:
output=self.model(input) return output

pytorch设置多GPU运行的方法的更多相关文章

  1. [源码解析] PyTorch 如何使用GPU

    [源码解析] PyTorch 如何使用GPU 目录 [源码解析] PyTorch 如何使用GPU 0x00 摘要 0x01 问题 0x02 移动模型到GPU 2.1 cuda 操作 2.2 Modul ...

  2. Pytorch使用多GPU

    在caffe中训练的时候如果使用多GPU则直接在运行程序的时候指定GPU的index即可,但是在Pytorch中则需要在声明模型之后,对声明的模型进行初始化,如: cnn = DataParallel ...

  3. GPU运行Tensorflow的几点建议

    1.在运行之前先查看GPU的使用情况: 指令:nvidia-smi 备注:查看GPU此时的使用情况 或者 指令:watch nvidia-smi 备注:实时返回GPU使用情况 2.指定GPU训练: 方 ...

  4. Linux scp 设置nohup后台运行

    Linux scp 设置nohup后台运行 1.正常执行scp命令 2.输入ctrl + z 暂停任务 3.bg将其放入后台 4.disown -h 将这个作业忽略HUP信号 5.测试会话中断,任务继 ...

  5. IIS7.5使用web.config设置伪静态的二种方法

    转自 网上赚钱自学网 .http://www.whosmall.com/post/121 近几天公司里开发的项目有几个运行在IIS7.5上,由于全站采用的是伪静态,因此从网上找到两两种方法来实现.这两 ...

  6. CentOS设置服务开机启动的方法

    CentOS设置服务开机启动的两种方法 1.利用 chkconfig 来配置启动级别在CentOS或者RedHat其他系统下,如果是后面安装的服务,如httpd.mysqld.postfix等,安装后 ...

  7. 微信JS-SDK“分享信息设置”API及数字签名生成方法(NodeJS版本)

    原文:微信JS-SDK"分享信息设置"API及数字签名生成方法(NodeJS版本) 先上测试地址以示成功: 用微信打开下面地址测试 http://game.4gshu.com/de ...

  8. 服务器编程入门(13) Linux套接字设置超时的三种方法

    摘要:     本文介绍在套接字的I/O操作上设置超时的三种方法. 图片可能有点宽,看不到的童鞋可以点击图片查看完整图片.. 1 调用alarm 使用SIGALRM为connect设置超时 设置方法: ...

  9. 怎么进入bios设置界面,电脑如何进入BIOS进行设置,怎么进入BIOS的方法集合

    怎么进入bios设置界面,电脑如何进入BIOS进行设置,怎么进入BIOS的方法集合 开机出现电脑商家图标时,按住F10键进入BIOS界面.进入BIOS界面一般都是开机后按<del,Esc,F1, ...

随机推荐

  1. P1801 黑匣子[堆]

    题目描述 Black Box是一种原始的数据库.它可以储存一个整数数组,还有一个特别的变量i.最开始的时候Black Box是空的.而i等于0.这个Black Box要处理一串命令. 命令只有两种: ...

  2. Windows 窗体的自适应分辨率、分屏显示、开机自启动

    前言 这里所说的针对Winform.WPF 都适用.开机自启动对于控制台的也可以. 还是从项目实践中得来的,在这里记录下来. 对于自适应.分屏显示,在以前感觉应该比较高大上的问题,会比较难.在经过这次 ...

  3. rocketmq那些事儿之集群环境搭建

    上一篇入门基础部分对rocketmq进行了一个基础知识的讲解说明,在正式使用前我们需要进行环境的搭建,今天就来说一说rockeketmq分布式集群环境的搭建 前言 之前已经介绍了rocketmq的入门 ...

  4. cookie,session,token介绍

    本文目录 发展史 Cookie Session Token 回到目录 发展史 1.很久很久以前,Web 基本上就是文档的浏览而已, 既然是浏览,作为服务器, 不需要记录谁在某一段时间里都浏览了什么文档 ...

  5. npm安装模块没有权限解决办法

    直接加上unsafe的参数即可 sudo npm install --unsafe-perm --verbose -g sails

  6. python 中 super函数的使用

    转载地址:http://python.jobbole.com/86787/ 1.简单的使用 在类的继承中,如果重定义某个方法,该方法会覆盖父类的同名方法,但有时,我们希望能同时实现父类的功能,这时,我 ...

  7. MySQL 索引原理以及慢查询优化

    本文以MySQL数据库为研究对象,讨论与数据库索引相关的一些话题.特别需要说明的是,MySQL支持诸多存储引擎,而各种存储引擎对索引的支持也各不相同,因此MySQL数据库支持多种索引类型,如BTree ...

  8. P4357 [CQOI2016]K远点对

    题意:给定平面中的 \(n\) 个点,求第 \(K\) 远的点对之间的距离,\(n\leq 1e5,K\leq min(100,\frac{n\times (n-1)}{2})\) 题解:kd-tre ...

  9. 020_Python3 File(文件) 方法

    1.open() 方法 Python open() 方法用于打开一个文件,并返回文件对象,在对文件进行处理过程都需要使用到这个函数,如果该文件无法被打开,会抛出 OSError. 注意:使用 open ...

  10. asp.net文件夹上传下载组件

    ASP.NET上传文件用FileUpLoad就可以,但是对文件夹的操作却不能用FileUpLoad来实现. 下面这个示例便是使用ASP.NET来实现上传文件夹并对文件夹进行压缩以及解压. ASP.NE ...