Pytorch数据读取详解
原文:http://studyai.com/article/11efc2bf#采样器 Sampler & BatchSampler
数据库DataBase + 数据集DataSet + 采样器Sampler = 加载器Loader
from torch.utils.data import *
IMDB + Dataset + Sampler || BatchSampler = DataLoader
数据库 DataBase
Image DataBase 简称IMDB,指的是存储在文件中的数据信息。
文件格式可以多种多样。比如xml, yaml, json, sql.
VOC是xml格式的,COCO是JSON格式的。
构造IMDB的过程,就是解析这些文件,并建立数据索引的过程。
一般会被解析为Python列表, 以方便后续迭代读取。
数据集 DataSet
数据集 DataSet: 在数据库IMDB的基础上,提供对数据的单例或切片访问方法。
换言之,就是定义数据库中对象的索引机制,如何实现单例索引或切片索引。
简言之,DataSet,通过__getitem__定义了数据集DataSet是一个可索引对象,An Indexerable Object。
即传入一个给定的索引Index之后,如何按此索引进行单例或切片访问,单例还是切片视Index是单值还是列表。
Pytorch源码如下:
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
# 定义单例/切片访问方法,即 dataItem = Dataset[index]
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
自定义数据集要基于上述Dataset基类、IMDB基类,有两种方法。
# 方法一: 单继承
class XxDataset(Dataset)
# 将IMDB作为参数传入,进行二次封装
imdb = IMDB()
pass
# 方法二: 双继承
class XxDataset(IMDB, Dataset):
pass
采样器 Sampler & BatchSampler
在实际应用中,数据并不一定是循规蹈矩的序惯访问,而需要随机打乱顺序来访问,或需要随机加权访问,
因此,按某种特定的规则来读取数据,就是采样操作,需要定义采样器:Sampler。
另外,数据也可能并不是一个一个读取的,而需要一批一批的读取,即需要批量采样操作,定义批量采样器:BatchSampler。
所以,只有Dataset的单例访问方法还不够,还需要在此基础上,进一步的定义批量访问方法。
简言之,采样器定义了索引(index)的产生规则,按指定规则去产生索引,从而控制数据的读取机制
BatchSampler 是基于 Sampler 来构造的: BatchSampler = Sampler + BatchSize
Pytorch源码如下,
class Sampler(object):
"""Base class for all Samplers.
采样器基类,可以基于此自定义采样器。
Every Sampler subclass has to provide an __iter__ method, providing a way
to iterate over indices of dataset elements, and a __len__ method that
returns the length of the returned iterators.
"""
def __init__(self, data_source):
pass
def __iter__(self):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
# 序惯采样
class SequentialSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self):
return len(self.data_source)
# 随机采样
class RandomSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(torch.randperm(len(self.data_source)).long())
def __len__(self):
return len(self.data_source)
# 随机子采样
class SubsetRandomSampler(Sampler):
pass
# 加权随机采样
class WeightedRandomSampler(Sampler):
pass
class BatchSampler(object):
"""Wraps another sampler to yield a mini-batch of indices.
Args:
sampler (Sampler): Base sampler.
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``
Example:
>>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
"""
def __init__(self, sampler, batch_size, drop_last):
self.sampler = sampler # ******
self.batch_size = batch_size
self.drop_last = drop_last
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
def __len__(self):
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
由上可见,Sampler本质就是个具有特定规则的可迭代对象,但只能单例迭代。
如 [x for x in range(10)], range(10)就是个最基本的Sampler,每次循环只能取出其中的一个值.
[x for x in range(10)]
Out[10]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
from torch.utils.data.sampler import SequentialSampler
[x for x in SequentialSampler(range(10))]
Out[14]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
from torch.utils.data.sampler import RandomSampler
[x for x in RandomSampler(range(10))]
Out[12]: [4, 9, 5, 0, 2, 8, 3, 1, 7, 6]
BatchSampler对Sampler进行二次封装,引入了batchSize参数,实现了批量迭代。
from torch.utils.data.sampler import BatchSampler
[x for x in BatchSampler(range(10), batch_size=3, drop_last=False)]
Out[9]: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
[x for x in BatchSampler(RandomSampler(range(10)), batch_size=3, drop_last=False)]
Out[15]: [[1, 3, 7], [9, 2, 0], [5, 4, 6], [8]]
加载器 DataLoader
在实际计算中,如果数据量很大,考虑到内存有限,且IO速度很慢,
因此不能一次性的将其全部加载到内存中,也不能只用一个线程去加载。
因而需要多线程、迭代加载, 因而专门定义加载器:DataLoader。
DataLoader 是一个可迭代对象, An Iterable Object, 内部配置了魔法函数——iter——,调用它将返回一个迭代器。
该函数可用内置函数iter直接调用,即 DataIteror = iter(DataLoader)。
dataloader = DataLoader(dataset=Dataset(imdb=IMDB()), sampler=Sampler(), num_works, ...)
__init__参数包含两部分,前半部分用于指定数据集 + 采样器,后半部分为多线程参数。
class DataLoader(object):
"""
Data loader. Combines a dataset and a sampler, and provides
single- or multi-process iterators over the dataset.
"""
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None):
self.dataset = dataset
self.batch_size = batch_size
self.num_workers = num_workers
self.collate_fn = collate_fn
self.pin_memory = pin_memory
self.drop_last = drop_last
self.timeout = timeout
self.worker_init_fn = worker_init_fn
if timeout < 0:
raise ValueError('timeout option should be non-negative')
# 检测是否存在参数冲突: 默认batchSampler vs 自定义BatchSampler
if batch_sampler is not None:
if batch_size > 1 or shuffle or sampler is not None or drop_last:
raise ValueError('batch_sampler is mutually exclusive with '
'batch_size, shuffle, sampler, and drop_last')
if sampler is not None and shuffle:
raise ValueError('sampler is mutually exclusive with shuffle')
if self.num_workers < 0:
raise ValueError('num_workers cannot be negative; '
'use num_workers=0 to disable multiprocessing.')
# 在此处会强行指定一个 BatchSampler
if batch_sampler is None:
# 在此处会强行指定一个 Sampler
if sampler is None:
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
# 使用自定义的采样器和批采样器
self.sampler = sampler
self.batch_sampler = batch_sampler
def __iter__(self):
# 调用Pytorch的多线程迭代器加载数据
return DataLoaderIter(self)
def __len__(self):
return len(self.batch_sampler)
数据迭代器 DataLoaderIter
迭代器与可迭代对象之间是有区别的。
可迭代对象,意思是对其使用Iter函数时,它可以返回一个迭代器,从而可以连续的迭代访问它。
迭代器对象,内部有额外的魔法函数__next__,用内置函数next作用其上,则可以连续产生下一个数据,产生规则即是由此函数来确定的。
可迭代对象描述了对象具有可迭代性,但具体的迭代规则由迭代器来描述,这样解耦的好处是可以对同一个可迭代对象配置多种不同规则的迭代器。

