上文介绍了数据读取、数据转换、批量处理等等。了解到在PyTorch中,数据加载主要有两种方式:

  • 1. 自定义的数据集对象。数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Dataset。且须实现__len__()和__getitem__()两个方法。
  • 2. 利用torchvision包。torchvision已经预先实现了常用的Dataset,包括前面使用过的CIFAR-10,以及ImageNet、COCO、MNIST、LSUN等数据集,可通过诸如torchvision.datasets.CIFAR10来调用。这里介绍ImageFolder,其也继承自DatasetImageFolder假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名,其构造函数如下:
   ImageFolder(root, transform=None, target_transform=None, loader=default_loader)

它主要有四个参数:

    • root:在root指定的路径下寻找图片
    • transform:对PIL Image进行的转换操作,transform的输入是使用loader读取图片的返回对象
    • target_transform:对label的转换
    • loader:给定路径后如何读取图片,默认读取为RGB格式的PIL Image对象

label是按照文件夹名顺序排序后存成字典,即{类名:类序号(从0开始)},一般来说最好直接将文件夹命名为从0开始的数字,这样会和ImageFolder实际的label一致,如果不是这种命名规 范,建议看看self.class_to_idx属性以了解label和文件夹名的映射关系。

下面我们进一步理解数据读取的内容:

1.  查看dataset实例都有哪些方法与成员:

from torchvision.datasets import ImageFolder
dataset = ImageFolder('data/ants&bee_2/', transform=transform)

打印一下dataset类的成员与方法:

print(dir(dataset))

['__add__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__init__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'class_to_idx', 'classes', 'imgs', 'loader', 'root', 'target_transform', 'transform']

dataset.__len__()  : 数据集的数目

dataset.__getitem__(idx) : 输入索引,返回对应的图片与标签

dataste.class_to_idx : 字典,类与标签  eg:{‘ants’:0, ‘bees’: 1}

dataset.classes: 列表,返回类别  eg:[ 'ants', 'bees' ]

dataset.imgs : 列表,返回所有图片的路径和对应的label

特别的对于dataset,可以根据dataset.__getitem__(idx)来返回第idx张图与标签,还可以直接进行索引:

dataset[0][1]    # 第一维是第几张图,第二维为1返回label
dataset[0][0] # 为0返回图片数据

还可以循环迭代:

for img, label in dataset:
print(img.size(), label)

所以无论是自定义的dataset,或是 ImageFolder得到的dataset,因其都继承自utils.data.Dataset, 故以上方法两种方法都有。

2.  Dataloader使用

如果只是每次读取一张图,那么上面的操作已经足够了,但是为了批量操作、打散数据、多进程处理、定制batch,那么我们还需要更高级的类:

DataLoader定义如下:

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)

  • dataset:加载的数据集(Dataset对象)
  • batch_size:batch size
  • shuffle::是否将数据打乱
  • sampler: 样本抽样。定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数。
  • batch_sampler(sampler,可选) - 和sampler一样,但一次返回一批索引。与batch_size,shuffle,sampler和drop_last相互排斥。
  • num_workers:使用多进程加载的进程数,0代表不使用多进程
  • collate_fn: 如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可
  • pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些,默认为false
  • drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃,默认为false
from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=3, shuffle=True, num_workers=0, drop_last=False)
dataiter = iter(dataloader) # 迭代器
imgs, labels = next(dataiter)
imgs.size() # batch_size, channel, height, weight # torch.Size([3, 3, 224, 224])

下面主要介绍collate_fn和sampler的用法:

1)collate_fn

在数据处理中,有时会出现某个样本无法读取等问题,比如某张图片损坏。这时在__getitem__函数中将出现异常,此时最好的解决方案即是将出错的样本剔除。如果实在是遇到这种情况无法处理,则可以返回None对象,然后在Dataloader中实现自定义的collate_fn,将空对象过滤掉。但要注意,在这种情况下dataloader返回的batch数目会少于batch_size。

eg:

class NewDogCat(DogCat): # 继承前面实现的DogCat数据集
def __getitem__(self, index):
try:
# 调用父类的获取函数,即 DogCat.__getitem__(self, index)
return super(NewDogCat,self).__getitem__(index)
except:
return None, None from torch.utils.data.dataloader import default_collate # 导入默认的拼接方式
def my_collate_fn(batch):
'''
batch中每个元素形如(data, label)
'''
# 过滤为None的数据
batch = list(filter(lambda x:x[0] is not None, batch))
return default_collate(batch) # 用默认方式拼接过滤后的batch数据
dataset = NewDogCat('data/dogcat_wrong/', transforms=transform)
dataset[5] # (None, None)

第5张图坏掉了所以返回None,下面查看对于批量读取怎么处理:

dataloader = DataLoader(dataset, 2, collate_fn=my_collate_fn, num_workers=1)     # 批量为2
for batch_datas, batch_labels in dataloader:
print(batch_datas.size(),batch_labels.size())
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([1, 3, 224, 224]) torch.Size([1])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([1, 3, 224, 224]) torch.Size([1])

可以看到第三个批量只有1张图,因为第5张图坏掉了,所以第三个批量只有第六张图。第五个批量也只有1张图是因为数据集总共只有9张图(含坏图)。如果设置drop_last为true,那么第五个批量就被丢弃了。对于诸如样本损坏或数据集加载异常等情况,还可以通过其它方式解决。例如但凡遇到异常情况,就随机取一张图片代替:

class NewDogCat(DogCat):
def __getitem__(self, index):
try:
return super(NewDogCat, self).__getitem__(index)
except:
new_index = random.randint(0, len(self)-1)
return self[new_index]

相比较丢弃异常图片而言,这种做法会更好一些,因为它能保证每个batch的数目仍是batch_size。但在大多数情况下,最好的方式还是对数据进行彻底清洗。

 2)sampler

sampler模块用来对数据进行采样。常用的有随机采样器:RandomSampler,当dataloader的shuffle参数为True时,系统会自动调用这个采样器,实现打乱数据。默认的是采用SequentialSampler,它会按顺序一个一个进行采样。这里介绍另外一个很有用的采样方法: WeightedRandomSampler,它会根据每个样本的权重选取数据,在样本比例不均衡的问题中,可用它来进行重采样。

class torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples, replacement=True)

构建WeightedRandomSampler时需提供两个参数:每个样本的权重weights、共选取的样本总数num_samples,以及一个可选参数replacement。权重越大的样本被选中的概率越大,待选取的样本数目一般小于全部的样本数目。replacement用于指定是否可以重复选取某一个样本,默认为True,即允许在一个epoch中重复采样某一个数据。如果设为False,则当某一类的样本被全部选取完,但其样本数目仍未达到num_samples时,sampler将不会再从该类中选择数据,此时可能导致weights参数失效。下面举例说明。

dataset = DogCat('data/dogcat/', transforms=transform)

# 狗的图片被取出的概率是猫的概率的两倍
# 两类图片被取出的概率与weights的绝对大小无关,只和比值有关
weights = [2 if label == 1 else 1 for data, label in dataset]
weights # [2, 2, 1, 1, 1, 1, 2, 2]
from torch.utils.data.sampler import  WeightedRandomSampler
sampler = WeightedRandomSampler(weights,\
num_samples=9,\
replacement=True)
dataloader = DataLoader(dataset,
batch_size=3,
sampler=sampler)
for datas, labels in dataloader:
print(labels.tolist())
[1, 0, 1]
[1, 0, 1]
[1, 1, 0]

可见猫狗样本比例约为1:2,另外一共只有8个样本,但是却返回了9个,说明肯定有被重复返回的,这就是replacement参数的作用,下面将replacement设为False试试:

sampler = WeightedRandomSampler(weights, 8, replacement=False)
dataloader = DataLoader(dataset, batch_size=4, sampler=sampler)
for datas, labels in dataloader:
print(labels.tolist())
[1, 0, 1, 0]
[1, 1, 0, 0]

在这种情况下,num_samples等于dataset的样本总数,为了不重复选取,sampler会将每个样本都返回,这样就失去weight参数的意义了。

从上面的例子可见sampler在样本采样中的作用:如果指定了sampler,shuffle将不再生效,并且sampler.num_samples会覆盖dataset的实际大小,即一个epoch返回的图片总数取决于sampler.num_samples。

部分转载自:pytorch-book-master

