原文:http://studyai.com/article/11efc2bf#采样器 Sampler & BatchSampler

数据库DataBase + 数据集DataSet + 采样器Sampler = 加载器Loader

from torch.utils.data import *

IMDB + Dataset + Sampler || BatchSampler = DataLoader

数据库 DataBase

Image DataBase 简称IMDB,指的是存储在文件中的数据信息。

文件格式可以多种多样。比如xml, yaml, json, sql.

VOC是xml格式的,COCO是JSON格式的。

构造IMDB的过程,就是解析这些文件,并建立数据索引的过程。

一般会被解析为Python列表, 以方便后续迭代读取。

数据集 DataSet

数据集 DataSet: 在数据库IMDB的基础上,提供对数据的单例或切片访问方法。

换言之,就是定义数据库中对象的索引机制,如何实现单例索引或切片索引。

简言之,DataSet,通过__getitem__定义了数据集DataSet是一个可索引对象,An Indexerable Object。

即传入一个给定的索引Index之后,如何按此索引进行单例或切片访问,单例还是切片视Index是单值还是列表。

Pytorch源码如下:

class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
# 定义单例/切片访问方法,即 dataItem = Dataset[index]
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])

自定义数据集要基于上述Dataset基类、IMDB基类,有两种方法。

# 方法一: 单继承
class XxDataset(Dataset)
# 将IMDB作为参数传入,进行二次封装
imdb = IMDB()
pass
# 方法二: 双继承
class XxDataset(IMDB, Dataset):
pass

采样器 Sampler & BatchSampler

在实际应用中,数据并不一定是循规蹈矩的序惯访问,而需要随机打乱顺序来访问,或需要随机加权访问,

因此,按某种特定的规则来读取数据,就是采样操作,需要定义采样器:Sampler

另外,数据也可能并不是一个一个读取的,而需要一批一批的读取,即需要批量采样操作,定义批量采样器:BatchSampler

所以,只有Dataset的单例访问方法还不够,还需要在此基础上,进一步的定义批量访问方法。

简言之,采样器定义了索引(index)的产生规则,按指定规则去产生索引,从而控制数据的读取机制

BatchSampler 是基于 Sampler 来构造的: BatchSampler = Sampler + BatchSize

Pytorch源码如下,

class Sampler(object):
"""Base class for all Samplers.
采样器基类,可以基于此自定义采样器。
Every Sampler subclass has to provide an __iter__ method, providing a way
to iterate over indices of dataset elements, and a __len__ method that
returns the length of the returned iterators.
"""
def __init__(self, data_source):
pass
def __iter__(self):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
# 序惯采样
class SequentialSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self):
return len(self.data_source)
# 随机采样
class RandomSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(torch.randperm(len(self.data_source)).long())
def __len__(self):
return len(self.data_source)
# 随机子采样
class SubsetRandomSampler(Sampler):
pass
# 加权随机采样
class WeightedRandomSampler(Sampler):
pass
class BatchSampler(object):
"""Wraps another sampler to yield a mini-batch of indices.
Args:
sampler (Sampler): Base sampler.
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``
Example:
>>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
"""
def __init__(self, sampler, batch_size, drop_last):
self.sampler = sampler # ******
self.batch_size = batch_size
self.drop_last = drop_last
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
def __len__(self):
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size

由上可见,Sampler本质就是个具有特定规则的可迭代对象,但只能单例迭代。

[x for x in range(10)], range(10)就是个最基本的Sampler,每次循环只能取出其中的一个值.

[x for x in range(10)]
Out[10]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
from torch.utils.data.sampler import SequentialSampler
[x for x in SequentialSampler(range(10))]
Out[14]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
from torch.utils.data.sampler import RandomSampler
[x for x in RandomSampler(range(10))]
Out[12]: [4, 9, 5, 0, 2, 8, 3, 1, 7, 6]

BatchSampler对Sampler进行二次封装,引入了batchSize参数,实现了批量迭代。

from torch.utils.data.sampler import BatchSampler
[x for x in BatchSampler(range(10), batch_size=3, drop_last=False)]
Out[9]: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
[x for x in BatchSampler(RandomSampler(range(10)), batch_size=3, drop_last=False)]
Out[15]: [[1, 3, 7], [9, 2, 0], [5, 4, 6], [8]]

加载器 DataLoader

在实际计算中,如果数据量很大,考虑到内存有限,且IO速度很慢,

因此不能一次性的将其全部加载到内存中,也不能只用一个线程去加载。

因而需要多线程、迭代加载, 因而专门定义加载器:DataLoader

