Pytorch: repeat, repeat_interleave, tile的用法
https://zhuanlan.zhihu.com/p/474153365
torch.repeat
使张量沿着某个维度进行复制, 并且不仅可以复制张量,也可以拓展张量的维度:
import torch
x = torch.randn(2, 4)
# 1. 沿着某个维度复制
x.repeat(1, 1).size() # torch.Size([2, 4])
x.repeat(2, 1).size() # torch.Size([4, 4])
x.repeat(1, 2).size() # torch.Size([2, 8])
# 2. 不仅可以复制维度, 还可以拓展维度
x.repeat(1, 1, 1).size() # torch.Size([1, 2, 4])
x.repeat(2, 1, 1).size() # torch.Size([2, 2, 4])
x.repeat(1, 1, 1, 1).size() # torch.Size([1, 1, 2, 4])
# 3. repeat中传入的参数不可以少于x的维度
x.repeat(1) # 报错
torch.repeat_interleave
torch.repeat_interleave的行为与numpy.repeat类似,但是和torch.repeat不同,这边还是以代码为例:
import torch
x = torch.randn(2, 2)
print(x)
>>> tensor([[ 0.4332, 0.1172],
[ 0.8808, -1.7127]])
print(x.repeat(2, 1))
>>> tensor([[ 0.4332, 0.1172],
[ 0.8808, -1.7127],
[ 0.4332, 0.1172],
[ 0.8808, -1.7127]])
print(x.repeat_interleave(2, dim=0))
>>> tensor([[ 0.4332, 0.1172],
[ 0.4332, 0.1172],
[ 0.8808, -1.7127],
[ 0.8808, -1.7127]])
print(x.repeat_interleave(2, dim=1))
>>> tensor([[ 0.4332, 0.4332, 0.1172, 0.1172],
[ 0.8808, 0.8808, -1.7127, -1.7127]])
# 如果不传dim参数, 则默认复制后拉平
print(x.repeat_interleave(2))
>>> tensor([ 0.4332, 0.4332, 0.1172, 0.1172, 0.8808, 0.8808, -1.7127, -1.7127])
从这个代码可以看出来torch.repeat更像是把tensor作为一个整体进行复制, 而torch.repeat_interleave更是针对tensor里的每个元素进行复制,并且torch.repeat_interleave可以通过传入一个一维的torch.Tensor来指定每个元素复制的次数
import torch
x = torch.tensor([[1, 2], [3, 4]])
result = torch.repeat_interleave(x, torch.tensor([1, 3]), dim=0)
print(result)
>>> tensor([[1, 2],
[3, 4],
[3, 4],
[3, 4]])
torch.tile
torch.tile函数也是元素复制的一个函数, 但是在传参上和torch.repeat不同,但是也是以input为一个整体进行复制, torch.tile如果只传入一个参数的话, 默认是沿着行进行复制
import torch
x = torch.tensor([[1, 2], [3, 4]])
# 只传入一个参数
print(x.tile((2, )))
>>> tensor([[1, 2, 1, 2],
[3, 4, 3, 4]])
print(x.repeat(1, 2))
>>> tensor([[1, 2, 1, 2],
[3, 4, 3, 4]])
torch.tile传入一个元组的话, 表示(行复制次数, 列复制次数)
import torch
x = torch.tensor([[1, 2], [3, 4]])
print(x.tile((2, 2)))
>>> tensor([[1, 2, 1, 2],
[3, 4, 3, 4],
[1, 2, 1, 2],
[3, 4, 3, 4]])
print(x.repeat(2, 2))
>>> tensor([[1, 2, 1, 2],
[3, 4, 3, 4],
[1, 2, 1, 2],
[3, 4, 3, 4]])
当传入的参数少于需要复制的元素的维度时, 如果一个tensor的形状为(2, 2, 2),传入tile中的参数为(2, 2)时, 会默认表示为(1, 2, 2)
import torch
x = torch.randn(2, 2, 2)
print(x)
>>> tensor([[[ 0.8517, 0.8721],
[-1.1591, -0.2000]],
[[ 0.3888, -0.8365],
[-1.6383, -0.1539]]])
print(x.tile((2, 2)))
>>> tensor([[[ 0.8517, 0.8721, 0.8517, 0.8721],
[-1.1591, -0.2000, -1.1591, -0.2000],
[ 0.8517, 0.8721, 0.8517, 0.8721],
[-1.1591, -0.2000, -1.1591, -0.2000]],
[[ 0.3888, -0.8365, 0.3888, -0.8365],
[-1.6383, -0.1539, -1.6383, -0.1539],
[ 0.3888, -0.8365, 0.3888, -0.8365],
[-1.6383, -0.1539, -1.6383, -0.1539]]])
当传入的参数多于需要复制的元素维度时,会拓展维度
import torch
x = torch.randn(2, 2)
print(x)
>>> tensor([[ 1.1165, -0.5559],
[-0.6341, 0.5215]])
print(x.tile((2, 2, 2)))
>>> tensor([[[ 1.1165, -0.5559, 1.1165, -0.5559],
[-0.6341, 0.5215, -0.6341, 0.5215],
[ 1.1165, -0.5559, 1.1165, -0.5559],
[-0.6341, 0.5215, -0.6341, 0.5215]],
[[ 1.1165, -0.5559, 1.1165, -0.5559],
[-0.6341, 0.5215, -0.6341, 0.5215],
[ 1.1165, -0.5559, 1.1165, -0.5559],
[-0.6341, 0.5215, -0.6341, 0.5215]]])
使用tile和reshape代替repeat_interleave
import torch
x = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: (2, 3)
y = torch.repeat_interleave(x, repeats=3, dim=0)
print(y)
>>> tensor([[1, 2, 3],
[1, 2, 3],
[1, 2, 3],
[4, 5, 6],
[4, 5, 6],
[4, 5, 6]])
# 直接使用tile, 无法得到类似的结果
z = torch.tile(x, (3, ))
print(z)
>>> tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6, 4, 5, 6]])
z = torch.tile(x, (3, 1))
print(z)
>>> tensor([[1, 2, 3],
[4, 5, 6],
[1, 2, 3],
[4, 5, 6],
[1, 2, 3],
[4, 5, 6]])
# 需要使用 tile + reshape 才可以得到类似的结果
z = torch.tile(x, (3, ))
print(z.shape) # (2, 9)
print(z.reshape(6, 3)) # 得到了和y一样的输出
>>> tensor([[1, 2, 3],
[1, 2, 3],
[1, 2, 3],
[4, 5, 6],
[4, 5, 6],
[4, 5, 6]])
Pytorch: repeat, repeat_interleave, tile的用法的更多相关文章
- Pytorch中nn.Conv2d的用法
Pytorch中nn.Conv2d的用法 nn.Conv2d是二维卷积方法,相对应的还有一维卷积方法nn.Conv1d,常用于文本数据的处理,而nn.Conv2d一般用于二维图像. 先看一下接口定义: ...
- numpy数组扩展函数repeat和tile用法
numpy.repeat(a, repeats, axis=None) >>> a = np.arange(3) >>> a array([0, 1, 2]) &g ...
- python tile函数用法
tile函数位于python模块 numpy.lib.shape_base中,他的功能是重复某个数组.比如tile(A,n),功能是将数组A重复n次,构成一个新的数组,我们还是使用具体的例子来说明问题 ...
- Python-Numpy的tile函数用法
1.函数的定义与说明 函数格式tile(A,reps) A和reps都是array_like A的类型众多,几乎所有类型都可以:array, list, tuple, dict, matrix以及基本 ...
- [PyTorch]PyTorch中反卷积的用法
文章来源:https://www.jianshu.com/p/01577e86e506 pytorch中的 2D 卷积层 和 2D 反卷积层 函数分别如下: class torch.nn.Conv2d ...
- python3中numpy函数tile的用法
tile函数位于python模块 numpy.lib.shape_base中,他的功能是重复某个数组.比如tile(A,n),功能是将数组A重复n次,构成一个新的数组,我们还是使用具体的例子来说明问题 ...
- numpy中tile的用法
a=arange(1,3) #a的结果是: array([1,2]) 1,当 tile(a,1) 时: tile(a,1) #结果是 array([1,2]) tile(a,2) #结果是 array ...
- pytorch实现yolov3(4) 非极大值抑制nms
在上一篇里我们实现了forward函数.得到了prediction.此时预测出了特别多的box以及各种class probability,现在我们要从中过滤出我们最终的预测box. 理解了yolov3 ...
- Python numpy中矩阵的用法总结
关于Python Numpy库基础知识请参考博文:https://www.cnblogs.com/wj-1314/p/9722794.html Python矩阵的基本用法 mat()函数将目标数据的类 ...
随机推荐
- UiPath邮件自动化
在UiPath中下载Outlook电子邮件附件Outlook电子邮件自动化教程UiPathRPAhttps://www.bilibili.com/video/BV1oK411L72T 在UiPath中 ...
- RPA应用场景-营业收入核对
场景概述营业收入核对 所涉系统名称 SAP ,Excel,门店业务系统 人工操作(时间/次) 4 小时 所涉人工数量 6 操作频率每日 场景流程 1.每日13点起进入SAP查询前一日营业收入记账情况: ...
- Python:socket编程教程
ocket是基于C/S架构的,也就是说进行socket网络编程,通常需要编写两个py文件,一个服务端,一个客户端. 首先,导入Python中的socket模块: import socket Pytho ...
- Tomcat深入浅出——Servlet(二)
一.Servlet简介 Servlet类最终开发步骤: 第一步:编写一个Servlet类,直接继承HttpServlet 第二步:重写doGet方法或者doPost方法,重写哪个我说的算! 第三步:将 ...
- labview从入门到出家5(进阶篇)--程序调试以及labview函数库的运用
跟了前面几章的操作流程,相信大家对labview有了一定的认识.其实只要了解了labview的编程思路,再熟悉地运用各个变量,函数以及属性,那么我们就可以打开labview的大门了.跟其他编程语言一样 ...
- JavaScript基本知识点——带你逐步解开JS的神秘面纱
JavaScript基本知识点--带你逐步解开JS的神秘面纱 在我们前面的文章中已经深入学了HTML和CSS,在网页设计中我们已经有能力完成一个美观的网页框架 但仅仅是网页框架不足以展现出网页的魅力, ...
- identity server4 授权成功页面跳转时遇到错误:Exception: Correlation failed. Unknown location的解决方法
一.异常信息描述 错误信息,看到这个页面是否耳熟能详担又不知道怎么解决 ,坑死个人不偿命,,,,,,,, 二.处理方法 1.在web项目中增加类SameSiteCookiesServiceCollec ...
- 常见加密算法C#实现(一)
前言:最近项目中需要用到字符串加解密,遂研究了一波,发现密码学真的是博大精深,好多算法的设计都相当巧妙,学到了不少东西,在这里做个小小的总结,方便后续查阅. 文中关键词: 明文(P,Plaintext ...
- [spring]spring详细总结
spring 1.spring简介 Spring框架是一个开源的应用程序框架,是针对bean的生命周期进行管理的轻量级容器. Spring解决了开发者在J2EE开发中遇到的许多常见的问题,提供了功能强 ...
- if条件控制语句和switch语句
if条件控制语句(判断范围,在一定区间内容进行判断) if 如果(第一个条件) else if 如果(第二个条件 可以无限加) else 否则(只能有一个 上面都不满足的情况下进入) if和else ...