暂时介绍 image-mask型数据集, 以人手分割数据集 EGTEA Gaze+ 为例.

准备数据文件夹

需要将Image和Mask分开存放, 对应文件的文件名必须保持一致. 提醒: Mask 图像一般为 png 单通道

EGTEA Gaze+ 数据集下载解压后即得到如下的目录, 无需处理

hand14k

┣━ Images

┃ ┣━ OP01-R01-PastaSalad_000014.jpg

┃ ┣━ OP01-R01-PastaSalad_000015.jpg

┃ ┣━ OP01-R01-PastaSalad_000016.jpg

┃ ┗━ ···

┗━ Masks

┣━ OP01-R01-PastaSalad_000014.png

┣━ OP01-R01-PastaSalad_000015.png

┣━ OP01-R01-PastaSalad_000016.png

┗━ ···

生成路径文件, 划分数据集

脚本如下:import cv2 as cv

import numpy as np

import PIL.Image as Image

import os

np.random.seed(42)

def split_dataset():

# 读取图像文件

images_path = "./Images/"

images_list = os.listdir(images_path)  # 每次返回文件列表顺序不一致

images_list.sort()  # 需要排序处理

# 读取标签/Mask图像

labels_path = "./Masks/"

labels_list = os.listdir(labels_path)

labels_list.sort()

# 创建路径文件 (使用二进制编码, 避免操作系统不匹配)

train_file = "./train.data"

test_file = "./test.data"

if os.path.isfile(train_file) and os.path.isfile(test_file):

return

train_file = open(train_file, "wb")

test_file = open(test_file, "wb")

# 外汇返佣

split_ratio = 0.8

for image, label in zip(images_list, labels_list):

image = os.path.join(images_path, image)

label = os.path.join(labels_path, label)

if os.path.basename(image).split('.')[0] != os.path.basename(label).split('.')[0]:

continue

file = train_file if np.random.rand() < split_ratio else test_file

file.write((image + "\t" + label + "\n").encode("utf-8"))

train_file.close()

test_file.close()

print("成功划分数据集!")

def read_image(path):

img = np.array(Image.open(path))

if img.ndim == 2:

img = cv.merge([img, img, img])

return img

def test_read():

train_file = "./test.data"

with open(train_file, 'rb') as f:

datalist = f.readlines()

datalist = [(k, v) for k, v in map(lambda x: x.decode('utf-8').strip('\n').split('\t'), datalist)]

item = datalist[np.random.randint(42)]

image = read_image(item[0])

mask = read_image(item[1])

cv.imshow("image", image)

cv.imshow("mask", mask)

cv.waitKey(0)

cv.destroyAllWindows()

if __name__ == '__main__':

split_dataset()

test_read()

派生 Dataset 类

class MyDataset(Dataset):

def __init__(

self, data_file, data_dir, transform_trn=None, transform_val=None

):

"""

Args:

data_file (string): Path to the data file with annotations.

data_dir (string): Directory with all the images.

transform_{trn, val} (callable, optional): Optional transform to be applied

on a sample.

"""

with open(data_file, 'rb') as f:

datalist = f.readlines()

self.datalist = [(k, v) for k, v in map(lambda x: x.decode('utf-8').strip('\n').split('\t'), datalist)]

self.root_dir = data_dir

self.transform_trn = transform_trn

self.transform_val = transform_val

self.stage = 'train'

def set_stage(self, stage):

self.stage = stage

def __len__(self):

return len(self.datalist)

def __getitem__(self, idx):

img_name = os.path.join(self.root_dir, self.datalist[idx][0])

msk_name = os.path.join(self.root_dir, self.datalist[idx][1])

def read_image(x):

img_arr = np.array(Image.open(x))

if len(img_arr.shape) == 2: # grayscale

img_arr = np.tile(img_arr, [3, 1, 1]).transpose(1, 2, 0)

return img_arr

image = read_image(img_name)

mask = np.array(Image.open(msk_name))

if img_name != msk_name:

assert len(mask.shape) == 2, 'Masks must be encoded without colourmap'

sample = {'image': image, 'mask': mask}

if self.stage == 'train':

if self.transform_trn:

sample = self.transform_trn(sample)

elif self.stage == 'val':

if self.transform_val:

sample = self.transform_val(sample)

return sample

构造DataLoader

# 定义Transform

composed_trn = transforms.Compose([ResizeShorterScale(shorter_side, low_scale, high_scale),

Pad(crop_size, [123.675, 116.28, 103.53], ignore_label),

RandomMirror(),

RandomCrop(crop_size),

Normalise(*normalise_params),

ToTensor()])

composed_val = transforms.Compose([Normalise(*normalise_params),

ToTensor()])

# 导入数据集

trainset = MyDataset(data_file=train_list,

data_dir=train_dir,

transform_trn=composed_trn,

transform_val=composed_val)

valset = MyDataset(data_file=val_list,

data_dir=val_dir,

transform_trn=None,

transform_val=composed_val)

# 构建生成器

train_loader = DataLoader(trainset,

batch_size=batch_size,

shuffle=True,

num_workers=num_workers,

pin_memory=True,

drop_last=True)

val_loader = DataLoader(valset,

batch_size=1,

shuffle=False,

num_workers=num_workers,

pin_memory=True)

训练

