图片数据一般有两种情况:

1、所有图片放在一个文件夹内,另外有一个txt文件显示标签。

2、不同类别的图片放在不同的文件夹内,文件夹就是图片的类别。

针对这两种不同的情况,数据集的准备也不相同,第一种情况可以自定义一个Dataset,第二种情况直接调用torchvision.datasets.ImageFolder来处理。下面分别进行说明:

一、所有图片放在一个文件夹内

这里以mnist数据集的10000个test为例, 我先把test集的10000个图片保存出来,并生着对应的txt标签文件。

先在当前目录创建一个空文件夹mnist_test, 用于保存10000张图片,接着运行代码:

import torch
import torchvision
import matplotlib.pyplot as plt
from skimage import io
mnist_test= torchvision.datasets.MNIST(
'./mnist', train=False, download=True
)
print('test set:', len(mnist_test)) f=open('mnist_test.txt','w')
for i,(img,label) in enumerate(mnist_test):
img_path="./mnist_test/"+str(i)+".jpg"
io.imsave(img_path,img)
f.write(img_path+' '+str(label)+'\n')
f.close()

经过上面的操作,10000张图片就保存在mnist_test文件夹里了,并在当前目录下生成了一个mnist_test.txt的文件,大致如下:

前期工作就装备好了,接着就进入正题了:

from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from PIL import Image def default_loader(path):
return Image.open(path).convert('RGB') class MyDataset(Dataset):
def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
fh = open(txt, 'r')
imgs = []
for line in fh:
line = line.strip('\n')
line = line.rstrip()
words = line.split()
imgs.append((words[0],int(words[1])))
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
self.loader = loader def __getitem__(self, index):
fn, label = self.imgs[index]
img = self.loader(fn)
if self.transform is not None:
img = self.transform(img)
return img,label def __len__(self):
return len(self.imgs) train_data=MyDataset(txt='mnist_test.txt', transform=transforms.ToTensor())
data_loader = DataLoader(train_data, batch_size=100,shuffle=True)
print(len(data_loader)) def show_batch(imgs):
grid = utils.make_grid(imgs)
plt.imshow(grid.numpy().transpose((1, 2, 0)))
plt.title('Batch from dataloader') for i, (batch_x, batch_y) in enumerate(data_loader):
if(i<4):
print(i, batch_x.size(),batch_y.size())
show_batch(batch_x)
plt.axis('off')
plt.show()

自定义了一个MyDataset, 继承自torch.utils.data.Dataset。然后利用torch.utils.data.DataLoader将整个数据集分成多个批次。

二、不同类别的图片放在不同的文件夹内

同样先准备数据,这里以flowers数据集为例,下载:

http://download.tensorflow.org/example_images/flower_photos.tgz

花总共有五类,分别放在5个文件夹下。大致如下图:

我的路径是d:/flowers/.

数据准备好了,就开始准备Dataset吧,这里直接调用torchvision里面的ImageFolder

import torch
import torchvision
from torchvision import transforms, utils
import matplotlib.pyplot as plt img_data = torchvision.datasets.ImageFolder('D:/bnu/database/flower',
transform=transforms.Compose([
transforms.Scale(256),
transforms.CenterCrop(224),
transforms.ToTensor()])
) print(len(img_data))
data_loader = torch.utils.data.DataLoader(img_data, batch_size=20,shuffle=True)
print(len(data_loader)) def show_batch(imgs):
grid = utils.make_grid(imgs,nrow=5)
plt.imshow(grid.numpy().transpose((1, 2, 0)))
plt.title('Batch from dataloader') for i, (batch_x, batch_y) in enumerate(data_loader):
if(i<4):
print(i, batch_x.size(), batch_y.size()) show_batch(batch_x)
plt.axis('off')
plt.show()

就是这样。

