By default, DataLoader uses a collate_fn to pack a list of samples into a batch, combining the images and targets into tensors whose first dimension is the batch size. The default collate_fn expects all images in a batch to have the same size because it uses torch.stack() to pack them. If the images produced by the Dataset vary in size, you have to provide a custom collate_fn. A simple example is shown below:

    # a simple custom collate function, just to show the idea
    # `batch` is a list of tuples where the first element is the image tensor
    # and the second element is the corresponding label
    def my_collate(batch):
        data = [item[0] for item in batch]   # just form a list of tensors
        target = [item[1] for item in batch]
        target = torch.LongTensor(target)
        return [data, target]
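
For contrast, here is what the default collation does and why it breaks on variable-size images. This is a minimal sketch, assuming a recent PyTorch release where torch.utils.data.default_collate is part of the public API (in older versions it lives in torch.utils.data.dataloader):

    import torch
    from torch.utils.data import default_collate

    # two samples of identical size collate fine: images are stacked along a new batch dim
    same_size = [(torch.rand(3, 32, 32), 0), (torch.rand(3, 32, 32), 1)]
    imgs, labels = default_collate(same_size)
    print(imgs.shape)    # torch.Size([2, 3, 32, 32])

    # two samples of different sizes make torch.stack() (and hence default_collate) fail
    var_size = [(torch.rand(3, 32, 32), 0), (torch.rand(3, 48, 64), 1)]
    try:
        default_collate(var_size)
    except RuntimeError as err:
        print("default collation failed:", err)

    # the custom collate function above handles the same batch without complaint
    imgs, labels = my_collate(var_size)
    print(len(imgs), labels)    # 2  tensor([0, 1])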

Reference: Writing Your Own Custom Dataset for Classification in PyTorch

By default, torch stacks the input images to form a tensor of shape N*C*H*W, so every image in the batch must have the same height and width. In order to load a batch with variable-size input images, we have to use our own collate_fn to pack the batch of images.

For image classification, the input to collate_fn is a list with length batch_size. Each element is a tuple where the first element is the input image (a torch.FloatTensor) and the second element is the image label, which is simply an int. Because the samples in a batch have different sizes, we can store the images in a list and store the corresponding labels in a torch.LongTensor. Then we put the image list and the label tensor into a list and return the result.

Here is a very simple snippet that demonstrates how to write a custom collate_fn:

    import torch
    from torch.utils.data import DataLoader
    from torchvision import transforms
    import torchvision.datasets as datasets
    import matplotlib.pyplot as plt

    # a simple custom collate function, just to show the idea
    def my_collate(batch):
        data = [item[0] for item in batch]
        target = [item[1] for item in batch]
        target = torch.LongTensor(target)
        return [data, target]

    def show_image_batch(img_list, title=None):
        num = len(img_list)
        fig = plt.figure()
        for i in range(num):
            ax = fig.add_subplot(1, num, i + 1)
            ax.imshow(img_list[i].numpy().transpose([1, 2, 0]))
            ax.set_title(title[i])
        plt.show()

    # do not use RandomCrop, to show that the custom collate_fn can handle
    # images of different sizes
    train_transforms = transforms.Compose([
        transforms.Resize(size=224),   # transforms.Scale is deprecated in favor of Resize
        transforms.ToTensor(),
    ])

    # change root to a valid dir on your system, see the ImageFolder
    # documentation for more info
    train_dataset = datasets.ImageFolder(root="/hd1/jdhao/toyset",
                                         transform=train_transforms)

    trainset = DataLoader(dataset=train_dataset,
                          batch_size=4,
                          shuffle=True,
                          collate_fn=my_collate,  # use the custom collate function here
                          pin_memory=True)

    trainiter = iter(trainset)
    imgs, labels = next(trainiter)  # print(type(imgs), type(labels))
    show_image_batch(imgs, title=[train_dataset.classes[x] for x in labels])

Reference: How to create a dataloader with variable-size input
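
Since my_collate returns the images as a plain Python list rather than a stacked 4D tensor, the training loop has to deal with the variable sizes itself, for example by forwarding the images one at a time. Below is a minimal sketch reusing train_dataset and trainset from the snippet above; the tiny size-agnostic model is hypothetical, only there to make the loop runnable, and not part of the referenced post:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # a hypothetical model: global average pooling makes it accept any input size
    model = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(16, len(train_dataset.classes)),
    )

    for imgs, labels in trainset:
        # imgs is a list of C x H x W tensors with varying H and W,
        # labels is a LongTensor of shape (batch_size,)
        outputs = torch.cat([model(img.unsqueeze(0)) for img in imgs], dim=0)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()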

A test case for DataLoader:

    import torch
    import torch.utils.data as Data
    import numpy as np

    test = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
    inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
    target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))

    torch_dataset = Data.TensorDataset(inputing, target)
    batch = 3

    loader = Data.DataLoader(
        dataset=torch_dataset,
        batch_size=batch,  # batch size
        # if the number of samples in the dataset is not divisible by batch_size,
        # the last batch simply uses however many samples are left
        collate_fn=lambda x: (
            torch.cat(
                [x[i][j].unsqueeze(0) for i in range(len(x))], 0
            ).unsqueeze(0) for j in range(len(x[0]))
        )
    )

    for (i, j) in loader:
        print(i)
        print(j)

Reference: The collate_fn argument of DataLoader
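
The lambda above is compact but hard to read. An equivalent collate function written out as a named function (a sketch with the same effect once the loader output is unpacked, not taken from the referenced post) could look like this:

    def column_collate(batch):
        # batch is a list of (input, target) tuples drawn from the TensorDataset;
        # for every position j in the tuple, stack the j-th fields of all samples
        # along dim 0 and then add an extra leading dimension of size 1
        num_fields = len(batch[0])
        return tuple(
            torch.cat([sample[j].unsqueeze(0) for sample in batch], 0).unsqueeze(0)
            for j in range(num_fields)
        )

With batch_size=3 the inputs yielded by the loader have shape (1, 3, 3) and the targets have shape (1, 3, 1); the final, smaller batch yields shapes (1, 1, 3) and (1, 1, 1).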

PyTorch: reading variable-length data

https://zhuanlan.zhihu.com/p/60129684
