图片数据一般有两种情况:

1、所有图片放在一个文件夹内,另外有一个txt文件显示标签。

2、不同类别的图片放在不同的文件夹内,文件夹就是图片的类别。

针对这两种不同的情况,数据集的准备也不相同,第一种情况可以自定义一个Dataset,第二种情况直接调用torchvision.datasets.ImageFolder来处理。下面分别进行说明:

一、所有图片放在一个文件夹内

这里以mnist数据集的10000个test为例, 我先把test集的10000个图片保存出来,并生着对应的txt标签文件。

先在当前目录创建一个空文件夹mnist_test, 用于保存10000张图片,接着运行代码:

import torch
import torchvision
import matplotlib.pyplot as plt
from skimage import io
mnist_test= torchvision.datasets.MNIST(
'./mnist', train=False, download=True
)
print('test set:', len(mnist_test)) f=open('mnist_test.txt','w')
for i,(img,label) in enumerate(mnist_test):
img_path="./mnist_test/"+str(i)+".jpg"
io.imsave(img_path,img)
f.write(img_path+' '+str(label)+'\n')
f.close()

经过上面的操作,10000张图片就保存在mnist_test文件夹里了,并在当前目录下生成了一个mnist_test.txt的文件,大致如下:

前期工作就装备好了,接着就进入正题了:

from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from PIL import Image def default_loader(path):
return Image.open(path).convert('RGB') class MyDataset(Dataset):
def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
fh = open(txt, 'r')
imgs = []
for line in fh:
line = line.strip('\n')
line = line.rstrip()
words = line.split()
imgs.append((words[0],int(words[1])))
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
self.loader = loader def __getitem__(self, index):
fn, label = self.imgs[index]
img = self.loader(fn)
if self.transform is not None:
img = self.transform(img)
return img,label def __len__(self):
return len(self.imgs) train_data=MyDataset(txt='mnist_test.txt', transform=transforms.ToTensor())
data_loader = DataLoader(train_data, batch_size=100,shuffle=True)
print(len(data_loader)) def show_batch(imgs):
grid = utils.make_grid(imgs)
plt.imshow(grid.numpy().transpose((1, 2, 0)))
plt.title('Batch from dataloader') for i, (batch_x, batch_y) in enumerate(data_loader):
if(i<4):
print(i, batch_x.size(),batch_y.size())
show_batch(batch_x)
plt.axis('off')
plt.show()

自定义了一个MyDataset, 继承自torch.utils.data.Dataset。然后利用torch.utils.data.DataLoader将整个数据集分成多个批次。

二、不同类别的图片放在不同的文件夹内

同样先准备数据,这里以flowers数据集为例,下载:

http://download.tensorflow.org/example_images/flower_photos.tgz

花总共有五类,分别放在5个文件夹下。大致如下图:

我的路径是d:/flowers/.

数据准备好了,就开始准备Dataset吧,这里直接调用torchvision里面的ImageFolder

import torch
import torchvision
from torchvision import transforms, utils
import matplotlib.pyplot as plt img_data = torchvision.datasets.ImageFolder('D:/bnu/database/flower',
transform=transforms.Compose([
transforms.Scale(256),
transforms.CenterCrop(224),
transforms.ToTensor()])
) print(len(img_data))
data_loader = torch.utils.data.DataLoader(img_data, batch_size=20,shuffle=True)
print(len(data_loader)) def show_batch(imgs):
grid = utils.make_grid(imgs,nrow=5)
plt.imshow(grid.numpy().transpose((1, 2, 0)))
plt.title('Batch from dataloader') for i, (batch_x, batch_y) in enumerate(data_loader):
if(i<4):
print(i, batch_x.size(), batch_y.size()) show_batch(batch_x)
plt.axis('off')
plt.show()

就是这样。

