前言

上文介绍了数据读取、数据转换、批量处理等等。了解到在PyTorch中,数据加载主要有两种方式:

  • 1.自定义的数据集对象。数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Dataset。且须实现__len__()和__getitem__()两个方法。
  • 2.利用torchvision包。torchvision已经预先实现了常用的Dataset,包括前面使用过的CIFAR-10,以及ImageNet、COCO、MNIST、LSUN等数据集,可通过诸如torchvision.datasets.CIFAR10来调用。这里介绍ImageFolder,其也继承自DatasetImageFolder假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名,其构造函数如下:
   ImageFolder(root, transform=None, target_transform=None, loader=default_loader)

它主要有四个参数:

  • root:在root指定的路径下寻找图片
  • transform:对PIL Image进行的转换操作,transform的输入是使用loader读取图片的返回对象
  • target_transform:对label的转换
  • loader:给定路径后如何读取图片,默认读取为RGB格式的PIL Image对象

label是按照文件夹名顺序排序后存成字典,即{类名:类序号(从0开始)},一般来说最好直接将文件夹命名为从0开始的数字,这样会和ImageFolder实际的label一致,如果不是这种命名规 范,建议看看self.class_to_idx属性以了解label和文件夹名的映射关系。

下面我们进一步理解数据读取的内容:

1.  查看dataset实例都有哪些方法与成员:

from torchvision.datasets import ImageFolder
dataset = ImageFolder('data/ants&bee_2/', transform=transform)

打印一下dataset类的成员与方法:

print(dir(dataset))
['\_\_add\_\_', '\_\_class\_\_', '\_\_delattr\_\_', '\_\_dict\_\_', '\_\_dir\_\_', '\_\_doc\_\_', '\_\_eq\_\_', '\_\_format\_\_', '\_\_ge\_\_', '\_\_getattribute\_\_', '\_\_getitem\_\_', '\_\_gt\_\_', '\_\_hash\_\_', '\_\_init\_\_', '\_\_le\_\_', '\_\_len\_\_', '\_\_lt\_\_', '\_\_module\_\_', '\_\_ne\_\_', '\_\_new\_\_', '\_\_reduce\_\_', '\_\_reduce\_ex\_\_', '\_\_repr\_\_', '\_\_setattr\_\_', '\_\_sizeof\_\_', '\_\_str\_\_', '\_\_subclasshook\_\_', '\_\_weakref\_\_', 'class\_to\_idx', 'classes', 'imgs', 'loader', 'root', 'target\_transform', 'transform']

dataset.__len__()  : 数据集的数目

dataset.__getitem__(idx) : 输入索引,返回对应的图片与标签

dataste.class_to_idx : 字典,类与标签  eg:{‘ants’:0, ‘bees’: 1}

dataset.classes: 列表,返回类别  eg:[ 'ants', 'bees' ]

dataset.imgs : 列表,返回所有图片的路径和对应的label

特别的对于dataset,可以根据dataset.__getitem__(idx)来返回第idx张图与标签,还可以直接进行索引:

dataset[0][1]    # 第一维是第几张图,第二维为1返回label
dataset[0][0] # 为0返回图片数据

还可以循环迭代:

for img, label in dataset:
print(img.size(), label)

所以无论是自定义的dataset,或是 ImageFolder得到的dataset,因其都继承自utils.data.Dataset, 故以上方法两种方法都有。

2.  Dataloader使用

如果只是每次读取一张图,那么上面的操作已经足够了,但是为了批量操作、打散数据、多进程处理、定制batch,那么我们还需要更高级的类:

DataLoader定义如下:

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)

  • dataset:加载的数据集(Dataset对象)
  • batch_size:batch size
  • shuffle::是否将数据打乱
  • sampler: 样本抽样。定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数。
  • batch_sampler(sampler,可选) - 和sampler一样,但一次返回一批索引。与batch_size,shuffle,sampler和drop_last相互排斥。
  • num_workers:使用多进程加载的进程数,0代表不使用多进程
  • collate_fn: 如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可
  • pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些,默认为false
  • drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃,默认为false
from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=3, shuffle=True, num_workers=0, drop_last=False) dataiter = iter(dataloader) # 迭代器 imgs, labels = next(dataiter)
imgs.size() # batch_size, channel, height, weight # torch.Size([3, 3, 224, 224])

下面主要介绍collate_fn和sampler的用法:

1)collate_fn

在数据处理中,有时会出现某个样本无法读取等问题,比如某张图片损坏。这时在__getitem__函数中将出现异常,此时最好的解决方案即是将出错的样本剔除。如果实在是遇到这种情况无法处理,则可以返回None对象,然后在Dataloader中实现自定义的****collate_fn,将空对象过滤掉。但要注意,在这种情况下dataloader返回的batch数目会少于batch_size。

