原文:http://studyai.com/article/11efc2bf#采样器 Sampler & BatchSampler

数据库DataBase + 数据集DataSet + 采样器Sampler = 加载器Loader

from torch.utils.data import *

IMDB + Dataset + Sampler || BatchSampler = DataLoader

数据库 DataBase

Image DataBase 简称IMDB,指的是存储在文件中的数据信息。

文件格式可以多种多样。比如xml, yaml, json, sql.

VOC是xml格式的,COCO是JSON格式的。

构造IMDB的过程,就是解析这些文件,并建立数据索引的过程。

一般会被解析为Python列表, 以方便后续迭代读取。

数据集 DataSet

数据集 DataSet: 在数据库IMDB的基础上,提供对数据的单例或切片访问方法。

换言之,就是定义数据库中对象的索引机制,如何实现单例索引或切片索引。

简言之,DataSet,通过__getitem__定义了数据集DataSet是一个可索引对象,An Indexerable Object。

即传入一个给定的索引Index之后,如何按此索引进行单例或切片访问,单例还是切片视Index是单值还是列表。

Pytorch源码如下:

class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
# 定义单例/切片访问方法,即 dataItem = Dataset[index]
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])

自定义数据集要基于上述Dataset基类、IMDB基类,有两种方法。

# 方法一: 单继承
class XxDataset(Dataset)
# 将IMDB作为参数传入,进行二次封装
imdb = IMDB()
pass
# 方法二: 双继承
class XxDataset(IMDB, Dataset):
pass

采样器 Sampler & BatchSampler

在实际应用中,数据并不一定是循规蹈矩的序惯访问,而需要随机打乱顺序来访问,或需要随机加权访问,

因此,按某种特定的规则来读取数据,就是采样操作,需要定义采样器:Sampler

另外,数据也可能并不是一个一个读取的,而需要一批一批的读取,即需要批量采样操作,定义批量采样器:BatchSampler

所以,只有Dataset的单例访问方法还不够,还需要在此基础上,进一步的定义批量访问方法。

简言之,采样器定义了索引(index)的产生规则,按指定规则去产生索引,从而控制数据的读取机制

BatchSampler 是基于 Sampler 来构造的: BatchSampler = Sampler + BatchSize

Pytorch源码如下,

class Sampler(object):
"""Base class for all Samplers.
采样器基类,可以基于此自定义采样器。
Every Sampler subclass has to provide an __iter__ method, providing a way
to iterate over indices of dataset elements, and a __len__ method that
returns the length of the returned iterators.
"""
def __init__(self, data_source):
pass
def __iter__(self):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
# 序惯采样
class SequentialSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self):
return len(self.data_source)
# 随机采样
class RandomSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(torch.randperm(len(self.data_source)).long())
def __len__(self):
return len(self.data_source)
# 随机子采样
class SubsetRandomSampler(Sampler):
pass
# 加权随机采样
class WeightedRandomSampler(Sampler):
pass
class BatchSampler(object):
"""Wraps another sampler to yield a mini-batch of indices.
Args:
sampler (Sampler): Base sampler.
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``
Example:
>>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
"""
def __init__(self, sampler, batch_size, drop_last):
self.sampler = sampler # ******
self.batch_size = batch_size
self.drop_last = drop_last
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
def __len__(self):
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size

由上可见,Sampler本质就是个具有特定规则的可迭代对象,但只能单例迭代。

[x for x in range(10)], range(10)就是个最基本的Sampler,每次循环只能取出其中的一个值.

[x for x in range(10)]
Out[10]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
from torch.utils.data.sampler import SequentialSampler
[x for x in SequentialSampler(range(10))]
Out[14]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
from torch.utils.data.sampler import RandomSampler
[x for x in RandomSampler(range(10))]
Out[12]: [4, 9, 5, 0, 2, 8, 3, 1, 7, 6]

BatchSampler对Sampler进行二次封装,引入了batchSize参数,实现了批量迭代。

from torch.utils.data.sampler import BatchSampler
[x for x in BatchSampler(range(10), batch_size=3, drop_last=False)]
Out[9]: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
[x for x in BatchSampler(RandomSampler(range(10)), batch_size=3, drop_last=False)]
Out[15]: [[1, 3, 7], [9, 2, 0], [5, 4, 6], [8]]

加载器 DataLoader

在实际计算中,如果数据量很大,考虑到内存有限,且IO速度很慢,

因此不能一次性的将其全部加载到内存中,也不能只用一个线程去加载。

