之前用过sklearn提供的划分数据集的函数,觉得超级方便。但是在使用TensorFlow和Pytorch的时候一直找不到类似的功能,之前搜索的关键字都是“pytorch split dataset”之类的,但是搜出来还是没有我想要的。结果今天见鬼了突然看见了这么一个函数torch.utils.data.Subset。我的天,为什么超级开心hhhh。终于不用每次都手动划分数据集了。

torch.utils.data

Pytorch提供的对数据集进行操作的函数详见:https://pytorch.org/docs/master/data.html#torch.utils.data.SubsetRandomSampler

torch的这个文件包含了一些关于数据集处理的类:

  • class torch.utils.data.Dataset: 一个抽象类, 所有其他类的数据集类都应该是它的子类。而且其子类必须重载两个重要的函数:len(提供数据集的大小)、getitem(支持整数索引)。
  • class torch.utils.data.TensorDataset: 封装成tensor的数据集,每一个样本都通过索引张量来获得。
  • class torch.utils.data.ConcatDataset: 连接不同的数据集以构成更大的新数据集。
  • class torch.utils.data.Subset(dataset, indices): 获取指定一个索引序列对应的子数据集。
  • class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。
  • torch.utils.data.random_split(dataset, lengths): 按照给定的长度将数据集划分成没有重叠的新数据集组合。
  • class torch.utils.data.Sampler(data_source):所有采样的器的基类。每个采样器子类都需要提供 iter 方-法以方便迭代器进行索引 和一个 len方法 以方便返回迭代器的长度。
  • class torch.utils.data.SequentialSampler(data_source):顺序采样样本,始终按照同一个顺序。
  • class torch.utils.data.RandomSampler(data_source):无放回地随机采样样本元素。
  • class torch.utils.data.SubsetRandomSampler(indices):无放回地按照给定的索引列表采样样本元素。
  • class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True): 按照给定的概率来采样样本。
  • class torch.utils.data.BatchSampler(sampler, batch_size, drop_last): 在一个batch中封装一个其他的采样器。
  • class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None):采样器可以约束数据加载进数据集的子集。

示例

下面Pytorch提供的划分数据集的方法以示例的方式给出:

SubsetRandomSampler

...

dataset = MyCustomDataset(my_path)
batch_size = 16
validation_split = .2
shuffle_dataset = True
random_seed= 42 # Creating data indices for training and validation splits:
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset :
np.random.seed(random_seed)
np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split] # Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices) train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
sampler=valid_sampler) # Usage Example:
num_epochs = 10
for epoch in range(num_epochs):
# Train:
for batch_index, (faces, labels) in enumerate(train_loader):
# ...

random_split

...

train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

参考:

微信公众号:AutoML机器学习

MARSGGBO♥原创

如有意合作或学术讨论欢迎私戳联系~
邮箱:marsggbo@foxmail.com


2019-3-8

