前言

上文介绍了数据读取、数据转换、批量处理等等。了解到在PyTorch中,数据加载主要有两种方式:

  • 1.自定义的数据集对象。数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Dataset。且须实现__len__()和__getitem__()两个方法。
  • 2.利用torchvision包。torchvision已经预先实现了常用的Dataset,包括前面使用过的CIFAR-10,以及ImageNet、COCO、MNIST、LSUN等数据集,可通过诸如torchvision.datasets.CIFAR10来调用。这里介绍ImageFolder,其也继承自DatasetImageFolder假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名,其构造函数如下:
   ImageFolder(root, transform=None, target_transform=None, loader=default_loader)

它主要有四个参数:

  • root:在root指定的路径下寻找图片
  • transform:对PIL Image进行的转换操作,transform的输入是使用loader读取图片的返回对象
  • target_transform:对label的转换
  • loader:给定路径后如何读取图片,默认读取为RGB格式的PIL Image对象

label是按照文件夹名顺序排序后存成字典,即{类名:类序号(从0开始)},一般来说最好直接将文件夹命名为从0开始的数字,这样会和ImageFolder实际的label一致,如果不是这种命名规 范,建议看看self.class_to_idx属性以了解label和文件夹名的映射关系。

下面我们进一步理解数据读取的内容:

1.  查看dataset实例都有哪些方法与成员:

from torchvision.datasets import ImageFolder
dataset = ImageFolder('data/ants&bee_2/', transform=transform)

打印一下dataset类的成员与方法:

print(dir(dataset))
['\_\_add\_\_', '\_\_class\_\_', '\_\_delattr\_\_', '\_\_dict\_\_', '\_\_dir\_\_', '\_\_doc\_\_', '\_\_eq\_\_', '\_\_format\_\_', '\_\_ge\_\_', '\_\_getattribute\_\_', '\_\_getitem\_\_', '\_\_gt\_\_', '\_\_hash\_\_', '\_\_init\_\_', '\_\_le\_\_', '\_\_len\_\_', '\_\_lt\_\_', '\_\_module\_\_', '\_\_ne\_\_', '\_\_new\_\_', '\_\_reduce\_\_', '\_\_reduce\_ex\_\_', '\_\_repr\_\_', '\_\_setattr\_\_', '\_\_sizeof\_\_', '\_\_str\_\_', '\_\_subclasshook\_\_', '\_\_weakref\_\_', 'class\_to\_idx', 'classes', 'imgs', 'loader', 'root', 'target\_transform', 'transform']

dataset.__len__()  : 数据集的数目

dataset.__getitem__(idx) : 输入索引,返回对应的图片与标签

dataste.class_to_idx : 字典,类与标签  eg:{‘ants’:0, ‘bees’: 1}

dataset.classes: 列表,返回类别  eg:[ 'ants', 'bees' ]

dataset.imgs : 列表,返回所有图片的路径和对应的label

特别的对于dataset,可以根据dataset.__getitem__(idx)来返回第idx张图与标签,还可以直接进行索引:

dataset[0][1]    # 第一维是第几张图,第二维为1返回label
dataset[0][0] # 为0返回图片数据

还可以循环迭代:

for img, label in dataset:
print(img.size(), label)

所以无论是自定义的dataset,或是 ImageFolder得到的dataset,因其都继承自utils.data.Dataset, 故以上方法两种方法都有。

2.  Dataloader使用

如果只是每次读取一张图,那么上面的操作已经足够了,但是为了批量操作、打散数据、多进程处理、定制batch,那么我们还需要更高级的类:

DataLoader定义如下:

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)

  • dataset:加载的数据集(Dataset对象)
  • batch_size:batch size
  • shuffle::是否将数据打乱
  • sampler: 样本抽样。定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数。
  • batch_sampler(sampler,可选) - 和sampler一样,但一次返回一批索引。与batch_size,shuffle,sampler和drop_last相互排斥。
  • num_workers:使用多进程加载的进程数,0代表不使用多进程
  • collate_fn: 如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可
  • pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些,默认为false
  • drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃,默认为false
from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=3, shuffle=True, num_workers=0, drop_last=False) dataiter = iter(dataloader) # 迭代器 imgs, labels = next(dataiter)
imgs.size() # batch_size, channel, height, weight # torch.Size([3, 3, 224, 224])

下面主要介绍collate_fn和sampler的用法:

1)collate_fn

在数据处理中,有时会出现某个样本无法读取等问题,比如某张图片损坏。这时在__getitem__函数中将出现异常,此时最好的解决方案即是将出错的样本剔除。如果实在是遇到这种情况无法处理,则可以返回None对象,然后在Dataloader中实现自定义的****collate_fn,将空对象过滤掉。但要注意,在这种情况下dataloader返回的batch数目会少于batch_size。