pytorch学习:准备自己的图片数据的更多相关文章

  1. [PyTorch 学习笔记] 2.2 图片预处理 transforms 模块机制

    PyTorch 的数据增强 我们在安装PyTorch时,还安装了torchvision,这是一个计算机视觉工具包.有 3 个主要的模块: torchvision.transforms: 里面包括常用的 ...

  2. pytorch: 准备、训练和测试自己的图片数据

    大部分的pytorch入门教程,都是使用torchvision里面的数据进行训练和测试.如果我们是自己的图片数据,又该怎么做呢? 一.我的数据 我在学习的时候,使用的是fashion-mnist.这个 ...

  3. pytorch初步学习(一):数据读取

    最近从tensorflow转向pytorch,感受到了动态调试的方便,也感受到了一些地方的不同. 所有实验都是基于uint16类型的单通道灰度图片. 一开始尝试用opencv中的cv.imread读取 ...

  4. [深度学习] pytorch利用Datasets和DataLoader读取数据

    本文简单描述如果自定义dataset,代码并未经过测试(只是说明思路),为半伪代码.所有逻辑需按自己需求另外实现: 一.分析DataLoader train_loader = DataLoader( ...

  5. Python库 - Albumentations 图片数据增强库

    Python图像处理库 - Albumentations,可用于深度学习中网络训练时的图片数据增强. Albumentations 图像数据增强库特点: 基于高度优化的 OpenCV 库实现图像快速数 ...

  6. 【深度学习】Pytorch学习基础

    目录 pytorch学习 numpy & Torch Variable 激励函数 回归 区分类型 快速搭建法 模型的保存与提取 批训练 加速神经网络训练 Optimizer优化器 CNN MN ...

  7. tensorflow学习笔记三:实例数据下载与读取

    一.mnist数据 深度学习的入门实例,一般就是mnist手写数字分类识别,因此我们应该先下载这个数据集. tensorflow提供一个input_data.py文件,专门用于下载mnist数据,我们 ...

  8. Caffe初试(三)使用caffe的cifar10网络模型训练自己的图片数据

    由于我涉及一个车牌识别系统的项目,计划使用深度学习库caffe对车牌字符进行识别.刚开始接触caffe,打算先将示例中的每个网络模型都拿出来用用,当然这样暴力的使用是不会有好结果的- -||| ,所以 ...

  9. 纠错:基于FPGA串口发送彩色图片数据至VGA显示

    今天这篇文章是要修改之前的一个错误,前面我写过一篇基于FPGA的串口发送图片数据至VGA显示的文章,最后是显示成功了,但是显示的效果图,看起来确实灰度图,当时我默认我使用的MATLAB代码将图片数据转 ...

随机推荐

  1. 开发环境中Docker的使用

    一. Ubuntu16.04+Django+Redis+Nginx的Web项目Docker化 1.创建Django项目的image # 创建项目image 执行 docker build -t ccn ...

  2. poj 2253 floyd最短路

    题目链接 : http://poj.org/problem?id=2253: 思路:这个题主要是理解了意思就行,题目意思是有两只青蛙和若干块石头,现在已知这些东西的坐标,两只青蛙A坐标和青蛙B坐标是第 ...

  3. docker-compose.yml 配置文件详解及项目发布

    摘自:https://blog.csdn.net/qq_36148847/article/details/79427878 docker部署tomcat项目 1.上传war包2.制作镜像 Docker ...

  4. JS与CSS那些特别小的知识点区别

    1:target与currentTarget的区别 currentTarget指向的事件绑定的元素,target指向的是你点击的元素 2:attr与jprop在jQuery在API当中的区别 2.1: ...

  5. [CF1132G]Greedy Subsequences

    [CF1132G]Greedy Subsequences 题目大意: 定义一个序列的最长贪心严格上升子序列为:任意选择第一个元素后,每次选择右侧第一个大于它的元素,直到不能选为止. 给定一个长度为\( ...

  6. Python网络编程基础pdf

    Python网络编程基础(高清版)PDF 百度网盘 链接:https://pan.baidu.com/s/1VGwGtMSZbE0bSZe-MBl6qA 提取码:mert 复制这段内容后打开百度网盘手 ...

  7. lua 语言基础

    1.数据类型: string(字符串) ·运算符“+.-.*./”等操作字符串,lua会尝试讲字符串转换为数字后操作: ·字符串连接用“..”运算符 ·用“#”来计算字符串的长度(放在字符串前面) · ...

  8. css3_transition: 体验好的过渡效果。附 好看的按钮

    利用css的transition属性详解,上图就是利用transition效果做的一个按钮. transition属性://举例子:transition:all 1s ease;transition: ...

  9. 《JavaScript高级程序设计(第3版)》阅读总结记录第二章之在HTML中使用JavaScript

    本章目录: 2.1 <script> 元素 2.1.1 标签的位置 2.1.2 延迟脚本 2.1.3 异步脚本 2.1.4 在XHTML 中的用法 2.1.5 不推荐使用的语法 2.2 嵌 ...

  10. php基础-mysqli

    基本八个步骤 //连接数据库 $link = mysqli_connect('localhost', 'root', ''); //判断是否连接成功 if (!$link) { exit('数据库连接 ...