文章目录:

1 Dataset基类

PyTorch 读取其他的数据,主要是通过 Dataset 类,所以先简单了解一下 Dataset 类。在看很多PyTorch的代码的时候,也会经常看到dataset这个东西的存在。Dataset类作为所有的 datasets 的基类存在,所有的 datasets 都需要继承它。

先看一下源码:

这里有一个__getitem__函数,__getitem__函数接收一个index,然后返回图片数据和标签,这个index通常是指一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。之后会举例子来讲解这个逻辑

其实说着了些都没用,因为在训练代码里是感觉不到这些操作的,只会看到通过DataLoader就可以获取一个batch的数据,这是触发去读取图片这些操作的是DataLoader里的__iter__(self)(后面再讲)。

2 构建Dataset子类

下面我们构建一下Dataset的子类,叫他MyDataset类:

import torch
from torch.utils.data import Dataset,DataLoader class MyDataset(Dataset):
def __init__(self):
self.data = torch.tensor([[1,2,3],[2,3,4],[3,4,5],[4,5,6]])
self.label = torch.LongTensor([1,1,0,0]) def __getitem__(self,index):
return self.data[index],self.label[index] def __len__(self):
return len(self.data)

2.1 Init

  • 初始化中,一般是把数据直接保存在这个类的属性中。像是self.data,self.label

2.2 getitem

  • index是一个索引,这个索引的取值范围是要根据__len__这个返回值确定的,在上面的例子中,__len__的返回值是4,所以这个index会在0,1,2,3这个范围内。

3 dataloader

从上文中,我们知道了MyDataset这个类中的__getitem__的返回值,应该是某一个样本的数据和标签(如果是测试集的dataset,那么就只返回数据),在梯度下降的过程中,一般是需要将多个数据组成batch,这个需要我们自己来组合吗?不需要的,所以PyTorch中存在DataLoader这个迭代器(这个名词用的准不准确有待考究)。

继续上面的代码,我们接着写代码:

mydataloader = DataLoader(dataset=mydataset,
batch_size=1)

我们现在创建了一个DataLoader的实例,并且把之前实例化的mydataset作为参数输入进去,并且还输入了batch_size这个参数,现在我们使用的batch_size是1.下面来用for循环来遍历这个dataloader:

for i,(data,label) in enumerate(mydataloader):
print(data,label)

输出结果是:

意料之中的结果,总共输出了4个batch,每个batch都是只有1个样本(数据+标签),值得注意的是,这个输出过程是顺序的

我们稍微修改一下上面的DataLoader的参数:

mydataloader = DataLoader(dataset=mydataset,
batch_size=2,
shuffle=True) for i,(data,label) in enumerate(mydataloader):
print(data,label)

结果是:

可以看到每一个batch内出现了2个样本。假如我们再运行一遍上面的代码,得到:

两次结果不同,这是因为shuffle=True,dataset中的index不再是按照顺序从0到3了,而是乱序,可能是[0,1,2,3],也可能是[2,3,1,0]。

【个人感想】

Dataloader和Dataset两个类是非常方便的,因为这个可以快速的做出来batch数据,修改batch_size和乱序都非常地方便。有下面两个希望注意的地方:

  1. 一般标签值应该是Long整数的,所以标签的tensor可以用torch.LongTensor(数据)或者用.long()来转化成Long整数的形式。
  2. 如果要使用PyTorch的GPU训练的话,一般是先判断cuda是否可用,然后把数据标签都用to()放到GPU显存上进行GPU加速。
device = 'cuda' if torch.cuda.is_available() else 'cpu'
for i,(data,label) in enumerate(mydataloader):
data = data.to(device)
label = label.to(device)
print(data,label)

看一下输出:

