PyTorch 之 Datasets
实现一个定制的 Dataset 类
Dataset 类是 PyTorch 图像数据集中最为重要的一个类,也是 PyTorch 中所有数据集加载类中应该继承的父类。其中,父类的两个私有成员函数必须被重载。
- getitem(self, index) # 支持数据集索引的函数
 - len(self) # 返回数据集的大小
 
Datasets 的框架:
class CustomDataset(data.Dataset): # 需要继承 data.Dataset
    def __init__(self):
        # TODO
        # Initialize file path or list of file names.
        pass
    def __getitem__(self, index):
        # TODO
        # 1. 从文件中读取指定 index 的数据(例:使用 numpy.fromfile, PIL.Image.open)
        # 2. 预处理读取的数据(例:torchvision.Transform)
        # 3. 返回数据对(例:图像和对应标签)
        pass
    def __len__(self):
        # TODO
        # You should change 0 to the total size of your dataset.
        return 0
举例:
class MyDataset(Dataset):
    """
     root: 图像存放地址根路径
     augment:是否需要图像增强
    """
    def __init__(self, root, augment=None):
        # 这个 list 存放所有图像的地址
        self.image_files = np.array([
            x.path for x in os.scandir(root)
            if x.name.endswith(".jpg") or x.name.endswith(".png") or x.name.endswith(".JPG")
        ])
        self.augment = augment
    def __getitem__(self, index):
        if self.augment:
            image = open_image(self.image_files[index])   # 这里的 open_image 是读取图像的函数,可以用 PIL 或者 OpenCV 等库进行读取
            image = self.augment(image)	  # 这里对图像进行了数据增强
            return to_tensor(image)	      # PyTorch 中得到的图像必须是 tensor
        else:
            image = open_image(self.image_files[index])
            return to_tensor(image)
下面是官方 MNIST 的例子:
class MNIST(data.Dataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
    Args:
        root (string): Root directory of dataset where ``processed/training.pt``
            and  ``processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    urls = [
        'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'
    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
    class_to_idx = {_class: i for i, _class in enumerate(classes)}
    @property
    def targets(self):
        if self.train:
            return self.train_labels
        else:
            return self.test_labels
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set
        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')
        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.test_file))
    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]
        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target
    def __len__(self):
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)
    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
            os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))
    def download(self):
        """Download the MNIST data if it doesn't exist in processed_folder already."""
        from six.moves import urllib
        import gzip
        if self._check_exists():
            return
        # download files
        try:
            os.makedirs(os.path.join(self.root, self.raw_folder))
            os.makedirs(os.path.join(self.root, self.processed_folder))
        except OSError as e:
            if e.errno == errno.EEXIST:
                pass
            else:
                raise
        for url in self.urls:
            print('Downloading ' + url)
            data = urllib.request.urlopen(url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.root, self.raw_folder, filename)
            with open(file_path, 'wb') as f:
                f.write(data.read())
            with open(file_path.replace('.gz', ''), 'wb') as out_f, \
                    gzip.GzipFile(file_path) as zip_f:
                out_f.write(zip_f.read())
            os.unlink(file_path)
        # process and save as torch files
        print('Processing...')
        training_set = (
            read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')),
            read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte'))
        )
        test_set = (
            read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')),
            read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte'))
        )
        with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f:
            torch.save(test_set, f)
        print('Done!')
    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        tmp = 'train' if self.train is True else 'test'
        fmt_str += '    Split: {}\n'.format(tmp)
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str
												
											PyTorch 之 Datasets的更多相关文章
- [深度学习] pytorch利用Datasets和DataLoader读取数据
		
