前言 

本文介绍了classdataset的几个要点,由哪些部分组成,每个部分需要完成哪些事情,如何进行数据增强,如何实现自己设计的数据增强。然后,介绍了分布式训练的数据加载方式,数据读取的整个流程,当面对超大数据集时,内存不足的改进思路。

本文延续了以往的写作态度和风格,即便是自己知道的内容,也仍然在写之前看了很多的文章来保证内容的正确性和全面性,因此写得极累,耗费时间较长。若有读者看完后觉得有所帮助,文末可以赞赏一点。

文末扫描二维码关注公众号CV技术指南 ,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读,招聘信息发布。

(零) 概述


浮躁是人性的一个典型的弱点,很多人总擅长看别人分享的现成代码解读的文章,看起来学会了好多东西,实际上仍然不具备自己从零搭建一个pipeline的能力。

在公众号(CV技术指南)的交流群里(群内交流氛围不错,有需要的请关注公众号加群),常有不少人问到一些问题,根据这些问题明显能看出是对pipeline不了解,却已经在搞项目或论文了,很难想象如果基本的pipeline都不懂,如何分析代码问题所在?如何分析结果不正常的可能原因?遇到问题如何改?

Pytorch在这几年逐渐成为了学术上的主流框架,其具有简单易懂的特点。网上有很多pytorch的教程,如果是一个已经懂的人去看这些教程,确实pipeline的要素都写到了,感觉这教程挺不错的。但实际上更多地像是写给自己看的一个笔记,记录了pipeline要写哪些东西,却没有介绍要怎么写,为什么这么写,刚入门的小白看的时候容易云里雾里。

鉴于此,本教程尝试对于pytorch搭建一个完整pipeline写一个比较明确且易懂的说明。

本教程将介绍以下内容:

  1. 准备数据,自定义classdataset,分布式训练的数据加载方式,加载超大数据集的改进思路。

  2. 搭建模型与模型初始化。

  3. 编写训练过程,包括加载预训练模型、设置优化器、设置损失函数等。

  4. 可视化并保存训练过程。

  5. 编写推理函数。

(一)数据读取


classdataset的定义

先来看一个完整的classdataset

import torch.utils.data as data
import torchvision.transforms as transforms class MyDataset(data.Dataset):
def __init__(self,data_folder):
self.data_folder = data_folder
self.filenames = []
self.labels = [] per_classes = os.listdir(data_folder)
for per_class in per_classes:
per_class_paths = os.path.join(data_folder, per_class)
label = torch.tensor(int(per_class)) per_datas = os.listdir(per_class_paths)
for per_data in per_datas:
self.filenames.append(os.path.join(per_class_paths, per_data))
self.labels.append(label) def __getitem__(self, index):
image = Image.open(self.filenames[index])
label = self.labels[index]
data = self.proprecess(image)
return data, label def __len__(self):
return len(self.filenames) def proprecess(self,data):
transform_train_list = [
transforms.Resize((self.opt.h, self.opt.w), interpolation=3),
transforms.Pad(self.opt.pad, padding_mode='edge'),
transforms.RandomCrop((self.opt.h, self.opt.w)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]
return transforms.Compose(transform_train_list)

  

classdataset的几个要点:

  1. classdataset类继承torch.utils.data.dataset。

  2. classdataset的作用是将任意格式的数据,通过读取、预处理或数据增强后以tensor的形式输出。其中任意格式的数据可能是以文件夹名作为类别的形式、或以txt文件存储图片地址的形式、或视频、或十几帧图像作为一份样本的形式。而输出则指的是经过处理后的一个batch的tensor格式数据和对应标签。

  3. classdataset主要有三个函数要完成:__init__函数、__getitem__ 函数和__len__函数。

__init__函数

init函数主要是完成两个静态变量的赋值。一个是用于存储所有数据路径的变量,变量的每个元素即为一份训练样本,(注:如果一份样本是十几帧图像,则变量每个元素存储的是这十几帧图像的路径),可以命名为self.filenames。一个是用于存储与数据路径变量一一对应的标签变量,可以命名为self.labels。

假如数据集的格式如下:

#这里的0,1指的是类别0,1
/data_path/0/image0.jpg
/data_path/0/image1.jpg
/data_path/0/image2.jpg
/data_path/0/image3.jpg
......
/data_path/1/image0.jpg
/data_path/1/image1.jpg
/data_path/1/image2.jpg
/data_path/1/image3.jpg

  

可通过per_classes = os.listdir(data_path) 获得所有类别的文件夹,在此处per_classes的每个元素即为对应的数据标签,通过for遍历per_classes即可获得每个类的标签,将其转换成int的tensor形式即可。在for下获得每个类下每张图片的路径,通过self.join获得每份样本的路径,通过append添加到self.filenames中。

__getitem__ 函数

getitem 函数主要是根据索引返回对应的数据。这个索引是在训练前通过dataloader切片获得的,这里先不管。它的参数默认是index,即每次传回在init函数中获得的所有样本中索引对应的数据和标签。因此,可通过下面两行代码找到对应的数据和标签。

image = Image.open(self.filenames[index]))
label = self.labels[index]

  

