1.将图片的路径和标签写入csv文件并实现读取

  # 创建一个文件,包含image,存放方式:label pokemeon\\mew\\0001.jpg,0
def load_csv(self,filename):
if not os.path.exists(os.path.join(self.root,filename)):
images = [] # 将所有的信息组成一个列表,类别信息通过中间的一个路径判断
for name in self.name2label.keys():
# pokemeon\\mew\\0001.jpg mew可以通过字典查看其类别
images += glob.glob(os.path.join(self.root,name,'*.png'))#img的完整路径
images += glob.glob(os.path.join(self.root,name,'*.jpg'))
random.shuffle(images)
with open(os.path.join(self.root,filename),'w') as f:
writer = csv.writer(f)
for img in images:
name = img.split(os.sep)
label = self.name2label[name[-2]]
writer.writerow([img,label]) # 从csv中读取文件
images, labels = [], []
with open(os.path.join(self.root,filename),'r') as f:
reader = csv.reader(f)
for row in reader:
img,label = row
label = int(label)
images.append(img)
labels.append(label)
assert len(images) == len(labels) # 保证数据长度一致
       return images,labels

2.加载自定义数据集

 """
自定义数据集
image_resize
data argumentation(数据增强):Rotate,crop
normalize:mean,std
ToTensor """
import torch
import os,glob
import random,csv
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image
import visdom class Pokemon(Dataset):
def __init__(self,root,resize,mode):
super(Pokemon,self).__init__()
self.root = root
self.resize = resize
self.name2label = {}
for name in os.listdir(os.path.join(root)): #把文件和dir都会加载近来
if not sorted(os.path.isdir(os.path.join(root,name))):#排序后,文件夹顺序固定了
continue
self.name2label[name] = len(self.name2label.keys())
# name2label:{文件夹名,类别编号}
# 创建一个文件,包含image,存放方式:label pokemeon\\mew\\0001.jpg,0
self.images, self.labels = self.load_csv('images.csv')
# 对数据进行裁剪,mode:train-0.6,validation-0.2,test-0.2数据量是不同的
if mode == 'train':
self.images = self.images[:,int(len(self.images)*0.6)]
self.labels = self.labels[:,int(len(self.images)*0.6)]
elif mode == 'val':
self.images = self.images[int(len(self.images)*0.6):int(len(self.images)*0.8)]
self.labels = self.labels[int(len(self.labels)*0.6):int(len(self.labels)*0.8)]
else:
self.images = self.images[int(len(self.images) * 0.8):]
self.labels = self.labels[int(len(self.labels) * 0.8):] def load_csv(self,filename):
if not os.path.exists(os.path.join(self.root,filename)):
images = [] # 将所有的信息组成一个列表,类别信息通过中间的一个路径判断
for name in self.name2label.keys():
# pokemeon\\mew\\0001.jpg mew可以通过字典查看其类别
images += glob.glob(os.path.join(self.root,name,'*.png'))#img的完整路径
images += glob.glob(os.path.join(self.root,name,'*.jpg'))
random.shuffle(images)
with open(os.path.join(self.root,filename),'w') as f:
writer = csv.writer(f)
for img in images:
name = img.split(os.sep)
label = self.name2label[name[-2]]
writer.writerow([img,label])
# 从csv中读取文件
images, labels = [], []
with open(os.path.join(self.root,filename),'r') as f:
reader = csv.reader(f)
for row in reader:
img,label = row
label = int(label)
images.append(img)
labels.append(label)
assert len(images) == len(labels) # 保证数据长度一致
return images,labels def __len__(self):
return len(self.images) def __getitem__(self, idx):
# idx是[0-len(self.images]
# self.images,self.label
# img:pokemeon\\mew\\0001.jpg(这是一个路径)要转变成img数据
# label:是数字
img, label = self.images[idx], self.labels[idx]
tf = transforms.Compose([
lambda x:Image.open(x).convert('RGB'),# string path -> img data
transforms.Resize(int(self.resize*1.25), int(self.resize*1.25)),
transforms.Randomrotation(15), # 旋转度数
transforms.CenterCrop(self.resize),#中心裁剪,保留resize大小
transforms.ToTensor(),
transforms.Normalize(mean=[0.485,0.456,0.406],
std=[0.229,0.224,0.225]) # 归一化之后,范围为-1~1,之前的图片范围为0~1
])
img = tf(img) # 将path转换成数据
label = torch.tensor(label) # 将变量label转换成tensor
return img,label def denormalize(self,x_hat):
mean=[0.485,0.456,0.406]
std=[0.229,0.224,0.225]
# x:[c,h,w]
# x_hat = (x-mean)/std
# maen[3]->[3,1,1]
mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
x = x_hat * std+mean
return x def main():
import torchvision
vis = visdom.Visdom()
"""
如果存储比较规范的话,可以使用下面简单的代码加载数据集,文件夹的标签从0开始编码
tf = transforms.Compose([
transforms.Resize((64,64)),
transforms.ToTensor()
])
db = torchvision.datasets.ImageFolder('./pokemon',transform=tf)
loader = DataLoader(db,batch_size=32,shuffle=True)
print(db.class_to_idx) #查看类标签 """
db = Pokemon('./pokemon', 224, 'train') # 根据idx,返回一个
x,y = next(iter(db))
print('sample:',x.shape,y.shape)
#可视化
vis.image(db.denormalize(x),win='sample_x',opts=dict(title = 'sample_x'))
# 加载一批
loader = DataLoader(db,batch_size = 32,shuffle=True,num_workers=8 )
for x,y in loader:
vis.images(db.denormalize(x), nrow=8, win='batch',opts=dict(title='batch'))
vis.text(str(y.numpy()),win='label',opts=dict(title='batch-y')) if __name__ == '__main__':
main()

