[转载]pytorch自定义数据集
为什么要定义Datasets:
PyTorch提供了一个工具函数torch.utils.data.DataLoader。通过这个类,我们在准备mini-batch的时候可以多线程并行处理,这样可以加快准备数据的速度。Datasets就是构建这个类的实例的参数之一。
如何自定义Datasets
下面是一个自定义Datasets的框架:
class CustomDataset(data.Dataset):#需要继承data.Dataset
def __init__(self):
# TODO
# 1. Initialize file path or list of file names.
pass
def __getitem__(self, index):
# TODO
# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
# 2. Preprocess the data (e.g. torchvision.Transform).
# 3. Return a data pair (e.g. image and label).
#这里需要注意的是,第一步:read one data,是一个data
pass
def __len__(self):
# You should change 0 to the total size of your dataset.
return 0
下面看一下官方MNIST的例子(代码被缩减,只留下了重要的部分):
class MNIST(data.Dataset):
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
self.root = root
self.transform = transform
self.target_transform = target_transform
self.train = train # training set or test set if download:
self.download() if not self._check_exists():
raise RuntimeError('Dataset not found.' +
' You can use download=True to download it') if self.train:
self.train_data, self.train_labels = torch.load(
os.path.join(root, self.processed_folder, self.training_file))
else:
self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file)) def __getitem__(self, index):
if self.train:
img, target = self.train_data[index], self.train_labels[index]
else:
img, target = self.test_data[index], self.test_labels[index] # doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.numpy(), mode='L') if self.transform is not None:
img = self.transform(img) if self.target_transform is not None:
target = self.target_transform(target) return img, target def __len__(self):
if self.train:
return 60000
else:
return 10000
[转载]pytorch自定义数据集的更多相关文章
- PyTorch 自定义数据集
准备数据 准备 COCO128 数据集,其是 COCO train2017 前 128 个数据.按 YOLOv5 组织的目录: $ tree ~/datasets/coco128 -L 2 /home ...
- Pytorch划分数据集的方法
之前用过sklearn提供的划分数据集的函数,觉得超级方便.但是在使用TensorFlow和Pytorch的时候一直找不到类似的功能,之前搜索的关键字都是"pytorch split dat ...
- pytorch加载语音类自定义数据集
pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合 torch.u ...
- MMDetection 快速开始,训练自定义数据集
本文将快速引导使用 MMDetection ,记录了实践中需注意的一些问题. 环境准备 基础环境 Nvidia 显卡的主机 Ubuntu 18.04 系统安装,可见 制作 USB 启动盘,及系统安装 ...
- Scaled-YOLOv4 快速开始,训练自定义数据集
代码: https://github.com/ikuokuo/start-scaled-yolov4 Scaled-YOLOv4 代码: https://github.com/WongKinYiu/S ...
- torch_13_自定义数据集实战
1.将图片的路径和标签写入csv文件并实现读取 # 创建一个文件,包含image,存放方式:label pokemeon\\mew\\0001.jpg,0 def load_csv(self,file ...
- [转载]PyTorch上的contiguous
[转载]PyTorch上的contiguous 来源:https://zhuanlan.zhihu.com/p/64551412 这篇文章写的非常好,我这里就不复制粘贴了,有兴趣的同学可以去看原文,我 ...
- [转载]PyTorch中permute的用法
[转载]PyTorch中permute的用法 来源:https://blog.csdn.net/york1996/article/details/81876886 permute(dims) 将ten ...
- [转载]Pytorch详解NLLLoss和CrossEntropyLoss
[转载]Pytorch详解NLLLoss和CrossEntropyLoss 来源:https://blog.csdn.net/qq_22210253/article/details/85229988 ...
随机推荐
- 具有避障和寻线功能的Arduino小车
标签: Arduino 乐高 机器人 创客对于成年人来说,多半是科技娱乐,或者是一种是一种向往科技的人生态度,总是希望自己不仅可以看到或者听到科技的资讯,还希望能够亲身制作科技玩意,从而更好地体 ...
- 配置php的curl模块问题
问题 checking for cURL in default path... not foundconfigure: error: Please reinstall the libcurl dist ...
- iOS 给Main.storyboard 添加button 事件《转》
XCODE中使用Main.Storyboard拉入控件并实现事件(Swift语言) 如何在XCODE中的Main.Storyboard内拉入控件并实现一个简单的效果呢?本人由于刚接触Swift语言 ...
- Spring 已看 没用
注解 @Autwired 依赖注入 作用: 自动按照类型注入.当使用注解注入属性时,set方法可以省略.它只能注入其他bean类型.当有多个类型匹配时,使用要注入的对象变量名称作为bean的id,在 ...
- c语言学习笔记 scanf和printf格式的问题
int a =0; int b =0; scanf("%d %d",&,&b); 上面这种和下面这种哪种对? int a =0; int b =0; scanf(& ...
- 505C Mr. Kitayuta, the Treasure Hunter
传送门 题目大意 一共有30000个位置,从第0个位置开始走,第一次走k步,对于每一次走步,可以走上一次的ki+1 ,ki ,ki-1步数(必须大于等于1),每个岛上有value,求最大能得到的val ...
- Spring学习大纲
1.BeanFactory 和 FactoryBean? 2.Spring IOC 的理解,其初始化过程? 3.BeanFactory 和 ApplicationContext? 4.Spring B ...
- 解决ScrollView嵌套viewpager滑动事件冲突问题
重写ScrollView 第一种方案能解决viewpager的滑动问题,但是scrollView有时会滑不动 public class VerticalScrollView extends Scrol ...
- git命令(二)
2.查看.添加.提交.删除.找回,重置修改文件 3.查看文件diff git diff <file> # 比较当前文件和暂存区文件差异 git diff git diff <$id1 ...
- PAT甲 1095 解码PAT准考证/1153 Decode Registration Card of PAT(优化技巧)
1095 解码PAT准考证/1153 Decode Registration Card of PAT(25 分) PAT 准考证号由 4 部分组成: 第 1 位是级别,即 T 代表顶级:A 代表甲级: ...