DataLoader与Dataset

pytorch中的数据读取机制

graph TB
DataLoader --> DataLoaderIter
DataLoaderIter --> Sampler
Sampler --> Index
Sampler --> DatasetFetcher
Index -->DatasetFetcher
DatasetFetcher -->Dataset
Dataset --> getitem
getitem -->img,label
img,label --> collate_fn
collate_fn --> BatchData

  1. 人民币二分类

    可以把人民币当成自变量x,类别是y。

    数据模块可以分为
  2. 数据收集->原始样本和标签,img,label
  3. 数据划分->划分train,valid,test。验证集来调整过拟合
  4. 数据读取->数据读取,DataLoader

    DataLoader分为两个子模块,分别是
  • Sampler生成索引,样本的序号index
  • DataSet根据索引,读取img和label
  1. 数据预处理->transforms
  2. DataLoader与Dataset

    DataLoader和Dataset是数据读取的核心
  3. DataLoader

    DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,num_workers=0,collate_fn=None,pin_memory=False,drop_last=False,timeout=0,work_init_fn=None,multiprocessing_context=None)

    主要是构建可迭代的数据转载器

    dataloader,我们在训练的时候在每一次循环中,就是从dataset中读取每一个batch_size大小的数据
  • dataset:Dataset类,决定数据从哪读取及如何读取
  • batchsize:批大小
  • num_works:是否多进程读取数据
  • shuffle:每个epoch是否乱序
  • drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据

    epoch,iteration,batchsize
  • Epoch:所有训练样本都已输入到模型中,称为一个Epoch
  • Iteration:一批样本输入到模型中,称为一个Iteration
  • Batchsize:批大小,决定一个Epoch有多少个Iteration

    样本总数:80,BatchSize:8

    1 Epoch = 10 Iteration

    如果样本总数不能被整除

    样本总数:87,Batchsize:8
  • 1 Epoch = 10 Iteration,drop_last=True
  • 1 Epoch = 11 Iteration,drop_last=False
  1. Dataset

    torch.utils.data.Dataset

    class Dataset(object):

    def getitem(self,index):

    ​ raise NotImplementedError

    def add(self,other):

    ​ return ConcatDataset([self,other])

    功能:Dataset抽象类,所有自定义的Dataset需要继承,并复写

    __getitem__()

    getitem: 接收一个索引,返回一个样本

数据读取机制

  1. 读哪些数据,在每一个iteration中读取哪些数据?
  2. 从哪读数据,在硬盘中如何读取?
  3. 怎么读数据?
import os
import random
import shutil
BASE_DIR = os.path.dirname(os.path.abspath(__file__)) def makedir(new_dir):
if not os.path.exists(new_dir):
os.makedirs(new_dir) if __name__ == '__main__': DATA_DIR = os.path.abspath(os.path.join(BASE_DIR, ".", "RMB_data"))
SPLIT_DIR = os.path.abspath(os.path.join(BASE_DIR, ".", "rmb_split"))
TRAIN_DIR = os.path.join(SPLIT_DIR, "train")
VALID_DIR = os.path.join(SPLIT_DIR, "valid")
TEST_DIR = os.path.join(SPLIT_DIR, "test") if not os.path.exists(DATA_DIR):
raise Exception("\n{}不存在,请下载RMBdata放到{}路径下".format(DATA_DIR, os.path.dirname(DATA_DIR))) train_pct = 0.8
valid_pct = 0.1
test_pct = 0.1 for paths, dirs, files in os.walk(DATA_DIR):
for sub_dirs in dirs:
imgs = os.listdir(os.path.join(paths, sub_dirs))
imgs = list(filter(lambda x: x.endswith('.jpg'),imgs))
# print(imgs)
random.shuffle(imgs)
# print(imgs)
imgs_count = len(imgs)
# print(imgs_count) train_pic = int(train_pct*imgs_count)
valid_pic = int((valid_pct+train_pct)*imgs_count) if imgs_count == 0 :
print("{}目录下,无图片,请检查".format(os.path.join(paths, sub_dirs)))
import sys
sys.exit(0) for i in range(imgs_count):
if i < train_pic :
out_dir = os.path.join(TRAIN_DIR, sub_dirs)
elif i < valid_pic :
out_dir = os.path.join(VALID_DIR, sub_dirs)
else:
out_dir = os.path.join(TEST_DIR, sub_dirs) makedir(out_dir) target_path = os.path.join(out_dir, imgs[i])
src_path = os.path.join(DATA_DIR, sub_dirs, imgs[i]) shutil.copy(src_path, target_path) print("Class:{}, train:{}, valid:{}, test:{}".format(sub_dirs, train_pic, valid_pic-train_pic, imgs_count-valid_pic-train_pic))
print("已在{}划分好".format(out_dir)
Class:1, train:80, valid:10, test:-70
已在D:\pythonProject\04_DataLoader\rmb_split\test\1划分好
Class:100, train:80, valid:10, test:-70
已在D:\pythonProject\04_DataLoader\rmb_split\test\100划分好
import numpy as np
import torch
import os
import random
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
BASE_PATH = os.path.abspath(__file__)
# print(BASE_PATH)
base_path = os.path.abspath(os.path.join(BASE_PATH, '..', 'TestDir'))
# print(base_path)
data_dir = os.path.abspath(os.path.join(BASE_PATH, '..', 'RMB_data'))
random.seed(1)
# print(data_dir)
test_label = {"1": 0, "100": 1}
data_info = list()
for path, dirs, files in os.walk(base_path):
for sub_dir in dirs:
# print(sub_dir)
sub_dirlist = os.listdir(os.path.join(base_path, sub_dir))
pynames = list(filter(lambda y: y.endswith('.jpg'), sub_dirlist))
# print(pynames)
# print(test_label[sub_dir])
for pyname in pynames:
datainfo_dir = os.path.join(base_path, sub_dir, pyname)
t_label=test_label[sub_dir]
t_label = int(t_label)
data_info.append((datainfo_dir, t_label))
# print(data_info)
new_data_info = list()
for data_info_e in data_info:
x_dir, x_label = data_info_e
x_img = Image.open(x_dir).convert('RGB')
ok_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
])
x_img = ok_transform(x_img)
new_data_info.append((x_img,x_label)) # print(len(new_data_info[0][0]))
print(len(new_data_info))
newdataLoader = DataLoader(new_data_info,batch_size=14, shuffle=True)
for ids, data in enumerate(newdataLoader):
print(ids)

