原文:http://studyai.com/article/11efc2bf#采样器 Sampler & BatchSampler

数据库DataBase + 数据集DataSet + 采样器Sampler = 加载器Loader

from torch.utils.data import *

IMDB + Dataset + Sampler || BatchSampler = DataLoader

数据库 DataBase

Image DataBase 简称IMDB,指的是存储在文件中的数据信息。

文件格式可以多种多样。比如xml, yaml, json, sql.

VOC是xml格式的,COCO是JSON格式的。

构造IMDB的过程,就是解析这些文件,并建立数据索引的过程。

一般会被解析为Python列表, 以方便后续迭代读取。

数据集 DataSet

数据集 DataSet: 在数据库IMDB的基础上,提供对数据的单例或切片访问方法。

换言之,就是定义数据库中对象的索引机制,如何实现单例索引或切片索引。

简言之,DataSet,通过__getitem__定义了数据集DataSet是一个可索引对象,An Indexerable Object。

即传入一个给定的索引Index之后,如何按此索引进行单例或切片访问,单例还是切片视Index是单值还是列表。

Pytorch源码如下:

class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
# 定义单例/切片访问方法,即 dataItem = Dataset[index]
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])

自定义数据集要基于上述Dataset基类、IMDB基类,有两种方法。

# 方法一: 单继承
class XxDataset(Dataset)
# 将IMDB作为参数传入,进行二次封装
imdb = IMDB()
pass
# 方法二: 双继承
class XxDataset(IMDB, Dataset):
pass

采样器 Sampler & BatchSampler

在实际应用中,数据并不一定是循规蹈矩的序惯访问,而需要随机打乱顺序来访问,或需要随机加权访问,

因此,按某种特定的规则来读取数据,就是采样操作,需要定义采样器:Sampler

另外,数据也可能并不是一个一个读取的,而需要一批一批的读取,即需要批量采样操作,定义批量采样器:BatchSampler

所以,只有Dataset的单例访问方法还不够,还需要在此基础上,进一步的定义批量访问方法。

简言之,采样器定义了索引(index)的产生规则,按指定规则去产生索引,从而控制数据的读取机制

BatchSampler 是基于 Sampler 来构造的: BatchSampler = Sampler + BatchSize

Pytorch源码如下,

class Sampler(object):
"""Base class for all Samplers.
采样器基类,可以基于此自定义采样器。
Every Sampler subclass has to provide an __iter__ method, providing a way
to iterate over indices of dataset elements, and a __len__ method that
returns the length of the returned iterators.
"""
def __init__(self, data_source):
pass
def __iter__(self):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
# 序惯采样
class SequentialSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self):
return len(self.data_source)
# 随机采样
class RandomSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(torch.randperm(len(self.data_source)).long())
def __len__(self):
return len(self.data_source)
# 随机子采样
class SubsetRandomSampler(Sampler):
pass
# 加权随机采样
class WeightedRandomSampler(Sampler):
pass
class BatchSampler(object):
"""Wraps another sampler to yield a mini-batch of indices.
Args:
sampler (Sampler): Base sampler.
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``
Example:
>>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
"""
def __init__(self, sampler, batch_size, drop_last):
self.sampler = sampler # ******
self.batch_size = batch_size
self.drop_last = drop_last
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
def __len__(self):
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size

由上可见,Sampler本质就是个具有特定规则的可迭代对象,但只能单例迭代。

[x for x in range(10)], range(10)就是个最基本的Sampler,每次循环只能取出其中的一个值.

[x for x in range(10)]
Out[10]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
from torch.utils.data.sampler import SequentialSampler
[x for x in SequentialSampler(range(10))]
Out[14]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
from torch.utils.data.sampler import RandomSampler
[x for x in RandomSampler(range(10))]
Out[12]: [4, 9, 5, 0, 2, 8, 3, 1, 7, 6]

BatchSampler对Sampler进行二次封装,引入了batchSize参数,实现了批量迭代。

from torch.utils.data.sampler import BatchSampler
[x for x in BatchSampler(range(10), batch_size=3, drop_last=False)]
Out[9]: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
[x for x in BatchSampler(RandomSampler(range(10)), batch_size=3, drop_last=False)]
Out[15]: [[1, 3, 7], [9, 2, 0], [5, 4, 6], [8]]

加载器 DataLoader

在实际计算中,如果数据量很大,考虑到内存有限,且IO速度很慢,