pytorch学习:准备自己的图片数据的更多相关文章

  1. [PyTorch 学习笔记] 2.2 图片预处理 transforms 模块机制

    PyTorch 的数据增强 我们在安装PyTorch时,还安装了torchvision,这是一个计算机视觉工具包.有 3 个主要的模块: torchvision.transforms: 里面包括常用的 ...

  2. pytorch: 准备、训练和测试自己的图片数据

    大部分的pytorch入门教程,都是使用torchvision里面的数据进行训练和测试.如果我们是自己的图片数据,又该怎么做呢? 一.我的数据 我在学习的时候,使用的是fashion-mnist.这个 ...

  3. pytorch初步学习(一):数据读取

    最近从tensorflow转向pytorch,感受到了动态调试的方便,也感受到了一些地方的不同. 所有实验都是基于uint16类型的单通道灰度图片. 一开始尝试用opencv中的cv.imread读取 ...

  4. [深度学习] pytorch利用Datasets和DataLoader读取数据

    本文简单描述如果自定义dataset,代码并未经过测试(只是说明思路),为半伪代码.所有逻辑需按自己需求另外实现: 一.分析DataLoader train_loader = DataLoader( ...

  5. Python库 - Albumentations 图片数据增强库

    Python图像处理库 - Albumentations,可用于深度学习中网络训练时的图片数据增强. Albumentations 图像数据增强库特点: 基于高度优化的 OpenCV 库实现图像快速数 ...

  6. 【深度学习】Pytorch学习基础

    目录 pytorch学习 numpy & Torch Variable 激励函数 回归 区分类型 快速搭建法 模型的保存与提取 批训练 加速神经网络训练 Optimizer优化器 CNN MN ...

  7. tensorflow学习笔记三:实例数据下载与读取

    一.mnist数据 深度学习的入门实例,一般就是mnist手写数字分类识别,因此我们应该先下载这个数据集. tensorflow提供一个input_data.py文件,专门用于下载mnist数据,我们 ...

  8. Caffe初试(三)使用caffe的cifar10网络模型训练自己的图片数据

    由于我涉及一个车牌识别系统的项目,计划使用深度学习库caffe对车牌字符进行识别.刚开始接触caffe,打算先将示例中的每个网络模型都拿出来用用,当然这样暴力的使用是不会有好结果的- -||| ,所以 ...

  9. 纠错:基于FPGA串口发送彩色图片数据至VGA显示

    今天这篇文章是要修改之前的一个错误,前面我写过一篇基于FPGA的串口发送图片数据至VGA显示的文章,最后是显示成功了,但是显示的效果图,看起来确实灰度图,当时我默认我使用的MATLAB代码将图片数据转 ...

随机推荐

  1. GLSL ES 3.0 和 2.0 的区别

    GLSL ES 3.0 和 2.0 的区别 语法区别 attribute和varying. 取而代之的是 in和out 头文件多了个#version 300 es 纹理 texture2D 和 tex ...

  2. 小程序实现GBK编码数据转为Unicode/UTF8

    首先,不存在一种计算算法将GBK编码转换为Unicode编码,因为这两套编码本身毫无关系. 要想实现两者之间的互转,只能通过查表法实现. 在浏览器中实现编码转换,只需要简单两句: var x = ne ...

  3. 在ideaUI中建立maven项目

    1.新建New 注意:第一次创建maven项目需要在有网情况下进行 2.可以看见我们建立的项目了

  4. Java类是如何默认继承Object的

    前言 学过Java的人都知道,Object是所有类的父类.但是你有没有这样的疑问,我并没有写extends Object,它是怎么默认继承Object的呢? 那么今天我们就来看看像Java这种依赖于虚 ...

  5. 3dmax 3dmax计算机要求 3dmax下载

    渲染首先是要X64兼容台式电脑,笔记本不行,笔记本就是学生拿来玩还行,渲染大图笔记本真的是发热. 配置一般的电脑和笔记本千万不要尝试安装3dmax2019了,很卡的,3dmax2019只有64位,没有 ...

  6. AES加密算法详解

    AES 是一个对称密码分组算法,分组长度为128bit,密钥长度为128.192 和 256 bit. 整个加密过程如下图所示. 1.密钥生成算法 密钥扩展过程: 1)  将种子密钥按下图所示的格式排 ...

  7. KindEditor富文本编辑器, 从客户端中检测到有潜在危险的 Request.Form 值

    在用富文本编辑器时经常会遇到的问题是asp.net报的”检测到有潜在危险的 Request.Form 值“一般的解法是在aspx页面   page  标签中加上 validaterequest='fa ...

  8. 使用Skaffold一键将项目发布到Kubernetes

    当前skaffold版本为v0.4,还未发布正式版本,不建议在生产环境中使用: skaffold用于开发人员快速部署程序到Kubernetes中:skaffold提供了dev.run两种模式:使用sk ...

  9. 在Linux上搭建测试环境常用命令(转自-测试小柚子)

    一.搭建测试环境: 二.查看应用日志: (1)vivi/vim 原本是指修改文件,同时可以使用vi 日志文件名,打开日志文件(2)lessless命令是查看日志最常用的命令.用法:less 日志文件名 ...

  10. window下如何使用文本编辑器(如记事本)创建、编译和执行Java程序

    window下如何使用文本编辑器(如记事本)创建Java源代码文件,并编译执行 第一步:在一个英文目录下创建一个 .text 文件 第二步:编写代码 第三步:保存文件 方法一:选择 文件>另存为 ...