pytorch(07)数据模型的读取的更多相关文章

  1. pytorch(08)数据模型的读取(2)

    import numpy as np import torch import os import random from PIL import Image from torch.utils.data ...

  2. [Pytorch]PyTorch Dataloader自定义数据读取

    整理一下看到的自定义数据读取的方法,较好的有一下三篇文章, 其实自定义的方法就是把现有数据集的train和test分别用 含有图像路径与label的list返回就好了,所以需要根据数据集随机应变. 所 ...

  3. 【小白学PyTorch】16 TF2读取图片的方法

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.NLP等多个学术交流分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx645016617. 参考 ...

  4. pytorch实现花朵数据集读取

    import os from PIL import Image from torch.utils import data import numpy as np from torchvision imp ...

  5. Pytorch使用PIL的读取单张图片并显示

    1. Image.open(fp, mode="r") 调用此方法需要引入头文件:from PIL import Image. 参数说明: fp:图片路径,可为绝对路径或相对路径. ...

  6. 【转载】PyTorch系列 (二):pytorch数据读取

    原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...

  7. PyTorch使用总览

    PyTorch使用总览 https://blog.csdn.net/u014380165/article/details/79222243 深度学习框架训练模型时的代码主要包含数据读取.网络构建和其他 ...

  8. pytorch的torch.utils.data.DataLoader认识

    PyTorch中数据读取的一个重要接口是torch.utils.data.DataLoader,该接口定义在dataloader.py脚本中,只要是用PyTorch来训练模型基本都会用到该接口, 该接 ...

  9. PyTorch源码解读之torch.utils.data.DataLoader(转)

    原文链接 https://blog.csdn.net/u014380165/article/details/79058479 写得特别好!最近正好在学习pytorch,学习一下! PyTorch中数据 ...

随机推荐

  1. 【noi 2.6_8464】股票买卖(DP)

    题意:N天可买卖2次股票,问最大利润. 解法:f[i]表示前 i 天买卖一次的最大利润,g[i]表示后 i 天. 注意--当天可以又买又卖,不要漏了这个要求:数据较大. 1 #include<c ...

  2. 牛客编程巅峰赛S2第7场 - 钻石&王者 A.牛牛的独特子序列 (字符串,二分)

    题意:给你一个字符串,找出一个类似为\(aaabbbccc\)这样的由连续的\(abc\)构成的子序列,其中\(|a|=|b|=|c|\),问字符串中能构造出的子序列的最大长度. 题解:这题刚开始一直 ...

  3. Redundant Paths POJ - 3177 把原图变成边—双连通图

    无向图概念:(这里的x->y表示x和y之间有一条无向边)1.桥:对于一个无向图,如果删除某条边后,该图的连通分量增加,则称这条边为桥 比如1->2->3->4这样一个简单得图一 ...

  4. c语言中qsort函数的使用、编程中的一些错误

    qsort()函数: 功能:相当于c++sort,具有快排的功能,复杂度的话nlog(n)注:C中的qsort()采用的是快排算法,C++的sort()则是改进的快排算法.两者的时间复杂度都是nlog ...

  5. C# 之 async / await

    直接看一个例子 private async void button1_Click(object sender, EventArgs e) { var t = Task.Run(() => { T ...

  6. 【史上最全】Hadoop 核心 - HDFS 分布式文件系统详解(上万字建议收藏)

    1. HDFS概述 Hadoop 分布式系统框架中,首要的基础功能就是文件系统,在 Hadoop 中使用 FileSystem 这个抽象类来表示我们的文件系统,这个抽象类下面有很多子实现类,究竟使用哪 ...

  7. Kubernets二进制安装(5)之私有仓库harbor搭建

    在IP地址为192.168.80.50,机器名为mfyxw50上搭建私有仓库harbor harbor下载地址: harbor下载连接地址:https://github.com/goharbor/ha ...

  8. 编写一个c函数,该函数给出一个字节中被置为1的位的个数

    请编写一个c函数,该函数给出一个字节中被置为1的位的个数 #include <stdio.h> void fun(char ch) { int i; int temp; int count ...

  9. Web 前端页面性能监控指标

    Web 前端页面性能监控指标 性能监控 / 性能指标 / 性能优化 白屏时间计算 FCP 白屏时间:从浏览器输入地址并回车后到页面开始有内容的时间: 首屏时间计算 FMP 首屏时间:从浏览器输入地址并 ...

  10. Git Best Practice All In One

    Git Best Practice All In One git workflow 本地开发环境: 开发人员自测的,可以是自己本地部署的静态服务器,当然也可类似是运行 npm server类似的环境, ...