在进行多卡训练的时候,经常会出现GPU利用率上不来的情况,无法发挥硬件的最大实力。 造成这种现象最有可能的原因是,CPU生成数据的能力,已经跟不上GPU处理数据的能力。


方法一


常见的方法为修改Dataloader里面的线程数量,利用多线程技术提高数据生产能力,但是这种方法提速并不是特别明显。

train_loader = DataLoader(dataset, batch_size,shuffle=True, num_worker=4)

而且windows机器上,num_worker大于0时,有时会出现卡死的情况,这应该是pytorch的bug,因此不是特别建议这种方法。

不过这种方法最简单,还是可以尝试一下更改线程数能否缓解你遇到的问题。nun_worker一般设置为处理器的物理线程数,不宜过大,因为会导致额外的线程开销。

方法二


本文主要介绍第二种方法,也就是Data Prefetcher,最早见于NVIDIA APEX

这里我把代码抠出来了,删除掉了一些不必要的注释,可以将其复用到自己的项目里来。

import torch

class data_prefetcher():
def __init__(self, loader):
self.loader = iter(loader)
self.stream = torch.cuda.Stream()
self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
self.preload() def preload(self):
try:
self.next_input, self.next_target = next(self.loader)
except StopIteration:
self.next_input = None
self.next_target = None
return with torch.cuda.stream(self.stream):
self.next_input = self.next_input.cuda(non_blocking=True)
self.next_target = self.next_target.cuda(non_blocking=True)
self.next_input = self.next_input.float()
self.next_input = self.next_input.sub_(self.mean).div_(self.std) def next(self):
torch.cuda.current_stream().wait_stream(self.stream)
input = self.next_input
target = self.next_target
if input is not None:
input.record_stream(torch.cuda.current_stream())
if target is not None:
target.record_stream(torch.cuda.current_stream())
self.preload()
return input, target

首先我们来看初始化函数,在初始化函数中,会直接调用preload,所以当这个对象初始化时,就会生成第一份的输入数据。

核心逻辑也就在预加载函数preload中,其中第13行是从原来的dataloader中取数,这一步和常规数据加载没有差别。有差别的是第19行,这里出现了Stream的概念。

一般来说,CUDA程序默认都运行在同一个Stream上,因此CPU->GPU,GPU->GPU以及GPU->CPU的一系列计算都是在同一个Stream里面串行运行的。 深度学习一般流程是先从dataloader中取数,这里是内存->CPU的运算,然后执行to_device操作,让数据从CPU->GPU,再是GPU->GPU的神经网络计算。

代码19行,使得data_prefetecher这个类是单独运行在一个Stream上的,因此它让数据加载和神经网络计算可以并行执行,也就加速了整体的运行速度。这样做带来的负面结果就是GPU同时在做两项任务,所以显存占用会增加。

这里不知道解释清楚没有,建议去看一下原作者的回答link

另外,重要的是,使用这个方法的时候一定要将Dataloader里面的pin_memory设置为True。

使用方法如下,非常简单,改造前是从dataloader里取数,改造后是将dataloader包在prefetecher里面,从prefetecher里面取数。

train_loader = DataLoader(dataset, batch_size,shuffle=True, num_worker=4,pin_memory=True)
prefetcher = data_prefetcher(train_loader)
input, target = prefetcher.next() while input is not None:
##
前后向计算...
###
input, target = prefetcher.next()