【小白学PyTorch】3 浅谈Dataset和Dataloader的更多相关文章

  1. 【小白学C#】浅谈.NET中的IL代码

    一.前言 前几天群里有位水友提问:”C#中,当一个方法所传入的参数是一个静态字段的时候,程序是直接到静态字段拿数据还是从复制的函数栈中拿数据“.其实很明显,这和方法参数的传递方式有关,如果是引用传递的 ...

  2. 【小白学PyTorch】20 TF2的eager模式与求导

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...

  3. 【小白学PyTorch】15 TF2实现一个简单的服装分类任务

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...

  4. 【小白学PyTorch】8 实战之MNIST小试牛刀

    文章来自微信公众号[机器学习炼丹术].有什么问题都可以咨询作者WX:cyx645016617.想交个朋友占一个好友位也是可以的~好友位快满了不过. 参考目录: 目录 1 探索性数据分析 1.1 数据集 ...

  5. 【小白学PyTorch】4 构建模型三要素与权重初始化

    文章目录: 目录 1 模型三要素 2 参数初始化 3 完整运行代码 4 尺寸计算与参数计算 1 模型三要素 三要素其实很简单 必须要继承nn.Module这个类,要让PyTorch知道这个类是一个Mo ...

  6. 【小白学PyTorch】5 torchvision预训练模型与数据集全览

    文章来自:微信公众号[机器学习炼丹术].一个ai专业研究生的个人学习分享公众号 文章目录: 目录 torchvision 1 torchvision.datssets 2 torchvision.mo ...

  7. 【小白学PyTorch】11 MobileNet详解及PyTorch实现

    文章来自微信公众号[机器学习炼丹术].我是炼丹兄,欢迎加我微信好友交流学习:cyx645016617. @ 目录 1 背景 2 深度可分离卷积 2.2 一般卷积计算量 2.2 深度可分离卷积计算量 2 ...

  8. 【小白学PyTorch】16 TF2读取图片的方法

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.NLP等多个学术交流分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx645016617. 参考 ...

  9. 【小白学PyTorch】17 TFrec文件的创建与读取

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...

随机推荐

  1. luogu P1452 [USACO03FALL]Beauty Contest G /【模板】旋转卡壳

    LINK:旋转卡壳 如题 是一道模板题. 容易想到n^2暴力 当然也能随机化选点 (还真有人过了 考虑旋转卡壳 其实就是对于某个点来说找到其最远的点. 在找的过程中需要借助一下个点的帮助 利用当前点到 ...

  2. 浅谈二分图的最大匹配和二分图的KM算法

    二分图还可以,但是我不太精通.我感觉这是一个很烦的问题但是学网络流不得不学它.硬啃吧. 人比较蠢,所以思考几天才有如下理解.希望能说服我或者说服你. 二分图的判定不再赘述一个图是可被划分成一个二分图当 ...

  3. JS&ES6学习笔记(持续更新)

    ES6学习笔记(2019.7.29) 目录 ES6学习笔记(2019.7.29) let和const let let 基本用法 let 不存在变量提升 暂时性死区 不允许重复声明 块级作用域 级作用域 ...

  4. 使用hibernate validate做参数校验

    1.为什么使用hibernate validate ​ 在开发http接口的时候,参数校验是必须有的一个环节,当参数校验较少的时候,一般是直接按照校验条件做校验,校验不通过,返回错误信息.比如以下校验 ...

  5. CI4框架应用一 - 环境搭建

    CI框架 (codeigniter)算是一个古老的框架了,由于在工作中一直在使用这个框架,还是比较有感情的.我对CI的感觉就是,简单易用,学习曲线平滑,对于新手友好. 目前CI框架已经更新到CI4了, ...

  6. 20行代码教你用python给证件照换底色

    1.图片来源 该图片来源于百度图片,如果侵权,请联系我删除!图片仅用于知识交流.本文只是为了告诉大家:python其实有很多黑科技(牛逼的库),我们既可以用python处理工作中的一些事儿,同时我们也 ...

  7. mysql存储引擎InnoDB详解,从底层看清InnoDB数据结构

    InnoDB一个支持事务安全的存储引擎,同时也是mysql的默认存储引擎.本文主要从数据结构的角度,详细介绍InnoDB行记录格式和数据页的实现原理,从底层看清InnoDB存储引擎. 本文主要内容是根 ...

  8. GitLab 系列文章

    GitLab 系列文章 记录 GitLab 的相关文章 列表 Docker 搭建 GitLab GitLab CI/CD 配置 GitLab 配置模板 访问 GitLab 数据库 GitLab 转让所 ...

  9. PHP基础之查找

    前言 之前的文章介绍了PHP的运算符.流程控制.函数.排序等.有兴趣可以去看看. PHP入门之类型与运算符 PHP入门之流程控制 PHP入门之函数 PHP入门之数组 PHP基础之排序 下面简单介绍一下 ...

  10. Java不可不知的泛型使用

    前面的文章: 详解Java的对象创建 一文打尽Java继承的相关问题 一文打尽Java抽象类和接口的相关问题 本文介绍了Java的泛型的基本使用. 1. 为什么使用泛型 看下面一个例子: 为了说明问题 ...