获得数据后,进行数据预处理。数据预处理主要通过 torchvision.transforms 来完成,这里面已经包含了常用的预处理、数据增强方式。其完整使用方式在官网有详细介绍:https://pytorch.org/vision/stable/transforms.html

上面这里介绍了最常用的几种,主要就是resize,随机裁剪,翻转,归一化等。

最后通过transforms.Compose(transform_train_list)来执行。

除了这些已经有的数据增强方式外,在《数据增强方法总结》中还介绍了十几种特殊的数据增强方式,像这种自己设计了一种新的数据增强方式,该如何添加进去呢

下面以随机擦除作为例子。

class RandomErasing(object):
""" Randomly selects a rectangle region in an image and erases its pixels.
'Random Erasing Data Augmentation' by Zhong et al.
See https://arxiv.org/pdf/1708.04896.pdf
Args:
probability: The probability that the Random Erasing operation will be performed.
sl: Minimum proportion of erased area against input image.
sh: Maximum proportion of erased area against input image.
r1: Minimum aspect ratio of erased area.
mean: Erasing value.
"""
def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[0.4914, 0.4822, 0.4465]):
self.probability = probability
self.mean = mean
self.sl = sl
self.sh = sh
self.r1 = r1 def __call__(self, img):
if random.uniform(0, 1) > self.probability:
return img
for attempt in range(100):
area = img.size()[1] * img.size()[2]
target_area = random.uniform(self.sl, self.sh) * area
aspect_ratio = random.uniform(self.r1, 1 / self.r1)
h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio)))
if w < img.size()[2] and h < img.size()[1]:
x1 = random.randint(0, img.size()[1] - h)
y1 = random.randint(0, img.size()[2] - w)
if img.size()[0] == 3:
img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
else:
img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
return img
return img

  

如上所示,自己写一个类RandomErasing,继承object,在call函数里完成你的操作。在transform_train_list里添加上RandomErasing的定义即可。

transform_train_list = [
transforms.Resize((self.opt.h, self.opt.w), interpolation=3),
transforms.Pad(self.opt.pad, padding_mode='edge'),
transforms.RandomCrop((self.opt.h, self.opt.w)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
RandomErasing(probability=self.opt.erasing_p, mean=[0.0, 0.0, 0.0])
#添加到这里
]

  

__len__函数

len函数主要就是返回数据长度,即样本的总数量。前面介绍了self.filenames的每个元素即为每份样本的路径,因此,self.filename的长度就是样本的数量。通过return len(self.filenames)即可返回数据长度。

验证classdataset

train_dataset = My_Dataset(data_folder=data_folder)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=False)
print('there are total %s batches for train' % (len(train_loader))) for i,(data,label) in enumerate(train_loader):
print(data.size(),label.size())

  

分布式训练的数据加载方式


前面介绍的是单卡的数据加载,实际上分布式也是这样,但为了高速高效读取,每张卡上也会保存所有数据的信息,即self.filenames和self.labels的信息。只是在DistributedSampler 中会给每张卡分配互不交叉的索引,然后由torch.utils.data.DataLoader来加载。

dataset = My_Dataset(data_folder=data_folder)
sampler = DistributedSampler(dataset) if is_distributed else None
loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler)

数据读取的完整流程


结合上面这段代码,在这里,我们介绍以下读取数据的整个流程。

  1. 首先定义一个classdataset,在初始化函数里获得所有数据的信息。

  2. classdataset中实现getitem函数,通过索引来获取对应的数据,然后对数据进行预处理和数据增强。

  3. 在模型训练前,初始化classdataset,通过Dataloader来加载数据,其加载方式是通过Dataloader中分配的索引,调用getitem函数来获取。

    关于索引的分配,在单卡上,可通过设置shuffle=True来随机生成索引顺序;在多机多卡的分布式训练上,shuffle操作通过DistributedSampler来完成,因此shuffle与sampler只能有一个,另一个必须为None。

