gather函数
gather(input, dim, index):根据 index,在 dim 维度上选取数据,输出的 size 与 index 一致
# input (Tensor) – 源张量
# dim (int) – 索引的轴
# index (LongTensor) – 聚合元素的下标(index需要是torch.longTensor类型)
# out (Tensor, optional) – 目标张量
for 3D tensor:
out[i][j][k] = tensor[index[i][j][k]][j][k] # dim=0
out[i][j][k] = tensor[i][index[i][j][k]][k] # dim=1
out[i][j][k] = tensor[i][j][index[i][j][k]] # dim=2
for 2D tensor:
out[i][j] = input[index[i][j]][j] # dim = 0
out[i][j] = input[i][index[i][j]] # dim = 1
import torch as t # 导入torch模块
c = t.arange(0, 60).view(3, 4, 5) # 定义tensor
print(c)
index = torch.LongTensor([[[0,1,2,0,2],
[0,0,0,0,0],
[1,1,1,1,1]],
[[1,2,2,2,2],
[0,0,0,0,0],
[2,2,2,2,2]]])
b = t.gather(c, 0, index)
print(b) 输出:
tensor([[[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]],
[[20, 21, 22, 23, 24],
[25, 26, 27, 28, 29],
[30, 31, 32, 33, 34],
[35, 36, 37, 38, 39]],
[[40, 41, 42, 43, 44],
[45, 46, 47, 48, 49],
[50, 51, 52, 53, 54],
[55, 56, 57, 58, 59]]])
报错:
Traceback (most recent call last):
File "E:/Release02/my_torch.py", line 14, in <module>
b = t.gather(c, 0, index)
RuntimeError: Size does not match at dimension 1 get 4 vs 3
(第1维尺寸不匹配)
将index调整为:
index = t.LongTensor([[[0, 1, 2, 0, 2], [0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]],
[[1, 2, 2, 2, 2], [0, 0, 0, 0, 0], [2, 2, 2, 2, 2], [1, 1, 1, 1, 1]],
[[1, 2, 2, 2, 2], [0, 0, 0, 0, 0], [2, 2, 2, 2, 2], [1, 1, 1, 1, 1]]])
则上文输出为:
tensor([[[ 0, 21, 42, 3, 44],
[ 5, 6, 7, 8, 9],
[30, 31, 32, 33, 34],
[35, 36, 37, 38, 39]],
[[20, 41, 42, 43, 44],
[ 5, 6, 7, 8, 9],
[50, 51, 52, 53, 54],
[35, 36, 37, 38, 39]],
[[20, 41, 42, 43, 44],
[ 5, 6, 7, 8, 9],
[50, 51, 52, 53, 54],
[35, 36, 37, 38, 39]]])
对于2D tensor 则无“index与tensor 的size一致”之要求,
这个要求在官方文档和其他博文、日志中均无提到
(可能是个坑吧丨可能是个坑吧丨可能是个坑吧)
eg:
代码(此部分来自https://www.yzlfxy.com/jiaocheng/python/337618.html):
b = torch.Tensor([[1,2,3],[4,5,6]])
print b
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print torch.gather(b, dim=1, index=index_1)
print torch.gather(b, dim=0, index=index_2)
输出:
1 2 3
4 5 6
[torch.FloatTensor of size 2x3] 1 2
6 4
[torch.FloatTensor of size 2x2] 1 5 6
1 2 3
[torch.FloatTensor of size 2x3]
官方文档:
torch.gather(input, dim, index, out=None) → Tensor Gathers values along an axis specified by dim. For a 3-D tensor the output is specified by: out[i][j][k] = input[index[i][j][k]][j][k] # dim=0
out[i][j][k] = input[i][index[i][j][k]][k] # dim=1
out[i][j][k] = input[i][j][index[i][j][k]] # dim=2 Parameters: input (Tensor)-The source tensor
dim (int)-The axis along which to index
index (LongTensor)-The indices of elements to gather
out (Tensor, optional)-Destination tensor Example: >>> t = torch.Tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
1 1
4 3
[torch.FloatTensor of size 2x2]
以上,学习中遇到的问题,记录方便回顾,亦示他人以之勉坑
gather函数的更多相关文章
- [软件使用][matlab]最近经常用到的一些函数的意思,和用法
① cat(dim,A,B)按指定的维度,将A和B串联,dim是维度,比如1,2.1指列,2指行: ②numel(A),返回数组中,元素的个数 ③gpuArray(A),在gpu中产生一个数组A,一般 ...
- 小白学习之pytorch框架(4)-softmax回归(torch.gather()、torch.argmax()、torch.nn.CrossEntropyLoss())
学习pytorch路程之动手学深度学习-3.4-3.7 置信度.置信区间参考:https://cloud.tencent.com/developer/news/452418 本人感觉还是挺好理解的 交 ...
- 理解pytorch几个高级选择函数(如gather)
目录 1. 引言 2. 维度的理解 3. gather函数 4. index_select函数 5. masked_select函数 6. nonzero函数 1. 引言 最近在刷开源的Pytor ...
- 从头造轮子:python3 asyncio之 gather (3)
前言 书接上文:,本文造第三个轮子,也是asyncio包里面非常常用的一个函数gather 一.知识准备 ● 相对于前两个函数,gather的使用频率更高,因为它支持多个协程任务"同时&qu ...
- R语言数据处理包dplyr、tidyr笔记
dplyr包是Hadley Wickham的新作,主要用于数据清洗和整理,该包专注dataframe数据格式,从而大幅提高了数据处理速度,并且提供了与其它数据库的接口:tidyr包的作者是Hadley ...
- Python中实现异步并发查询数据库
这周又填了一个以前挖下的坑. 这个博客系统使用Psycopy库实现与PostgreSQL数据库的通信.前期,只是泛泛地了解了一下SQL语言,然后就胡乱拼凑出这么一个简易博客系统. 10月份找到工作以后 ...
- tidyr包--数据处理包
tidyr包的作者是Hadley Wickham.这个包常跟dplyr结合使用.本文将介绍tidyr包中下述四个函数的用法: gather—宽数据转为长数据.类似于reshape2包中的melt函数 ...
- R----tidyr包介绍学习
tidyr包:reshape2的替代者,功能更纯粹 tidyr包的应用 tidyr主要提供了一个类似Excel中数据透视表(pivot table)的功能;gather和spread函数将数据在长格式 ...
- TensorFlow深度学习笔记 循环神经网络实践
转载请注明作者:梦里风林 Github工程地址:https://github.com/ahangchen/GDLnotes 欢迎star,有问题可以到Issue区讨论 官方教程地址 视频/字幕下载 加 ...
随机推荐
- 快速搭建网站信息库(小型Zoomeye)
前言:本来是不想重复造车轮的,网上资料有开源的fofa,和一些设计.有的架设太复杂了,好用东西不会用,整个毛线.还有的没有完整代码. 设计方案: 测试平台:windows 测试环境:php ...
- bugku ctf 逆向题
1.逆向入门 2.Easy_vb 直接找出来. 3.easy_re 4.游戏过关 摁着嗯着就出来了... 5.Timer{阿里ctf} apk文件,不会搞. 6.逆向入门 发现是base64,直接转图 ...
- 常问的MySQL面试题整理
char.varchar 的区别是什么? varchar是变长而char的长度是固定的.如果创建的列是固定大小的,你会得到更好的性能 truncate 和 delete 的区别是什么? delete ...
- Redis分布式锁的正确姿势
1. 核心代码: import redis.clients.jedis.Jedis; import java.util.Collections; /** * @Author: qijigui * @C ...
- linux--配置开发环境 --Apache篇
现在我的的linux服务器上一般都是使用:Apache 和 Nginx 这两种配置. 你现在安装好了,启动了,也无法通过你服务器绑定的网址访问你的网站. 这是你可以通过这个命令查看一下你的80端口: ...
- kubeadm 默认镜像配置问题引申
背景: 每次使用功能kubeadm的时候都需要提前准备好镜像,为什么自定义使用的镜像源呢? 在没有翻越围墙时 kubeadm init --kubernetes-version=v1.13.0 --p ...
- Windows 版本 Enterprise、Ultimate、Home、Professional
关于Windows 的安装光盘版本很多种,很多人不知道选择哪些. Ultimate 旗舰版,VISTA开始有了这个级别,是最全最高级的,一般程序开发的电脑,玩游戏的电脑,建议用它,不过对配置稍有一些要 ...
- Debugging Under Unix: gdb Tutorial (https://www.cs.cmu.edu/~gilpin/tutorial/)
//注释掉 #include <iostream.h> //替换为 #include <iostream> using namespace std; Contents Intr ...
- Navicat premium15安装破解教程
Navicat premium15安装破解教程 注意:安装之前请卸载干净navicat,不要覆盖安装 1.去官网下载Navicat premium15的安装包 官网地址:https://www.nav ...
- ps 和 top
ps 进程和线程的关系: (1)一个线程只能属于一个进程,而一个进程可以有多个线程,但至少有一个线程. (2)资源分配给进程,同一进程的所有线程共享该进程的所有资源. (3)处理机分给线程,即真正在处 ...