因而需要多线程、迭代加载, 因而专门定义加载器:DataLoader

DataLoader 是一个可迭代对象, An Iterable Object, 内部配置了魔法函数——iter——,调用它将返回一个迭代器。

该函数可用内置函数iter直接调用,即 DataIteror = iter(DataLoader)

dataloader = DataLoader(dataset=Dataset(imdb=IMDB()), sampler=Sampler(), num_works, ...)

__init__参数包含两部分,前半部分用于指定数据集 + 采样器,后半部分为多线程参数

class DataLoader(object):
"""
Data loader. Combines a dataset and a sampler, and provides
single- or multi-process iterators over the dataset.
"""
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None):
self.dataset = dataset
self.batch_size = batch_size
self.num_workers = num_workers
self.collate_fn = collate_fn
self.pin_memory = pin_memory
self.drop_last = drop_last
self.timeout = timeout
self.worker_init_fn = worker_init_fn
if timeout < 0:
raise ValueError('timeout option should be non-negative')
# 检测是否存在参数冲突: 默认batchSampler vs 自定义BatchSampler
if batch_sampler is not None:
if batch_size > 1 or shuffle or sampler is not None or drop_last:
raise ValueError('batch_sampler is mutually exclusive with '
'batch_size, shuffle, sampler, and drop_last')
if sampler is not None and shuffle:
raise ValueError('sampler is mutually exclusive with shuffle')
if self.num_workers < 0:
raise ValueError('num_workers cannot be negative; '
'use num_workers=0 to disable multiprocessing.')
# 在此处会强行指定一个 BatchSampler
if batch_sampler is None:
# 在此处会强行指定一个 Sampler
if sampler is None:
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
# 使用自定义的采样器和批采样器
self.sampler = sampler
self.batch_sampler = batch_sampler
def __iter__(self):
# 调用Pytorch的多线程迭代器加载数据
return DataLoaderIter(self)
def __len__(self):
return len(self.batch_sampler)

数据迭代器 DataLoaderIter

迭代器与可迭代对象之间是有区别的。

可迭代对象,意思是对其使用Iter函数时,它可以返回一个迭代器,从而可以连续的迭代访问它。

迭代器对象,内部有额外的魔法函数__next__,用内置函数next作用其上,则可以连续产生下一个数据,产生规则即是由此函数来确定的。

可迭代对象描述了对象具有可迭代性,但具体的迭代规则由迭代器来描述,这样解耦的好处是可以对同一个可迭代对象配置多种不同规则的迭代器。

数据集/容器遍历的一般化流程:NILIS

NILIS规则: data = next(iter(loader(DataSet[sampler])))data=next(iter(loader(DataSet[sampler])))

  1. sampler 定义索引index的生成规则,返回一个index列表,控制后续的索引访问过程。
  2. indexer 基于__item__在容器上定义按索引访问的规则,让容器成为可索引对象,可用[]操作。
  3. loader 基于__iter__在容器上定义可迭代性,描述加载规则,包括返回一个迭代器,让容器成为可迭代对象, 可用iter()操作。
  4. next 基于__next__在容器上定义迭代器,描述具体的迭代规则,让容器成为迭代器对象, 可用next()操作。
## 初始化
sampler = Sampler()
dataSet = DataSet(sampler) # __getitem__
dataLoader = DataLoader(dataSet, sampler) / DataIterable() # __iter__()
dataIterator = DataLoaderIter(dataLoader) #__next__()
data_iter = iter(dataLoader)
## 遍历方法1
for _ in range(len(data_iter))
data = next(data_iter)
## 遍历方法2
for i, data in enumerate(dataLoader):
data = data

微信公众号:AutoML机器学习

MARSGGBO♥原创

如有意合作或学术讨论欢迎私戳联系~
邮箱:marsggbo@foxmail.com






2019-8-4