因此不能一次性的将其全部加载到内存中,也不能只用一个线程去加载。

因而需要多线程、迭代加载, 因而专门定义加载器:DataLoader

DataLoader 是一个可迭代对象, An Iterable Object, 内部配置了魔法函数——iter——,调用它将返回一个迭代器。

该函数可用内置函数iter直接调用,即 DataIteror = iter(DataLoader)

dataloader = DataLoader(dataset=Dataset(imdb=IMDB()), sampler=Sampler(), num_works, ...)

__init__参数包含两部分,前半部分用于指定数据集 + 采样器,后半部分为多线程参数

class DataLoader(object):
"""
Data loader. Combines a dataset and a sampler, and provides
single- or multi-process iterators over the dataset.
"""
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None):
self.dataset = dataset
self.batch_size = batch_size
self.num_workers = num_workers
self.collate_fn = collate_fn
self.pin_memory = pin_memory
self.drop_last = drop_last
self.timeout = timeout
self.worker_init_fn = worker_init_fn
if timeout < 0:
raise ValueError('timeout option should be non-negative')
# 检测是否存在参数冲突: 默认batchSampler vs 自定义BatchSampler
if batch_sampler is not None:
if batch_size > 1 or shuffle or sampler is not None or drop_last:
raise ValueError('batch_sampler is mutually exclusive with '
'batch_size, shuffle, sampler, and drop_last')
if sampler is not None and shuffle:
raise ValueError('sampler is mutually exclusive with shuffle')
if self.num_workers < 0:
raise ValueError('num_workers cannot be negative; '
'use num_workers=0 to disable multiprocessing.')
# 在此处会强行指定一个 BatchSampler
if batch_sampler is None:
# 在此处会强行指定一个 Sampler
if sampler is None:
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
# 使用自定义的采样器和批采样器
self.sampler = sampler
self.batch_sampler = batch_sampler
def __iter__(self):
# 调用Pytorch的多线程迭代器加载数据
return DataLoaderIter(self)
def __len__(self):
return len(self.batch_sampler)

数据迭代器 DataLoaderIter

迭代器与可迭代对象之间是有区别的。

可迭代对象,意思是对其使用Iter函数时,它可以返回一个迭代器,从而可以连续的迭代访问它。

迭代器对象,内部有额外的魔法函数__next__,用内置函数next作用其上,则可以连续产生下一个数据,产生规则即是由此函数来确定的。

可迭代对象描述了对象具有可迭代性,但具体的迭代规则由迭代器来描述,这样解耦的好处是可以对同一个可迭代对象配置多种不同规则的迭代器。

数据集/容器遍历的一般化流程:NILIS

NILIS规则: data = next(iter(loader(DataSet[sampler])))data=next(iter(loader(DataSet[sampler])))

  1. sampler 定义索引index的生成规则,返回一个index列表,控制后续的索引访问过程。
  2. indexer 基于__item__在容器上定义按索引访问的规则,让容器成为可索引对象,可用[]操作。
  3. loader 基于__iter__在容器上定义可迭代性,描述加载规则,包括返回一个迭代器,让容器成为可迭代对象, 可用iter()操作。
  4. next 基于__next__在容器上定义迭代器,描述具体的迭代规则,让容器成为迭代器对象, 可用next()操作。
## 初始化
sampler = Sampler()
dataSet = DataSet(sampler) # __getitem__
dataLoader = DataLoader(dataSet, sampler) / DataIterable() # __iter__()
dataIterator = DataLoaderIter(dataLoader) #__next__()
data_iter = iter(dataLoader)
## 遍历方法1
for _ in range(len(data_iter))
data = next(data_iter)
## 遍历方法2
for i, data in enumerate(dataLoader):
data = data

微信公众号:AutoML机器学习

MARSGGBO♥原创

如有意合作或学术讨论欢迎私戳联系~
邮箱:marsggbo@foxmail.com






2019-8-4

