暂时介绍 image-mask型数据集, 以人手分割数据集 EGTEA Gaze+ 为例.

准备数据文件夹

需要将Image和Mask分开存放, 对应文件的文件名必须保持一致. 提醒: Mask 图像一般为 png 单通道

EGTEA Gaze+ 数据集下载解压后即得到如下的目录, 无需处理

hand14k

┣━ Images

┃ ┣━ OP01-R01-PastaSalad_000014.jpg

┃ ┣━ OP01-R01-PastaSalad_000015.jpg

┃ ┣━ OP01-R01-PastaSalad_000016.jpg

┃ ┗━ ···

┗━ Masks

┣━ OP01-R01-PastaSalad_000014.png

┣━ OP01-R01-PastaSalad_000015.png

┣━ OP01-R01-PastaSalad_000016.png

┗━ ···

生成路径文件, 划分数据集

脚本如下:import cv2 as cv

import numpy as np

import PIL.Image as Image

import os

np.random.seed(42)

def split_dataset():

# 读取图像文件

images_path = "./Images/"

images_list = os.listdir(images_path)  # 每次返回文件列表顺序不一致

images_list.sort()  # 需要排序处理

# 读取标签/Mask图像

labels_path = "./Masks/"

labels_list = os.listdir(labels_path)

labels_list.sort()

# 创建路径文件 (使用二进制编码, 避免操作系统不匹配)

train_file = "./train.data"

test_file = "./test.data"

if os.path.isfile(train_file) and os.path.isfile(test_file):

return

train_file = open(train_file, "wb")

test_file = open(test_file, "wb")

# 外汇返佣

split_ratio = 0.8

for image, label in zip(images_list, labels_list):

image = os.path.join(images_path, image)

label = os.path.join(labels_path, label)

if os.path.basename(image).split('.')[0] != os.path.basename(label).split('.')[0]:

continue

file = train_file if np.random.rand() < split_ratio else test_file

file.write((image + "\t" + label + "\n").encode("utf-8"))

train_file.close()

test_file.close()

print("成功划分数据集!")

def read_image(path):

img = np.array(Image.open(path))

if img.ndim == 2:

img = cv.merge([img, img, img])

return img

def test_read():

train_file = "./test.data"

with open(train_file, 'rb') as f:

datalist = f.readlines()

datalist = [(k, v) for k, v in map(lambda x: x.decode('utf-8').strip('\n').split('\t'), datalist)]

item = datalist[np.random.randint(42)]

image = read_image(item[0])

mask = read_image(item[1])

cv.imshow("image", image)

cv.imshow("mask", mask)

cv.waitKey(0)

cv.destroyAllWindows()

if __name__ == '__main__':

split_dataset()

test_read()

派生 Dataset 类

class MyDataset(Dataset):

def __init__(

self, data_file, data_dir, transform_trn=None, transform_val=None

):

"""

Args:

data_file (string): Path to the data file with annotations.

data_dir (string): Directory with all the images.

transform_{trn, val} (callable, optional): Optional transform to be applied

on a sample.

"""

with open(data_file, 'rb') as f:

datalist = f.readlines()

self.datalist = [(k, v) for k, v in map(lambda x: x.decode('utf-8').strip('\n').split('\t'), datalist)]

self.root_dir = data_dir

self.transform_trn = transform_trn

self.transform_val = transform_val

self.stage = 'train'

def set_stage(self, stage):

self.stage = stage

def __len__(self):

return len(self.datalist)

def __getitem__(self, idx):

img_name = os.path.join(self.root_dir, self.datalist[idx][0])

msk_name = os.path.join(self.root_dir, self.datalist[idx][1])

def read_image(x):

img_arr = np.array(Image.open(x))

if len(img_arr.shape) == 2: # grayscale

img_arr = np.tile(img_arr, [3, 1, 1]).transpose(1, 2, 0)

return img_arr

image = read_image(img_name)

mask = np.array(Image.open(msk_name))

if img_name != msk_name:

assert len(mask.shape) == 2, 'Masks must be encoded without colourmap'

sample = {'image': image, 'mask': mask}

if self.stage == 'train':

if self.transform_trn:

sample = self.transform_trn(sample)

elif self.stage == 'val':

if self.transform_val:

sample = self.transform_val(sample)

return sample

构造DataLoader

# 定义Transform

composed_trn = transforms.Compose([ResizeShorterScale(shorter_side, low_scale, high_scale),

Pad(crop_size, [123.675, 116.28, 103.53], ignore_label),

RandomMirror(),

RandomCrop(crop_size),

Normalise(*normalise_params),

ToTensor()])

composed_val = transforms.Compose([Normalise(*normalise_params),

ToTensor()])

# 导入数据集

trainset = MyDataset(data_file=train_list,

data_dir=train_dir,

transform_trn=composed_trn,

transform_val=composed_val)

valset = MyDataset(data_file=val_list,

data_dir=val_dir,

transform_trn=None,

transform_val=composed_val)

# 构建生成器

train_loader = DataLoader(trainset,

batch_size=batch_size,

shuffle=True,

num_workers=num_workers,

pin_memory=True,

drop_last=True)

val_loader = DataLoader(valset,

batch_size=1,

shuffle=False,

num_workers=num_workers,

pin_memory=True)

训练

for i, sample in enumerate(train_loader):