eg:

class NewDogCat(DogCat): # 继承前面实现的DogCat数据集
def __getitem__(self, index):
try:
# 调用父类的获取函数,即 DogCat.__getitem__(self, index)
return super(NewDogCat,self).__getitem__(index)
except:
return None, None from torch.utils.data.dataloader import default_collate # 导入默认的拼接方式
def my_collate_fn(batch):
'''
batch中每个元素形如(data, label)
'''
# 过滤为None的数据
batch = list(filter(lambda x:x[0] is not None, batch))
return default_collate(batch) # 用默认方式拼接过滤后的batch数据
dataset = NewDogCat('data/dogcat_wrong/', transforms=transform)
dataset[5] # (None, None)

第5张图坏掉了所以返回None,下面查看对于批量读取怎么处理:

dataloader = DataLoader(dataset, 2, collate_fn=my_collate_fn, num_workers=1)     # 批量为2
for batch_datas, batch_labels in dataloader:
print(batch_datas.size(),batch_labels.size())
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([1, 3, 224, 224]) torch.Size([1])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([1, 3, 224, 224]) torch.Size([1])

可以看到第三个批量只有1张图,因为第5张图坏掉了,所以第三个批量只有第六张图。第五个批量也只有1张图是因为数据集总共只有9张图(含坏图)。如果设置drop_last为true,那么第五个批量就被丢弃了。对于诸如样本损坏或数据集加载异常等情况,还可以通过其它方式解决。例如但凡遇到异常情况,就随机取一张图片代替:

class NewDogCat(DogCat):
def __getitem__(self, index):
try:
return super(NewDogCat, self).__getitem__(index)
except:
new_index = random.randint(0, len(self)-1)
return self[new_index]

相比较丢弃异常图片而言,这种做法会更好一些,因为它能保证每个batch的数目仍是batch_size。但在大多数情况下,最好的方式还是对数据进行彻底清洗。

2)sampler

sampler模块用来对数据进行采样。常用的有随机采样器:RandomSampler,当dataloader的shuffle参数为True时,系统会自动调用这个采样器,实现打乱数据。默认的是采用SequentialSampler,它会按顺序一个一个进行采样。这里介绍另外一个很有用的采样方法: WeightedRandomSampler,它会根据每个样本的权重选取数据,在样本比例不均衡的问题中,可用它来进行重采样。

class torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples, replacement=True)

构建WeightedRandomSampler时需提供两个参数:每个样本的权重weights、共选取的样本总数num_samples,以及一个可选参数replacement。权重越大的样本被选中的概率越大,待选取的样本数目一般小于全部的样本数目。replacement用于指定是否可以重复选取某一个样本,默认为True,即允许在一个epoch中重复采样某一个数据。如果设为False,则当某一类的样本被全部选取完,但其样本数目仍未达到num_samples时,sampler将不会再从该类中选择数据,此时可能导致weights参数失效。下面举例说明。

dataset = DogCat('data/dogcat/', transforms=transform)

# 狗的图片被取出的概率是猫的概率的两倍
# 两类图片被取出的概率与weights的绝对大小无关,只和比值有关
weights = [2 if label == 1 else 1 for data, label in dataset]
weights # [2, 2, 1, 1, 1, 1, 2, 2]
from torch.utils.data.sampler import  WeightedRandomSampler
sampler = WeightedRandomSampler(weights,\
num_samples=9,\
replacement=True)
dataloader = DataLoader(dataset,
batch_size=3,
sampler=sampler)
for datas, labels in dataloader:
print(labels.tolist())
[1, 0, 1]
[1, 0, 1]
[1, 1, 0]

可见猫狗样本比例约为1:2,另外一共只有8个样本,但是却返回了9个,说明肯定有被重复返回的,这就是replacement参数的作用,下面将replacement设为False试试:

sampler = WeightedRandomSampler(weights, 8, replacement=False)
dataloader = DataLoader(dataset, batch_size=4, sampler=sampler)
for datas, labels in dataloader:
print(labels.tolist())
[1, 0, 1, 0]
[1, 1, 0, 0]

在这种情况下,num_samples等于dataset的样本总数,为了不重复选取,sampler会将每个样本都返回,这样就失去weight参数的意义了。

从上面的例子可见sampler在样本采样中的作用:如果指定了sampler,shuffle将不再生效,并且sampler.num_samples会覆盖dataset的实际大小,即一个epoch返回的图片总数取决于sampler.num_samples。

部分转载自:pytorch-book-master

MARSGGBO♥原创







2019-3-3