Pytorch数据读取详解的更多相关文章

  1. hbase实践之数据读取详解

    hbase基本存储组织结构与数据读取组织结构对比 Segment是Hbase2.0的概念,MemStore由一个可写的Segment,以及一个或多个不可写的Segments构成.故hbase 1.*版 ...

  2. Pytorch autograd,backward详解

    平常都是无脑使用backward,每次看到别人的代码里使用诸如autograd.grad这种方法的时候就有点抵触,今天花了点时间了解了一下原理,写下笔记以供以后参考.以下笔记基于Pytorch1.0 ...

  3. ContentProvider数据访问详解

    ContentProvider数据访问详解 Android官方指出的数据存储方式总共有五种:Shared Preferences.网络存储.文件存储.外储存储.SQLite,这些存储方式一般都只是在一 ...

  4. 【转载】PyTorch系列 (二):pytorch数据读取

    原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...

  5. 【HANA系列】SAP HANA XS使用JavaScript数据交互详解

    公众号:SAP Technical 本文作者:matinal 原文出处:http://www.cnblogs.com/SAPmatinal/ 原文链接:[HANA系列]SAP HANA XS使用Jav ...

  6. JVM 运行时数据区详解

    一.运行时数据区 Java虚拟机在执行Java程序的过程中会把它所管理的内存划分为若干个不同数据区域. 1.有一些是随虚拟机的启动而创建,随虚拟机的退出而销毁,所有的线程共享这些数据区. 2.第二种则 ...

  7. 学习《深度学习与计算机视觉算法原理框架应用》《大数据架构详解从数据获取到深度学习》PDF代码

    <深度学习与计算机视觉 算法原理.框架应用>全书共13章,分为2篇,第1篇基础知识,第2篇实例精讲.用通俗易懂的文字表达公式背后的原理,实例部分提供了一些工具,很实用. <大数据架构 ...

  8. 【HANA系列】【第一篇】SAP HANA XS使用JavaScript数据交互详解

    公众号:SAP Technical 本文作者:matinal 原文出处:http://www.cnblogs.com/SAPmatinal/ 原文链接:[HANA系列][第一篇]SAP HANA XS ...

  9. 3dTiles 数据规范详解[1] 介绍

    版权:转载请带原地址.https://www.cnblogs.com/onsummer/p/12799366.html @秋意正寒 Web中的三维 html5和webgl技术使得浏览器三维变成了可能. ...

随机推荐

  1. Shell脚本之八 函数

    一.函数定义 Linux shell 可以用户定义函数,然后在shell脚本中可以随便调用. shell中函数的定义格式如下: [ function ] funname [()] { action; ...

  2. hive 批量添加,删除分区

    一.批量添加分区:   use bigdata; alter table siebel_member add if not exists partition(dt='20180401') locati ...

  3. Python 中如何判断 list 中是否包含某个元素

    在python中判断 list 中是否包含某个元素: ——可以通过in和not in关键字来判读 例如: abcList=['a','b','c',1,2,3] if 'a' in abcList: ...

  4. css3自定义上传图片输入框的方法

    css3自定义上传图片输入框的方法 代码如下<pre> <form class="form1"> <img src="/kelatoupia ...

  5. Linux之三剑客

    LINUX之三剑客 本篇主要介绍linux下常用的增删改查工具: grep sed awk grep是linux下一个强大的搜索工具,几乎操作linux的用户每天都会或多或少的用到grep命令,单一个 ...

  6. 使用benchmarkSQL测试数据库的TPCC

    压力测试是指在MySQL上线前,需要进行大量的压力测试,从而达到交付的标准.压力测试不仅可以测试MySQL服务的稳定性,还可以测试出MySQL和系统的瓶颈. TPCC测试:Transaction Pr ...

  7. 小白的C++之路——求质数

    初学C++,打算用博客记录学习的足迹.写了两个求质数的程序,修修改改. #include <iostream> #include <math.h> using namespac ...

  8. 『7.5 NOIP模拟赛题解』

    T1 Gift Description ​ 人生赢家老王在网上认识了一个妹纸,然后妹纸的生日到了,为了表示自己的心 意,他决定送她礼物.可是她喜爱的东西特别多,然而他的钱数有限,因此他想 知道当他花一 ...

  9. syntax error near unexpected token 脚本报错误解决

    hadoop老师给了一个shell文件,在windows里面瞅了一眼然后在ubuntu环境下运行就报错了.看了一些博客,用vim -b filename查看的时候发现每一行的末尾都多了一个^M.... ...

  10. 封装:WPF绘制曲线视图

    原文:封装:WPF绘制曲线视图 一.目的:绘制简单轻量级的曲线视图 二.实现: 1.动画加载曲线 2.点击图例显示隐藏对应曲线 3.绘制标准基准线 4.绘制蒙板显示标准区域 曲线图示例: 心电图示例: ...