PyTorch 数据集类 和 数据加载类 的一些尝试
最近在学习PyTorch, 但是对里面的数据类和数据加载类比较迷糊,可能是封装的太好大部分情况下是不需要有什么自己的操作的,不过偶然遇到一些自己导入的数据时就会遇到一些问题,因此自己对此做了一些小实验,小尝试。
下面给出一个常用的数据类使用方式:
def data_tf(x):
x = np.array(x, dtype='float32') / 255 # 将数据变到 0 ~ 1 之间
x = (x - 0.5) / 0.5 # 标准化,这个技巧之后会讲到
x = x.reshape((-1,)) # 拉平
x = torch.from_numpy(x)
return x from torchvision.datasets import MNIST # 导入 pytorch 内置的 mnist 数据
train_set = MNIST('./data', train=True, transform=data_tf, download=True) # 载入数据集,申明定义的数据变换
test_set = MNIST('./data', train=False, transform=data_tf, download=True)
其中, data_tf 并不是必须要有的,比如:
from torchvision.datasets import MNIST # 导入 pytorch 内置的 mnist 数据
train_set = MNIST('./data', train=True, download=True) # 载入数据集,申明定义的数据变换
test_set = MNIST('./data', train=False, download=True)
这里面的MNIST类是框架自带的,可以自动下载MNIST数据库, ./data 是指将下载的数据集存放在当前目录下的哪个目录下, train 这个属性 True时 则在 ./data文件夹下面在建立一个 train的文件夹然后把下载的数据存放在其中, 当train属性是False的时候则把下载的数据放在 test文件夹下面。
划线部分是老版本的PyTorch的处理方式, 最近试了一下最新版本 PyTorch 1.0 , train为True的时候是把数据放在 ./data/processed 文件夹下面, 命名为training.pt , 为False 的时候则放在 ./data/processed 文件夹下面, 命名为test.pt 。
这时候就出现了一个问题, 如果你使用的数据集不是框架自带的那么如何使用数据类呢,这个时候就要使用 pytorch 中的 Dataset 类了。
from torch.utils.data import Dataset
我们需要重写 Dataset类, 需要实现的方法为 __len__ 和 __getitem__ 这两个内置方法, 这里可以看出其思想就是要重写的类需要支持按照索引查找的方法。
这里我们还是举个例子:
从这个例子可以看出 mydataset就是我们自定义的 myDataset 类生成的自定义数据类对象。我们可以在myDataset类中自定义一些方法来对需要的数据进行处理。
为说明该问题另附加一个例子:
from torch.utils.data import Dataset #需要在pytorch中使用的数据
data=[[1.1, 1.2, 1.3], [2.1, 2.2, 2.3], [3.1, 3.2, 3.3], [4.1, 4.2, 4.3], [5.1, 5.2, 5.3]] class myDataset(Dataset):
def __init__(self, indata):
self.data=indata
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx] mydataset=myDataset(data)
那么又来了一个问题,我们不重写 Dataset类的话可不可以呢, 经过尝试发现还真可以,如下:
又如:
由这个例子可以看出数据类对象可以不重写Dataset类, 只要具备 __len__ __getitem__ 方法就可以。而且从这个例子我们可以看出 DataLoader 是一个迭代器, 如果shuffle 设置为 True 那么在每次迭代之前都会重新排序。
同时由上面两个例子可以看出 DataLoader类会把传入的数据集合中的数据转化为 torch.tensor 类型, 当然是采用默认的 DataLoader类中转化函数 transform的情况下。
这也就是说 DataLoader 默认的转化函数 transform操作为 传入的[ [x, x, x], [y, y, y] ] 输出的是 [ tensor([x, x, x]), tensor([y, y, y]) ] ,
传入的是 tensor([ [x, x, x], [y, y, y] ]) 输出的是 tensor([ tensor([x, x, x]), tensor([y, y, y]) ] ), (这个例子是在 batch_size=2 的情况)。
综上,可知 其实 Dataset类, 和 DataLoader类其实在pytorch 计算过程中都不是一定要有的, 其中Dataset类是起一个规范作用,意义在于要人们对不同的类型数据做一些初步的调整,使其支持按照索引读取,以使其可以在 DataLoader中使用。
DataLoader 是一个迭代器, 可以方便的通过设置 batch_size 来实现 batch过程,transform则是对数据的一些处理。
---------------------------------------------------------------------------------------------------
上述内容更正:
import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader #需要在pytorch中使用的数据
data=[[1.1, 1.2, 1.3], [2.1, 2.2, 2.3], [3.1, 3.2, 3.3], [4.1, 4.2, 4.3], [5.1, 5.2, 5.3]] class myDataset(Dataset):
def __init__(self, indata):
self.data=indata
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx] mydataset=myDataset(data)
train_data=DataLoader(mydataset, batch_size=3, shuffle=True) print("上文的错误操作:") for i in train_data:
print(i)
print('-'*30)
print('again')
for i in train_data:
print(i)
print('-'*30) ######################################### data=np.array(data)
data=torch.from_numpy(data) mydataset=myDataset(data)
train_data=DataLoader(mydataset, batch_size=3, shuffle=True) print("修正后的正确操作:") for i in train_data:
print(i)
print('-'*30)
print('again')
for i in train_data:
print(i)
print('-'*30)
(base) devil@devilmaycry:/tmp$ python w.py
上文的错误操作:
[tensor([3.1000, 4.1000, 5.1000], dtype=torch.float64), tensor([3.2000, 4.2000, 5.2000], dtype=torch.float64), tensor([3.3000, 4.3000, 5.3000], dtype=torch.float64)]
------------------------------
[tensor([1.1000, 2.1000], dtype=torch.float64), tensor([1.2000, 2.2000], dtype=torch.float64), tensor([1.3000, 2.3000], dtype=torch.float64)]
------------------------------
again
[tensor([3.1000, 5.1000, 1.1000], dtype=torch.float64), tensor([3.2000, 5.2000, 1.2000], dtype=torch.float64), tensor([3.3000, 5.3000, 1.3000], dtype=torch.float64)]
------------------------------
[tensor([2.1000, 4.1000], dtype=torch.float64), tensor([2.2000, 4.2000], dtype=torch.float64), tensor([2.3000, 4.3000], dtype=torch.float64)] ------------------------------ 修正后的正确操作:
tensor([[2.1000, 2.2000, 2.3000],
[1.1000, 1.2000, 1.3000],
[3.1000, 3.2000, 3.3000]], dtype=torch.float64)
------------------------------
tensor([[4.1000, 4.2000, 4.3000],
[5.1000, 5.2000, 5.3000]], dtype=torch.float64)
------------------------------
again
tensor([[5.1000, 5.2000, 5.3000],
[4.1000, 4.2000, 4.3000],
[3.1000, 3.2000, 3.3000]], dtype=torch.float64)
------------------------------
tensor([[2.1000, 2.2000, 2.3000],
[1.1000, 1.2000, 1.3000]], dtype=torch.float64)
------------------------------
可以看出 传入到 Dataset 中的对象必须是 torch 类型的 tensor 类型, 如果传入的是list则会得出错误结果。
-----------------------------------------------------------------------------------------------------
补充:
之所以发现上面的这个错误,是因为发现了下面的代码:
import numpy as np
from torchvision.datasets import mnist # 导入 pytorch 内置的 mnist 数据
from torch.utils.data import DataLoader
#from torch.utils.data import Dataset def data_tf(x):
x = np.array(x, dtype='float32') / 255
x = (x - 0.5) / 0.5 # 数据预处理,标准化
x = x.reshape((-1,)) # 拉平
x = torch.from_numpy(x)
return x #Dataset
# 重新载入数据集,申明定义的数据变换
train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True)
test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True) train_data = DataLoader(train_set, batch_size=64, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)
从上面的 data_tf 函数中我们发现, Dataset对象返回的是 torch 的 tensor 对象。
PyTorch 数据集类 和 数据加载类 的一些尝试的更多相关文章
- arcgis python 使用光标和内存中的要素类将数据加载到要素集 学习:http://zhihu.esrichina.com.cn/article/634
学习:http://zhihu.esrichina.com.cn/article/634使用光标和内存中的要素类将数据加载到要素集 import arcpy arcpy.env.overwriteOu ...
- java动态加载类和静态加载类笔记
JAVA中的静态加载类是编译时刻加载类 动态加载类指的是运行时刻加载类 二者有什么区别呢 举一个例子 现在我创建了一个类 实现的功能假设为通过传入的参数调用具体的类和方法 class offic ...
- Unity3d通用工具类之数据配置加载类
今天,我们来讲讲游戏中的数据配置加载. 什么是游戏数据加载呢?一般来说游戏中会有场景地图. 按照国际惯例,先贴一张游戏场景的地图: 在这张地图上,我们可以看到有很多正六边形,正六边形上有树木.岩石等. ...
- [javaSE] 反射-动态加载类
Class.forName(“类的全称”) ①不仅表示了类的类类型,还代表了动态加载类 ②请大家区分编译,运行 ③编译时刻加载类是静态加载类,运行时刻加载类是动态加载类 Ⅰ所有的new对象都是静态加载 ...
- Java 编程下使用 Class.forName() 加载类
在一些应用中,无法事先知道使用者将加载什么类,而必须让使用者指定类名称以加载类,可以使用 Class 的静态 forName() 方法实现动态加载类.下面的范例让你可以指定类名称来获得类的相关信息. ...
- Java 编程下使用 Class.forName() 加载类【转】
在一些应用中,无法事先知道使用者将加载什么类,而必须让使用者指定类名称以加载类,可以使用 Class 的静态 forName() 方法实现动态加载类.下面的范例让你可以指定类名称来获得类的相关信息. ...
- java中class.forName和classLoader加载类的区分
java中class.forName和classLoader都可用来对类进行加载.前者除了将类的.class文件加载到jvm中之外,还会对类进行解释,执行类中的static块.而classLoade ...
- 反射01 Class类的使用、动态加载类、类类型说明、获取类的信息
0 Java反射机制 反射(Reflection)是 Java 的高级特性之一,是框架实现的基础. 0.1 定义 Java 反射机制是在运行状态中,对于任意一个类,都能够知道这个类的所有属性和方法:对 ...
- java反射机制与动态加载类
什么是java反射机制? 1.当程序运行时,允许改变程序结构或变量类型,这种语言称为动态语言.我们认为java并不是动态语言,但是它却有一个非常突出的动态相关机制,俗称:反射. IT行业里这么说,没有 ...
随机推荐
- Python -- Json 数据编码及解析
Python -- Json 数据编码及解析 Json 简单介绍 JSON: JavaScript Object Notation(JavaScript 对象表示法) JSON 是存储和交换文本 ...
- 机器学习 Numpy库入门
2017-06-28 13:56:25 Numpy 提供了一个强大的N维数组对象ndarray,提供了线性代数,傅里叶变换和随机数生成等的基本功能,可以说Numpy是Scipy,Pandas等科学计算 ...
- 解决SVN图标不显示问题
Windows最多只允许15个覆盖图标,它自己又用了几个,结果给用户用的就11个左右了,如果你安装了其他网盘,那可用的就更少了. 解决方法: 1.在运行里输入regedit进入注册表 2.依次打开HK ...
- shell中引号的妙用
#!/bin/bashfile=('leon 01.cap' leon-02.cap nicky-01.cap whoareu-01.cap 8dbb-01.cap)dict=(simple.txt ...
- Educational Codeforces Round 47 (Rated for Div. 2)G. Allowed Letters 网络流
题意:给你一个字符串,和每个位置可能的字符(没有就可以放任意字符)要求一个排列使得每个位置的字符在可能的字符中,求字典序最小的那个 题解:很容易判断有没有解,建6个点表示从a-f,和源点连边,容量为原 ...
- mxnet(gluon) 实现DQN简单小例子
参考文献 莫凡系列课程视频 增强学习入门之Q-Learning 关于增强学习的基本知识可以参考第二个链接,讲的挺有意思的.DQN的东西可以看第一个链接相关视频.课程中实现了Tensorflow和pyt ...
- nyoj1248(阅读理解???)
海岛争霸 时间限制:1000 ms | 内存限制:65535 KB 难度:3 描述 神秘的海洋,惊险的探险之路,打捞海底宝藏,激烈的海战,海盗劫富等等.加勒比海盗,你知道吧?杰克船长驾驶着自己 ...
- ps -ef |grep xxx 输出的具体含义
ps:将某个进程显示出来 -A 显示所有程序. -e 此参数的效果和指定"A"参数相同. -f 显示UID,PPIP,C与STIME栏位. grep命令是查找 中间的|是管道命令 ...
- java并发带返回结果的批量任务执行(CompletionService:Executor + BlockingQueue)
转载:http://www.it165.net/pro/html/201405/14551.html 一般情况下,我们使用Runnable作为基本的任务表示形式,但是Runnable是一种有很大局限的 ...
- css实现椭圆
先实现个简单点的,用css实现一个圆,ok,直接上代码: .circle{ width: 100px; height:100px; background: red; border-radius: 50 ...