Pytorch划分数据集的方法的更多相关文章

  1. 使用python划分数据集

    无论是训练机器学习或是深度学习,第一步当然是先划分数据集啦,今天小白整理了一些划分数据集的方法,希望大佬们多多指教啊,嘻嘻~ 首先看一下数据集的样子,flower_data文件夹下有四个文件夹,每个文 ...

  2. 【noi 2.6_8787】数的划分(DP){附【转】整数划分的解题方法}

    题意:问把整数N分成K份的分法数.(与"放苹果"不同,在这题不可以有一份为空,但可以类比)解法:f[i][j]表示把i分成j份的方案数.f[i][j]=f[i-1][j-1](新开 ...

  3. Pytorch指定GPU的方法总结

    Pytorch指定GPU的方法 改变系统变量 改变系统环境变量仅使目标显卡,编辑 .bashrc文件,添加系统变量 export CUDA_VISIBLE_DEVICES=0 #这里是要使用的GPU编 ...

  4. sklearn 划分数据集。

    1.sklearn.model_selection.train_test_split随机划分训练集和测试集 函数原型: X_train,X_test, y_train, y_test =cross_v ...

  5. (数据科学学习手札27)sklearn数据集分割方法汇总

    一.简介 在现实的机器学习任务中,我们往往是利用搜集到的尽可能多的样本集来输入算法进行训练,以尽可能高的精度为目标,但这里便出现一个问题,一是很多情况下我们不能说搜集到的样本集就能代表真实的全体,其分 ...

  6. PyTorch 自定义数据集

    准备数据 准备 COCO128 数据集,其是 COCO train2017 前 128 个数据.按 YOLOv5 组织的目录: $ tree ~/datasets/coco128 -L 2 /home ...

  7. Delphi调用MSSQL存储过程返回的多个数据集的方法

    varaintf:_Recordset;RecordsAffected:OleVariant; begin ADOStoredProc1.Close;ADOStoredProc1.Open;aintf ...

  8. 使用Sklearn-train_test_split 划分数据集

    使用sklearn.model_selection.train_test_split可以在数据集上随机划分出一定比例的训练集和测试集 1.使用形式为: from sklearn.model_selec ...

  9. PyTorch常用参数初始化方法详解

    1. 均匀分布 torch.nn.init.uniform_(tensor, a=0, b=1) 从均匀分布U(a, b)中采样,初始化张量. 参数: tensor - 需要填充的张量 a - 均匀分 ...

随机推荐

  1. Elastic Stack-Elasticsearch使用介绍(二)

    一.前言     写博客,更要努力写博客! 二.Mapping介绍 Mapping类似于数据库中的表结构的定义:这里我们试想一下表结构定义需要那些: 1.字段和字段类型,在Elasticsearch中 ...

  2. PHP将汉字转为拼音

    没什么难度,最大的难点应该是需要有一个汉字-拼音库. <?php function spell($str, $ishead=0){ $restr = ''; $str = trim($str); ...

  3. Win7删除右键菜单中“图形属性”和“图形选项”

    完win7操作系统后,打完驱动在桌面右键会出现如下两个选项,平时没啥用又占用空间,那么如何删掉这两个选项呢? 操作步骤: 1.在运行中输入 regedit 确定打开注册表: 2.依次单击展开HKEY_ ...

  4. Vue 2.6版本基础知识概要(一)

    挂载组件 //将 App组件挂载到div#app节点里 new Vue({ render: h => h(App), }).$mount('#app') VueComponent.$mount ...

  5. dedecms织梦的不同栏目调用不同banner图的方法

    在做织梦站的时候我们会有不同的栏目,比如联系我们,产品中心等等,banner也不一样,方法如下: 我们可以使用织梦的顶级栏目ID标签,把图片命名成顶级栏目typeid ,代码如下: <img s ...

  6. react-redux的基本用法

    注意:读懂本文需要具备redux基础知识, 注明:本文旨在说明如何在实际项目中快速使用react-redux,限于篇幅,本文对具体的原理并未做分析,请参考redux官网 我一直以为我写了一篇关于rea ...

  7. B-Tree和B+Tree的区别

    B+树索引是B+树在数据库中的一种实现,是最常见也是数据库中使用最为频繁的一种索引.B+树中的B代表平衡(balance),而不是二叉(binary),因为B+树是从最早的平衡二叉树演化而来的.在讲B ...

  8. Verilog语言实现并行(循环冗余码)CRC校验

    1 前言 (1)    什么是CRC校验? CRC即循环冗余校验码:是数据通信领域中最常用的一种查错校验码,其特征是信息字段和校验字段的长度可以任意选定.循环冗余检查(CRC)是一种数据传输检错功能, ...

  9. 【集训队作业2018】取名字太难了 任意模数FFT

    题目大意 求多项式 \(\prod_{i=1}^n(x+i)\) 的系数在模 \(p\) 意义下的分布,对 \(998244353\) 取模. \(p\) 为质数. \(n\leq {10}^{18} ...

  10. Linux keepalived+lvs实现高可用负载均衡

    LVS的具有强大的负载均衡功能,但是它缺少对负载层节点(DS)的健康状态检测功能,也不能对后端服务(RS)进行健康状态检测:keepalived是专门用来监控高可用集群架构的中各服务的节点状态,如果某 ...