Pytorch Dataloader加速的更多相关文章

  1. pytorch :: Dataloader中的迭代器和生成器应用

    在使用pytorch训练模型,经常需要加载大量图片数据,因此pytorch提供了好用的数据加载工具Dataloader. 为了实现小批量循环读取大型数据集,在Dataloader类具体实现中,使用了迭 ...

  2. [Pytorch]PyTorch Dataloader自定义数据读取

    整理一下看到的自定义数据读取的方法,较好的有一下三篇文章, 其实自定义的方法就是把现有数据集的train和test分别用 含有图像路径与label的list返回就好了,所以需要根据数据集随机应变. 所 ...

  3. pytorch dataloader num_workers

    https://discuss.pytorch.org/t/guidelines-for-assigning-num-workers-to-dataloader/813/5 num_workers 影 ...

  4. pytorch dataloader 取batch_size时候 出现bug

    1.RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 342 and 2 ...

  5. PyTorch DataLoader NumberWorkers Deep Learning Speed Limit Increase

    这意味着训练过程将按顺序在主流程中工作. 即:run.num_workers.   ,此外, ,因此,主进程不需要从磁盘读取数据:相反,这些数据已经在内存中准备好了. 这个例子中,我们看到了20%的加 ...

  6. 【深度学习】Pytorch 学习笔记

    目录 Pytorch Leture 05: Linear Rregression in the Pytorch Way Logistic Regression 逻辑回归 - 二分类 Lecture07 ...

  7. [源码解析] PyTorch 分布式(9) ----- DistributedDataParallel 之初始化

    [源码解析] PyTorch 分布式(9) ----- DistributedDataParallel 之初始化 目录 [源码解析] PyTorch 分布式(9) ----- DistributedD ...

  8. [源码解析] PyTorch 分布式(10)------DistributedDataParallel 之 Reducer静态架构

    [源码解析] PyTorch 分布式(10)------DistributedDataParallel之Reducer静态架构 目录 [源码解析] PyTorch 分布式(10)------Distr ...

  9. [源码解析] PyTorch 分布式(11) ----- DistributedDataParallel 之 构建Reducer

    [源码解析] PyTorch 分布式(11) ----- DistributedDataParallel 之 构建Reducer 目录 [源码解析] PyTorch 分布式(11) ----- Dis ...

随机推荐

  1. Java函数的学习

    函数的定义 - 定义的位置:定义在类的内部 - 组成部分: 函数修饰符 类型 函数名(形式参数){ 局部变量: 注释: 函数体: } 函数的调用 - 调用函数时使用 : `函数名():` - 函数在执 ...

  2. 尤娜故事-迷雾-springboot扮酷小技巧

    前情回顾 从前,有一个简单的通道系统叫尤娜-- 尤娜系统的第一次飞行中换引擎的架构垂直拆分改造 四种常用的微服务架构拆分方式 尤娜,我去面试了 正文 我回到日常的尤娜系统建设中,最近事情比较少,总有一 ...

  3. js console.log打印变量注意事项

    如果是基本类型变量是没有异常的 let str = 'string' console.log(str) // string str = '改变了str变量' 如果是引用类型,打印就要注意了 let o ...

  4. Python技法:用re模块实现简易tokenizer

    一个简单的tokenizer 分词(tokenization)任务是Python字符串处理中最为常见任务了.我们这里讲解用正则表达式构建简单的表达式分词器(tokenizer),它能够将表达式字符串从 ...

  5. 一文带你读懂什么是vxlan网络

    一个执着于技术的公众号 一.背景 随着云计算.虚拟化相关技术的发展,传统网络无法满足大规模.灵活性要求高的云数据中心的要求,于是便有了overlay网络的概念.overlay网络中被广泛应用的就是vx ...

  6. 手脱MoleBox(2.3.3-2.6.4)

    1.查壳 2.找到OEP 对第二个Call使用ESP定律,再跳转后的位置进入第一个Call,这里就是OEP了,在这里直接dump的话会失败,那是因为MoleBox壳对IAT进行二次跳转,我们先在OEP ...

  7. 没想到吧!这个可可爱爱的游戏居然是用 ECharts 实现的!

    摘要:echarts 是一个很强大的图表库,除了我们常见的图表功能,还可以自定义图形,这个功能让我们可以很简单地在画布上绘制一些非常规的图形,基于此,我们来玩一些花哨的:做一个 Flappy Bird ...

  8. Ruby 趣学笔记(一)

    Ruby 趣学笔记(一) 本文写于 2020 年 5 月 6 日 Ruby 趣学笔记(一) 变量 变量声明 变量类型 常量 输出 字符串 字符串操作 Array 数组的遍历 数组的连接 怎么判断该变量 ...

  9. C# WPF后台动态添加控件(经典)

    概述 在Winform中从后台添加控件相对比较容易,但是在WPF中,我们知道界面是通过XAML编写的,如何把后台写好的控件动态添加到前台呢?本节举例介绍这个问题. 这里要用到UniformGrid布局 ...

  10. java并发编程-StampedLock高性能读写锁

    目录 一.读写锁 二.悲观读锁 三.乐观读 欢迎关注我的博客,更多精品知识合集 一.读写锁 在我的<java并发编程>上一篇文章中为大家介绍了<ReentrantLock读写锁> ...