Pytorch中数据集读取
  在机器学习中,有很多形式的数据,我们就以最常用的几种来看:
  在Pytorch中,他自带了很多数据集,比如MNIST、CIFAR10等,这些自带的数据集获得和读取十分简便:
  

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
train_data = torchvision.datasets.MNIST(
  root='./mnist/', # 数据集存放的位置,他会先查找,如果该地没有对应数据集,就下载(不连外网的话推荐从网上下载好后直接放到对应目录)
            #比如这里root是"./mnist/" 那么下载好的数据放入./mnist/raw/下
  train=True, # 表示这是训练集,如果是测试集就改为false
  transform=torchvision.transforms.ToTensor(), # 表示数据转换的格式
  download=True,
)

以上就获得了对应的数据集,接下来就是读取:
  

train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) #使用自定义好的DataLoader函数来进行
for epoch in range(EPOCH):
for step, (b_x, b_y) in enumerate(train_loader): # 这里就可以遍历数据集,进行训练了
XXXX
XXXX

可以看到,这样获得数据集确实方便,但是也有缺点:

  1、只能获得固定的数据集,对于自己的特有的数据集,无法读取。

  2、如果要对这些数据集处理,没有办法做到。

因此,我们需要一种方法来读取本地的数据集:

先看看如何把从网上下载的文件转换为图片存储:


import os
from skimage import io
import torchvision.datasets.mnist as mnist
import numpy

root = "data/MNIST/raw/"  #自己对应数据集的目录
#获得对应的训练集的数据和标签
train_set = (
mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),
mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte'))
)
#获得对应的测试机的数据和标签
test_set = (
mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte'))
) print("train set:", train_set[0].size())
print("test set:", test_set[0].size())
#将数据转换为图片格式
def convert_to_img(train=True):
if (train):
#如果是训练集的话 就放到训练集中
f = open(root + 'train.txt', 'w')
#打开对应文件
data_path = root + '/train/'
if (not os.path.exists(data_path)):
os.makedirs(data_path)
for i, (img, label) in enumerate(zip(train_set[0], train_set[1])):
# 拼合出图片路径
      img_path = data_path + str(i) + '.jpg'
     #将img这个获取的数据转换为numpy后存入,而因为文件存入后打开格式为.jpg,所以自动变成了图片 io.imsave(img_path, img.numpy())
      #存入标签
int_label = str(label).replace('tensor(', '')
int_label = int_label.replace(')', '')
f.write(img_path + ' ' + str(int_label) + '\n')
f.close()
else:
f = open(root + 'test.txt', 'w')
data_path = root + '/test/'
if (not os.path.exists(data_path)):
os.makedirs(data_path)
for i, (img, label) in enumerate(zip(test_set[0], test_set[1])):
img_path = data_path + str(i) + '.jpg'
io.imsave(img_path, img.numpy())
int_label = str(label).replace('tensor(', '')
int_label = int_label.replace(')', '')
f.write(img_path + ' ' + str(int_label) + '\n')
f.close()
convert_to_img(True)
convert_to_img(False)

这样以后,得到的效果是:

    

在train.txt中,存放着文件的路径和标签

然后在文件路径中,有如图的图片:

    

数据转换成图片格式了,那接下来看看如何将图片读取:

读取数据的重点就是重写torch.utils.data的Dataset方法,其中有三个重要的方法:__init__,__getitem__,__len__,这三个方法分别表示:对数据集的初始化,在循环的时候获得数据,还有数据的长度。

我们就用已有的图片数据集,来进行CNN的训练:

首先要写一个自己的数据集加载的类:

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import FileLoader
from torch.utils import data
import numpy as np
import matplotlib.pyplot as plt
#定义对数据集的加载方式
def default_loader(path):
return Image.open(path)
#继承自data.Dataset,完成对方法的重写
class MyImageFloder(data.Dataset):
def __init__(self,FileName,transform = None,target_transform = None,loader = default_loader):
#获得对应的数据位置和标签
     FilePlaces,LabelSet = FileLoader.filePlace_loader(FileName)
# 对标签进行修改,因为读进来变成了str类型,要修改成long类型
LabelSet = [np.long(i) for i in LabelSet]
LabelSet = torch.Tensor(LabelSet).long() self.imgs_place = FilePlaces
self.LabelSet = LabelSet
self.transform = transform
self.target_transform = target_transform
self.loader = loader
# 这里是对数据进行读取 使用之前定义的loader方法来执行
def __getitem__(self, item):
img_place = self.imgs_place[item]
label = self.LabelSet[item]
img = self.loader(img_place)
if self.transform is not None:
img = self.transform(img) return img,label def __len__(self):
return len(self.imgs_place)

