预备知识

模型并行( model parallelism ):即把模型拆分放到不同的设备进行训练,分布式系统中的不同机器(GPU/CPU等)负责网络模型的不同部分 —— 例如,神经网络模型的不同网络层被分配到不同的机器,或者同一层内部的不同参数被分配到不同机器,如AlexNet的训练。

数据并行( data parallelism ):即把数据切分,输入到不同的机器有同一个模型的多个副本,每个机器分配到不同的数据,然后将所有机器的计算结果按照某种方式合并。

多进程最佳实践

torch.multiprocessing 是 Python 的 multiprocessing 多进程模块的替代品。它支持完全相同的操作,但对其进行了扩展,以便所有通过多进程队列 multiprocessing.Queue 发送的张量都能将其数据移入共享内存,而且仅将其句柄发送到另一个进程。

注意:

当张量 Tensor 被发送到另一个进程时,张量的数据和梯度 torch.Tensor.grad 都将被共享。

这一特性允许实现各种训练方法,如 Hogwild,A3C 或任何其他需要异步操作的训练方法。

一、CUDA 张量的共享

仅 Python 3 支持进程之间共享 CUDA 张量,我们可以使用 spawnforkserver 启动此类方法。 Python 2 中的 multiprocessing 多进程处理只能使用 fork 创建子进程,并且CUDA运行时不支持多进程处理。

警告:

CUDA API 规定输出到其他进程的共享张量,只要它们被这些进程使用时,都将持续保持有效。您应该小心并确保您共享的 CUDA 张量不会超出它应该的作用范围(不会出现作用范围延伸的问题)。这对于共享模型的参数应该不是问题,但应该小心地传递其他类型的数据。请注意,此限制不适用于共享的 CPU 内存。

也可以参阅: 使用 nn.DataParallel 替代多进程处理

二、最佳实践和技巧

1、避免和防止死锁

产生新进程时会出现很多错误,导致死锁最常见的原因是后台线程。如果有任何持有锁或导入模块的线程,并且 fork 被调用,则子进程很可能处于崩溃状态,并且会以不同方式死锁或失败。请注意,即使您没有这样做,Python 中内置的库也可能会,更不必说 多进程处理 了。multiprocessing.Queue 多进程队列实际上是一个非常复杂的类,它产生了多个用于序列化、发送和接收对象的线程,并且它们也可能导致上述问题。如果您发现自己处于这种情况,请尝试使用multiprocessing.queues.SimpleQueue ,它不使用任何其他额外的线程。

我们正在尽可能的为您提供便利,并确保这些死锁不会发生,但有些事情不受我们控制。如果您有任何问题暂时无法应对,请尝试到论坛求助,我们会查看是否可以解决问题。

2、重用通过队列发送的缓冲区

请记住,每次将张量放入多进程队列 multiprocessing.Queue 时,它必须被移动到共享内存中。如果它已经被共享,将会是一个空操作,否则会产生一个额外的内存拷贝,这会减慢整个过程。即使您有一组进程将数据发送到单个进程,也可以让它将缓冲区发送回去,这几乎是不占资源的,并且可以在发送下一批时避免产生拷贝动作。

3、异步多进程训练(如: Hogwild)

使用多进程处理 torch.multiprocessing,可以异步地训练一个模型,参数既可以一直共享,也可以周期性同步。在第一种情况下,我们建议发送整个模型对象,而在后者中,我们建议只发送状态字典 state_dict()

我们建议使用多进程处理队列 multiprocessing.Queue 在进程之间传递各种 PyTorch 对象。使用 fork 启动一个方法时,它也可能会继承共享内存中的张量和存储空间,但这种方式也非常容易出错,应谨慎使用,最好只能让高阶用户使用。而队列,尽管它们有时候不太优雅,却能在任何情况下正常工作。

警告:

你应该留意没有用 if __name__ =='__main__' 来保护的全局语句。如果使用了不同于 fork 启动方法,它们将在所有子进程中执行。

4、Hogwild

具体的 Hogwild 实现可以在 示例库 中找到,但为了展示代码的整体结构,下面还有一个最简单的示例:

import torch.multiprocessing as mp
from model import MyModel def train(model):
# 构建 data_loader,优化器等
for data, labels in data_loader:
optimizer.zero_grad()
loss_fn(model(data), labels).backward()
optimizer.step() # 更新共享的参数 if __name__ == '__main__':
num_processes = 4
model = MyModel()
# 注意:这是 "fork" 方法工作所必需的
model.share_memory()
processes = []
for rank in range(num_processes):
p = mp.Process(target=train, args=(model,))
p.start()
processes.append(p)
for p in processes:
p.join()

Reference

https://ptorch.com/news/176.html