DataLoader 是一个可迭代对象, An Iterable Object, 内部配置了魔法函数——iter——,调用它将返回一个迭代器。

该函数可用内置函数iter直接调用,即 DataIteror = iter(DataLoader)

dataloader = DataLoader(dataset=Dataset(imdb=IMDB()), sampler=Sampler(), num_works, ...)

__init__参数包含两部分,前半部分用于指定数据集 + 采样器,后半部分为多线程参数

class DataLoader(object):
"""
Data loader. Combines a dataset and a sampler, and provides
single- or multi-process iterators over the dataset.
"""
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None):
self.dataset = dataset
self.batch_size = batch_size
self.num_workers = num_workers
self.collate_fn = collate_fn
self.pin_memory = pin_memory
self.drop_last = drop_last
self.timeout = timeout
self.worker_init_fn = worker_init_fn
if timeout < 0:
raise ValueError('timeout option should be non-negative')
# 检测是否存在参数冲突: 默认batchSampler vs 自定义BatchSampler
if batch_sampler is not None:
if batch_size > 1 or shuffle or sampler is not None or drop_last:
raise ValueError('batch_sampler is mutually exclusive with '
'batch_size, shuffle, sampler, and drop_last')
if sampler is not None and shuffle:
raise ValueError('sampler is mutually exclusive with shuffle')
if self.num_workers < 0:
raise ValueError('num_workers cannot be negative; '
'use num_workers=0 to disable multiprocessing.')
# 在此处会强行指定一个 BatchSampler
if batch_sampler is None:
# 在此处会强行指定一个 Sampler
if sampler is None:
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
# 使用自定义的采样器和批采样器
self.sampler = sampler
self.batch_sampler = batch_sampler
def __iter__(self):
# 调用Pytorch的多线程迭代器加载数据
return DataLoaderIter(self)
def __len__(self):
return len(self.batch_sampler)

数据迭代器 DataLoaderIter

迭代器与可迭代对象之间是有区别的。

可迭代对象,意思是对其使用Iter函数时,它可以返回一个迭代器,从而可以连续的迭代访问它。

迭代器对象,内部有额外的魔法函数__next__,用内置函数next作用其上,则可以连续产生下一个数据,产生规则即是由此函数来确定的。

可迭代对象描述了对象具有可迭代性,但具体的迭代规则由迭代器来描述,这样解耦的好处是可以对同一个可迭代对象配置多种不同规则的迭代器。

数据集/容器遍历的一般化流程:NILIS

NILIS规则: data = next(iter(loader(DataSet[sampler])))data=next(iter(loader(DataSet[sampler])))

  1. sampler 定义索引index的生成规则,返回一个index列表,控制后续的索引访问过程。
  2. indexer 基于__item__在容器上定义按索引访问的规则,让容器成为可索引对象,可用[]操作。
  3. loader 基于__iter__在容器上定义可迭代性,描述加载规则,包括返回一个迭代器,让容器成为可迭代对象, 可用iter()操作。
  4. next 基于__next__在容器上定义迭代器,描述具体的迭代规则,让容器成为迭代器对象, 可用next()操作。
## 初始化
sampler = Sampler()
dataSet = DataSet(sampler) # __getitem__
dataLoader = DataLoader(dataSet, sampler) / DataIterable() # __iter__()
dataIterator = DataLoaderIter(dataLoader) #__next__()
data_iter = iter(dataLoader)
## 遍历方法1
for _ in range(len(data_iter))
data = next(data_iter)
## 遍历方法2
for i, data in enumerate(dataLoader):
data = data

微信公众号:AutoML机器学习

MARSGGBO♥原创

如有意合作或学术讨论欢迎私戳联系~
邮箱:marsggbo@foxmail.com






2019-8-4

