gather(input, dim, index):根据  index,在  dim  维度上选取数据,输出的  size  与  index  一致

# input (Tensor) – 源张量

# dim (int) – 索引的轴

# index (LongTensor) – 聚合元素的下标(index需要是torch.longTensor类型)

# out (Tensor, optional) – 目标张量

for 3D tensor:

out[i][j][k] = tensor[index[i][j][k]][j][k]   # dim=0

out[i][j][k] = tensor[i][index[i][j][k]][k]   # dim=1

out[i][j][k] = tensor[i][j][index[i][j][k]]   # dim=2

for 2D tensor:

out[i][j] = input[index[i][j]][j]  # dim = 0

out[i][j] = input[i][index[i][j]]  # dim = 1

import torch as t  # 导入torch模块
c = t.arange(0, 60).view(3, 4, 5) # 定义tensor
print(c)
index = torch.LongTensor([[[0,1,2,0,2],
                [0,0,0,0,0],
                [1,1,1,1,1]],
                [[1,2,2,2,2],
                 [0,0,0,0,0],
                [2,2,2,2,2]]])
b = t.gather(c, 0, index)
print(b) 输出:

tensor([[[ 0, 1, 2, 3, 4],
             [ 5, 6, 7, 8, 9],
             [10, 11, 12, 13, 14],
             [15, 16, 17, 18, 19]],

[[20, 21, 22, 23, 24],
             [25, 26, 27, 28, 29],
             [30, 31, 32, 33, 34],
             [35, 36, 37, 38, 39]],

[[40, 41, 42, 43, 44],
             [45, 46, 47, 48, 49],
             [50, 51, 52, 53, 54],
             [55, 56, 57, 58, 59]]])

报错:

Traceback (most recent call last):
File "E:/Release02/my_torch.py", line 14, in <module>
b = t.gather(c, 0, index)
RuntimeError: Size does not match at dimension 1 get 4 vs 3

(第1维尺寸不匹配)

将index调整为:

index = t.LongTensor([[[0, 1, 2, 0, 2], [0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]],
[[1, 2, 2, 2, 2], [0, 0, 0, 0, 0], [2, 2, 2, 2, 2], [1, 1, 1, 1, 1]],
[[1, 2, 2, 2, 2], [0, 0, 0, 0, 0], [2, 2, 2, 2, 2], [1, 1, 1, 1, 1]]])
则上文输出为:

tensor([[[ 0, 21, 42, 3, 44],
             [ 5, 6, 7, 8, 9],
             [30, 31, 32, 33, 34],
             [35, 36, 37, 38, 39]],

[[20, 41, 42, 43, 44],
             [ 5, 6, 7, 8, 9],
             [50, 51, 52, 53, 54],
             [35, 36, 37, 38, 39]],

[[20, 41, 42, 43, 44],
             [ 5, 6, 7, 8, 9],
             [50, 51, 52, 53, 54],
             [35, 36, 37, 38, 39]]])

对于2D tensor 则无“index与tensor 的size一致”之要求,

这个要求在官方文档和其他博文、日志中均无提到

可能是个坑吧丨可能是个坑吧丨可能是个坑吧

eg:

代码(此部分来自https://www.yzlfxy.com/jiaocheng/python/337618.html):

b = torch.Tensor([[1,2,3],[4,5,6]])
print b
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print torch.gather(b, dim=1, index=index_1)
print torch.gather(b, dim=0, index=index_2)

输出:

 1 2 3
4 5 6
[torch.FloatTensor of size 2x3] 1 2
6 4
[torch.FloatTensor of size 2x2] 1 5 6
1 2 3
[torch.FloatTensor of size 2x3]

官方文档:
torch.gather(input, dim, index, out=None) → Tensor

 Gathers values along an axis specified by dim.

 For a 3-D tensor the output is specified by:

 out[i][j][k] = input[index[i][j][k]][j][k] # dim=0
out[i][j][k] = input[i][index[i][j][k]][k] # dim=1
out[i][j][k] = input[i][j][index[i][j][k]] # dim=2 Parameters: input (Tensor)-The source tensor
dim (int)-The axis along which to index
index (LongTensor)-The indices of elements to gather
out (Tensor, optional)-Destination tensor Example: >>> t = torch.Tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
1 1
4 3
[torch.FloatTensor of size 2x2]

以上,学习中遇到的问题,记录方便回顾,亦示他人以之勉坑

												

gather函数的更多相关文章

  1. [软件使用][matlab]最近经常用到的一些函数的意思,和用法

    ① cat(dim,A,B)按指定的维度,将A和B串联,dim是维度,比如1,2.1指列,2指行: ②numel(A),返回数组中,元素的个数 ③gpuArray(A),在gpu中产生一个数组A,一般 ...

  2. 小白学习之pytorch框架(4)-softmax回归(torch.gather()、torch.argmax()、torch.nn.CrossEntropyLoss())

    学习pytorch路程之动手学深度学习-3.4-3.7 置信度.置信区间参考:https://cloud.tencent.com/developer/news/452418 本人感觉还是挺好理解的 交 ...

  3. 理解pytorch几个高级选择函数(如gather)

    目录 1. 引言 2. 维度的理解 3. gather函数 4. index_select函数 5. masked_select函数 6. nonzero函数 1. 引言   最近在刷开源的Pytor ...

  4. 从头造轮子:python3 asyncio之 gather (3)

    前言 书接上文:,本文造第三个轮子,也是asyncio包里面非常常用的一个函数gather 一.知识准备 ● 相对于前两个函数,gather的使用频率更高,因为它支持多个协程任务"同时&qu ...

  5. R语言数据处理包dplyr、tidyr笔记

    dplyr包是Hadley Wickham的新作,主要用于数据清洗和整理,该包专注dataframe数据格式,从而大幅提高了数据处理速度,并且提供了与其它数据库的接口:tidyr包的作者是Hadley ...

  6. Python中实现异步并发查询数据库

    这周又填了一个以前挖下的坑. 这个博客系统使用Psycopy库实现与PostgreSQL数据库的通信.前期,只是泛泛地了解了一下SQL语言,然后就胡乱拼凑出这么一个简易博客系统. 10月份找到工作以后 ...

  7. tidyr包--数据处理包

    tidyr包的作者是Hadley Wickham.这个包常跟dplyr结合使用.本文将介绍tidyr包中下述四个函数的用法: gather—宽数据转为长数据.类似于reshape2包中的melt函数 ...

  8. R----tidyr包介绍学习

    tidyr包:reshape2的替代者,功能更纯粹 tidyr包的应用 tidyr主要提供了一个类似Excel中数据透视表(pivot table)的功能;gather和spread函数将数据在长格式 ...

  9. TensorFlow深度学习笔记 循环神经网络实践

    转载请注明作者:梦里风林 Github工程地址:https://github.com/ahangchen/GDLnotes 欢迎star,有问题可以到Issue区讨论 官方教程地址 视频/字幕下载 加 ...

随机推荐

  1. 详解 缓冲区(Buffer 抽象类)

    在本篇博文中,本人主要讲解NIO 的两个核心点 -- 缓冲区(Buffer) 和 通道 (Channel)之一的 缓冲区(Buffer), 有关NIO流的其他知识点请观看本人博文<详解 NIO流 ...

  2. BUUOJ [WUSTCTF2020]朴实无华

    [WUSTCTF2020]朴实无华 复现了武科大的一道题/// 进入界面 一个hack me 好吧,直接看看有没有robot.txt 哦豁,还真有 好吧 fAke_f1agggg.php 看了里面,然 ...

  3. 详解PHP反序列化中的字符逃逸

    首发先知社区,https://xz.aliyun.com/t/6718/ PHP 反序列化字符逃逸 下述所有测试均在 php 7.1.13 nts 下完成 先说几个特性,PHP 在反序列化时,对类中不 ...

  4. 初探Redis-基础类型String

    Redis存在五种基础类型:字符串(String).队列(List).哈希(Hash).集合(Set).有序集合(Sorted Set).String的出镜率算是最高的.本次列举出String的常用操 ...

  5. Spring Cloud Gateway+Nacos,yml+properties两种配置文件方式搭建网关服务

    写在前面 网关的作用不在此赘述,举个最常用的例子,我们搭建了微服务,前端调用各服务接口时,由于各服务接口不一样,如果让前端同事分别调用,前端同事会疯的.而网关就可以解决这个问题,网关屏蔽了各业务服务的 ...

  6. 每天都在用,但你知道 Tomcat 的线程池有多努力吗?

    这是why的第 45 篇原创文章.说点不一样的线程池执行策略和线程拒绝策略,探讨怎么让线程池先用完最大线程池再把任务放到队列中. 荒腔走板 大家好,我是 why,一个四川程序猿,成都好男人. 先是本号 ...

  7. vim的常用指令

    vim的常用指令如下: 光标运动: h,j , k, l (上/下/左/右) 删除字符: x 删除行 : dd 模式退出 : Esc,Insert(或者i) 退出编辑器 : q 强制退出不保存: q! ...

  8. python实现线性回归之简单回归

    代码来源:https://github.com/eriklindernoren/ML-From-Scratch 首先定义一个基本的回归类,作为各种回归方法的基类: class Regression(o ...

  9. 如何调试 Inno Setup

    从命令行运行安装包,并加上 /log=filename

  10. 防cc攻击利器之Httpgrard

    一.httpgrard介绍 HttpGuard是基于openresty,以lua脚本语言开发的防cc攻击软件.而openresty是集成了高性能web服务器Nginx,以及一系列的Nginx模块,这其 ...