eg:

class NewDogCat(DogCat): # 继承前面实现的DogCat数据集
def __getitem__(self, index):
try:
# 调用父类的获取函数,即 DogCat.__getitem__(self, index)
return super(NewDogCat,self).__getitem__(index)
except:
return None, None from torch.utils.data.dataloader import default_collate # 导入默认的拼接方式
def my_collate_fn(batch):
'''
batch中每个元素形如(data, label)
'''
# 过滤为None的数据
batch = list(filter(lambda x:x[0] is not None, batch))
return default_collate(batch) # 用默认方式拼接过滤后的batch数据
dataset = NewDogCat('data/dogcat_wrong/', transforms=transform)
dataset[5] # (None, None)

第5张图坏掉了所以返回None,下面查看对于批量读取怎么处理:

dataloader = DataLoader(dataset, 2, collate_fn=my_collate_fn, num_workers=1)     # 批量为2
for batch_datas, batch_labels in dataloader:
print(batch_datas.size(),batch_labels.size())
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([1, 3, 224, 224]) torch.Size([1])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([1, 3, 224, 224]) torch.Size([1])

可以看到第三个批量只有1张图,因为第5张图坏掉了,所以第三个批量只有第六张图。第五个批量也只有1张图是因为数据集总共只有9张图(含坏图)。如果设置drop_last为true,那么第五个批量就被丢弃了。对于诸如样本损坏或数据集加载异常等情况,还可以通过其它方式解决。例如但凡遇到异常情况,就随机取一张图片代替:

class NewDogCat(DogCat):
def __getitem__(self, index):
try:
return super(NewDogCat, self).__getitem__(index)
except:
new_index = random.randint(0, len(self)-1)
return self[new_index]

相比较丢弃异常图片而言,这种做法会更好一些,因为它能保证每个batch的数目仍是batch_size。但在大多数情况下,最好的方式还是对数据进行彻底清洗。

2)sampler

sampler模块用来对数据进行采样。常用的有随机采样器:RandomSampler,当dataloader的shuffle参数为True时,系统会自动调用这个采样器,实现打乱数据。默认的是采用SequentialSampler,它会按顺序一个一个进行采样。这里介绍另外一个很有用的采样方法: WeightedRandomSampler,它会根据每个样本的权重选取数据,在样本比例不均衡的问题中,可用它来进行重采样。

class torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples, replacement=True)

构建WeightedRandomSampler时需提供两个参数:每个样本的权重weights、共选取的样本总数num_samples,以及一个可选参数replacement。权重越大的样本被选中的概率越大,待选取的样本数目一般小于全部的样本数目。replacement用于指定是否可以重复选取某一个样本,默认为True,即允许在一个epoch中重复采样某一个数据。如果设为False,则当某一类的样本被全部选取完,但其样本数目仍未达到num_samples时,sampler将不会再从该类中选择数据,此时可能导致weights参数失效。下面举例说明。

dataset = DogCat('data/dogcat/', transforms=transform)

# 狗的图片被取出的概率是猫的概率的两倍
# 两类图片被取出的概率与weights的绝对大小无关,只和比值有关
weights = [2 if label == 1 else 1 for data, label in dataset]
weights # [2, 2, 1, 1, 1, 1, 2, 2]
from torch.utils.data.sampler import  WeightedRandomSampler
sampler = WeightedRandomSampler(weights,\
num_samples=9,\
replacement=True)
dataloader = DataLoader(dataset,
batch_size=3,
sampler=sampler)
for datas, labels in dataloader:
print(labels.tolist())
[1, 0, 1]
[1, 0, 1]
[1, 1, 0]

可见猫狗样本比例约为1:2,另外一共只有8个样本,但是却返回了9个,说明肯定有被重复返回的,这就是replacement参数的作用,下面将replacement设为False试试:

sampler = WeightedRandomSampler(weights, 8, replacement=False)
dataloader = DataLoader(dataset, batch_size=4, sampler=sampler)
for datas, labels in dataloader:
print(labels.tolist())
[1, 0, 1, 0]
[1, 1, 0, 0]

在这种情况下,num_samples等于dataset的样本总数,为了不重复选取,sampler会将每个样本都返回,这样就失去weight参数的意义了。

从上面的例子可见sampler在样本采样中的作用:如果指定了sampler,shuffle将不再生效,并且sampler.num_samples会覆盖dataset的实际大小,即一个epoch返回的图片总数取决于sampler.num_samples。

部分转载自:pytorch-book-master

MARSGGBO♥原创







2019-3-3