这里的重点就是,一定看好了读取以后的格式和维度,还有类型,在实际使用的时候,经常报错,要根据提示来修改对应的代码!

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import FileLoader
from torch.utils import data
import numpy as np
import matplotlib.pyplot as plt
import MyImageLoader
batch_size = 64 mytransform = transforms.Compose([
transforms.ToTensor()
]
)
batch_size = 64 train_loader = DataLoader(MyImageLoader.MyImageFloder(FileName='data/MNIST/raw/train.txt',transform=mytransform), batch_size=batch_size,
shuffle=True) test_loader = DataLoader(MyImageLoader.MyImageFloder(FileName='data/MNIST/raw/test.txt',transform=mytransform), batch_size=batch_size,
shuffle=True) class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
# 输入1通道,输出10通道,kernel 5*5
self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, 5)
self.conv3 = nn.Conv2d(20, 40, 3) self.mp = nn.MaxPool2d(2)
# fully connect
self.fc = nn.Linear(40, 10)#(in_features, out_features) def forward(self, x):
# in_size = 64
in_size = x.size(0) # one batch 此时的x是包含batchsize维度为4的tensor,即(batchsize,channels,x,y),x.size(0)指batchsize的值 把batchsize的值作为网络的in_size
# x: 64*1*28*28
x = F.relu(self.mp(self.conv1(x)))
# x: 64*10*12*12 feature map =[(28-4)/2]^2=12*12
x = F.relu(self.mp(self.conv2(x)))
# x: 64*20*4*4
x = F.relu(self.mp(self.conv3(x))) x = x.view(in_size, -1) # flatten the tensor 相当于resharp
# print(x.size())
# x: 64*320
x = self.fc(x)
# x:64*10
# print(x.size())
return F.log_softmax(x) #64*10 model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5) losses =[]
test_losses = []
test_acces = [] def train(epoch):
loss=0; for batch_idx, (data, target) in enumerate(train_loader):#batch_idx是enumerate()函数自带的索引,从0开始
# data.size():[64, 1, 28, 28]
# target.size():[64] output = model(data)
#output:64*10 # target = [np.long(i) for i in target]
# target = torch.Tensor(target).long() loss = F.nll_loss(output, target) if batch_idx % 200 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item())) optimizer.zero_grad() # 所有参数的梯度清零
loss.backward() #即反向传播求梯度
optimizer.step() #调用optimizer进行梯度下降更新参数
# 每次训练完一整轮,记录一次
losses.append(loss.item()) def test():
test_loss = 0
correct = 0
for data, target in test_loader:
# target = [np.long(i) for i in target]
# target = torch.Tensor(target).long()
with torch.no_grad():
data, target = Variable(data), Variable(target)
output = model(data)
# sum up batch loss
test_loss += F.nll_loss(output, target, size_average=False).item()
# get the index of the max log-probability
pred = output.data.max(1, keepdim=True)[1] correct += pred.eq(target.data.view_as(pred)).cpu().sum() test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset))) test_losses.append(test_loss)
print(correct.item())
test_acces.append(correct.item() / len(test_loader.dataset)) for epoch in range(1, 5):
train(epoch)
test() plt.figure(22) x_loss = list(range(len(losses)))
x_acc = list(range(len(test_acces))) plt.subplot(221)
plt.title('train loss ')
plt.plot(x_loss, losses) plt.subplot(222)
plt.title('test loss')
plt.plot(x_loss, test_losses) plt.subplot(212)
plt.plot(x_acc, test_acces)
plt.title('test acc')
plt.show()

这里直接套用CNN的代码即可。

关于此问题比较好的方法可以参考:https://blog.csdn.net/qq_36852276/article/details/94588656

                https://blog.csdn.net/sjtuxx_lee/article/details/83031718

