在进行多卡训练的时候,经常会出现GPU利用率上不来的情况,无法发挥硬件的最大实力。 造成这种现象最有可能的原因是,CPU生成数据的能力,已经跟不上GPU处理数据的能力。


方法一


常见的方法为修改Dataloader里面的线程数量,利用多线程技术提高数据生产能力,但是这种方法提速并不是特别明显。

train_loader = DataLoader(dataset, batch_size,shuffle=True, num_worker=4)

而且windows机器上,num_worker大于0时,有时会出现卡死的情况,这应该是pytorch的bug,因此不是特别建议这种方法。

不过这种方法最简单,还是可以尝试一下更改线程数能否缓解你遇到的问题。nun_worker一般设置为处理器的物理线程数,不宜过大,因为会导致额外的线程开销。

方法二


本文主要介绍第二种方法,也就是Data Prefetcher,最早见于NVIDIA APEX

这里我把代码抠出来了,删除掉了一些不必要的注释,可以将其复用到自己的项目里来。

import torch

class data_prefetcher():
def __init__(self, loader):
self.loader = iter(loader)
self.stream = torch.cuda.Stream()
self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
self.preload() def preload(self):
try:
self.next_input, self.next_target = next(self.loader)
except StopIteration:
self.next_input = None
self.next_target = None
return with torch.cuda.stream(self.stream):
self.next_input = self.next_input.cuda(non_blocking=True)
self.next_target = self.next_target.cuda(non_blocking=True)
self.next_input = self.next_input.float()
self.next_input = self.next_input.sub_(self.mean).div_(self.std) def next(self):
torch.cuda.current_stream().wait_stream(self.stream)
input = self.next_input
target = self.next_target
if input is not None:
input.record_stream(torch.cuda.current_stream())
if target is not None:
target.record_stream(torch.cuda.current_stream())
self.preload()
return input, target

首先我们来看初始化函数,在初始化函数中,会直接调用preload,所以当这个对象初始化时,就会生成第一份的输入数据。

核心逻辑也就在预加载函数preload中,其中第13行是从原来的dataloader中取数,这一步和常规数据加载没有差别。有差别的是第19行,这里出现了Stream的概念。

一般来说,CUDA程序默认都运行在同一个Stream上,因此CPU->GPU,GPU->GPU以及GPU->CPU的一系列计算都是在同一个Stream里面串行运行的。 深度学习一般流程是先从dataloader中取数,这里是内存->CPU的运算,然后执行to_device操作,让数据从CPU->GPU,再是GPU->GPU的神经网络计算。

代码19行,使得data_prefetecher这个类是单独运行在一个Stream上的,因此它让数据加载和神经网络计算可以并行执行,也就加速了整体的运行速度。这样做带来的负面结果就是GPU同时在做两项任务,所以显存占用会增加。

这里不知道解释清楚没有,建议去看一下原作者的回答link

另外,重要的是,使用这个方法的时候一定要将Dataloader里面的pin_memory设置为True。

使用方法如下,非常简单,改造前是从dataloader里取数,改造后是将dataloader包在prefetecher里面,从prefetecher里面取数。

train_loader = DataLoader(dataset, batch_size,shuffle=True, num_worker=4,pin_memory=True)
prefetcher = data_prefetcher(train_loader)
input, target = prefetcher.next() while input is not None:
##
前后向计算...
###
input, target = prefetcher.next()