Pytorch tutorial 之Datar Loading and Processing (2)的更多相关文章

  1. Pytorch tutorial 之Datar Loading and Processing (1)

    引自Pytorch tutorial: Data Loading and Processing Tutorial 这节主要介绍数据的读入与处理. 数据描述:人脸姿态数据集.共有69张人脸,每张人脸都有 ...

  2. 【转载】Pytorch tutorial 之Datar Loading and Processing

    前言 上文介绍了数据读取.数据转换.批量处理等等.了解到在PyTorch中,数据加载主要有两种方式: 1.自定义的数据集对象.数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Datase ...

  3. pytorch例子学习-DATA LOADING AND PROCESSING TUTORIAL

    参考:https://pytorch.org/tutorials/beginner/data_loading_tutorial.html DATA LOADING AND PROCESSING TUT ...

  4. Pytorch tutorial 之Transfer Learning

    引自官方:  Transfer Learning tutorial Ng在Deeplearning.ai中讲过迁移学习适用于任务A.B有相同输入.任务B比任务A有更少的数据.A任务的低级特征有助于任务 ...

  5. pytorch tutorial 2

    这里使用pytorch进行一个简单的二分类模型 导入所有我们需要的库 import torch import matplotlib.pyplot as plt import torch.nn.func ...

  6. Pytorch model saving and loading 模型保存和读取

    It is really useful to save and reload the model and its parameters during or after training in deep ...

  7. pytorch tutorial 1

    这里用torch 做一个最简单的测试 目标就是我们用torch 建立一个一层的网络,然后拟合一组可以回归的数据 import torch from torch.autograd import Vari ...

  8. Pytorch从0开始实现YOLO V3指南 part5——设计输入和输出的流程

    本节翻译自:https://blog.paperspace.com/how-to-implement-a-yolo-v3-object-detector-from-scratch-in-pytorch ...

  9. (转)Awesome PyTorch List

    Awesome-Pytorch-list 2018-08-10 09:25:16 This blog is copied from: https://github.com/Epsilon-Lee/Aw ...

随机推荐

  1. Linux报错

    Linux报错 ------------------- 在VMware虚拟机中配置yum源时,执行 mount /dev/cdrom /mnt/cdrom 出现 mount: no medium fo ...

  2. OGNL中的#、%和$符号的用法

    转自:https://blog.csdn.net/qq_24963197/article/details/51773224 一.OGNL中的#.%和$符号 1.#符号的三种用法 1)访问非根对象属性, ...

  3. H5新属性FileReader实现选择图片后立即显示在页面上

    <!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8&quo ...

  4. @GetMapping(value="/") , "/" 可加可不加 ,是不是一样的

    @GetMapping(value = "/user") 和  @GetMapping(value = "user") 的区别 1.带上 "/&quo ...

  5. 前台ajax传参数,后台spring mvc用对象接受

    第二种方法:利用spring mvc的机制,调用对象的get方法,要求对象的属性名和传的参数名字一致(有兴趣的同学看 springmvc源码) 1.将参数名直接写成对象的属性名 $.ajax({ ur ...

  6. HDU - 1542 Atlantis(线段树求面积并)

    https://cn.vjudge.net/problem/HDU-1542 题意 求矩形的面积并 分析 点为浮点数,需要离散化处理. 给定一个矩形的左下角坐标和右上角坐标分别为:(x1,y1).(x ...

  7. 属性集合java.util.Properties

    属性集合java.util.Properties java.util.Properties集合 extends Hashtable<k, v> implements Map<k, v ...

  8. 404.17 - 动态内容通过通配符 MIME 映射映射到静态文件处理程序

    刚刚重装了系统,原有的ASP.NET工程下面的WebService无法运行,如下: 404.17 - 动态内容通过通配符 MIME 映射映射到静态文件处理程序 微软的提示,是做三项更改,但是我改了之后 ...

  9. 百度编辑器 ueditor 会屏蔽过滤 body html head DOCTYPE ... 的解决办法

    百度编辑器很强,但有时候复制到html里时,会带有 body  html head 等标签,切到视图时,内容都不见了 是因为白名单 解决办法: 我测的是1.4.3版本 在 ueditor.config ...

  10. Java8新特性 并行流与串行流 Fork Join

    并行流就是把一个内容分成多个数据块,并用不同的线程分 别处理每个数据块的流. Java 8 中将并行进行了优化,我们可以很容易的对数据进行并 行操作. Stream API 可以声明性地通过 para ...