DataLoader

DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,
batch_sampler=None,num_workers=0,collate_fn=None,pin_memory=False,
drop_last=False,timeout=0,work_init_fn=None)

常用参数说明:

  • dataset: Dataset类 ( 详见下文数据集构建 ),可以自定义数据集或者读取pytorch自带数据集

  • batch_size: 每个batch加载多少个样本, 默认1

  • shuffle: 是否顺序读取,True表示随机打乱,默认False

  • sampler:定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数。

  • batch_sampler: 定义一个按照batch_size大小返回索引的采样器。采样器详见下文Batch_Sampler

  • num_workers: 数据读取进程数量, 默认0

  • collate_fn: 自定义一个函数,接收一个batch的数据,进行自定义处理,然后返回处理后这个batch的数据。例如改变数据类型:

def my_collate_fn(batch_data):
x_batch = []
y_batch = []
for x,y in batch_data:
x_batch.append(x.float())
y_batch.append(y.int())
return x_batch,y_batch
  • pin_memory:设置pin_memory=True,则意味着生成的Tensor数据最开始是属于内存中的锁页内存,这样将内存的Tensor转义到GPU的显存就会更快一些。默认为False.

    主机中的内存,有两种,一种是锁页,一种是不锁页。锁页内存存放的内容在任何情况下都不会与主机的虚拟内存 (硬盘)进行交换,而不锁页内存在主机内存不足时,数据会存放在虚拟内存中。注意显卡中的显存全部都是锁业内存。如果计算机内存充足的话,设置为True可以加快数据交换顺序。

  • drop_last:默认False, 最后剩余数据量不够batch_size时候,是否丢弃。

  • timeout: 设置数据读取的时间限制,超过限制时间还未完成数据读取则报错。数值必须大于等于0

数据集构建

自定义数据集

自定义数据集,需要继承torch.utils.data.Dataset,然后在__getitem__()中,接受一个索引,返回一个样本, 基本流程,首先在__init__()加载数据以及做一些处理,在__getitem__()中返回单个数据样本,在__len__() 中,返回样本数量

import torch
import torch.utils.data.dataset as Data class MyDataset(Data.Dataset): def __init__(self):
self.x = torch.randn((10,20))
self.y = torch.tensor([1 if i>5 else 0 for i in range(10)],
dtype=torch.long) def __getitem__(self,idx):
return self.x[idx],self.y[idx] def __len__(self):
return self.x.__len__()

torchvision数据集

pytorch自带torchvision库可以帮助我们方便快捷的读取和加载数据

import torch
from torchvision import datasets, transforms # 定义一个预处理方法
transform = transforms.Compose([transforms.ToTensor()])
# 加载一个自带数据集
trainset = datasets.MNIST('/pytorch/MNIST_data/', download=True, train=True,
transform=transform)

TensorDataset

注意这里的tensor必须是一维度的数据。

import torch.utils.data as Data
x = torch.tensor([1,2,3,4,5])
y = torch.tensor([0,0,0,1,1])
dataset = Data.TensorDataset(x,y)

从文件夹中加载数据集

如果想要加载自己的数据集可以这样,用猫狗数据集举例,根目录下 ( "data/train" ),分别放置两个文件夹,dog和cat,这样使用ImageFolder函数就可以自动的将猫狗照片自动的按照文件夹定义为猫狗两个标签

import torch
from torchvision import datasets, transforms data_dir = "data/train"
transform = transforms.Compose([transforms.Resize(255),transforms.ToTensor()]) dataset = datasets.ImageFolder(data_dir, transform=transform)

数据集操作

数据拼接

连接不同的数据集以构成更大的新数据集。

class torch.utils.data.ConcatDataset( [datasets, ... ] )

newDataset = torch.utils.data.ConcatDataset([dataset1,dataset2])

数据切分

方法一: class torch.utils.data.Subset(dataset, indices)

取指定一个索引序列对应的子数据集。

from torch.utils.data import Subset

