本文简单描述如果自定义dataset,代码并未经过测试(只是说明思路),为半伪代码。所有逻辑需按自己需求另外实现:

一、分析DataLoader

train_loader = DataLoader(
datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=batch_size, shuffle=True)

datasets.MNIST()是一个torch.utils.data.Datasets对象,batch_size表示我们定义的batch大小(即每轮训练使用的批大小),shuffle表示是否打乱数据顺序(对于整个datasets里包含的所有数据)。

对于batch_size和shuffle都是根据业务需求来认为指定的,不做过多说明。

对于Datasets对象来说,我们可以根据自己的数据类型来自定义,自己定义一个类,继承Datasets类。

二、分析Datasets类

class Dataset(object):
"""An abstract class representing a Dataset. All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
""" def __getitem__(self, index):
raise NotImplementedError def __len__(self):
raise NotImplementedError def __add__(self, other):
return ConcatDataset([self, other])

上述代码是pytorch中Datasets的源码,注意成员方法__getitem__和__len__都是未实现的。我们要实现自定义Datasets类来完成数据的读取,则只需要完成这两个成员方法的重写。

  首先,__getitem__()方法用来从datasets中读取一条数据,这条数据包含训练图片(已CV距离)和标签,参数index表示图片和标签在总数据集中的Index。

  __len__()方法返回数据集的总长度(训练集的总数)。

三、简单实现MyDatasets类

# -*- coding:utf-8 -*-
__author__ = 'Leo.Z' import os from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import matplotlib.image as mpimg # 对所有图片生成path-label map.txt
def generate_map(root_dir):
current_path = os.path.abspath(__file__)
father_path = os.path.abspath(os.path.dirname(current_path) + os.path.sep + ".") with open(root_dir + 'map.txt', 'w') as wfp:
for idx in range(10):
subdir = os.path.join(root_dir, '%d/' % idx)
for file_name in os.listdir(subdir):
abs_name = os.path.join(father_path, subdir, file_name)
linux_abs_name = abs_name.replace("\\", '/')
wfp.write('{file_dir} {label}\n'.format(file_dir=linux_abs_name, label=idx)) # 实现MyDatasets类
class MyDatasets(Dataset): def __init__(self, dir):
# 获取数据存放的dir
# 例如d:/images/
self.data_dir = dir
# 用于存放(image,label) tuple的list,存放的数据例如(d:/image/1.png,4)
self.image_target_list = []
# 从dir--label的map文件中将所有的tuple对读取到image_target_list中
# map.txt中全部存放的是d:/.../image_data/1/3.jpg 1 路径最好是绝对路径
with open(os.path.join(dir, 'map.txt'), 'r') as fp:
content = fp.readlines()
str_list = [s.rstrip().split() for s in content]
# 将所有图片的dir--label对都放入列表,如果要执行多个epoch,可以在这里多复制几遍,然后统一shuffle比较好
self.image_target_list = [(x[0], int(x[1])) for x in str_list] def __getitem__(self, index):
image_label_pair = self.image_target_list[index]
# 按path读取图片数据,并转换为图片格式例如[3,32,32]
img = mpimg.imread(image_label_pair[0])
return img, image_label_pair[1] def __len__(self):
return len(self.image_target_list) if __name__ == '__main__':
# 生成map.txt
# generate_map('train/') train_loader = DataLoader(MyDatasets('train/'), batch_size=128, shuffle=True) for step in range(20000):
for idx, (img, label) in enumerate(train_loader):
print(img.shape)
print(label.shape)

上述代码简要说明了利用Datasets类和DataLoader类来读取数据,本例用的是图片原始数据,大概的结构如下:

如果使用其他形式的数据,例如二进制文件,则需要字节读取文件,分割成每一张图片和label,然后从__getitem__中返回就可以了。例如cifar-10数据,我们只需要在__getitem__方法中,按index来读取对应位置的字节,然后转换为label和img,并返回。在__len__中返回cifar-10训练集的总样本数。DataLoader就可以根据我们提供的index,len以及batch_size,shuffle来返回相应的batch数据和label。