Pytorch Dataloader加速的更多相关文章

  1. pytorch :: Dataloader中的迭代器和生成器应用

    在使用pytorch训练模型,经常需要加载大量图片数据,因此pytorch提供了好用的数据加载工具Dataloader. 为了实现小批量循环读取大型数据集,在Dataloader类具体实现中,使用了迭 ...

  2. [Pytorch]PyTorch Dataloader自定义数据读取

    整理一下看到的自定义数据读取的方法,较好的有一下三篇文章, 其实自定义的方法就是把现有数据集的train和test分别用 含有图像路径与label的list返回就好了,所以需要根据数据集随机应变. 所 ...

  3. pytorch dataloader num_workers

    https://discuss.pytorch.org/t/guidelines-for-assigning-num-workers-to-dataloader/813/5 num_workers 影 ...

  4. pytorch dataloader 取batch_size时候 出现bug

    1.RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 342 and 2 ...

  5. PyTorch DataLoader NumberWorkers Deep Learning Speed Limit Increase

    这意味着训练过程将按顺序在主流程中工作. 即:run.num_workers.   ,此外, ,因此,主进程不需要从磁盘读取数据:相反,这些数据已经在内存中准备好了. 这个例子中,我们看到了20%的加 ...

  6. 【深度学习】Pytorch 学习笔记

    目录 Pytorch Leture 05: Linear Rregression in the Pytorch Way Logistic Regression 逻辑回归 - 二分类 Lecture07 ...

  7. [源码解析] PyTorch 分布式(9) ----- DistributedDataParallel 之初始化

    [源码解析] PyTorch 分布式(9) ----- DistributedDataParallel 之初始化 目录 [源码解析] PyTorch 分布式(9) ----- DistributedD ...

  8. [源码解析] PyTorch 分布式(10)------DistributedDataParallel 之 Reducer静态架构

    [源码解析] PyTorch 分布式(10)------DistributedDataParallel之Reducer静态架构 目录 [源码解析] PyTorch 分布式(10)------Distr ...

  9. [源码解析] PyTorch 分布式(11) ----- DistributedDataParallel 之 构建Reducer

    [源码解析] PyTorch 分布式(11) ----- DistributedDataParallel 之 构建Reducer 目录 [源码解析] PyTorch 分布式(11) ----- Dis ...

随机推荐

  1. 忘记VMware vcenter的Administrator@vsphere.local密码

    忘记VMware vcenter的Administrator@vsphere.local密码的解决办法一. 重置密码:ssh root@192.168.230.100Connecting to 192 ...

  2. [AcWing 779] 最长公共字符串后缀

    点击查看代码 #include<iostream> using namespace std; const int N = 200; string str[N]; int n ; int m ...

  3. R 数据可视化: PCA 主成分分析图

    简介 主成分分析(Principal Component Analysis,PCA)是一种无监督的数据降维方法,通过主成分分析可以尽可能保留下具备区分性的低维数据特征.主成分分析图能帮助我们直观地感受 ...

  4. 手把手教你在Linux中快速检测端口的 3 个小技巧

    一个执着于技术的公众号 前言 无论是要解决网络连接问题还是配置防火墙,第一件事是要检查系统实际打开了哪些端口. 本文介绍了几种快速查找 Linux 系统上哪些端口向外部开放的方法. 什么是开放端口 监 ...

  5. Centos 7.4_64位系统安装指南

    小土豆Linux学习随笔 -- 清听凌雪慕忆 目录 1. 范围 1.1标识 1.2 文档概述 2. 安装环境 3. 安装步骤 4. 注意事项 1. 范围 1.1标识 CentOS 7.4 64位系统安 ...

  6. 【面试普通人VS高手系列】说说缓存雪崩和缓存穿透的理解,以及如何避免?

    听说10个人去互联网公司面试,有9个人会被问到缓存雪崩和缓存穿透的问题. 听说,这9个人里面,至少有8个人回答得不完整. 而这8个人里面,全都是在网上找的各种面试资料去应付的,并没有真正理解. 当然, ...

  7. Linux的快捷使用(不断更新中)

    Linux 命令行提示符 ~代表当前目录,即家目录,#是超级用户提示符,如果是普通用户使用$ 基本快捷键的使用 移动光标命令 Ctrl+A:移动光标到开头 Ctrl+E:移动光标到结尾 Ctrl+F: ...

  8. React 与 Hooks 如何使用 TypeScript 书写类型?

    React 与 Hooks 如何使用 TypeScript 书写类型? 本文写于 2020 年 9 月 20 日 函数组件与 TS 对于 Hooks 来说是不支持使用 class 组件的. 如何在函数 ...

  9. 23. Merge k Sorted Lists - LeetCode

    Question 23. Merge k Sorted Lists Solution 题目大意:合并链表数组(每个链表中的元素是有序的),要求合并后的链表也是有序的 思路:遍历链表数组,每次取最小节点 ...

  10. 234. Palindrome Linked List - LeetCode

    Question 234. Palindrome Linked List Solution 题目大意:给一个链表,判断是该链表中的元素组成的串是否回文 思路:遍历链表添加到一个list中,再遍历lis ...