【转载】Pytorch tutorial 之Datar Loading and Processing的更多相关文章

  1. Pytorch tutorial 之Datar Loading and Processing (1)

    引自Pytorch tutorial: Data Loading and Processing Tutorial 这节主要介绍数据的读入与处理. 数据描述:人脸姿态数据集.共有69张人脸,每张人脸都有 ...

  2. Pytorch tutorial 之Datar Loading and Processing (2)

    上文介绍了数据读取.数据转换.批量处理等等.了解到在PyTorch中,数据加载主要有两种方式: 1. 自定义的数据集对象.数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Dataset. ...

  3. pytorch例子学习-DATA LOADING AND PROCESSING TUTORIAL

    参考:https://pytorch.org/tutorials/beginner/data_loading_tutorial.html DATA LOADING AND PROCESSING TUT ...

  4. [转载]PyTorch上的contiguous

    [转载]PyTorch上的contiguous 来源:https://zhuanlan.zhihu.com/p/64551412 这篇文章写的非常好,我这里就不复制粘贴了,有兴趣的同学可以去看原文,我 ...

  5. [转载]PyTorch中permute的用法

    [转载]PyTorch中permute的用法 来源:https://blog.csdn.net/york1996/article/details/81876886 permute(dims) 将ten ...

  6. [转载]Pytorch详解NLLLoss和CrossEntropyLoss

    [转载]Pytorch详解NLLLoss和CrossEntropyLoss 来源:https://blog.csdn.net/qq_22210253/article/details/85229988 ...

  7. [转载]Pytorch中nn.Linear module的理解

    [转载]Pytorch中nn.Linear module的理解 本文转载并援引全文纯粹是为了构建和分类自己的知识,方便自己未来的查找,没啥其他意思. 这个模块要实现的公式是:y=xAT+*b 来源:h ...

  8. (转载)XML Tutorial for iOS: How To Read and Write XML Documents with GDataXML

    In my recent post on How To Choose the Best XML Parser for Your iPhone Project, Saliom from the comm ...

  9. (转载)XML Tutorial for iOS: How To Choose The Best XML Parser for Your iPhone Project

    There are a lot of options when it comes to parsing XML on the iPhone. The iPhone SDK comes with two ...

随机推荐

  1. [LeetCode] 6. Z 字形变换

    题目链接:(https://leetcode-cn.com/problems/zigzag-conversion/) 题目描述: 将一个给定字符串根据给定的行数,以从上往下.从左到右进行 Z 字形排列 ...

  2. 如何使用 IDEA 创建项目并且上传到 GitHub

    在 GitHub中 注册创建账号 :https://github.com 下载安装 Git : https://git-scm.com 安装成功后打开 Git Bash,输入下列命令,设置 Git 全 ...

  3. redis 初步认识一(下载安装redis)

    1.下载redis  https://github.com/MicrosoftArchive/redis/releases 2.开启redis服务 3.使用redis 4.redis可视化工具 一 开 ...

  4. sqlServer:行列转换之多行转一行

    1.建表:学生表(姓名,学科,成绩) CREATE TABLE teststudent(    stuname varchar(50) NULL,    subjects varchar(50) NU ...

  5. Jmeter名词注解

    取值 ${ip}排除 .*\.js .*\.css .*\.png .*\.gif .*\.msp .*\.js 提取值 (.+?) (.*?)[() 括起来的部分就是需要提取的,对于你要提的内容需要 ...

  6. Educational Codeforces Round 62 (Rated for Div. 2) Solution

    最近省队前联考被杭二成七南外什么的吊锤得布星,拿一场Div. 2恢复信心 然后Div.2 Rk3.Div. 1+Div. 2 Rk9,rating大涨200引起舒适 现在的Div. 2都怎么了,最难题 ...

  7. FJUTOJ-周赛2016-11-25

    注:fjutoj基本每周都有一次周赛,欢迎大家都来参加! 网址:http://59.77.139.92/ 或 acm.fjut.edu.cn A题 题意:一年中,每个月有可能亏x 元,有可能赚y 元, ...

  8. Electron桌面应用打包流程

    一. 准备工作 1.npm的安装需要下载node.js,安装完node.js之后npm自然会有. 参考链接:http://www.runoob.com/nodejs/nodejs-install-se ...

  9. 4月10日java多线程3

    在之前我学习了java中的Thread 来实现多线程,今日我学习了ThreadGroup.Executor框架.ForkJoin框架.Executor 和ForkJoin都可以直接定义线程池,可以根据 ...

  10. jQuery对页面的操作

    一.对元素内容和值进行操作 1.对元素内容操作 [text()]:获取值. [text(val)]:获取并修改值. [html()]:获取值. [html(val)]:获取并修改值,与text的区别在 ...