[深度学习] pytorch利用Datasets和DataLoader读取数据的更多相关文章

  1. [深度学习] Pytorch(三)—— 多/单GPU、CPU,训练保存、加载模型参数问题

    [深度学习] Pytorch(三)-- 多/单GPU.CPU,训练保存.加载预测模型问题 上一篇实践学习中,遇到了在多/单个GPU.GPU与CPU的不同环境下训练保存.加载使用使用模型的问题,如果保存 ...

  2. [深度学习] Pytorch学习(一)—— torch tensor

    [深度学习] Pytorch学习(一)-- torch tensor 学习笔记 . 记录 分享 . 学习的代码环境:python3.6 torch1.3 vscode+jupyter扩展 #%% im ...

  3. [深度学习] pytorch学习笔记(2)(梯度、梯度下降、凸函数、鞍点、激活函数、Loss函数、交叉熵、Mnist分类实现、GPU)

    一.梯度 导数是对某个自变量求导,得到一个标量. 偏微分是在多元函数中对某一个自变量求偏导(将其他自变量看成常数). 梯度指对所有自变量分别求偏导,然后组合成一个向量,所以梯度是向量,有方向和大小. ...

  4. 深度学习PyTorch环境安装——mac

    参考:http://python.jobbole.com/87522/ 1.首先要安装Anaconda 1)什么是Anaconda Anaconda是Python的包管理器和环境管理器,是一个包含18 ...

  5. 深度学习(tensorflow) —— 自己数据集读取opencv

    先来看一下我们的目录: dataset1 和creat_dataset.py 属于同一目录 mergeImg1 和mergeImg2 为Dataset1的两子目录(两类为例子)目录中存储图像等文件 核 ...

  6. SPSS学习系列之SPSS Statistics导入读取数据(多种格式)(图文详解)

    不多说,直接上干货! SPSS Statistics导入读取数据的步骤: 文件  ->  导入数据 成功! 欢迎大家,加入我的微信公众号:大数据躺过的坑     免费给分享       同时,大 ...

  7. JMeter 参数化之利用JDBCConnectionConfiguration从数据库读取数据并关联变量

    参数化之利用DBC Connection Configuration从数据库读取数据并关联变量   by:授客 QQ:1033553122 1.   下载mysql jar包 下载mysql jar包 ...

  8. [深度学习] pytorch学习笔记(4)(Module类、实现Flatten类、Module类作用、数据增强)

    一.继承nn.Module类并自定义层 我们要利用pytorch提供的很多便利的方法,则需要将很多自定义操作封装成nn.Module类. 首先,简单实现一个Mylinear类: from torch ...

  9. [深度学习] pytorch学习笔记(3)(visdom可视化、正则化、动量、学习率衰减、BN)

    一.visdom可视化工具 安装:pip install visdom 启动:命令行直接运行visdom 打开WEB:在浏览器使用http://localhost:8097打开visdom界面 二.使 ...

随机推荐

  1. 图解DMZ

    图解DMZ 1. 概念介绍 DMZ是英文“demilitarized zone”的缩写,中文译为“隔离区”.“非军事区”.它是为了解决安装防火墙后外部网络不能访问内部网络服务器的问题,而设立的一个非安 ...

  2. java 实现读取某个目录下指定类型的文件

    我这里是读取txt类型的文件,在指定的目录下有不同类型的文件 实现代码,读取txt类型的文件并打印出该文件的绝对路径 package com.SBgong.test; import java.io.F ...

  3. PTA(Basic Level)1037.在霍格沃茨找零钱

    如果你是哈利·波特迷,你会知道魔法世界有它自己的货币系统 -- 就如海格告诉哈利的:"十七个银西可(Sickle)兑一个加隆(Galleon),二十九个纳特(Knut)兑一个西可,很容易.& ...

  4. [转帖]流言终结者 —— “SQL Server 是Sybase的产品而不是微软的”

    流言终结者 —— “SQL Server 是Sybase的产品而不是微软的” https://www.cnblogs.com/xxxtech/archive/2011/12/30/2307859.ht ...

  5. 【3.2】【mysql基本实验】mysql GTID复制(基于空数据的配置)

    概述:本质上和传统异步复制没什么区别,就是加了GTID参数. 且可以用传统的方式来配置主从,也可以用GTID的方式来自动配置主从. 这里使用GTID的方式来自动适配主从. 需要mysql5.6.5以上 ...

  6. 结合docker做flask+kafka数据接口与压力测试

    一.需求 需要做实时数据接入的接口.数据最终要写入库,要做到高并发,数据的完整,不丢失数据. 二.技术选型 1.因为只是做简单的接口,不需要复杂功能,所以决定用flask这个简单的python框架(因 ...

  7. Django中ajax发送post请求,报403错误CSRF验证失败解决办法

    今天学习Django框架,用ajax向后台发送post请求,直接报了403错误,说CSRF验证失败:先前用模板的话都是在里面加一个 {% csrf_token %} 就直接搞定了CSRF的问题了:很显 ...

  8. 喝奶茶最大值(不能喝自己班级的)2019 Multi-University Training Contest 8--hdu杭电第8场(Roundgod and Milk Tea)

    题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=6667 题意: 有 n个班级,每个班级有a个人.b个奶茶,每个班的人不能喝自己的奶茶,只能喝别人班的奶茶 ...

  9. [BJOI2014]大融合(Link Cut Tree)

    [BJOI2014]大融合(Link Cut Tree) 题面 给出一棵树,动态加边,动态查询通过每条边的简单路径数量. 分析 通过每条边的简单路径数量显然等于边两侧节点x,y子树大小的乘积. 我们知 ...

  10. 公司PL/SQL考核及小结

    一.数据库初始化脚本 -- Create table 学生信息 drop table HAND_STUDENT; create table HAND_STUDENT ( STUDENT_NO ) no ...