Pytorch数据读取详解的更多相关文章

  1. hbase实践之数据读取详解

    hbase基本存储组织结构与数据读取组织结构对比 Segment是Hbase2.0的概念,MemStore由一个可写的Segment,以及一个或多个不可写的Segments构成.故hbase 1.*版 ...

  2. Pytorch autograd,backward详解

    平常都是无脑使用backward,每次看到别人的代码里使用诸如autograd.grad这种方法的时候就有点抵触,今天花了点时间了解了一下原理,写下笔记以供以后参考.以下笔记基于Pytorch1.0 ...

  3. ContentProvider数据访问详解

    ContentProvider数据访问详解 Android官方指出的数据存储方式总共有五种:Shared Preferences.网络存储.文件存储.外储存储.SQLite,这些存储方式一般都只是在一 ...

  4. 【转载】PyTorch系列 (二):pytorch数据读取

    原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...

  5. 【HANA系列】SAP HANA XS使用JavaScript数据交互详解

    公众号:SAP Technical 本文作者:matinal 原文出处:http://www.cnblogs.com/SAPmatinal/ 原文链接:[HANA系列]SAP HANA XS使用Jav ...

  6. JVM 运行时数据区详解

    一.运行时数据区 Java虚拟机在执行Java程序的过程中会把它所管理的内存划分为若干个不同数据区域. 1.有一些是随虚拟机的启动而创建,随虚拟机的退出而销毁,所有的线程共享这些数据区. 2.第二种则 ...

  7. 学习《深度学习与计算机视觉算法原理框架应用》《大数据架构详解从数据获取到深度学习》PDF代码

    <深度学习与计算机视觉 算法原理.框架应用>全书共13章,分为2篇,第1篇基础知识,第2篇实例精讲.用通俗易懂的文字表达公式背后的原理,实例部分提供了一些工具,很实用. <大数据架构 ...

  8. 【HANA系列】【第一篇】SAP HANA XS使用JavaScript数据交互详解

    公众号:SAP Technical 本文作者:matinal 原文出处:http://www.cnblogs.com/SAPmatinal/ 原文链接:[HANA系列][第一篇]SAP HANA XS ...

  9. 3dTiles 数据规范详解[1] 介绍

    版权:转载请带原地址.https://www.cnblogs.com/onsummer/p/12799366.html @秋意正寒 Web中的三维 html5和webgl技术使得浏览器三维变成了可能. ...

随机推荐

  1. dockerfile 的问题 FROM alpine:3.8 temporary error (try again later)

    FROM alpine:3.8 apk add xxx安装软件 fetch http://dl-cdn.alpinelinux.org/alpine/v3.8/main/x86_64/APKINDEX ...

  2. Linux 就该这么学 CH07 使用RAID和LVM磁盘阵列技术

    1 RAID (独立冗余磁盘阵列) RAID 技术通过把多个硬盘设备组合成一个容量更大.安全性更好的磁盘阵列,并把数据切割成多个区段之后分别存在各个不同的物理硬盘设备上,然后利用分散读写计数来提升磁盘 ...

  3. 浅谈设计模式-visitor访问者模式

    先看一个和visitor无关的案例.假设你现在有一个书架,这个书架有两种操作,1添加书籍2阅读每一本书籍的简介. //书架public class Bookcase { List<Book> ...

  4. 【剑指offer】数据流中的中位数

    题目描述 如何得到一个数据流中的中位数?如果从数据流中读出奇数个数值,那么中位数就是所有数值排序之后位于中间的数值.如果从数据流中读出偶数个数值,那么中位数就是所有数值排序之后中间两个数的平均值.我们 ...

  5. netcore3.0 webapi集成Swagger 5.0,Swagger使用

    Swagger使用 1.描述 Swagger 是一个规范和完整的框架,用于生成.描述.调用和可视化 RESTful 风格的 Web 服务. 作用: 1.接口的文档在线自动生成. 2.功能测试 本文转自 ...

  6. 百度前端技术学院task16源代码

    欢迎访问我的github:huanshen 做这道题目的时候遇到了很多困难. 1.怎么给空对象添加数据,愣是不知道从哪里下手:遍历对象,一个个输出操作: 2.中英文的正则表达式不知道,赶紧去百度: 3 ...

  7. Channel延续篇

    上篇文章中介绍了NIO中的Channel,从Channel是什么.特性.分类几个方面做了下简单的介绍.但是后面Channel的分类,个人感觉不够全面,容易误导读者,特此以这篇文章加以补充. Chann ...

  8. php for循环a到z

    首先先介绍2个php内置函数 ord(string):函数返回字符串的首个字符的 ASCII 值.//string:必需.要从中获得 ASCII 值的字符串. chr(ascll): 函数从指定的 A ...

  9. VS2019 Nuget找不到包的问题处理

    VS不记得改了什么设置之后,发现找不到EF 解决办法 1.点击右侧的设置按钮 2.弹出窗中左侧树形结构选择“程序包源”,再点击右上方的添加按钮 输入一下信息:https://www.nuget.org ...

  10. MySQL5.6.17 绿色版 安装配置

    安装篇: 下载完成之后,用解压工具解压到没有中文.空格的文件夹下,解压后的显示如图: 个人建议把解压后的文件夹重命名,如果有中文去掉中文,便于自己理解使用,如图: 打开重命名之后的文件夹,找到mysq ...