超大数据集的加载思路


问题所在

再回顾一下上面这个流程,前面提到所有数据信息在classdataset初始化部分都会保存在变量中,因此当面对超大数据集时,会出现内存不足的情况。

思路

将切片获取索引的步骤放到classdataset初始化的位置,此时每张卡都是保存不同的数据子集。通过这种方式,可以将内存用量减少到原来的world_size倍(world_size指卡的数量)。

参考代码

class RankDataset(Dataset):
'''
实际流程
获取rank和world_size 信息 -> 获取dataset长度 -> 根据dataset长度产生随机indices ->
给不同的rank 分配indices -> 根据这些indices产生metas
'''
def __init__(self, meta_file, world_size, rank, seed):
super(RankDataset, self).__init__()
random.seed(seed)
np.random.seed(seed)
self.world_size = world_size
self.rank = rank
self.metas = self.parse(meta_file) def parse(self, meta_file):
dataset_size = self.get_dataset_size(meta_file) # 获取metafile的行数
local_rank_index = self.get_local_index(dataset_size, self.rank, self.world_size) # 根据world size和rank,获取当前epoch,当前rank需要训练的index。
self.metas = self.read_file(meta_file, local_rank_index) def __getitem__(self, idx):
return self.metas[idx] def __len__(self):
return len(self.metas) ##train
for epoch_num in range(epoch_num):
dataset = RankDataset("/path/to/meta", world_size, rank, seed=epoch_num)
sampler = RandomSampler(datset)
dataloader = DataLoader(
dataset=dataset,
batch_size=32,
shuffle=False,
num_workers=4,
sampler=sampler)

  

但这种思路比较明显的问题时,为了让每张卡上在每个epoch都加载不同的训练子集,因此需要在每个epoch重新build dataloader。

这一节参考链接:https://zhuanlan.zhihu.com/p/357809861

总结


本篇文章介绍了数据读取的完整流程,如何自定义classdataset,如何进行数据增强,自己设计的数据增强如何写,分布式训练是如何加载数据的,超大数据集的数据加载改进思路。

相信读完本文的读者对数据读取有了比较清晰的认识,下一篇将介绍搭建模型与模型初始化。

关注公众号可加计算机视觉交流群

欢迎关注公众号 CV技术指南 ,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读。

在公众号中回复关键字 “入门指南“可获取计算机视觉入门所有必备资料。

其它文章

自编码器综述论文:概念、图解和应用

解决图像分割落地场景真实问题,港中文等提出:开放世界实体分割

资源分享 |Nebullvm:一行代码测试多个DL编译器,模型推理提高5-20倍

目标检测、实例分割、多目标跟踪的Anchor-free应用方法总结

Soft Sampling:探索更有效的采样策略

如何解决工业缺陷检测小样本问题

机器学习、深度学习面试知识点汇总

深度学习图像识别的未来:机遇与挑战并存

招聘 | 22-65k!迁移科技:招聘深度学习、传统视觉、3D视觉算法工程师、项目经理、机械设计

关于快速学习一项新技术或新领域的一些个人思维习惯与思想总结

计算机视觉中的图像标注工具总结

计算机视觉中的神经网络可视化工具与项目

计算机视觉中的高效阅读论文的方法总结

计算机视觉中的transformer模型创新思路总结

一文概括机器视觉常用算法以及常用开发库

HOG和SIFT图像特征提取简述|  特征金字塔技术总结

目标检测中回归损失函数总结|    实例分割综述总结综合整理版

2021年小目标检测最新研究综述    |小目标检测常用方法总结