【转载】Pytorch tutorial 之Datar Loading and Processing的更多相关文章

  1. Pytorch tutorial 之Datar Loading and Processing (1)

    引自Pytorch tutorial: Data Loading and Processing Tutorial 这节主要介绍数据的读入与处理. 数据描述:人脸姿态数据集.共有69张人脸,每张人脸都有 ...

  2. Pytorch tutorial 之Datar Loading and Processing (2)

    上文介绍了数据读取.数据转换.批量处理等等.了解到在PyTorch中,数据加载主要有两种方式: 1. 自定义的数据集对象.数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Dataset. ...

  3. pytorch例子学习-DATA LOADING AND PROCESSING TUTORIAL

    参考:https://pytorch.org/tutorials/beginner/data_loading_tutorial.html DATA LOADING AND PROCESSING TUT ...

  4. [转载]PyTorch上的contiguous

    [转载]PyTorch上的contiguous 来源:https://zhuanlan.zhihu.com/p/64551412 这篇文章写的非常好,我这里就不复制粘贴了,有兴趣的同学可以去看原文,我 ...

  5. [转载]PyTorch中permute的用法

    [转载]PyTorch中permute的用法 来源:https://blog.csdn.net/york1996/article/details/81876886 permute(dims) 将ten ...

  6. [转载]Pytorch详解NLLLoss和CrossEntropyLoss

    [转载]Pytorch详解NLLLoss和CrossEntropyLoss 来源:https://blog.csdn.net/qq_22210253/article/details/85229988 ...

  7. [转载]Pytorch中nn.Linear module的理解

    [转载]Pytorch中nn.Linear module的理解 本文转载并援引全文纯粹是为了构建和分类自己的知识,方便自己未来的查找,没啥其他意思. 这个模块要实现的公式是:y=xAT+*b 来源:h ...

  8. (转载)XML Tutorial for iOS: How To Read and Write XML Documents with GDataXML

    In my recent post on How To Choose the Best XML Parser for Your iPhone Project, Saliom from the comm ...

  9. (转载)XML Tutorial for iOS: How To Choose The Best XML Parser for Your iPhone Project

    There are a lot of options when it comes to parsing XML on the iPhone. The iPhone SDK comes with two ...

随机推荐

  1. android 获取通话记录

    在manifest添加以下权限<uses-permission android:name="android.permission.READ_CALL_LOG" />&l ...

  2. Html 改变原有标签属性

    内容简要: 当标签内内容 达到某以条件的时候改变当前标签属性 例如原标签为<tr> 当tr内的值符合某一条件时把<tr>变成<a>标签 例:当订单状体编程已支付的时 ...

  3. 装饰器模式以及Laravel框架下的中间件应用

    Laravel框架的中间件使用:从请求进来到响应返回,经过中间件的层层包装,这种场景很适合用到一种设计模式---装饰器模式. 装饰器模式的作用,多种外界因素改变对象的行为.使用继承的方式改变行为不太被 ...

  4. 控制结构(3): 状态机(state machine)

    // 上一篇:卫语句(guard clause) // 下一篇:局部化(localization) 基于语言提供的基本控制结构,更好地组织和表达程序,需要良好的控制结构. 前情回顾 上次分析了guar ...

  5. 用Python开发小学二年级口算自动出题程序

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 武汉光谷一小二年级要求家长每天要给小孩出口算题目,让孩子练习. 根据老师出题要求编写了Python程序 ...

  6. 其它综合-VMware虚拟机安装Ubuntu 19.04 版本

    Ubuntu 19.04 版本安装过程 1. 环境: 使用的虚拟机软件是VMware,版本为 12 .(网上一搜一大推,在此不再演示.) 使用的 ISO镜像为Ubuntu 19.04.(自己也可以在网 ...

  7. 简单实现计算机上多个jdk环境切换

    实现多个jdk环境切换,大致有两种方式 安装两个jdk,并配置相应的环境变量,在java的控制面板中修改设置 非主要的jdk仅仅是用来测试,并不常用,故只要让ide配置对应的jdk位置就可以了,属于懒 ...

  8. python之路7-正则表达式

    正则表达式用于做字符串匹配,在python中用re模块来操作 生成正则的在线工具:http://tool.chinaz.com/regex

  9. IO多路复用和local概念

    一.local 在多个线程之间使用threading.local对象,可以实现多个线程之间的数据隔离 import time import random from threading import T ...

  10. Hive 口袋手册

    2019-04-01 关键字:Hive 学习总结.Hive 基础 . Hive 进阶 .Hive 调优 . Hive 入门手册.Hive PDF 下载 本篇文章系本人就目前所掌握的知识对 Apache ...