本文简单描述如果自定义dataset,代码并未经过测试(只是说明思路),为半伪代码.所有逻辑需按自己需求另外实现: 一.分析DataLoader train_loader = DataLoader( ...
 - (转)Awesome PyTorch List
		
Awesome-Pytorch-list 2018-08-10 09:25:16 This blog is copied from: https://github.com/Epsilon-Lee/Aw ...
 - 如何使用Pytorch迅速实现Mnist数据及分类器
		
一段时间没有更新博文,想着也该写两篇文章玩玩了.而从一个简单的例子作为开端是一个比较不错的选择.本文章会手把手地教读者构建一个简单的Mnist(Fashion-Mnist同理)的分类器,并且会使用相对 ...
 - pytorch实现VAE
		
一.VAE的具体结构 二.VAE的pytorch实现 1加载并规范化MNIST import相关类: from __future__ import print_function import argp ...
 - PyTorch教程之Training a classifier
		
我们已经了解了如何定义神经网络,计算损失并对网络的权重进行更新. 接下来的问题就是: 一.What about data? 通常处理图像.文本.音频或视频数据时,可以使用标准的python包将数据加载 ...
 - PyTorch官方中文文档:PyTorch中文文档
		
PyTorch中文文档 PyTorch是使用GPU和CPU优化的深度学习张量库. 说明 自动求导机制 CUDA语义 扩展PyTorch 多进程最佳实践 序列化语义 Package参考 torch to ...
 - 基于pytorch的电影推荐系统
		
本文介绍一个基于pytorch的电影推荐系统. 代码移植自https://github.com/chengstone/movie_recommender. 原作者用了tf1.0实现了这个基于movie ...
 - [深度应用]·实战掌握PyTorch图片分类简明教程
		
[深度应用]·实战掌握PyTorch图片分类简明教程 个人网站--> http://www.yansongsong.cn/ 项目GitHub地址--> https://github.com ...
 - pytorch识别CIFAR10:训练ResNet-34(自定义transform,动态调整学习率,准确率提升到94.33%)
		
版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 前面通过数据增强,ResNet-34残差网络识别CIFAR10,准确率达到了92.6. 这里对训练过程 ...
 
随机推荐
- Mac下安装npm全局包提示权限不够
			
Mac OS下安装npm的全局包,总是出现如下提示Missing write access,需要提升权限才能继续. npm WARN checkPermissions Missing write ac ...
 - Java有参构造方法和成员方法
			
Java面向对象基础就包括有参构造方法和成员方法 无参构造声明方式为 class Teacher{ //无参构造方法,一般用在初始化变量 public Teacher(){ } } 有参构造声明为 c ...
 - Windows 下 pycharm 创建Django 项目【用虚拟环境的解释器】
			
1. 背景 我在 Windows 下的 pycharm 直接创建 全新 Django 项目 会 pip 和其他报错 ,暂时解决不了,另外后续的多个项目只需要一套python 环境, 所以可以 ...
 - django-图形验证码(django-simple-captcha)
			
在网站开发的登录页面中,经常会需要使用到图形验证码来验证.在Django中,django-simple-captcha库包提供了图形验证码的使用. django-simple-captcha的安装 p ...
 - Jenkins + pipeline + Git + PHP (九)
			
一.准备环境介绍 192.168.5.71 # gitlab 仓库IP 192.168.5.72 # 开发环境,用于提交代码等 192.168.5.150 # www.leon.com 运行wordp ...
 - C++学习视频和资料
			
我在学习c++时,比较迷茫,而且当时学完c++primer时不知道该学习什么, 犹豫了好久,最后找到了一些关于c++学习路线的视频,包含源代码,我感觉还不错,分享给大家. 下载地址 https://d ...
 - 201871010126 王亚涛 《面向对象程序设计(Java)》第七周实验总结
			
---恢复内容开始--- 项目 内容 这个作业属于哪个课程 https://www.cnblogs.com/nwnu-daizh/ 这个作业的要求在哪里 https://www.cnblogs.com ...
 - 【Spring AOP】通知(五)
			
一.通知介绍 1. 前置通知(Before) 在目标方法执行之前执行的通知. 前置通知方法,可以没有参数,也可以额外接收一个JoinPoint,Spring会自动将该对象传入,代表当前的连接点,通过该 ...
 - Eclipse安装svn插件(五)
			
一.在线安装 1. 点击 Help --> Install New Software... 2. 在弹出的窗口中点击add按钮,输入Name(任意)和Location(插件的URL),点击OK ...
 - vscode使用插件来添加文件说明和函数说明——42header——psioniq File Header——koroFileHeader
			
安装号以后,设置快捷键如下: 同时需要根据自己的需要的修改json文件 // 文件头部注释 "fileheader.customMade": { "Description ...