从零搭建Pytorch模型教程(一)数据读取的更多相关文章

  1. 从零搭建Pytorch模型教程(三)搭建Transformer网络

    ​ 前言 本文介绍了Transformer的基本流程,分块的两种实现方式,Position Emebdding的几种实现方式,Encoder的实现方式,最后分类的两种方式,以及最重要的数据格式的介绍. ...

  2. 从零搭建Pytorch模型教程(四)编写训练过程--参数解析

    ​  前言 训练过程主要是指编写train.py文件,其中包括参数的解析.训练日志的配置.设置随机数种子.classdataset的初始化.网络的初始化.学习率的设置.损失函数的设置.优化方式的设置. ...

  3. 新入手服务器不会玩?抢占式实例服务器教程,从零搭建tomcat超简流程

    新入手服务器不会玩?抢占式实例服务器教程,从零搭建tomcat超简流程 相信很多新人入手Linux服务器后,一脸无奈,这黑框框究竟能干啥?忽觉巨亏血亏不是? 这里面门道可不是你想象中的那么点,简则服务 ...

  4. 深度学习实践系列(1)- 从零搭建notMNIST逻辑回归模型

    MNIST 被喻为深度学习中的Hello World示例,由Yann LeCun等大神组织收集的一个手写数字的数据集,有60000个训练集和10000个验证集,是个非常适合初学者入门的训练集.这个网站 ...

  5. Pytorch系列教程-使用Seq2Seq网络和注意力机制进行机器翻译

    前言 本系列教程为pytorch官网文档翻译.本文对应官网地址:https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutor ...

  6. Note | PyTorch官方教程学习笔记

    目录 1. 快速入门PYTORCH 1.1. 什么是PyTorch 1.1.1. 基础概念 1.1.2. 与NumPy之间的桥梁 1.2. Autograd: Automatic Differenti ...

  7. Pytorch系列教程-使用字符级RNN对姓名进行分类

    前言 本系列教程为pytorch官网文档翻译.本文对应官网地址:https://pytorch.org/tutorials/intermediate/char_rnn_classification_t ...

  8. 生产与学术之Pytorch模型导出为安卓Apk尝试记录

    生产与学术 写于 2019-01-08 的旧文, 当时是针对一个比赛的探索. 觉得可能对其他人有用, 就放出来分享一下 生产与学术, 真实的对立... 这是我这两天对pytorch深度学习->a ...

  9. [Pytorch]PyTorch Dataloader自定义数据读取

    整理一下看到的自定义数据读取的方法,较好的有一下三篇文章, 其实自定义的方法就是把现有数据集的train和test分别用 含有图像路径与label的list返回就好了,所以需要根据数据集随机应变. 所 ...

随机推荐

  1. FHQtreap(我有个绝妙的理解方法,但课的时间不够[doge])

    FHQtreap板子(P1486 [NOI2004] 郁闷的出纳员) 会了FHQ,treap什么的就忘了吧...... #include<bits/stdc++.h> using name ...

  2. Kubernetes:Pod基础知识总结

    Blog:博客园 个人 官方文档详尽介绍了Pod的概念. 概念 Pods are the smallest deployable units of computing that you can cre ...

  3. 【Containerd版】Kubeadm高可用安装K8s集群1.23+

    目录 基本环境配置 节点规划 网段规划及软件版本 基本配置 内核升级配置 K8s组件及Runtime安装 Containerd安装 K8s组件安装 高可用实现 集群初始化 Master01初始化 添加 ...

  4. 记一次.net core 异步线程设置超时时间

    前言: 刷帖子看到一篇 Go 记录一次groutine通信与context控制 看了一下需求背景,挺有意思的,琢磨了下.net core下的实现 需求背景: 项目中需要定期执行任务A来做一些辅助的工作 ...

  5. JAVA变量初始化赋值问题

    感谢大佬:https://www.cnblogs.com/znsongshu/p/6282672.html 在Java中,null值表示引用不指向任何对象.运行过程中系统发现使用了这样一个引用时·可以 ...

  6. curl 查看接口的网络分段响应时间

    示例如下 curl -o /dev/null -s -w %{time_namelookup}::%{time_connect}::%{time_starttransfer}::%{time_tota ...

  7. byte溢出栗子

    原创:转载需注明原创地址 https://www.cnblogs.com/fanerwei222/p/11634402.html byte溢出测试: byte b1 = (byte) 127; byt ...

  8. Java链式写法

    原创:转载需注明原创地址 https://www.cnblogs.com/fanerwei222/p/11613067.html Java 链式写法:详细看代码 package chain; /** ...

  9. 计算机辅助数据绘图(matlab\python\js)

    1. matlab绘图 官方说明:https://ww2.mathworks.cn/help/matlab/creating_plots/types-of-matlab-plots.html 基本图形 ...

  10. 二叉树的基本操作(C语言版)

    今天走进数据结构之二叉树 二叉树的基本操作(C 语言版) 1 二叉树的定义 二叉树的图长这样: 二叉树是每个结点最多有两个子树的树结构,常被用于实现二叉查找树和二叉堆.二叉树是链式存储结构,用的是二叉 ...