Pytorch数据集读取的更多相关文章

  1. 【转载】PyTorch系列 (二):pytorch数据读取

    原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...

  2. pytorch实现花朵数据集读取

    import os from PIL import Image from torch.utils import data import numpy as np from torchvision imp ...

  3. PyTorch 数据集类 和 数据加载类 的一些尝试

    最近在学习PyTorch,  但是对里面的数据类和数据加载类比较迷糊,可能是封装的太好大部分情况下是不需要有什么自己的操作的,不过偶然遇到一些自己导入的数据时就会遇到一些问题,因此自己对此做了一些小实 ...

  4. Pytorch数据读取框架

    训练一个模型需要有一个数据库,一个网络,一个优化函数.数据读取是训练的第一步,以下是pytorch数据输入框架. 1)实例化一个数据库 假设我们已经定义了一个FaceLandmarksDataset数 ...

  5. Pytorch数据读取详解

    原文:http://studyai.com/article/11efc2bf#%E9%87%87%E6%A0%B7%E5%99%A8%20Sampler%20&%20BatchSampler ...

  6. Pytorch数据读取与预处理实现与探索

    在炼丹时,数据的读取与预处理是关键一步.不同的模型所需要的数据以及预处理方式各不相同,如果每个轮子都我们自己写的话,是很浪费时间和精力的.Pytorch帮我们实现了方便的数据读取与预处理方法,下面记录 ...

  7. CIFAR-10数据集读取

    参考:https://jingyan.baidu.com/article/656db9183296c7e381249cf4.html 1.使用读取方式pickle def unpickle(file) ...

  8. 深度学习(tensorflow) —— 自己数据集读取opencv

    先来看一下我们的目录: dataset1 和creat_dataset.py 属于同一目录 mergeImg1 和mergeImg2 为Dataset1的两子目录(两类为例子)目录中存储图像等文件 核 ...

  9. 用于DataLoader的pytorch数据集

    暂时介绍 image-mask型数据集, 以人手分割数据集 EGTEA Gaze+ 为例. 准备数据文件夹 需要将Image和Mask分开存放, 对应文件的文件名必须保持一致. 提醒: Mask 图像 ...

随机推荐

  1. 数据挖掘入门系列教程(十点五)之DNN介绍及公式推导

    深度神经网络(DNN,Deep Neural Networks)简介 首先让我们先回想起在之前博客(数据挖掘入门系列教程(七点五)之神经网络介绍)中介绍的神经网络:为了解决M-P模型中无法处理XOR等 ...

  2. Apache漏洞利用与安全加固实例分析

    Apache 作为Web应用的载体,一旦出现安全问题,那么运行在其上的Web应用的安全也无法得到保障,所以,研究Apache的漏洞与安全性非常有意义.本文将结合实例来谈谈针对Apache的漏洞利用和安 ...

  3. sql语句------合并结果集

    select id,max(val) FROM (select id,a 列名 val from 表名unionselect id,b 列名 val from 表名unionselect id,c 列 ...

  4. Redis持久化存储(三)

    redis高级特性-发布订阅消息服务功能 Pub/Sub 订阅,取消订阅和发布实现了发布/订阅消息范式(引自wikipedia),发送者(发布者)不是计划发送消息给特定的接收者(订阅者).而是发布的消 ...

  5. JavaScript HTMlL DOM对象(上)

    Dom:document.相当于把所有的html文件,转换成了文档对象. 之前说过:html-裸体的人:css-穿上衣服:js-让人动起来. 让人动起来,就得先找到他,再修改它内容或属性. 找到标签 ...

  6. JAVA进程CPU高的解决方法

    无限循环的while会导致CPU使用率飙升吗?经常使用Young GC会导致CPU占用率飙升吗?具有大量线程的应用程序的CPU使用率是否较高?CPU使用率高的应用程序的线程数是多少?处于BLOCKED ...

  7. Vue Router路由守卫妙用:异步获取数据成功后再进行路由跳转并传递数据,失败则不进行跳转

    问题引入 试想这样一个业务场景: 在用户输入数据,点击提交按钮后,这时发起了ajax请求,如果请求成功, 则跳转到详情页面并展示详情数据,失败则不跳转到详情页面,只是在当前页面给出错误消息. 难点所在 ...

  8. 【React踩坑记三】React项目报错Can't perform a React state update on an unmounted component

    意思为:我们不能在组件销毁后设置state,防止出现内存泄漏的情况 分析出现问题的原因: 我这里在组件加载完成的钩子函数里调用了一个EventBus的异步方法,如果监听到异步方法,则会更新state中 ...

  9. MYSQl 全表扫描以及查询性能

    MYSQl 全表扫描以及查询性能 -- 本文章仅用于学习,记录 一. Mysql在一些情况下全表检索比索引查询更快: 1.表格数据很少,使用全表检索会比使用索引检索更快.一般当表格总数据小于10行并且 ...

  10. INTERVIEW #3

    菊厂的面试本来没打算记录,因为当时投的是非技术岗(技术支持).为了全面,就寥做记录. 菊厂的面试因为有口头保密协议,所以不能透露具体题目. 0 群面 简历通过筛选后,会有短信通知去面试. 非技术岗第一 ...