Pytorch划分数据集的方法
之前用过sklearn提供的划分数据集的函数,觉得超级方便。但是在使用TensorFlow和Pytorch的时候一直找不到类似的功能,之前搜索的关键字都是“pytorch split dataset”之类的,但是搜出来还是没有我想要的。结果今天见鬼了突然看见了这么一个函数torch.utils.data.Subset。我的天,为什么超级开心hhhh。终于不用每次都手动划分数据集了。
torch.utils.data
Pytorch提供的对数据集进行操作的函数详见:https://pytorch.org/docs/master/data.html#torch.utils.data.SubsetRandomSampler
torch的这个文件包含了一些关于数据集处理的类:
- class torch.utils.data.Dataset: 一个抽象类, 所有其他类的数据集类都应该是它的子类。而且其子类必须重载两个重要的函数:len(提供数据集的大小)、getitem(支持整数索引)。
- class torch.utils.data.TensorDataset: 封装成tensor的数据集,每一个样本都通过索引张量来获得。
- class torch.utils.data.ConcatDataset: 连接不同的数据集以构成更大的新数据集。
- class torch.utils.data.Subset(dataset, indices): 获取指定一个索引序列对应的子数据集。
- class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。
- torch.utils.data.random_split(dataset, lengths): 按照给定的长度将数据集划分成没有重叠的新数据集组合。
- class torch.utils.data.Sampler(data_source):所有采样的器的基类。每个采样器子类都需要提供 iter 方-法以方便迭代器进行索引 和一个 len方法 以方便返回迭代器的长度。
- class torch.utils.data.SequentialSampler(data_source):顺序采样样本,始终按照同一个顺序。
- class torch.utils.data.RandomSampler(data_source):无放回地随机采样样本元素。
- class torch.utils.data.SubsetRandomSampler(indices):无放回地按照给定的索引列表采样样本元素。
- class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True): 按照给定的概率来采样样本。
- class torch.utils.data.BatchSampler(sampler, batch_size, drop_last): 在一个batch中封装一个其他的采样器。
- class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None):采样器可以约束数据加载进数据集的子集。
示例
下面Pytorch提供的划分数据集的方法以示例的方式给出:
SubsetRandomSampler
...
dataset = MyCustomDataset(my_path)
batch_size = 16
validation_split = .2
shuffle_dataset = True
random_seed= 42
# Creating data indices for training and validation splits:
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset :
np.random.seed(random_seed)
np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]
# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
sampler=valid_sampler)
# Usage Example:
num_epochs = 10
for epoch in range(num_epochs):
# Train:
for batch_index, (faces, labels) in enumerate(train_loader):
# ...
random_split
...
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
参考:
- How do I split a custom dataset into training and test datasets?
- PyTorch系列 (二): pytorch数据读取
- pytorch: 自定义数据集加载
Pytorch划分数据集的方法的更多相关文章
- 使用python划分数据集
无论是训练机器学习或是深度学习,第一步当然是先划分数据集啦,今天小白整理了一些划分数据集的方法,希望大佬们多多指教啊,嘻嘻~ 首先看一下数据集的样子,flower_data文件夹下有四个文件夹,每个文 ...
- 【noi 2.6_8787】数的划分(DP){附【转】整数划分的解题方法}
题意:问把整数N分成K份的分法数.(与"放苹果"不同,在这题不可以有一份为空,但可以类比)解法:f[i][j]表示把i分成j份的方案数.f[i][j]=f[i-1][j-1](新开 ...
- Pytorch指定GPU的方法总结
Pytorch指定GPU的方法 改变系统变量 改变系统环境变量仅使目标显卡,编辑 .bashrc文件,添加系统变量 export CUDA_VISIBLE_DEVICES=0 #这里是要使用的GPU编 ...
- sklearn 划分数据集。
1.sklearn.model_selection.train_test_split随机划分训练集和测试集 函数原型: X_train,X_test, y_train, y_test =cross_v ...
- (数据科学学习手札27)sklearn数据集分割方法汇总
一.简介 在现实的机器学习任务中,我们往往是利用搜集到的尽可能多的样本集来输入算法进行训练,以尽可能高的精度为目标,但这里便出现一个问题,一是很多情况下我们不能说搜集到的样本集就能代表真实的全体,其分 ...
- PyTorch 自定义数据集
准备数据 准备 COCO128 数据集,其是 COCO train2017 前 128 个数据.按 YOLOv5 组织的目录: $ tree ~/datasets/coco128 -L 2 /home ...
- Delphi调用MSSQL存储过程返回的多个数据集的方法
varaintf:_Recordset;RecordsAffected:OleVariant; begin ADOStoredProc1.Close;ADOStoredProc1.Open;aintf ...
- 使用Sklearn-train_test_split 划分数据集
使用sklearn.model_selection.train_test_split可以在数据集上随机划分出一定比例的训练集和测试集 1.使用形式为: from sklearn.model_selec ...
- PyTorch常用参数初始化方法详解
1. 均匀分布 torch.nn.init.uniform_(tensor, a=0, b=1) 从均匀分布U(a, b)中采样,初始化张量. 参数: tensor - 需要填充的张量 a - 均匀分 ...
随机推荐
- 工具(2): 极简MarkDown排版介绍(How to)
如何切换编辑器 切换博客园编辑器为MarkDown:MarkDown Editor 选择一个在线编辑和预览站点:StackEdit 如何排版章节 MarkDown: 大标题 ========== 小标 ...
- promise源码解析
前言 大部分同学对promise,可能还停留在会使用es6的promise,还没有深入学习.我们都知道promise内部通过reslove.reject来判断执行哪个函数,原型上面的then同样的,也 ...
- NodeJs之服务搭建与数据库连接
NodeJs之服务搭建与数据库连接 一,介绍与需求分析 1.1,介绍 Node.js 是一个基于 Chrome V8 引擎的 JavaScript 运行环境. Node.js 使用了一个事件驱动.非阻 ...
- 管理者的情商EQ
管理者的情商EQ1 IQ与EQ与AQ: IQ:智慧.逻辑.解决问题 EQ:情感商数.领导团队的热情.互动 AQ:逆商.碰到逆境怎么办.得重大疾病怎么办 成功者的概率: 放弃者:70% 半途而废者:25 ...
- 回忆曾经的SSM框架实现文件上传
近期在使用springboot实现文件上传的功能,想到曾经用SSM做过这个功能,在这里记录一下过去实现的方式 maven添加文件上传所需的依赖 springMVC的配置文件配置一下文件上传 我实现的是 ...
- PHP7中的数据类型
PHP中变量名→zval,变量值→zend_value.其变量内存是通过引用计数管理的,在PHP7中引用计数在value结构中. 变量类型: 头文件在PHP源码 /zend/zend_types.h ...
- Go语言中的Package问题
问题一.Go使用Package组织源码的好处是什么? 1.任何源码属于一个包 2.用包组织便于代码的易读和复用 问题二.Go语言中Package的种类 Go语言中存在两种包.一种是可执行程序的包.一种 ...
- mysql学习基础知识3
1.视图 简化sql语句的编写,限制可以查看的数据 一张虚拟的表,不占任何内存,查视图时都是临时从所查的表中拿数据 特点: 对于视图的增删改查 都会同步到原始表 对原始表的修改,会同步到视图内可查看的 ...
- C语言博客05--指针
C语言博客05--指针 1.本章学习总结 1.1 思维导图 1.2 本章学习体会及代码量学习体会 1.2.1 学习体会 在本周的学习过程中,我们学习了指针的用法.说实话,指针的用法有点绕,之前一直没搞 ...
- thrift使用
一.什么是thrift Thrift是一种接口描述语言和二进制通讯协议,它被用来定义和创建跨语言的服务.它被当作一个远程过程调用(RPC)框架来使用,是由FaceBook为“大规模跨语言服务开发”而开 ...