image = sample['image'].cuda()

target = sample['mask'].cuda()

image_var = torch.autograd.Variable(image).float()

target_var = torch.autograd.Variable(target).long()

# Compute output

output = net(image_var)

...

原文链接:https://blog.csdn.net/Augurlee/article/details/103652444

用于DataLoader的pytorch数据集的更多相关文章

  1. PyTorch 数据集类 和 数据加载类 的一些尝试

    最近在学习PyTorch,  但是对里面的数据类和数据加载类比较迷糊,可能是封装的太好大部分情况下是不需要有什么自己的操作的,不过偶然遇到一些自己导入的数据时就会遇到一些问题,因此自己对此做了一些小实 ...

  2. Pytorch数据集读取

    Pytorch中数据集读取 在机器学习中,有很多形式的数据,我们就以最常用的几种来看: 在Pytorch中,他自带了很多数据集,比如MNIST.CIFAR10等,这些自带的数据集获得和读取十分简便: ...

  3. Pytorch数据集读入——Dataset类,实现数据集打乱Shuffle

    在进行相关平台的练习过程中,由于要自己导入数据集,而导入方法在市面上五花八门,各种库都可以应用,在这个过程中我准备尝试torchvision的库dataset torchvision.datasets ...

  4. [Pytorch数据集下载] 下载MNIST数据缓慢的方案

    步骤一 首先访问下面的网站,手工下载数据集.http://yann.lecun.com/exdb/mnist/ 把四个压缩包下载到任意文件夹,以便之后使用. 步骤二 把自己电脑上已经下载好的数据集的文 ...

  5. PyTorch 之 DataLoader

    DataLoader DataLoader 是 PyTorch 中读取数据的一个重要接口,该接口定义在 dataloader.py 文件中,该接口的目的: 将自定义的 Dataset 根据 batch ...

  6. 什么是pytorch(4.数据集加载和处理)(翻译)

    数据集加载和处理 这里主要涉及两个包:torchvision.datasets 和torch.utils.data.Dataset 和DataLoader torchvision.datasets是一 ...

  7. 【pytorch】torch.utils.data.DataLoader

    简介 DataLoader是PyTorch中的一种数据类型.用于训练/验证/测试时的数据按批读取. torch.utils.data.DataLoader(dataset, batch_size=1, ...

  8. pytorch加载语音类自定义数据集

    pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合 torch.u ...

  9. [实现] 利用 Seq2Seq 预测句子后续字词 (Pytorch)2

    最近有个任务:利用 RNN 进行句子补全,即给定一个不完整的句子,预测其后续的字词.本文使用了 Seq2Seq 模型,输入为 5 个中文字词,输出为 1 个中文字词.目录 关于RNN 语料预处理 搭建 ...

随机推荐

  1. CentOS yum 安装历史版本 java

    1.以1.6为例,找到对应版本 $ yum --showduplicate list java* |grep 1.6 java--openjdk.x86_64 :1.6.0.41-1.13.13.1. ...

  2. 二进制安装MySQL5.6 MySQL5.7

    1:系统版本 [root@vhost1 ~]# cat /etc/redhat-release Red Hat Enterprise Linux Server release 6.5 (Santiag ...

  3. Oracle Flashback Drop

    Ensure that the prerequisites described in Prerequisites of Flashback Drop are met. The following li ...

  4. duliu题之狼抓兔子题解

    拖了将近5天的正解和AC.........emmmmm........... 事实告诉我们这种毒瘤题一定要建双向边(用了不知道多少个小时质疑建边的人欲哭无泪) 心态爆炸的传送 题了个面 这是个求最小割 ...

  5. 在Linux环境中运行python 项目

    1首先创建一个虚拟环境或者在一个已有的虚拟环境中创建一个django项目 1.1 创建一个虚拟环境: mkvirtualenv my_django115 这会在 ~/Envs 中创建 my_djang ...

  6. leetcode 62. 不同路径(C++)

    一个机器人位于一个 m x n 网格的左上角 (起始点在下图中标记为“Start” ). 机器人每次只能向下或者向右移动一步.机器人试图达到网格的右下角(在下图中标记为“Finish”). 问总共有多 ...

  7. 012-elasticsearch5.4.3【五】-搜索API【一】搜索匹配所有matchAllQuery、全文查询[matchQuery、multiMatchQuery、commonTermsQuery、queryStringQuery、simpleQueryStringQuery]

    一.概述 查询所使用的 QueryBuilders来源于以下 import static org.elasticsearch.index.query.QueryBuilders.*; 请注意,您可以使 ...

  8. BIN转换成HEX格式及HEX转换成BIN的两个函数接口

    unsigned char HEX2BYTE(unsigned char hex_ch) { ') { '; } if (hex_ch >= 'a' && hex_ch < ...

  9. ubuntu 上用virtualenv安装python不同版本的开发环境。

    1.用pip安装virtualenv apt-get install python-virtualenv 2.创建python2的虚拟环境,进入要创建虚拟环境的目录下,我是放在/home/pyenv/ ...

  10. 20190906 On Java8 第十八章 字符串

    第十八章 字符串 +的重载与StringBuilder 用于String的+与+=是Java中仅有的两个重载过的操作符,Java不允许程序员重载任何其他的操作符.编译器自动引入了java.lang.S ...