数据集/容器遍历的一般化流程:NILIS
NILIS规则: data = next(iter(loader(DataSet[sampler])))data=next(iter(loader(DataSet[sampler])))
- sampler 定义索引index的生成规则,返回一个index列表,控制后续的索引访问过程。
- indexer 基于
__item__在容器上定义按索引访问的规则,让容器成为可索引对象,可用[]操作。 - loader 基于
__iter__在容器上定义可迭代性,描述加载规则,包括返回一个迭代器,让容器成为可迭代对象, 可用iter()操作。 - next 基于
__next__在容器上定义迭代器,描述具体的迭代规则,让容器成为迭代器对象, 可用next()操作。
## 初始化
sampler = Sampler()
dataSet = DataSet(sampler) # __getitem__
dataLoader = DataLoader(dataSet, sampler) / DataIterable() # __iter__()
dataIterator = DataLoaderIter(dataLoader) #__next__()
data_iter = iter(dataLoader)
## 遍历方法1
for _ in range(len(data_iter))
data = next(data_iter)
## 遍历方法2
for i, data in enumerate(dataLoader):
data = data
Pytorch数据读取详解的更多相关文章
- hbase实践之数据读取详解
hbase基本存储组织结构与数据读取组织结构对比 Segment是Hbase2.0的概念,MemStore由一个可写的Segment,以及一个或多个不可写的Segments构成.故hbase 1.*版 ...
- Pytorch autograd,backward详解
平常都是无脑使用backward,每次看到别人的代码里使用诸如autograd.grad这种方法的时候就有点抵触,今天花了点时间了解了一下原理,写下笔记以供以后参考.以下笔记基于Pytorch1.0 ...
- ContentProvider数据访问详解
ContentProvider数据访问详解 Android官方指出的数据存储方式总共有五种:Shared Preferences.网络存储.文件存储.外储存储.SQLite,这些存储方式一般都只是在一 ...
- 【转载】PyTorch系列 (二):pytorch数据读取
原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...
- 【HANA系列】SAP HANA XS使用JavaScript数据交互详解
公众号:SAP Technical 本文作者:matinal 原文出处:http://www.cnblogs.com/SAPmatinal/ 原文链接:[HANA系列]SAP HANA XS使用Jav ...
- JVM 运行时数据区详解
一.运行时数据区 Java虚拟机在执行Java程序的过程中会把它所管理的内存划分为若干个不同数据区域. 1.有一些是随虚拟机的启动而创建,随虚拟机的退出而销毁,所有的线程共享这些数据区. 2.第二种则 ...
- 学习《深度学习与计算机视觉算法原理框架应用》《大数据架构详解从数据获取到深度学习》PDF代码
<深度学习与计算机视觉 算法原理.框架应用>全书共13章,分为2篇,第1篇基础知识,第2篇实例精讲.用通俗易懂的文字表达公式背后的原理,实例部分提供了一些工具,很实用. <大数据架构 ...
- 【HANA系列】【第一篇】SAP HANA XS使用JavaScript数据交互详解
公众号:SAP Technical 本文作者:matinal 原文出处:http://www.cnblogs.com/SAPmatinal/ 原文链接:[HANA系列][第一篇]SAP HANA XS ...
- 3dTiles 数据规范详解[1] 介绍
版权:转载请带原地址.https://www.cnblogs.com/onsummer/p/12799366.html @秋意正寒 Web中的三维 html5和webgl技术使得浏览器三维变成了可能. ...
随机推荐
- c++开发遇到的错误和引用配置
1. libcurl引入的时候必须要加载下面三个库 #pragma comment(lib, "ws2_32.lib") #pragma comment(lib, "wl ...
- Qt相关博客总览
一.Qt快速入门 Qt快速入门之一:开始学习Qt 与Qt Creator Qt快速入门之二:Qt Creator简介 Qt快速入门之三:Qt程序编译和源码详解 Qt对话框之一:标准对话框 二.Qt窗口 ...
- 《Linux就该这么学》自学笔记_ch22_使用openstack部署云计算服务环境
<Linux就该这么学>自学笔记_ch22_使用openstackb部署云计算服务环境 文章主要内容: 了解云计算 Openstack项目 服务模块组件详解 安装Openstack软件 使 ...
- PatchMatch小详解
最近发了两片patch match的,其实自己也是有一些一知半解的,找了一篇不知道谁的大论文看了看,又回顾了一下,下面贴我的笔记. The PatchMatch Algorithm patchmatc ...
- CentOS7使用tar.gz包安装MySql的踩坑之旅
由于客户的CentOS服务器没有安装yum工具,只能通过下载tar.gz包安装mysql,于是跟着万能的百度开启了漫漫踩坑之旅: 1.下载mysql-5.6.33-linux-glibc2.5-x86 ...
- 最细的eclipse 安装maven踩过的坑
Eclipse安装maven插件踩过的坑 在线安装maven eclipse安装maven插件,在网上有各种各样的方法,博主使用过的也不止一种,但是留下的印象总是时好时不好,同样的方法也不确定那一次能 ...
- stm32 普通IO口模拟串口通信
普通IO口模拟串口通信 串口通信协议 串口传输 默认 波特率9600 1起始位 1停止位 其他0 数据位是8位(注意图上的给错了). 传输时,从起始位开始,从一个数据的低位(LSB)开始发送,如图从左 ...
- Go语言入门——hello world
Go 语言源代码文件扩展名是.go. 知识点:1. go语言代码的第1行必须声明包2. 入口的go语言代码(包含main函数的代码文件)的包必须是main,否则运行go程序会显示go run: can ...
- Remote System Explorer Operation总是运行后台服务,卡死eclipse解决办法
当你右键编辑控件的id或者其他属性时都会卡很久,发现原来是eclipse后台进程在远程操作,就是右下角显示的“Remote System Explorer Operation”.折腾了半天,在Stac ...
- 数据库权限优化,权限设计BigInteger
最近看到了一个项目的权限是根据bigineger来进行计算的菜单权限,觉得还是不错,存储上只需要存储在一个字段里就可以了,通过计算算出该角色的菜单权限即可,效率也非常的快,放在session中也非常的 ...