显然我们在学习深度学习时,不能只局限于通过使用官方提供的MNSIT、CIFAR-10、CIFAR-100这样的数据集,很多时候我们还是需要根据自己遇到的实际问题自己去搜集数据,然后制作数据集(收集数据集的方法有很多,这里就不过多的展开了)。这里只介绍数据集的读取。

  1. 自定义数据集的方法

    首先创建一个Dataset类



    在代码中:

    def init() 一些初始化的过程写在这个函数下

    def len() 返回所有数据的数量,比如我们这里将数据划分好之后,这里仅仅返回的是被处理后的关系

    def getitem() 回数据和标签

  2. 补充代码

    上述已经将框架打出来了,接下来就是将框架填充完整就行了,下面是完整的代码,代码的解释说明我也已经写在其中了

# -*- coding: utf-8 -*-
# @Author : 胡子旋
# @Email :1017190168@qq.com import torch
import os,glob
import visdom
import time
import torchvision
import random,csv
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image class pokemom(Dataset):
def __init__(self,root,resize,mode,):
super(pokemom,self).__init__()
# 保存参数
self.root=root
self.resize=resize
# 给每一个类做映射
self.name2label={} # "squirtle":0 ,"pikachu":1……
for name in sorted(os.listdir(os.path.join(root))):
# 过滤掉文件夹
if not os.path.isdir(os.path.join(root,name)):
continue
# 保存在表中;将最长的映射作为最新的元素的label的值
self.name2label[name]=len(self.name2label.keys())
print(self.name2label)
# 加载文件
self.images,self.labels=self.load_csv('images.csv')
# 裁剪数据
if mode=='train':
self.images=self.images[:int(0.6*len(self.images))] # 将数据集的60%设置为训练数据集合
self.labels=self.labels[:int(0.6*len(self.labels))] # label的60%分配给训练数据集合
elif mode=='val':
self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))] # 从60%-80%的地方
self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
else:
self.images = self.images[int(0.8 * len(self.images)):] # 从80%的地方到最末尾
self.labels = self.labels[int(0.8 * len(self.labels)):]
# image+label 的路径
def load_csv(self,filename):
# 将所有的图片加载进来
# 如果不存在的话才进行创建
if not os.path.exists(os.path.join(self.root,filename)):
images=[]
for name in self.name2label.keys():
images+=glob.glob(os.path.join(self.root,name,'*.png'))
images+=glob.glob(os.path.join(self.root, name, '*.jpg'))
images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
print(len(images),images)
# 1167 'pokeman\\bulbasaur\\00000000.png'
# 将文件以上述的格式保存在csv文件内
random.shuffle(images)
with open(os.path.join(self.root,filename),mode='w',newline='') as f:
writer=csv.writer(f)
for img in images: # 'pokeman\\bulbasaur\\00000000.png'
name=img.split(os.sep)[-2]
label=self.name2label[name]
writer.writerow([img,label])
print("write into csv into :",filename) # 如果存在的话就直接的跳到这个地方
images,labels=[],[]
with open(os.path.join(self.root, filename)) as f:
reader=csv.reader(f)
for row in reader:
# 接下来就会得到 'pokeman\\bulbasaur\\00000000.png' 0 的对象
img,label=row
# 将label转码为int类型
label=int(label)
images.append(img)
labels.append(label)
# 保证images和labels的长度是一致的
assert len(images)==len(labels)
return images,labels # 返回数据的数量
def __len__(self):
return len(self.images) # 返回的是被裁剪之后的关系 def denormalize(self, x_hat): mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
# print(mean.shape, std.shape)
x = x_hat * std + mean
return x
# 返回idx的数据和当前图片的label
def __getitem__(self,idx):
# idex-[0-总长度]
# retrun images,labels
# 将图片,label的路径取出来
# 得到的img是这样的一个类型:'pokeman\\bulbasaur\\00000000.png'
# 然而label得到的则是 0,1,2 这样的整形的格式
img,label=self.images[idx],self.labels[idx]
tf=transforms.Compose([
lambda x:Image.open(x).convert('RGB'), # 将t图片的路径转换可以处理图片数据
# 进行数据加强
transforms.Resize((int(self.resize*1.25),int(self.resize*1.25))),
# 随机旋转
transforms.RandomRotation(15), # 设置旋转的度数小一些,否则的话会增加网络的学习难度
# 中心裁剪
transforms.CenterCrop(self.resize), # 此时:既旋转了又不至于导致图片变得比较的复杂
transforms.ToTensor(),
transforms.Normalize(mean=[0.485,0.456,0.406],
std=[0.229,0.224,0.225]) ])
img=tf(img)
label=torch.tensor(label)
return img,label def main():
# 验证工作
viz=visdom.Visdom() db=pokemom('pokeman',64,'train') # 这里可以改变大小 224->64,可以通过visdom进行查看
# 可视化样本
x,y=next(iter(db))
print('sample:',x.shape,y.shape,y)
viz.image(db.denormalize(x),win='sample_x',opts=dict(title='sample_x'))
# 加载batch_size的数据
loader=DataLoader(db,batch_size=32,shuffle=True,num_workers=8)
for x,y in loader:
viz.images(db.denormalize(x),nrow=8,win='batch',opts=dict(title='batch'))
viz.text(str(y.numpy()),win='label',opts=dict(title='batch-y'))
# 每一次加载后,休息10s
time.sleep(10) if __name__ == '__main__':
main()