for i, sample in enumerate(train_loader):

image = sample['image'].cuda()

target = sample['mask'].cuda()

image_var = torch.autograd.Variable(image).float()

target_var = torch.autograd.Variable(target).long()

# Compute output

output = net(image_var)

...

原文链接:https://blog.csdn.net/Augurlee/article/details/103652444

用于DataLoader的pytorch数据集的更多相关文章

  1. PyTorch 数据集类 和 数据加载类 的一些尝试

    最近在学习PyTorch,  但是对里面的数据类和数据加载类比较迷糊,可能是封装的太好大部分情况下是不需要有什么自己的操作的,不过偶然遇到一些自己导入的数据时就会遇到一些问题,因此自己对此做了一些小实 ...

  2. Pytorch数据集读取

    Pytorch中数据集读取 在机器学习中,有很多形式的数据,我们就以最常用的几种来看: 在Pytorch中,他自带了很多数据集,比如MNIST.CIFAR10等,这些自带的数据集获得和读取十分简便: ...

  3. Pytorch数据集读入——Dataset类,实现数据集打乱Shuffle

    在进行相关平台的练习过程中,由于要自己导入数据集,而导入方法在市面上五花八门,各种库都可以应用,在这个过程中我准备尝试torchvision的库dataset torchvision.datasets ...

  4. [Pytorch数据集下载] 下载MNIST数据缓慢的方案

    步骤一 首先访问下面的网站,手工下载数据集.http://yann.lecun.com/exdb/mnist/ 把四个压缩包下载到任意文件夹,以便之后使用. 步骤二 把自己电脑上已经下载好的数据集的文 ...

  5. PyTorch 之 DataLoader

    DataLoader DataLoader 是 PyTorch 中读取数据的一个重要接口,该接口定义在 dataloader.py 文件中,该接口的目的: 将自定义的 Dataset 根据 batch ...

  6. 什么是pytorch(4.数据集加载和处理)(翻译)

    数据集加载和处理 这里主要涉及两个包:torchvision.datasets 和torch.utils.data.Dataset 和DataLoader torchvision.datasets是一 ...

  7. 【pytorch】torch.utils.data.DataLoader

    简介 DataLoader是PyTorch中的一种数据类型.用于训练/验证/测试时的数据按批读取. torch.utils.data.DataLoader(dataset, batch_size=1, ...

  8. pytorch加载语音类自定义数据集

    pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合 torch.u ...

  9. [实现] 利用 Seq2Seq 预测句子后续字词 (Pytorch)2

    最近有个任务:利用 RNN 进行句子补全,即给定一个不完整的句子,预测其后续的字词.本文使用了 Seq2Seq 模型,输入为 5 个中文字词,输出为 1 个中文字词.目录 关于RNN 语料预处理 搭建 ...

随机推荐

  1. IntelliJ常用配置备忘

    前言 最近IntelliJ又由于自己的骚操作给弄崩溃了,导致之前弄的一大波配置又找不到了,十分蛋疼的又要开始重头开始弄环境.很多之前精心搞过的配置又都记不住了,为了防止以后出现这种情况,这里就把我日常 ...

  2. pve-备份

    一个50g的磁盘,用了13分钟 INFO: starting new backup job: vzdump 111 --node cu-pve04 --mode snapshot --compress ...

  3. MySQL-default设置

    Both statements insert a value into the phone column, but the first inserts a NULL value and the sec ...

  4. Oracle-优化SQL语句

    建议不使用(*)来代替所有列名 用truncate代替delete 在SQL*Plus环境中直接使用truncate table即可:要在PL/SQL中使用,如: 创建一个存储过程,实现使用trunc ...

  5. 让人失望透顶的 CSDN 博客改版

    前言 在 CSDN 写博已经 2 年有余,相比一些大佬,时间不算太长.但工作再忙,我也会保持每月产出,从未间断.每天上线回复评论,勘误内容,参加活动,看看阅读量已经成为一种习惯,可以说是 CSDN 博 ...

  6. Delphi XE2 之 FireMonkey 入门(19) - TFmxObject 的子类们(表)

    参考: 和 FMX 相关的类(表) TFmxObject IFreeNotification             TAnimation TBitmapAnimation           TBi ...

  7. 阶段1 语言基础+高级_1-3-Java语言高级_09-基础加强_第1节 基础加强_4_Junit_@Before&@After

    为了演示输出一段话 测试add的方法 虽然报错了 但是打印的结果还是输出

  8. 阶段1 语言基础+高级_1-3-Java语言高级_06-File类与IO流_07 缓冲流_1_缓冲流的原理

    一个字节一个字节的读取,先读取到a,a给到os操作系统.os再给JVM,.jVM再把a给java程序 读完a再读取b.这样一层层的返回,效率低下 一次读取,缓冲区数组返回来.

  9. ELK Stack 企业级日志收集平台

    ELK Stack介绍 大型项目,多产品线的日志收集 ,分析平台 为什么用ELK? 1.开发人员排查问题,服务器上查看权限 2.项目多,服务器多,日志类型多 ELK 架构介绍 数据源--->lo ...

  10. dig中文帮助

    NAME(名称)     dig — 发送域名查询信息包到域名服务器 SYNOPSIS(总览)     dig [@server] domain [⟨query-type⟩] [⟨query-clas ...