小结:

在加载自定义数据集时,一般步骤

1.定义一个类继承Dataset

2.在类中读取数据集(图片的路径),重写len函数,和getitem函数

在len函数中返回数据集的长度

在getitem函数中,处理一张图片,单个图片路径转换成图片数据(包括transform转换),返回该图片数据和标签

3,将处理好的数据集(均为张量)放入DataLoader中,进行分批

loader = DataLoader(db,batch_size = 32,shuffle=True,num_workers=8 )

4.训练时通过enumerate遍历每个batchsize

torch_13_自定义数据集实战的更多相关文章

  1. SpringBoot2.x过滤器Filter和使用Servlet3.0配置自定义Filter实战

    补充:SpringBoot启动日志 1.深入SpringBoot2.x过滤器Filter和使用Servlet3.0配置自定义Filter实战(核心知识) 简介:讲解SpringBoot里面Filter ...

  2. Tensorflow2 自定义数据集图片完成图片分类任务

    对于自定义数据集的图片任务,通用流程一般分为以下几个步骤: Load data Train-Val-Test Build model Transfer Learning 其中大部分精力会花在数据的准备 ...

  3. pytorch加载语音类自定义数据集

    pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合 torch.u ...

  4. MMDetection 快速开始,训练自定义数据集

    本文将快速引导使用 MMDetection ,记录了实践中需注意的一些问题. 环境准备 基础环境 Nvidia 显卡的主机 Ubuntu 18.04 系统安装,可见 制作 USB 启动盘,及系统安装 ...

  5. Scaled-YOLOv4 快速开始,训练自定义数据集

    代码: https://github.com/ikuokuo/start-scaled-yolov4 Scaled-YOLOv4 代码: https://github.com/WongKinYiu/S ...

  6. PyTorch 自定义数据集

    准备数据 准备 COCO128 数据集,其是 COCO train2017 前 128 个数据.按 YOLOv5 组织的目录: $ tree ~/datasets/coco128 -L 2 /home ...

  7. Android自定义View实战(SlideTab-可滑动的选择器)

    转载请标明出处: http://blog.csdn.net/xmxkf/article/details/52178553 本文出自:[openXu的博客] 目录: 初步分析重写onDraw绘制 重写o ...

  8. 高级UI晋升之自定义view实战(七)

    更多Android高级架构进阶视频学习请点击:https://space.bilibili.com/474380680本篇文章自定义ViewGroup实现瀑布流效果来进行详解dispatchTouch ...

  9. 自定义View实战

    PS:上一篇从0开始学自定义View有博友给我留言说要看实战,今天我特意写了几个例子,供大家参考,所画的图案加上动画看着确实让人舒服,喜欢的博友可以直接拿到自己的项目中去使用,由于我这个写的是demo ...

随机推荐

  1. 【MySQL】MySQL 8.0的SYS视图

    MySQL的SYS视图 MySQL8.0的发展越来越趋同与Oracle,为了更好的监控MySQL的一些相关指标,出现了SYS视图,用于监控. 1.MySQL版本 (root@localhost) [s ...

  2. java 连缀用法

    连缀用法,即是在实例化对象时,同时为对象的属性设值. 如示例所示,在创建对象时,同时调用属性的设值函数,为属性赋值 Apple apple = new Apple() .setColor(" ...

  3. C++ 类的前向声明的用法

    我们知道C++的类应当是先定义,然后使用.但在处理相对复杂的问题.考虑类的组合时,很可能遇到俩个类相互引用的情况,这种情况称为循环依赖. 例如: class A { public: void f(B ...

  4. Java生鲜电商平台-B2B生鲜的互联网思维

    Java生鲜电商平台-B2B生鲜的互联网思维 在互联网高速发展的今天,为我们的生活带来了众多便利.然而互联网从早期的萌芽状态到现在妇孺皆知,它的崛起速度远远超乎世人的想象.人们开始关注互联网并且研究它 ...

  5. Javase之集合体系(4)之Map集合

    集合体系之Map集合 ##Map<K,V>( 接口 ) 特点:将键映射到值对象,一个映射不能包含重复的键:每个键只能映射一个值 Map集合与Collection集合的区别 ​ Map集合存 ...

  6. 高强度学习训练第八天总结:MySQL的一些优化

    为什么要做MYSQL优化 系统的吞吐量瓶颈往往出现在数据库的访问速度上 随着应用程序的运行,数据库中的数据会越来越多,处理时间会相应变慢. 数据是存放在磁盘上的,读写速度无法和内存相比 如何优化 设计 ...

  7. opencv::BackgroundSubtraction基本原理

    背景消除 BS算法 - 图像分割(GMM – 高斯混合模型) - 机器学习(KNN –K个最近邻) BackgroundSubtractor (父类) - BackgroundSubtractorMO ...

  8. 剑指offer 12:二进制中1的个数

    题目描述 输入一个整数,输出该数二进制表示中1的个数.其中负数用补码表示. 解法一:设置标志为flag=1,逐个位移至不同位置,比较是否为1. C++实现 class Solution { publi ...

  9. PHP-RPM 安装指南(亲测有用)

      小注:此教程可能有很多弯路,但是最终是肯定安装成功了的,一个问题就是刚开始安装编译的指令版本好像不对,但是后面纠正过来了,但是此教程一共遇到了 十多个问题,也一并解决了,具有一定的借鉴意义,还有( ...

  10. 1_Swift概况

    Swift 标准库 解决复杂的问题并编写高性能,可读的代码 概况 Swift标准库定义了用于编写Swift程序的基本功能,其中包括 1.如基本数据类型Int,Double以及String 2.共同的数 ...