PyTorch 中自定义数据集的读取方法的更多相关文章

  1. Django 模型中自定义Manager和模型方法

    1.自定义管理器(Manager) 在语句Book.objects.all()中,objects是一个特殊的属性,通过它来查询数据库,它就是模型的一个Manager. 每个Django模型至少有一个m ...

  2. Pytorch中自定义神经网络卷积核权重

    1. 自定义神经网络卷积核权重 神经网络被深度学习者深深喜爱,究其原因之一是神经网络的便利性,使用者只需要根据自己的需求像搭积木一样搭建神经网络框架即可,搭建过程中我们只需要考虑卷积核的尺寸,输入输出 ...

  3. 转--Android中自定义字体的实现方法

    1.Android系统默认支持三种字体,分别为:“sans”, “serif”, “monospace 2.在Android中可以引入其他字体 . 复制代码 代码如下: <?xml versio ...

  4. Wordpress在主题或者插件中自定义存储附件的方法

    1.前端使用form表单 //单文件上传,我的业务需求中限制了必须上传图片 <input type="file" name="singlename" ac ...

  5. Vue中自定义指令的使用方法!

    除了核心功能默认内置的指令 (v-model 和 v-show),Vue 也允许注册自定义指令.注意,在 Vue2.0 中,代码复用和抽象的主要形式是组件.然而,有的情况下,你仍然需要对普通 DOM ...

  6. 【Azure 应用服务】在 App Service for Windows 中自定义 PHP 版本的方法

    问题描述 在App Service for Windows的环境中,当前只提供了PHP 7.4 版本的选择情况下,如何实现自定义PHP Runtime的版本呢? 如 PHP Version 8.1.9 ...

  7. pytorch 中交叉熵损失实现方法

  8. 6.1 如何在spring中自定义xml标签

    dubbo自定义了很多xml标签,例如<dubbo:application>,那么这些自定义标签是怎么与spring结合起来的呢?我们先看一个简单的例子. 一 编写模型类 package ...

  9. [Pytorch]PyTorch Dataloader自定义数据读取

    整理一下看到的自定义数据读取的方法,较好的有一下三篇文章, 其实自定义的方法就是把现有数据集的train和test分别用 含有图像路径与label的list返回就好了,所以需要根据数据集随机应变. 所 ...

  10. pytorch加载语音类自定义数据集

    pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合 torch.u ...

随机推荐

  1. xv6 中的进程切换:MIT6.s081/6.828 lectrue11:Scheduling 以及 Lab6 Thread 心得

    絮絮叨 这两节主要介绍 xv6 中的线程切换,首先预警说明,这节课程的容量和第 5/6 节:进程的用户态到内核态的切换一样,细节多到爆炸,连我自己复习时都有点懵,看来以后不能偷懒了,学完课程之后要马上 ...

  2. 领域驱动设计(DDD):DDD落地问题和一些解决方法

    欢迎继续关注本系列文章,下面我们继续讲解下DDD在实战落地时候,会具体碰到哪些问题,以及解决的方式有哪些. DDD 是一种思想,主要知道我们方向,具体如何做,需要我们根据业务场景具体问题具体分析. 充 ...

  3. 用OLED屏幕播放视频(2): 为OLED屏幕开发I2C驱动

    下面的系列文章记录了如何使用一块linux开发扳和一块OLED屏幕实现视频的播放: 项目介绍 为OLED屏幕开发I2C驱动 使用cuda编程加速视频处理 这是此系列文章的第2篇, 主要总结和记录一个I ...

  4. JDK8升级JDK11最全实践干货来了

    1.前言 截至目前(2023年),Java8发布至今已有9年,2018年9月25日,Oracle发布了Java11,这是Java8之后的首个LTS版本.那么从JDK8到JDK11,到底带来了哪些特性呢 ...

  5. Vue3中的Ref与Reactive:深入理解响应式编程

    前言 Vue 3是一个功能强大的前端框架,它引入了一些令人兴奋的新特性,其中最引人注目的是ref和reactive.这两个API是Vue 3中响应式编程的核心,本文将深入探讨它们的用法和差异. 什么是 ...

  6. Django框架项目——BBS项目介绍、表设计、表创建同步、注册、登录功能、登录功能、首页搭建、admin、头像、图片防盗、个人站点、侧边栏筛选、文章的详情页、点赞点踩、评论、后台管理、添加文章、头像

    文章目录 1 BBS项目介绍.表设计 项目开发流程 表设计 2 表创建同步.注册.登录功能 数据库表创建及同步 注册功能 登陆功能 3 登录功能.首页搭建.admin.头像.图片防盗.个人站点.侧边栏 ...

  7. vue 甘特图(三):甘特图右侧内容拖动展示

    vue3 甘特图(三):甘特图右侧内容拖动展示内容 解决因多个项目周期跨度不同,在一页屏幕里展示不完全,需要通过拖动甘特图下方的滚动条,去查看对应时间段内的内容 拖拽滚动视图,展示对应时间甘特图 构思 ...

  8. 嵌入式BI的精解与探索

    摘要:本文由葡萄城技术团队原创并首发.转载请注明出处:葡萄城官网,葡萄城为开发者提供专业的开发工具.解决方案和服务,赋能开发者. 前言 1996年,商业智能(BI)的概念首次浮现,随后的20多年间,商 ...

  9. Redis的速度不够用?为什么你应该考虑使用 KeyDB,一个更快、更强大、更灵活的开源数据库

    你是否正在使用 Redis 作为您的数据结构存储,享受它的高性能.高可用的特性?如果是这样,那么你可能会对 KeyDB 感兴趣. 什么是 KeyDB? KeyDB 一个由 Snap 提供支持.专为扩展 ...

  10. JS个人总结(2)

    1.null被认为是一个空的对象引用..如果定义的变量准备将来用保存对象,最好将该变量初始化null.即 var x=null;这样只有检查null值就可以知道这个变量是否已经保存了一个对象.. 2. ...