Pytorch数据读取详解的更多相关文章

  1. hbase实践之数据读取详解

    hbase基本存储组织结构与数据读取组织结构对比 Segment是Hbase2.0的概念,MemStore由一个可写的Segment,以及一个或多个不可写的Segments构成.故hbase 1.*版 ...

  2. Pytorch autograd,backward详解

    平常都是无脑使用backward,每次看到别人的代码里使用诸如autograd.grad这种方法的时候就有点抵触,今天花了点时间了解了一下原理,写下笔记以供以后参考.以下笔记基于Pytorch1.0 ...

  3. ContentProvider数据访问详解

    ContentProvider数据访问详解 Android官方指出的数据存储方式总共有五种:Shared Preferences.网络存储.文件存储.外储存储.SQLite,这些存储方式一般都只是在一 ...

  4. 【转载】PyTorch系列 (二):pytorch数据读取

    原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...

  5. 【HANA系列】SAP HANA XS使用JavaScript数据交互详解

    公众号:SAP Technical 本文作者:matinal 原文出处:http://www.cnblogs.com/SAPmatinal/ 原文链接:[HANA系列]SAP HANA XS使用Jav ...

  6. JVM 运行时数据区详解

    一.运行时数据区 Java虚拟机在执行Java程序的过程中会把它所管理的内存划分为若干个不同数据区域. 1.有一些是随虚拟机的启动而创建,随虚拟机的退出而销毁,所有的线程共享这些数据区. 2.第二种则 ...

  7. 学习《深度学习与计算机视觉算法原理框架应用》《大数据架构详解从数据获取到深度学习》PDF代码

    <深度学习与计算机视觉 算法原理.框架应用>全书共13章,分为2篇,第1篇基础知识,第2篇实例精讲.用通俗易懂的文字表达公式背后的原理,实例部分提供了一些工具,很实用. <大数据架构 ...

  8. 【HANA系列】【第一篇】SAP HANA XS使用JavaScript数据交互详解

    公众号:SAP Technical 本文作者:matinal 原文出处:http://www.cnblogs.com/SAPmatinal/ 原文链接:[HANA系列][第一篇]SAP HANA XS ...

  9. 3dTiles 数据规范详解[1] 介绍

    版权:转载请带原地址.https://www.cnblogs.com/onsummer/p/12799366.html @秋意正寒 Web中的三维 html5和webgl技术使得浏览器三维变成了可能. ...

随机推荐

  1. Spring Security教程(四)

    在前面三个博客的例子中,登陆页面都是用的Spring Security自己提供的,这明显不符合实际开发场景,同时也没有退出和注销按钮,因此在每次测试的时候都要通过关闭浏览器来注销达到清除session ...

  2. centos7 df 命令卡死

    登录服务器想查看磁盘使用情况,使用了df,但卡住半天没有响应. 运行strace df -h,发现最后卡在了 stat("/proc/sys/fs/binfmt_misc", 无法 ...

  3. 镜像仓库 Nexus 3.18.1

    说明:Nexus是Sonatype提供的仓库管理平台,Nuexus Repository OSS3能够支持Maven.npm.Docker.YUM.Helm等格式数据的存储和发布. 一.安装jdk 1 ...

  4. VisualStudio ------- vs发布软件

    上线的系统和自己做的系统有什么区别 上线的没有源代码,没有实体层,数据库访问层  业务逻辑层 只有表现层  而且也也没有    .cs 和 .psd   文件,这样就不能修改系统代码 他们都在 Web ...

  5. intellij idea快速通过mapper跳转到xml文件

    安装完之后重启idea即可!

  6. 类的练习3——python编程从入门到实践

    9-13 使用OrderedDict: 在练习6-4中,使用一个标准字典来表示词汇表.使用OrderedDict类来重写这个程序,并确认输出的顺序与在字典中添加的键值对的顺序一致. from coll ...

  7. TCMalloc - 细节

    1,释放速度控制 在将一个Span删除掉的时候,会优先将它加入到normal队列中,这之后会尝试从normal队列中释放一部分同样大小的内存给系统. 释放内存给系统的时候,tcmalloc使用了一个延 ...

  8. [CF24A]Ring road(2019-11-15考试)

    题目大意 给你一个\(n\)个点的环,每条边有方向,改变第\(i\)条边的方向代价为\(w_i\),问将其改为强连通图的最小代价.\(n\leqslant100\) 题解 求出把边全部改为顺时针和全部 ...

  9. json文件 乱码问题 根本解决办法

    1 工具→自定义:2 点击 命令 标签:3 在上方单选区选中 菜单栏,下拉列表选 文件:4 点击 添加命令5 在类别中,找到文件,在右侧找到高级保存选项,确定6 然后可以通过下移调整该选项在文件菜单中 ...

  10. jq动画插件,自制基于vue的圆形时钟

    首先附上jq插件库,里面的东西太炫了,建议学前端的可以看看学习下:http://www.jq22.com/ 里面有个“超个性动画版本的个人简历”,通过屏幕上不断打印内容,改变相应样式来实现动画简历,我 ...