Pytorch技法:继承Subset类完成自定义数据拆分
我们在《torch.utils.data.DataLoader与迭代器转换》中介绍了如何使用Pytorch内置的数据集进行论文实现,如torchvision.datasets
。下面是加载内置训练数据集的常见操作:
from torchvision.datasets import FashionMNIST
from torchvision.transforms import Compose, ToTensor, Normalize
RAW_DATA_PATH = './rawdata'
transform = Compose(
[ToTensor(),
Normalize((0.1307,), (0.3081,))
]
)
train_data = FashionMNIST(
root=RAW_DATA_PATH,
download=True,
train=True,
transform=transform
)
这里的train_data
做为dataset
对象,它拥有许多熟悉,我们可以通过以下方法获取样本数据的分类类别集合、样本的特征维度、样本的标签集合等信息。
classes = train_data.classes
num_features = train_data.data[0].shape[0]
train_labels = train_data.targets
print(classes)
print(num_features)
print(train_labels)
输出如下:
['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
28
tensor([9, 0, 0, ..., 3, 0, 5])
但是,我们常常会在训练集的基础上拆分出验证集(或者只用部分数据来进行训练)。我们想到的第一个方法是使用torch.utils.data.random_split
对dataset
进行划分,下面我们假设划分10000个样本做为训练集,其余样本做为验证集:
from torch.utils.data import random_split
k = 10000
train_data, valid_data = random_split(train_data, [k, len(train_data)-k])
注意我们如果打印train_data
和valid_data
的类型,可以看到显示:
<class 'torch.utils.data.dataset.Subset'>
已经不再是torchvision.datasets.mnist.FashionMNIST
对象,而是一个所谓的Subset
对象!此时Subset
对象虽然仍然还存有data
属性,但是内置的target
和classes
属性已经不复存在,比如如果我们强行访问valid_data
的target
属性:
valid_target = valid_data.target
就会报如下错误:
'Subset' object has no attribute 'target'
但如果我们在后续的代码中常常会将拆分后的数据集也默认为dataset
对象,那么该如何做到代码的一致性呢?
这里有一个trick,那就是以继承SubSet
类的方式的方式定义一个新的CustomSubSet
类,使新类在保持SubSet
类的基本属性的基础上,拥有和原本数据集类相似的属性,如targets
和classes
等:
from torch.utils.data import Subset
class CustomSubset(Subset):
'''A custom subset class'''
def __init__(self, dataset, indices):
super().__init__(dataset, indices)
self.targets = dataset.targets # 保留targets属性
self.classes = dataset.classes # 保留classes属性
def __getitem__(self, idx): #同时支持索引访问操作
x, y = self.dataset[self.indices[idx]]
return x, y
def __len__(self): # 同时支持取长度操作
return len(self.indices)
然后就引出了第二种划分方法,即通过初始化CustomSubset
对象的方式直接对数据集进行划分(这里为了简化省略了shuffle的步骤):
import numpy as np
from copy import deepcopy
origin_data = deepcopy(train_data)
train_data = CustomSubset(origin_data, np.arange(k))
valid_data = CustomSubset(origin_data, np.arange(k, len(origin_data))-k)
注意,CustomSubset
类的初始化方法的第二个参数indices
为样本索引,我们可以通过np.arange()
的方法来创建。
然后,我们再访问valid_data
对应的classes
和targes
属性:
print(valid_data.classes)
print(valid_data.targets)
此时,我们发现可以成功访问这些属性了:
['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
tensor([9, 0, 0, ..., 3, 0, 5])
当然,CustomSubset
的作用并不只是添加数据集的属性,我们还可以自定义一些数据预处理操作。我们将类的结构修改如下:
class CustomSubset(Subset):
'''A custom subset class with customizable data transformation'''
def __init__(self, dataset, indices, subset_transform=None):
super().__init__(dataset, indices)
self.targets = dataset.targets
self.classes = dataset.classes
self.subset_transform = subset_transform
def __getitem__(self, idx):
x, y = self.dataset[self.indices[idx]]
if self.subset_transform:
x = self.subset_transform(x)
return x, y
def __len__(self):
return len(self.indices)
我们可以在使用样本前设置好数据预处理算子:
from torchvision import transforms
valid_data.subset_transform = transforms.Compose(\
[transforms.RandomRotation((180,180))])
这样,我们再像下列这样用索引访问取出数据集样本时,就会自动调用算子完成预处理操作:
print(valid_data[0])
打印结果缩略如下:
(tensor([[[-0.4242, -0.4242, -0.4242, ......-0.4242, -0.4242, -0.4242, -0.4242, -0.4242]]]), 9)
Pytorch技法:继承Subset类完成自定义数据拆分的更多相关文章
- [Pytorch]PyTorch Dataloader自定义数据读取
整理一下看到的自定义数据读取的方法,较好的有一下三篇文章, 其实自定义的方法就是把现有数据集的train和test分别用 含有图像路径与label的list返回就好了,所以需要根据数据集随机应变. 所 ...
- [深度学习] pytorch学习笔记(4)(Module类、实现Flatten类、Module类作用、数据增强)
一.继承nn.Module类并自定义层 我们要利用pytorch提供的很多便利的方法,则需要将很多自定义操作封装成nn.Module类. 首先,简单实现一个Mylinear类: from torch ...
- .Net 配置文件--继承ConfigurationSection实现自定义处理类处理自定义配置节点
除了使用继承IConfigurationSectionHandler的方法定义处理自定义节点的类,还可以通过继承ConfigurationSection类实现同样效果. 首先说下.Net配置文件中一个 ...
- .Net 配置文件——继承ConfigurationSection实现自定义处理类处理自定义配置节点
除了使用继承IConfigurationSectionHandler的方法定义处理自定义节点的类,还可以通过继承ConfigurationSection类实现同样效果. 首先说下.Net配置文件中一个 ...
- WPF 之 创建继承自Window 基类的自定义窗口基类
开发项目时,按照美工的设计其外边框(包括最大化,最小化,关闭等按钮)自然不同于 Window 自身的,但窗口的外边框及窗口移动.最小化等标题栏操作基本都是一样的.所以通过查看资料,可按如下方法创建继承 ...
- QVariant类及QVariant与自定义数据类型转换的方法
这个类型相当于是Java里面的Object,它把绝大多数Qt提供的数据类型都封装起来,起到一个数据类型“擦除”的作用.比如我们的 table单元格可以是string,也可以是int,也可以是一个颜色值 ...
- 【spring boot】7.静态资源和拦截器处理 以及继承WebMvcConfigurerAdapter类进行更多自定义配置
开头是鸡蛋,后面全靠编!!! ======================================================== 1.默认静态资源映射路径以及优先顺序 Spring B ...
- JS面向对象(1) -- 简介,入门,系统常用类,自定义类,constructor,typeof,instanceof,对象在内存中的表现形式
相关链接: JS面向对象(1) -- 简介,入门,系统常用类,自定义类,constructor,typeof,instanceof,对象在内存中的表现形式 JS面向对象(2) -- this的使用,对 ...
- [转]MVC自定义数据验证(两个时间的比较)
本文转自:http://www.cnblogs.com/zhangliangzlee/archive/2012/07/26/2610071.html Model: public class Model ...
随机推荐
- vue组件中的.sync修饰符使用
在vue的组件通信props中,一般情况下,数据都是单向的,子组件不会更改父组件的值,那么vue提供.sync作为双向传递的关键字,实现了父组件的变动会传递给子组件,而子组件的carts改变时,通过事 ...
- gitlab新增ssh
https://blog.csdn.net/u011925641/article/details/79897517
- Web发送邮件
1.首先注册一个163邮箱 自己的邮箱地址是xyqq769552629@163.com 登陆的密码是自己设定 使用邮箱发邮件,邮件必须开启pop和smtp服务,登陆邮件 开启SMTP服务,这个时候提示 ...
- 开启mysql外部访问(root外连)
MySQL外部访问 mysql 默认是禁止远程连接的,你在安装mysql的系统行运行mysql -u root -p 后进入mysql 输入如下: mysql>use mysql; mysql& ...
- Java反射详解:入门+使用+原理+应用场景
反射非常强大和有用,现在市面上绝大部分框架(spring.mybatis.rocketmq等等)中都有反射的影子,反射机制在框架设计中占有举足轻重的作用. 所以,在你Java进阶的道路上,你需要掌握好 ...
- 你的Kubernetes Java应用优雅停机了吗?
Java 应用优雅停机 我们首先考虑下,一般在什么场景下数据会丢失呢? 升级服务时 pod重启时 服务器断电时 因为服务器断电属于极端情况,我们暂且不考虑.那就只有 Java 退出时我们要保证数据的完 ...
- 小白也能看懂的Redis教学基础篇——做一个时间窗限流就是这么简单
不知道ZSet(有序集合)的看官们,可以翻阅我的上一篇文章: 小白也能看懂的REDIS教学基础篇--朋友面试被SKIPLIST跳跃表拦住了 书接上回,话说我朋友小A童鞋,终于面世通过加入了一家公司.这 ...
- Java的八大基本数据类型
Java的八大基本数据类型 前言 Bit是计算机存储数据的基本单元,bit叫做位,也被称作比特位. Byte意为字节,1Byte=1字节,一字节可以存储八个二进制位的数字, 即为1Byte=8bit. ...
- 解决windows下因为防火墙无法通过go get 下载gin的问题
使用: go get -u github.com/gin-gonic/gin 出现以下错误: unrecognized import path "gopkg.in/yaml.v2" ...
- Linux下进程线程,Nignx与php-fpm的进程线程方式
1.进程与线程区别 进程是程序执行时的一个实例,即它是程序已经执行到课中程度的数据结构的汇集.从内核的观点看,进程的目的就是担当分配系统资源(CPU时间.内存等)的基本单位. 线程是进程的一个执行流, ...