用于DataLoader的pytorch数据集
暂时介绍 image-mask型数据集, 以人手分割数据集 EGTEA Gaze+ 为例.
准备数据文件夹
需要将Image和Mask分开存放, 对应文件的文件名必须保持一致. 提醒: Mask 图像一般为 png 单通道
EGTEA Gaze+ 数据集下载解压后即得到如下的目录, 无需处理
hand14k
┣━ Images
┃ ┣━ OP01-R01-PastaSalad_000014.jpg
┃ ┣━ OP01-R01-PastaSalad_000015.jpg
┃ ┣━ OP01-R01-PastaSalad_000016.jpg
┃ ┗━ ···
┗━ Masks
┣━ OP01-R01-PastaSalad_000014.png
┣━ OP01-R01-PastaSalad_000015.png
┣━ OP01-R01-PastaSalad_000016.png
┗━ ···
生成路径文件, 划分数据集
脚本如下:import cv2 as cv
import numpy as np
import PIL.Image as Image
import os
np.random.seed(42)
def split_dataset():
# 读取图像文件
images_path = "./Images/"
images_list = os.listdir(images_path) # 每次返回文件列表顺序不一致
images_list.sort() # 需要排序处理
# 读取标签/Mask图像
labels_path = "./Masks/"
labels_list = os.listdir(labels_path)
labels_list.sort()
# 创建路径文件 (使用二进制编码, 避免操作系统不匹配)
train_file = "./train.data"
test_file = "./test.data"
if os.path.isfile(train_file) and os.path.isfile(test_file):
return
train_file = open(train_file, "wb")
test_file = open(test_file, "wb")
# 外汇返佣
split_ratio = 0.8
for image, label in zip(images_list, labels_list):
image = os.path.join(images_path, image)
label = os.path.join(labels_path, label)
if os.path.basename(image).split('.')[0] != os.path.basename(label).split('.')[0]:
continue
file = train_file if np.random.rand() < split_ratio else test_file
file.write((image + "\t" + label + "\n").encode("utf-8"))
train_file.close()
test_file.close()
print("成功划分数据集!")
def read_image(path):
img = np.array(Image.open(path))
if img.ndim == 2:
img = cv.merge([img, img, img])
return img
def test_read():
train_file = "./test.data"
with open(train_file, 'rb') as f:
datalist = f.readlines()
datalist = [(k, v) for k, v in map(lambda x: x.decode('utf-8').strip('\n').split('\t'), datalist)]
item = datalist[np.random.randint(42)]
image = read_image(item[0])
mask = read_image(item[1])
cv.imshow("image", image)
cv.imshow("mask", mask)
cv.waitKey(0)
cv.destroyAllWindows()
if __name__ == '__main__':
split_dataset()
test_read()
派生 Dataset 类
class MyDataset(Dataset):
def __init__(
self, data_file, data_dir, transform_trn=None, transform_val=None
):
"""
Args:
data_file (string): Path to the data file with annotations.
data_dir (string): Directory with all the images.
transform_{trn, val} (callable, optional): Optional transform to be applied
on a sample.
"""
with open(data_file, 'rb') as f:
datalist = f.readlines()
self.datalist = [(k, v) for k, v in map(lambda x: x.decode('utf-8').strip('\n').split('\t'), datalist)]
self.root_dir = data_dir
self.transform_trn = transform_trn
self.transform_val = transform_val
self.stage = 'train'
def set_stage(self, stage):
self.stage = stage
def __len__(self):
return len(self.datalist)
def __getitem__(self, idx):
img_name = os.path.join(self.root_dir, self.datalist[idx][0])
msk_name = os.path.join(self.root_dir, self.datalist[idx][1])
def read_image(x):
img_arr = np.array(Image.open(x))
if len(img_arr.shape) == 2: # grayscale
img_arr = np.tile(img_arr, [3, 1, 1]).transpose(1, 2, 0)
return img_arr
image = read_image(img_name)
mask = np.array(Image.open(msk_name))
if img_name != msk_name:
assert len(mask.shape) == 2, 'Masks must be encoded without colourmap'
sample = {'image': image, 'mask': mask}
if self.stage == 'train':
if self.transform_trn:
sample = self.transform_trn(sample)
elif self.stage == 'val':
if self.transform_val:
sample = self.transform_val(sample)
return sample
构造DataLoader
# 定义Transform
composed_trn = transforms.Compose([ResizeShorterScale(shorter_side, low_scale, high_scale),
Pad(crop_size, [123.675, 116.28, 103.53], ignore_label),
RandomMirror(),
RandomCrop(crop_size),
Normalise(*normalise_params),
ToTensor()])
composed_val = transforms.Compose([Normalise(*normalise_params),
ToTensor()])
# 导入数据集
trainset = MyDataset(data_file=train_list,
data_dir=train_dir,
transform_trn=composed_trn,
transform_val=composed_val)
valset = MyDataset(data_file=val_list,
data_dir=val_dir,
transform_trn=None,
transform_val=composed_val)
# 构建生成器
train_loader = DataLoader(trainset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=True,
drop_last=True)
val_loader = DataLoader(valset,
batch_size=1,
shuffle=False,
num_workers=num_workers,
pin_memory=True)
训练
for i, sample in enumerate(train_loader):
image = sample['image'].cuda()
target = sample['mask'].cuda()
image_var = torch.autograd.Variable(image).float()
target_var = torch.autograd.Variable(target).long()
# Compute output
output = net(image_var)
...
原文链接:https://blog.csdn.net/Augurlee/article/details/103652444
用于DataLoader的pytorch数据集的更多相关文章
- PyTorch 数据集类 和 数据加载类 的一些尝试
最近在学习PyTorch, 但是对里面的数据类和数据加载类比较迷糊,可能是封装的太好大部分情况下是不需要有什么自己的操作的,不过偶然遇到一些自己导入的数据时就会遇到一些问题,因此自己对此做了一些小实 ...
- Pytorch数据集读取
Pytorch中数据集读取 在机器学习中,有很多形式的数据,我们就以最常用的几种来看: 在Pytorch中,他自带了很多数据集,比如MNIST.CIFAR10等,这些自带的数据集获得和读取十分简便: ...
- Pytorch数据集读入——Dataset类,实现数据集打乱Shuffle
在进行相关平台的练习过程中,由于要自己导入数据集,而导入方法在市面上五花八门,各种库都可以应用,在这个过程中我准备尝试torchvision的库dataset torchvision.datasets ...
- [Pytorch数据集下载] 下载MNIST数据缓慢的方案
步骤一 首先访问下面的网站,手工下载数据集.http://yann.lecun.com/exdb/mnist/ 把四个压缩包下载到任意文件夹,以便之后使用. 步骤二 把自己电脑上已经下载好的数据集的文 ...
- PyTorch 之 DataLoader
DataLoader DataLoader 是 PyTorch 中读取数据的一个重要接口,该接口定义在 dataloader.py 文件中,该接口的目的: 将自定义的 Dataset 根据 batch ...
- 什么是pytorch(4.数据集加载和处理)(翻译)
数据集加载和处理 这里主要涉及两个包:torchvision.datasets 和torch.utils.data.Dataset 和DataLoader torchvision.datasets是一 ...
- 【pytorch】torch.utils.data.DataLoader
简介 DataLoader是PyTorch中的一种数据类型.用于训练/验证/测试时的数据按批读取. torch.utils.data.DataLoader(dataset, batch_size=1, ...
- pytorch加载语音类自定义数据集
pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合 torch.u ...
- [实现] 利用 Seq2Seq 预测句子后续字词 (Pytorch)2
最近有个任务:利用 RNN 进行句子补全,即给定一个不完整的句子,预测其后续的字词.本文使用了 Seq2Seq 模型,输入为 5 个中文字词,输出为 1 个中文字词.目录 关于RNN 语料预处理 搭建 ...
随机推荐
- Git 创建版本库并实现本地上传数据到GitHub库
版本库又叫做仓库,其实也是一个目录,这个目录里的所有文件都是被Git管理着,对每个文件的修改,删除,Git都会进行记录,方便我们对其进行跟踪. 因为本地是window环境,我们先从官网下载好windo ...
- Vuex的安装、使用及注意事项
使用Vuex的步骤: (1)安装: 1.使用npm安装: 1 npm install vuex --save 2.使用script标签引入 1 2 3 <script src="/p ...
- 用Vue来实现音乐播放器(十六):滚动列表的实现
滚动列表是一个基础组件 他是基于scroll组件实现的 在base文件夹下面创建一个list-view文件夹 里面有list-view.vue组件 <template> < ...
- 阶段1 语言基础+高级_1-3-Java语言高级_04-集合_04 数据结构_1_数据结构_栈
2.1 数据结构有什么用? 当你用着java里面的容器类很爽的时候,你有没有想过,怎么ArrayList就像一个无限扩充的数组,也好像链表之类 的.好用吗?好用,这就是数据结构的用处,只不过你在不知不 ...
- 四种方法给Vmware虚拟机清理瘦身
随着VMware虚拟机使用时间的增长,其所占用的空间也越来越大,本文来说说怎么给VMware虚拟机占用的空间进行瘦身. **方法一:VMware自带的清理磁盘 **这个方法是VMware自带,具有普适 ...
- Win32InputBox,C接口的,实现类似VB的InputBox的功能
#ifndef __03022006__WIN32INPUTBOX__ #define __03022006__WIN32INPUTBOX__ /* This library is (c) Elias ...
- 2 Vue.js基础
1 简易计算器 <!DOCTYPE html> <html lang="en"> <head> <meta charset="U ...
- TCP通信 - 服务器开启多线程与read()导致服务器阻塞问题
TCP通信的文件上传案例 本地流:客户端和服务器和本地硬盘进行读写,需要使用自己创建的字节流 网络流:客户端和服务器之间读写,必须使用Socket中提供的字节流对象 客户端工作:读取本地文件,上传到服 ...
- C# 反射实现动态加载程序集
原文:https://blog.csdn.net/pengdayong77/article/details/47622235 在.Net 中,程序集(Assembly)中保存了元数据(MetaData ...
- C#传特定的值,获得特定的数组排序
一,在实际业务中,我们会有当我们传任何值进来时,我们要有特定的排序,,比如传进来的是"生物", "历史","化学", 但实际上我们需要的是& ...