Pytorch多进程最佳实践的更多相关文章

  1. PyTorch最佳实践,怎样才能写出一手风格优美的代码

    [摘要] PyTorch是最优秀的深度学习框架之一,它简单优雅,非常适合入门.本文将介绍PyTorch的最佳实践和代码风格都是怎样的. 虽然这是一个非官方的 PyTorch 指南,但本文总结了一年多使 ...

  2. (转载)PyTorch代码规范最佳实践和样式指南

    A PyTorch Tools, best practices & Styleguide 中文版:PyTorch代码规范最佳实践和样式指南 This is not an official st ...

  3. PyTorch模型加载与保存的最佳实践

    一般来说PyTorch有两种保存和读取模型参数的方法.但这篇文章我记录了一种最佳实践,可以在加载模型时避免掉一些问题. 第一种方案是保存整个模型: 1 torch.save(model_object, ...

  4. MySQL面试必考知识点:揭秘亿级高并发数据库调优与最佳实践法则

    做业务,要懂基本的SQL语句: 做性能优化,要懂索引,懂引擎: 做分库分表,要懂主从,懂读写分离... 数据库的使用,是开发人员的基本功,对它掌握越清晰越深入,你能做的事情就越多. 今天我们用10分钟 ...

  5. [转]10分钟梳理MySQL知识点:揭秘亿级高并发数据库调优与最佳实践法则

    转:https://mp.weixin.qq.com/s/RYIiHAHHStIMftQT6lQSgA 做业务,要懂基本的SQL语句: 做性能优化,要懂索引,懂引擎: 做分库分表,要懂主从,懂读写分离 ...

  6. python 工业日志模块 未来的python日志最佳实践

    目录 介绍 好的功能 安装方法 参数介绍 呆log 参数与 使用方法 版本说明 后期版本规划 todo 感谢 介绍 呆log:工业中,python日志模块,安装即用.理论上支持 python2, py ...

  7. ASP.NET跨平台最佳实践

    前言 八年的坚持敌不过领导的固执,最终还是不得不阔别已经成为我第二语言的C#,转战Java阵营.有过短暂的失落和迷茫,但技术转型真的没有想象中那么难.回头审视,其实单从语言本身来看,C#确实比Java ...

  8. 《AngularJS深度剖析与最佳实践》简介

    由于年末将至,前阵子一直忙于工作的事务,不得已暂停了微信订阅号的更新,我将会在后续的时间里尽快的继续为大家推送更多的博文.毕竟一个人的力量微薄,精力有限,希望大家能理解,仍然能一如既往的关注和支持sh ...

  9. ASP.NET MVC防范CSRF最佳实践

    XSS与CSRF 哈哈,有点标题党,但我保证这篇文章跟别的不太一样. 我认为,网站安全的基础有三块: 防范中间人攻击 防范XSS 防范CSRF 注意,我讲的是基础,如果更高级点的话可以考虑防范机器人刷 ...

随机推荐

  1. 【刷题】BZOJ 2243 [SDOI2011]染色

    Description 给定一棵有n个节点的无根树和m个操作,操作有2类: 1.将节点a到节点b路径上所有点都染成颜色c: 2.询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段), 如 ...

  2. 【BZOJ1914】数三角形(组合数,极角排序)

    [BZOJ1914]数三角形(组合数,极角排序) 题面 BZOJ权限题 良心洛谷 题解 这种姿势很吼啊,表示计算几何啥的一窍不通来着. 题目就是这样,正难则反,所以我们不考虑过原点的三角形, 反过来, ...

  3. BZOJ3835 [Poi2014]Supercomputer 【斜率优化】

    题目链接 BZOJ3835 题解 对于\(k\),设\(s[i]\)为深度大于\(i\)的点数 \[ans = max\{i + \lceil \frac{s[i]}{k}\} \rceil\] 最优 ...

  4. 【bzoj3570】 Cqoi2014—通配符匹配

    http://www.lydsy.com/JudgeOnline/problem.php?id=3507 (题目链接) 题意 给出一个主串,里面有些通配符,'*'可以代替任意字符串或者消失,'?'可以 ...

  5. debian7编译安装tengine添加lua和ldap模块

    1.安装开发环境 # aptitute update # aptitude install -y build-essential # aptitude install -y libldap2-dev ...

  6. python检测服务器是否ping通

    好想在2014结束前再赶出个10篇博文来,~(>_<)~,不写博客真不是一个好兆头,至少说明对学习的欲望和对知识的研究都不是那么积极了,如果说这1天的时间我能赶出几篇精致的博文,你们信不信 ...

  7. 在Kubernetes集群里安装微服务DevOps平台fabric8

    转载于https://blog.csdn.net/wzp1986/article/details/72128063?utm_source=itdadao&utm_medium=referral ...

  8. git更换 拉取推送地址

    更换:git remote set-url originhttp://git.caomall.cn/sucry/mouse.git」

  9. NATS_02:NATS消息通信模型

    消息通信模型 NATS的消息通信是这样的:应用程序的数据被编码为一条消息,并通过发布者发送出去:订阅者接收到消息,进行解码,再处理.订阅者处理NATS消息可以是同步的或异步的. * 异步处理  异步处 ...

  10. GO_04:GO语言基础条件、跳转、Array和Slice

    1. 判断语句if 1. 条件表达式没有括号(这点其他语言转过来的需要注意) 2. 支持一个初始化表达式(可以是并行方式,即:a, b, c := 1, 2, 3) 3. 左大括号必须和条件语句或 e ...