pytorch中DataLoader, DataSet, Sampler之间的关系
转自:https://mp.weixin.qq.com/s/RTv0cUWvc0kuXBeNoXVu_A
自上而下理解三者关系
首先我们看一下DataLoader.__next__的源代码长什么样,为方便理解我只选取了num_works为0的情况(num_works简单理解就是能够并行化地读取数据)。
class DataLoader(object):
...
def __next__(self):
if self.num_workers == 0:
indices = next(self.sample_iter) # Sampler
batch = self.collate_fn([self.dataset[i] for i in indices]) # Dataset
if self.pin_memory:
batch = _utils.pin_memory.pin_memory_batch(batch)
return batch
在阅读上面代码前,我们可以假设我们的数据是一组图像,每一张图像对应一个index,那么如果我们要读取数据就只需要对应的index即可,即上面代码中的indices,而选取index的方式有多种,有按顺序的,也有乱序的,所以这个工作需要Sampler完成,现在你不需要具体的细节,后面会介绍,你只需要知道DataLoader和Sampler在这里产生关系。
那么Dataset和DataLoader在什么时候产生关系呢?没错就是下面一行。我们已经拿到了indices,那么下一步我们只需要根据index对数据进行读取即可了。
再下面的if语句的作用简单理解就是,如果pin_memory=True,那么Pytorch会采取一系列操作把数据拷贝到GPU,总之就是为了加速。
综上可以知道DataLoader,Sampler和Dataset三者关系如下:

在阅读后文的过程中,你始终需要将上面的关系记在心里,这样能帮助你更好地理解。
Sampler
参数传递
要更加细致地理解Sampler原理,我们需要先阅读一下DataLoader 的源代码,如下:
class DataLoader(object):
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=default_collate,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
可以看到初始化参数里有两种sampler:sampler和batch_sampler,都默认为None。前者的作用是生成一系列的index,而batch_sampler则是将sampler生成的indices打包分组,得到一个又一个batch的index。例如下面示例中,BatchSampler将SequentialSampler生成的index按照指定的batch size分组。
>>>in : list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
>>>out: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
Pytorch中已经实现的Sampler有如下几种:
SequentialSamplerRandomSamplerWeightedSamplerSubsetRandomSampler
需要注意的是DataLoader的部分初始化参数之间存在互斥关系,这个你可以通过阅读源码更深地理解,这里只做总结:
如果你自定义了
batch_sampler,那么这些参数都必须使用默认值:batch_size,shuffle,sampler,drop_last.如果你自定义了
sampler,那么shuffle需要设置为False如果
sampler和batch_sampler都为None,那么batch_sampler使用Pytorch已经实现好的BatchSampler,而sampler分两种情况:若
shuffle=True,则sampler=RandomSampler(dataset)若
shuffle=False,则sampler=SequentialSampler(dataset)
如何自定义Sampler和BatchSampler?
仔细查看源代码其实可以发现,所有采样器其实都继承自同一个父类,即Sampler,其代码定义如下:
class Sampler(object):
r"""Base class for all Samplers.
Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
way to iterate over indices of dataset elements, and a :meth:`__len__` method
that returns the length of the returned iterators.
.. note:: The :meth:`__len__` method isn't strictly required by
:class:`~torch.utils.data.DataLoader`, but is expected in any
calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
"""
def __init__(self, data_source):
pass
def __iter__(self):
raise NotImplementedError
def __len__(self):
return len(self.data_source)
所以你要做的就是定义好__iter__(self)函数,不过要注意的是该函数的返回值需要是可迭代的。例如SequentialSampler返回的是iter(range(len(self.data_source)))。
另外BatchSampler与其他Sampler的主要区别是它需要将Sampler作为参数进行打包,进而每次迭代返回以batch size为大小的index列表。也就是说在后面的读取数据过程中使用的都是batch sampler。
Dataset
Dataset定义方式如下:
class Dataset(object):
def __init__(self):
...
def __getitem__(self, index):
return ...
def __len__(self):
return ...
上面三个方法是最基本的,其中__getitem__是最主要的方法,它规定了如何读取数据。但是它又不同于一般的方法,因为它是python built-in方法,其主要作用是能让该类可以像list一样通过索引值对数据进行访问。假如你定义好了一个dataset,那么你可以直接通过dataset[0]来访问第一个数据。在此之前我一直没弄清楚__getitem__是什么作用,所以一直不知道该怎么进入到这个函数进行调试。现在如果你想对__getitem__方法进行调试,你可以写一个for循环遍历dataset来进行调试了,而不用构建dataloader等一大堆东西了,建议学会使用ipdb这个库,非常实用!!!以后有时间再写一篇ipdb的使用教程。另外,其实我们通过最前面的Dataloader的__next__函数可以看到DataLoader对数据的读取其实就是用了for循环来遍历数据,不用往上翻了,我直接复制了一遍,如下:
class DataLoader(object):
...
def __next__(self):
if self.num_workers == 0:
indices = next(self.sample_iter)
batch = self.collate_fn([self.dataset[i] for i in indices]) # this line
if self.pin_memory:
batch = _utils.pin_memory.pin_memory_batch(batch)
return batch
我们仔细看可以发现,前面还有一个self.collate_fn方法,这个是干嘛用的呢?在介绍前我们需要知道每个参数的意义:
indices: 表示每一个iteration,sampler返回的indices,即一个batch size大小的索引列表self.dataset[i]: 前面已经介绍了,这里就是对第i个数据进行读取操作,一般来说self.dataset[i]=(img, label)
看到这不难猜出collate_fn的作用就是将一个batch的数据进行合并操作。默认的collate_fn是将img和label分别合并成imgs和labels,所以如果你的__getitem__方法只是返回 img, label,那么你可以使用默认的collate_fn方法,但是如果你每次读取的数据有img, box, label等等,那么你就需要自定义collate_fn来将对应的数据合并成一个batch数据,这样方便后续的训练步骤。
pytorch中DataLoader, DataSet, Sampler之间的关系的更多相关文章
- 一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
以下内容都是针对Pytorch 1.0-1.1介绍. 很多文章都是从Dataset等对象自下往上进行介绍,但是对于初学者而言,其实这并不好理解,因为有的时候会不自觉地陷入到一些细枝末节中去,而不能把握 ...
- java中paint repaint update 之间的关系
最近总结了一下java中的paint,repaint和updata三者之间的关系,首先咱们都知道用paint方法来绘图,用repaint重绘,用update来写双缓冲.但是他们之间是怎么来调用的呢,咱 ...
- ASP.NET-MVC中Entity和Model之间的关系
Entity 与 Model之间的关系图 ViewModel类是MVC中与浏览器交互的,Entity是后台与数据库交互的,这两者可以在MVC中的model类中转换 MVC基础框架 来自为知笔记(Wiz ...
- pytorch 中Dataloader中的collate_fn参数
一般的,默认的collate_fn函数是要求一个batch中的图片都具有相同size(因为要做stack操作),当一个batch中的图片大小都不同时,可以使用自定义的collate_fn函数,则一个b ...
- 讨论SQL语句中主副表之间的关系
在公司这么多些时间,自己在写SQL语句这方面的功夫实在是太差劲了,有时候自己写出来的SQL语句自己都不知道能不能使用,只是自己写出来的SQL语句是不报错的,但是,这对于真正意义上的SQL语句还差的真的 ...
- ado.net中的 sqlconnection sqlcommand datareader dataset SqlDataAdapter之间的关系
Connection: 和数据库交互,必须连接它.连接帮助指明数据库服务器.数据库名字.用户名.密码,和连接数据库所需要的其它参数.Connection对象会被Command对象使用,这样就能够知道是 ...
- FFmpeg 结构体学习(八):FFMPEG中重要结构体之间的关系
FFMPEG中结构体很多.最关键的结构体可以分成以下几类: 解协议(http,rtsp,rtmp,mms) AVIOContext,URLProtocol,URLContext主要存储视音频使用的协议 ...
- python中赋值-浅拷贝-深拷贝之间的关系
赋值: 变量的引用,没有拷贝空间 对象之间赋值本质上 是对象之间的引用传递而已.也就是多个对象指向同一个数据空间. 拷贝的对象分两种类型: . 拷贝可变类型 浅拷贝: 只拷贝第一层数据,不关心里面的第 ...
- JS中BOM和DOM之间的关系
一.Javascript组成JavaScript的实现包括以下3个部分:1.核心(ECMAScript):描述了JS的语法和基本对象.2.文档对象模型 (DOM):处理网页内容的方法和接口.3.浏览器 ...
随机推荐
- 13条必知必会&&测试
1.13条必知必会 <> all(): 查询所有结果 <> filter(**kwargs): 它包含了与所给筛选条件相匹配的对象 <> get(**kwargs) ...
- Facebook 发布深度学习工具包 PyTorch Hub,让论文复现变得更容易
近日,PyTorch 社区发布了一个深度学习工具包 PyTorchHub, 帮助机器学习工作者更快实现重要论文的复现工作.PyTorchHub 由一个预训练模型仓库组成,专门用于提高研究工作的复现性以 ...
- C++ 结构体的定义
struct 结构体名称{ 数据类型 A: 数据类型 B; }结构体变量名; 相当于: struct 结构体名称{ 数据类型 A: 数据类型 B; }; struct 结构体名 ...
- 2012-2013 Northwestern European Regional Contest (NWERC 2012)
B - Beer Pressure \(dp(t, p_1, p_2, p_3, p_4)\)表示总人数为\(t\),\(p_i\)对应酒吧投票人数的概率. 使用滚动数组优化掉一维空间. 总的时间复杂 ...
- 解决Pycharm中SystemError报错
报错描述- 代码逻辑大致是, 开启线程, 监听kafka生产者push的topic消息.- 问题出现在监听过程中, 线程在接收几条topic之后出现报错, 不再处理数据12报错代码Exception ...
- 条件随机场(CRF) - 4 - 学习方法和预测算法(维特比算法)
声明: 1,本篇为个人对<2012.李航.统计学习方法.pdf>的学习总结,不得用作商用,欢迎转载,但请注明出处(即:本帖地址). 2,由于本人在学习初始时有很多数学知识都已忘记,所以为了 ...
- Codeforces Round #187 (Div. 1 + Div. 2)
A. Sereja and Bottles 模拟. B. Sereja and Array 维护全局增量\(Y\),对于操作1(即\(a_{v_i}=x\))操作,改为\(a_{v_i}=x-Y\). ...
- H3C PPP基本概念
- 2018-8-10-VisualStudio-合并代码文件
title author date CreateTime categories VisualStudio 合并代码文件 lindexi 2018-08-10 19:16:52 +0800 2018-2 ...
- CSS3 box-sizing 盒子布局
在CSS中盒模型被分为两种,第一种是W3C的标准模型,第二种是IE怪异盒模型.不同之处在于后者的宽高定义的是可见元素框的尺寸,而不是元素框的内容区尺寸.目前对于浏览器大多数元素都是基于W3C标准的盒模 ...