train_set = Subset(dataset,[i for i in range(1,100)]
test_set = Subset(test0_ds,[i for i in range(100,150)]

方法二:torch.utils.data.random_split(dataset, lengths)

from torch.utils.data import random_split
train_set, test_set = random_split(dataset,[100,50])

采样器

所有采样器都在 torch.utils.data 中,采样器会根据该有的策略返回一组索引,在DataLoader中设定了采样器之后,会根据索引读取相应的样本, 不同采样器生成的索引不一样,从而实现不同的采样目的。

Sampler

所有采样器的基类,自定义采样器的时候需要实现 __iter__() 函数

class Sampler(object):
"""
Base class for all Samplers.
""" def __init__(self, data_source):
pass def __iter__(self):
raise NotImplementedError

RandomSampler

RandomSampler,当DataLoader的shuffle参数为True时,系统会自动调用这个采样器,实现打乱数据。默认的是采用SequentialSampler,它会按顺序一个一个进行采样。

SequentialSampler

按顺序采样,当DataLoader的shuffle参数为False时,使用的就是SequentialSampler。

SubsetRandomSampler

输入一个列表,按照这个列表采样。也可以通过这个采样器来分割数据集。

BatchSampler

参数:sampler, batch_size, drop_last

每此返回batch_size数量的采样索引,通过设置sampler参数来使用不同的采样方法。

WeightedRandomSampler

参数:weights, num_samples, replacement

它会根据每个样本的权重选取数据,在样本比例不均衡的问题中,可用它来进行重采样。通过weights 设定样本权重,权重越大的样本被选中的概率越大,待选取的样本数目一般小于全部的样本数目。num_samples 为返回索引的数量,replacement表示是否是放回抽样,如果为True,表示可以重复采样,默认为True

自定义采样器

集成Sampler类,然后实现__iter__() 方法,比如,下面实现一个SequentialSampler类

class SequentialSampler(Sampler):
r"""Samples elements sequentially, always in the same order.
Arguments:
data_source (Dataset): dataset to sample from
""" def __init__(self, data_source):
self.data_source = data_source def __iter__(self):
return iter(range(len(self.data_source))) def __len__(self):
return len(self.data_source)

Pytorch系列:(二)数据加载的更多相关文章

  1. [源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler

    [源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler 目录 [源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampl ...

  2. [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader

    [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader 目录 [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader 0x00 摘要 0x01 ...

  3. QT自定义控件系列(二) --- Loading加载动画控件

    本系列主要使用Qt painter来实现一些基础控件.主要是对平时自行编写的一些自定义控件的总结. 为了简洁.低耦合,我们尽量不使用图片,qrc,ui等文件,而只使用c++的.h和.cpp文件. 由于 ...

  4. 【转载】PyTorch系列 (二):pytorch数据读取

    原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...

  5. Android4.0图库Gallery2代码分析(二) 数据管理和数据加载

    Android4.0图库Gallery2代码分析(二) 数据管理和数据加载 2012-09-07 11:19 8152人阅读 评论(12) 收藏 举报 代码分析android相册优化工作 Androi ...

  6. PyTorch 数据集类 和 数据加载类 的一些尝试

    最近在学习PyTorch,  但是对里面的数据类和数据加载类比较迷糊,可能是封装的太好大部分情况下是不需要有什么自己的操作的,不过偶然遇到一些自己导入的数据时就会遇到一些问题,因此自己对此做了一些小实 ...

  7. pytorch数据加载

    一.方法一数据组织形式dataset_name----train----val from torchvision import datasets, models, transforms # Data ...

  8. PyTorch数据加载处理

    PyTorch数据加载处理 PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. 1.下载安装包 scikit-image:用于图像的IO和变换 pandas:用于更容易地进行csv解 ...

  9. MPP 二、Greenplum数据加载

    Loading external data into greenplum database table using different ways... Greenplum 有常规的COPY加载方法,有 ...

  10. apache ignite系列(三):数据处理(数据加载,数据并置,数据查询)

    ​ 使用ignite的一个常见思路就是将现有的关系型数据库中的数据导入到ignite中,然后直接使用ignite中的数据,相当于将ignite作为一个缓存服务,当然ignite的功能远不止于此,下面以 ...

随机推荐

  1. Vue使用&nbsp空白占位符

    当有时候需要在页面显示时显示空格时,可以使用 ,但是使用这个占位符时,无论写多少个,就只能显示一个空格.要想显示多个空格进行占位,这种方式显然是可行的,解决方法是使用转义字符. 先看代码: <t ...

  2. 可以设置过期时间的Java缓存Map

    前言 最近项目需求需要一个类似于redis可以设置过期时间的K,V存储方式.项目前期暂时不引进redis,暂时用java内存代替. 解决方案 1. ExpiringMap 功能简介 : 1.可设置Ma ...

  3. 将VMware虚拟机最小化到托盘栏

    版权:本文采用「署名-非商业性使用-相同方式共享 4.0 国际」知识共享许可协议进行许可.   目录 前言 将VMware最小化到托盘栏的方法 1.下载 Trayconizer 2.解压 trayco ...

  4. 【HTB系列】靶机Netmon的渗透测试

    出品|MS08067实验室(www.ms08067.com) 本文作者:是大方子(Ms08067实验室核心成员) 总结和反思: win中执行powershell的远程代码下载执行注意双引号转义 对po ...

  5. 剑指 Offer 55 - II. 平衡二叉树 + 平衡二叉树(AVL)的判断

    剑指 Offer 55 - II. 平衡二叉树 Offer_55_2 题目描述 方法一:使用后序遍历+边遍历边判断 package com.walegarrett.offer; /** * @Auth ...

  6. CentOS7安装 xmlsec1 编译并运行官方示例

    1. 自动安装下列软件和依赖(默认已安装libxml2和libxslt) yum install xmlsec1-openssl xmlsec1-openssl-devel 2. 查看官网 www.a ...

  7. [MongoDB知识体系] 一文全面总结MongoDB知识体系

    MongoDB教程 - Mongo知识体系详解 本系列将给大家构建MongoDB全局知识体系.@pdai MongoDB教程 - Mongo知识体系详解 知识体系 学习要点 学习资料 官网资料 入门系 ...

  8. Java流程控制:三种基本结构

    顺序结构: Java的基本结构就是顺序结构,除非特别指明,否则就按照顺序一句一句执行顺序结构是最简单的算法结构语句与语句之间,框与框之间是按从上到下的顺序进行的,它是由若干个依次执行的处理步骤组成的, ...

  9. golang 实现两数组对应元素相除

    func ArrayDivision(arr1 []float64,arr2 []float64) (arr3 []float64) { //两数组对应元素相除 for p:=0;p< len( ...

  10. Python之基础算法介绍

    一.算法介绍 1. 算法是什么 算法是指解题方案的准确而完整的描述,是一系列解决问题的清晰指令,算法代表着用系统的方法描述解决问题的策略机制.也就是说,能够对一定规范的输入,在有限时间内获得所要求的输 ...