【转载】Pytorch tutorial 之Datar Loading and Processing
前言
上文介绍了数据读取、数据转换、批量处理等等。了解到在PyTorch中,数据加载主要有两种方式:
- 1.自定义的数据集对象。数据集对象被抽象为
Dataset类,实现自定义的数据集需要继承Dataset。且须实现__len__()和__getitem__()两个方法。 - 2.利用torchvision包。torchvision已经预先实现了常用的Dataset,包括前面使用过的CIFAR-10,以及ImageNet、COCO、MNIST、LSUN等数据集,可通过诸如
torchvision.datasets.CIFAR10来调用。这里介绍ImageFolder,其也继承自Dataset。ImageFolder假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名,其构造函数如下:
ImageFolder(root, transform=None, target_transform=None, loader=default_loader)
它主要有四个参数:
root:在root指定的路径下寻找图片transform:对PIL Image进行的转换操作,transform的输入是使用loader读取图片的返回对象target_transform:对label的转换loader:给定路径后如何读取图片,默认读取为RGB格式的PIL Image对象
label是按照文件夹名顺序排序后存成字典,即{类名:类序号(从0开始)},一般来说最好直接将文件夹命名为从0开始的数字,这样会和ImageFolder实际的label一致,如果不是这种命名规 范,建议看看self.class_to_idx属性以了解label和文件夹名的映射关系。
下面我们进一步理解数据读取的内容:
1. 查看dataset实例都有哪些方法与成员:
from torchvision.datasets import ImageFolder
dataset = ImageFolder('data/ants&bee_2/', transform=transform)
打印一下dataset类的成员与方法:
print(dir(dataset))
['\_\_add\_\_', '\_\_class\_\_', '\_\_delattr\_\_', '\_\_dict\_\_', '\_\_dir\_\_', '\_\_doc\_\_', '\_\_eq\_\_', '\_\_format\_\_', '\_\_ge\_\_', '\_\_getattribute\_\_', '\_\_getitem\_\_', '\_\_gt\_\_', '\_\_hash\_\_', '\_\_init\_\_', '\_\_le\_\_', '\_\_len\_\_', '\_\_lt\_\_', '\_\_module\_\_', '\_\_ne\_\_', '\_\_new\_\_', '\_\_reduce\_\_', '\_\_reduce\_ex\_\_', '\_\_repr\_\_', '\_\_setattr\_\_', '\_\_sizeof\_\_', '\_\_str\_\_', '\_\_subclasshook\_\_', '\_\_weakref\_\_', 'class\_to\_idx', 'classes', 'imgs', 'loader', 'root', 'target\_transform', 'transform']
dataset.__len__() : 数据集的数目
dataset.__getitem__(idx) : 输入索引,返回对应的图片与标签
dataste.class_to_idx : 字典,类与标签 eg:{‘ants’:0, ‘bees’: 1}
dataset.classes: 列表,返回类别 eg:[ 'ants', 'bees' ]
dataset.imgs : 列表,返回所有图片的路径和对应的label
特别的对于dataset,可以根据dataset.__getitem__(idx)来返回第idx张图与标签,还可以直接进行索引:
dataset[0][1] # 第一维是第几张图,第二维为1返回label
dataset[0][0] # 为0返回图片数据
还可以循环迭代:
for img, label in dataset:
print(img.size(), label)
所以无论是自定义的dataset,或是 ImageFolder得到的dataset,因其都继承自utils.data.Dataset, 故以上方法两种方法都有。
2. Dataloader使用
如果只是每次读取一张图,那么上面的操作已经足够了,但是为了批量操作、打散数据、多进程处理、定制batch,那么我们还需要更高级的类:
DataLoader定义如下:
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
- dataset:加载的数据集(Dataset对象)
- batch_size:batch size
- shuffle::是否将数据打乱
- sampler: 样本抽样。定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数。
- batch_sampler(sampler,可选) - 和sampler一样,但一次返回一批索引。与batch_size,shuffle,sampler和drop_last相互排斥。
- num_workers:使用多进程加载的进程数,0代表不使用多进程
- collate_fn: 如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可
- pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些,默认为false
- drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃,默认为false
from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=3, shuffle=True, num_workers=0, drop_last=False)
dataiter = iter(dataloader) # 迭代器
imgs, labels = next(dataiter)
imgs.size() # batch_size, channel, height, weight # torch.Size([3, 3, 224, 224])
下面主要介绍collate_fn和sampler的用法:
1)collate_fn
在数据处理中,有时会出现某个样本无法读取等问题,比如某张图片损坏。这时在__getitem__函数中将出现异常,此时最好的解决方案即是将出错的样本剔除。如果实在是遇到这种情况无法处理,则可以返回None对象,然后在Dataloader中实现自定义的****collate_fn,将空对象过滤掉。但要注意,在这种情况下dataloader返回的batch数目会少于batch_size。
eg:
class NewDogCat(DogCat): # 继承前面实现的DogCat数据集
def __getitem__(self, index):
try:
# 调用父类的获取函数,即 DogCat.__getitem__(self, index)
return super(NewDogCat,self).__getitem__(index)
except:
return None, None
from torch.utils.data.dataloader import default_collate # 导入默认的拼接方式
def my_collate_fn(batch):
'''
batch中每个元素形如(data, label)
'''
# 过滤为None的数据
batch = list(filter(lambda x:x[0] is not None, batch))
return default_collate(batch) # 用默认方式拼接过滤后的batch数据
dataset = NewDogCat('data/dogcat_wrong/', transforms=transform)
dataset[5] # (None, None)
第5张图坏掉了所以返回None,下面查看对于批量读取怎么处理:
dataloader = DataLoader(dataset, 2, collate_fn=my_collate_fn, num_workers=1) # 批量为2
for batch_datas, batch_labels in dataloader:
print(batch_datas.size(),batch_labels.size())
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([1, 3, 224, 224]) torch.Size([1])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([1, 3, 224, 224]) torch.Size([1])
可以看到第三个批量只有1张图,因为第5张图坏掉了,所以第三个批量只有第六张图。第五个批量也只有1张图是因为数据集总共只有9张图(含坏图)。如果设置drop_last为true,那么第五个批量就被丢弃了。对于诸如样本损坏或数据集加载异常等情况,还可以通过其它方式解决。例如但凡遇到异常情况,就随机取一张图片代替:
class NewDogCat(DogCat):
def __getitem__(self, index):
try:
return super(NewDogCat, self).__getitem__(index)
except:
new_index = random.randint(0, len(self)-1)
return self[new_index]
相比较丢弃异常图片而言,这种做法会更好一些,因为它能保证每个batch的数目仍是batch_size。但在大多数情况下,最好的方式还是对数据进行彻底清洗。
2)sampler
sampler模块用来对数据进行采样。常用的有随机采样器:RandomSampler,当dataloader的shuffle参数为True时,系统会自动调用这个采样器,实现打乱数据。默认的是采用SequentialSampler,它会按顺序一个一个进行采样。这里介绍另外一个很有用的采样方法: WeightedRandomSampler,它会根据每个样本的权重选取数据,在样本比例不均衡的问题中,可用它来进行重采样。
class torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples, replacement=True)
构建WeightedRandomSampler时需提供两个参数:每个样本的权重weights、共选取的样本总数num_samples,以及一个可选参数replacement。权重越大的样本被选中的概率越大,待选取的样本数目一般小于全部的样本数目。replacement用于指定是否可以重复选取某一个样本,默认为True,即允许在一个epoch中重复采样某一个数据。如果设为False,则当某一类的样本被全部选取完,但其样本数目仍未达到num_samples时,sampler将不会再从该类中选择数据,此时可能导致weights参数失效。下面举例说明。
dataset = DogCat('data/dogcat/', transforms=transform)
# 狗的图片被取出的概率是猫的概率的两倍
# 两类图片被取出的概率与weights的绝对大小无关,只和比值有关
weights = [2 if label == 1 else 1 for data, label in dataset]
weights # [2, 2, 1, 1, 1, 1, 2, 2]
from torch.utils.data.sampler import WeightedRandomSampler
sampler = WeightedRandomSampler(weights,\
num_samples=9,\
replacement=True)
dataloader = DataLoader(dataset,
batch_size=3,
sampler=sampler)
for datas, labels in dataloader:
print(labels.tolist())
[1, 0, 1]
[1, 0, 1]
[1, 1, 0]
可见猫狗样本比例约为1:2,另外一共只有8个样本,但是却返回了9个,说明肯定有被重复返回的,这就是replacement参数的作用,下面将replacement设为False试试:
sampler = WeightedRandomSampler(weights, 8, replacement=False)
dataloader = DataLoader(dataset, batch_size=4, sampler=sampler)
for datas, labels in dataloader:
print(labels.tolist())
[1, 0, 1, 0]
[1, 1, 0, 0]
在这种情况下,num_samples等于dataset的样本总数,为了不重复选取,sampler会将每个样本都返回,这样就失去weight参数的意义了。
从上面的例子可见sampler在样本采样中的作用:如果指定了sampler,shuffle将不再生效,并且sampler.num_samples会覆盖dataset的实际大小,即一个epoch返回的图片总数取决于sampler.num_samples。
部分转载自:pytorch-book-master
【转载】Pytorch tutorial 之Datar Loading and Processing的更多相关文章
- Pytorch tutorial 之Datar Loading and Processing (1)
引自Pytorch tutorial: Data Loading and Processing Tutorial 这节主要介绍数据的读入与处理. 数据描述:人脸姿态数据集.共有69张人脸,每张人脸都有 ...
- Pytorch tutorial 之Datar Loading and Processing (2)
上文介绍了数据读取.数据转换.批量处理等等.了解到在PyTorch中,数据加载主要有两种方式: 1. 自定义的数据集对象.数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Dataset. ...
- pytorch例子学习-DATA LOADING AND PROCESSING TUTORIAL
参考:https://pytorch.org/tutorials/beginner/data_loading_tutorial.html DATA LOADING AND PROCESSING TUT ...
- [转载]PyTorch上的contiguous
[转载]PyTorch上的contiguous 来源:https://zhuanlan.zhihu.com/p/64551412 这篇文章写的非常好,我这里就不复制粘贴了,有兴趣的同学可以去看原文,我 ...
- [转载]PyTorch中permute的用法
[转载]PyTorch中permute的用法 来源:https://blog.csdn.net/york1996/article/details/81876886 permute(dims) 将ten ...
- [转载]Pytorch详解NLLLoss和CrossEntropyLoss
[转载]Pytorch详解NLLLoss和CrossEntropyLoss 来源:https://blog.csdn.net/qq_22210253/article/details/85229988 ...
- [转载]Pytorch中nn.Linear module的理解
[转载]Pytorch中nn.Linear module的理解 本文转载并援引全文纯粹是为了构建和分类自己的知识,方便自己未来的查找,没啥其他意思. 这个模块要实现的公式是:y=xAT+*b 来源:h ...
- (转载)XML Tutorial for iOS: How To Read and Write XML Documents with GDataXML
In my recent post on How To Choose the Best XML Parser for Your iPhone Project, Saliom from the comm ...
- (转载)XML Tutorial for iOS: How To Choose The Best XML Parser for Your iPhone Project
There are a lot of options when it comes to parsing XML on the iPhone. The iPhone SDK comes with two ...
随机推荐
- 前端——jQuery
初识jQuery 什么是jQuery? jQuery就是JavaScript和Query,是辅助JavaScript开发的库,应用广泛,形成了行业标准.它对DOM操作做了很好的封装,我们可以用jQue ...
- 基于element ui的级联选择器组件实现的分类后台接口
今天在做资产管理系统的时候遇到一个分类的级联选择器,前端是用的element的组件,需要后台提供接口支持. 这个组件需要传入的数据结构大概是这样的,详细的可参考官方案例: [{ value: ...
- Mybatis实现高级映射一对一、一对多查询
终于把论文写得差不多了,系统也不急着完成,可以抽出点时间来完成这个系列的博客了.在写本博客之前我是惶恐不安的,高级映射一贯是持久层框架里的重中之重,小到自己开发小系统,大到企业级开发,表的存在从来就不 ...
- odoo中各视图写法
透视图: 还需要将一个pivot表添加到要待办任务(To-Do Tasks)中,请使用以下代码: <record id="view_pivot_todo_task" mode ...
- LOJ3053 十二省联考2019 希望 容斥、树形DP、长链剖分
传送门 官方题解其实讲的挺清楚了,就是锅有点多-- 一些有启发性的部分分 L=N 一个经典(反正我是不会)的容斥:最后的答案=对于每个点能够以它作为集合点的方案数-对于每条边能够以其两个端点作为集合点 ...
- CentOS 7 rpm -i 时 警告warning: /var/tmp/rpm-tmp.z7O820: Header V4 RSA/SHA512 Signature, key ID a14fe591: NOKEY 解决方法
这是由于yum安装了旧版本的GPG keys造成的,解决办法就是 运行下面命令即可 # rpm --import /etc/pki/rpm-gpg/RPM* 查询已安装的rpm源 # rpm -qa ...
- ckeditor,关于数据回显
- spring cloud实战与思考(二) 微服务之间通过fiegn上传一组文件(上)
需求场景: 微服务之间调用接口一次性上传多个文件. 上传文件的同时附带其他参数. 多个文件能有效的区分开,以便进行不同处理. Spring cloud的微服务之间接口调用使用Feign.原装的Feig ...
- [NOIp2016] 换教室
题目类型:期望\(DP\) 传送门:>Here< 题意:现有\(N\)个时间段,每个时间段上一节课.如果不申请换教室,那么时间段\(i\)必须去教室\(c[i]\)上课,如果申请换课成功, ...
- [NOI2018]屠龙勇士
题目描述 题解 考虑增量法. 假设我们已经做完了前k个条件,前面的模数连乘起来的结果为M,答案为X,当前的攻击力为x,龙的血量为a. 那么我们这一次的答案的表达形式是X+t*M的. 这一次需要满足的是 ...