Pytorch Dataloader加速
在进行多卡训练的时候,经常会出现GPU利用率上不来的情况,无法发挥硬件的最大实力。 造成这种现象最有可能的原因是,CPU生成数据的能力,已经跟不上GPU处理数据的能力。
方法一
常见的方法为修改Dataloader里面的线程数量,利用多线程技术提高数据生产能力,但是这种方法提速并不是特别明显。
train_loader = DataLoader(dataset, batch_size,shuffle=True, num_worker=4)
而且windows机器上,num_worker大于0时,有时会出现卡死的情况,这应该是pytorch的bug,因此不是特别建议这种方法。
不过这种方法最简单,还是可以尝试一下更改线程数能否缓解你遇到的问题。nun_worker一般设置为处理器的物理线程数,不宜过大,因为会导致额外的线程开销。
方法二
本文主要介绍第二种方法,也就是Data Prefetcher,最早见于NVIDIA APEX。
这里我把代码抠出来了,删除掉了一些不必要的注释,可以将其复用到自己的项目里来。
import torch
class data_prefetcher():
def __init__(self, loader):
self.loader = iter(loader)
self.stream = torch.cuda.Stream()
self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
self.preload()
def preload(self):
try:
self.next_input, self.next_target = next(self.loader)
except StopIteration:
self.next_input = None
self.next_target = None
return
with torch.cuda.stream(self.stream):
self.next_input = self.next_input.cuda(non_blocking=True)
self.next_target = self.next_target.cuda(non_blocking=True)
self.next_input = self.next_input.float()
self.next_input = self.next_input.sub_(self.mean).div_(self.std)
def next(self):
torch.cuda.current_stream().wait_stream(self.stream)
input = self.next_input
target = self.next_target
if input is not None:
input.record_stream(torch.cuda.current_stream())
if target is not None:
target.record_stream(torch.cuda.current_stream())
self.preload()
return input, target
首先我们来看初始化函数,在初始化函数中,会直接调用preload,所以当这个对象初始化时,就会生成第一份的输入数据。
核心逻辑也就在预加载函数preload中,其中第13行是从原来的dataloader中取数,这一步和常规数据加载没有差别。有差别的是第19行,这里出现了Stream的概念。
一般来说,CUDA程序默认都运行在同一个Stream上,因此CPU->GPU,GPU->GPU以及GPU->CPU的一系列计算都是在同一个Stream里面串行运行的。 深度学习一般流程是先从dataloader中取数,这里是内存->CPU的运算,然后执行to_device操作,让数据从CPU->GPU,再是GPU->GPU的神经网络计算。
代码19行,使得data_prefetecher这个类是单独运行在一个Stream上的,因此它让数据加载和神经网络计算可以并行执行,也就加速了整体的运行速度。这样做带来的负面结果就是GPU同时在做两项任务,所以显存占用会增加。
这里不知道解释清楚没有,建议去看一下原作者的回答link
另外,重要的是,使用这个方法的时候一定要将Dataloader里面的pin_memory设置为True。
使用方法如下,非常简单,改造前是从dataloader里取数,改造后是将dataloader包在prefetecher里面,从prefetecher里面取数。
train_loader = DataLoader(dataset, batch_size,shuffle=True, num_worker=4,pin_memory=True)
prefetcher = data_prefetcher(train_loader)
input, target = prefetcher.next()
while input is not None:
##
前后向计算...
###
input, target = prefetcher.next()
Pytorch Dataloader加速的更多相关文章
- pytorch :: Dataloader中的迭代器和生成器应用
在使用pytorch训练模型,经常需要加载大量图片数据,因此pytorch提供了好用的数据加载工具Dataloader. 为了实现小批量循环读取大型数据集,在Dataloader类具体实现中,使用了迭 ...
- [Pytorch]PyTorch Dataloader自定义数据读取
整理一下看到的自定义数据读取的方法,较好的有一下三篇文章, 其实自定义的方法就是把现有数据集的train和test分别用 含有图像路径与label的list返回就好了,所以需要根据数据集随机应变. 所 ...
- pytorch dataloader num_workers
https://discuss.pytorch.org/t/guidelines-for-assigning-num-workers-to-dataloader/813/5 num_workers 影 ...
- pytorch dataloader 取batch_size时候 出现bug
1.RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 342 and 2 ...
- PyTorch DataLoader NumberWorkers Deep Learning Speed Limit Increase
这意味着训练过程将按顺序在主流程中工作. 即:run.num_workers. ,此外, ,因此,主进程不需要从磁盘读取数据:相反,这些数据已经在内存中准备好了. 这个例子中,我们看到了20%的加 ...
- 【深度学习】Pytorch 学习笔记
目录 Pytorch Leture 05: Linear Rregression in the Pytorch Way Logistic Regression 逻辑回归 - 二分类 Lecture07 ...
- [源码解析] PyTorch 分布式(9) ----- DistributedDataParallel 之初始化
[源码解析] PyTorch 分布式(9) ----- DistributedDataParallel 之初始化 目录 [源码解析] PyTorch 分布式(9) ----- DistributedD ...
- [源码解析] PyTorch 分布式(10)------DistributedDataParallel 之 Reducer静态架构
[源码解析] PyTorch 分布式(10)------DistributedDataParallel之Reducer静态架构 目录 [源码解析] PyTorch 分布式(10)------Distr ...
- [源码解析] PyTorch 分布式(11) ----- DistributedDataParallel 之 构建Reducer
[源码解析] PyTorch 分布式(11) ----- DistributedDataParallel 之 构建Reducer 目录 [源码解析] PyTorch 分布式(11) ----- Dis ...
随机推荐
- 忘记VMware vcenter的Administrator@vsphere.local密码
忘记VMware vcenter的Administrator@vsphere.local密码的解决办法一. 重置密码:ssh root@192.168.230.100Connecting to 192 ...
- [AcWing 779] 最长公共字符串后缀
点击查看代码 #include<iostream> using namespace std; const int N = 200; string str[N]; int n ; int m ...
- R 数据可视化: PCA 主成分分析图
简介 主成分分析(Principal Component Analysis,PCA)是一种无监督的数据降维方法,通过主成分分析可以尽可能保留下具备区分性的低维数据特征.主成分分析图能帮助我们直观地感受 ...
- 手把手教你在Linux中快速检测端口的 3 个小技巧
一个执着于技术的公众号 前言 无论是要解决网络连接问题还是配置防火墙,第一件事是要检查系统实际打开了哪些端口. 本文介绍了几种快速查找 Linux 系统上哪些端口向外部开放的方法. 什么是开放端口 监 ...
- Centos 7.4_64位系统安装指南
小土豆Linux学习随笔 -- 清听凌雪慕忆 目录 1. 范围 1.1标识 1.2 文档概述 2. 安装环境 3. 安装步骤 4. 注意事项 1. 范围 1.1标识 CentOS 7.4 64位系统安 ...
- 【面试普通人VS高手系列】说说缓存雪崩和缓存穿透的理解,以及如何避免?
听说10个人去互联网公司面试,有9个人会被问到缓存雪崩和缓存穿透的问题. 听说,这9个人里面,至少有8个人回答得不完整. 而这8个人里面,全都是在网上找的各种面试资料去应付的,并没有真正理解. 当然, ...
- Linux的快捷使用(不断更新中)
Linux 命令行提示符 ~代表当前目录,即家目录,#是超级用户提示符,如果是普通用户使用$ 基本快捷键的使用 移动光标命令 Ctrl+A:移动光标到开头 Ctrl+E:移动光标到结尾 Ctrl+F: ...
- React 与 Hooks 如何使用 TypeScript 书写类型?
React 与 Hooks 如何使用 TypeScript 书写类型? 本文写于 2020 年 9 月 20 日 函数组件与 TS 对于 Hooks 来说是不支持使用 class 组件的. 如何在函数 ...
- 23. Merge k Sorted Lists - LeetCode
Question 23. Merge k Sorted Lists Solution 题目大意:合并链表数组(每个链表中的元素是有序的),要求合并后的链表也是有序的 思路:遍历链表数组,每次取最小节点 ...
- 234. Palindrome Linked List - LeetCode
Question 234. Palindrome Linked List Solution 题目大意:给一个链表,判断是该链表中的元素组成的串是否回文 思路:遍历链表添加到一个list中,再遍历lis ...