看得不明不白(我在下一篇中写了如何理解gather的用法)

gather是一个比较复杂的操作,对一个2维tensor,输出的每个元素如下:

out[i][j] = input[index[i][j]][j]  # dim=0
out[i][j] = input[i][index[i][j]] # dim=1

二维tensor的gather操作

针对0轴

注意index此时的值

输入

index = t.LongTensor([[0,1,2,3]])
print("index = \n", index) #index是2维
print("index的形状: ",index.shape) #index形状是(1,4)

输出

index =
tensor([[0, 1, 2, 3]])
index的形状: torch.Size([1, 4])

分割线============

针对1轴

注意index此时的值

输入

index = t.LongTensor([[0,1,2,3]]).t()  #index是2维
print("index = \n", index) #index形状是(4,1)
print("index的形状: ",index.shape)

输出

index =
tensor([[0],
[1],
[2],
[3]])
index的形状: torch.Size([4, 1])

分割线===========

再来看看几个例子

注意index在以0轴和1轴为标准时的表达式是不一样的。

b.gather()中取0维时,输出的结果是行形式,取1维时,输出的结果是列形式。

  • b是一个 $ 3\times4 $ 型的
>>> import torch as t
>>> b = t.arange(0,12).view(3,4)
>>> b
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
>>> index = t.LongTensor([[0,1,2]]) >>> index
tensor([[0, 1, 2]]) >>> b.gather(0,index) #运行失败了
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [1 x 3], src [3 x 4] and index [1 x 3] to have the same size apart from dimension 0 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620 >>> index2 = t.LongTensor([[0,1,2]]).t() >>> b.gather(1,index2) #运行成功了
tensor([[ 0],
[ 5],
[10]]) >>> index3 = t.LongTensor([[0,1,2,3]]).t() >>> b.gather(1,index3) #运行失败了
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [4 x 1], src [3 x 4] and index [4 x 1] to have the same size apart from dimension 1 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620
  • b是一个 $ 6\times6 $ 型的
>>> import torch as t
>>> b = t.arange(0,36).view(6,6)
>>> b
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29],
[30, 31, 32, 33, 34, 35]]) >>> index = t.LongTensor([[0,1,2,3,4,5,6]])
>>> b.gather(0,index) #运行失败了
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [1 x 7], src [6 x 6] and index [1 x 7] to have the same size apart from dimension 0 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620 >>> index = t.LongTensor([[0,1,2,3,4,5]])
>>> b.gather(0,index) #运行成功了
tensor([[ 0, 7, 14, 21, 28, 35]])
>>> b.gather(1,index)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [1 x 6], src [6 x 6] and index [1 x 6] to have the same size apart from dimension 1 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620 >>> index2 = t.LongTensor([[0,1,2,3,4,5]]).t()
>>> b.gather(1,index2) #运行成功了
tensor([[ 0],
[ 7],
[14],
[21],
[28],
[35]]) >>> index3 = t.LongTensor([[0,1,2,3,4]]).t()
>>> b.gather(1,index3)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [5 x 1], src [6 x 6] and index [5 x 1] to have the same size apart from dimension 1 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620 >>> index4 = t.LongTensor([[0,1,2,3,4]])
>>> b.gather(0,index4)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [1 x 5], src [6 x 6] and index [1 x 5] to have the same size apart from dimension 0 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620

与gather相对应的逆操作是scatter_,gather把数据从input中按index取出,而scatter_是把取出的数据再放回去。注意scatter_函数是inplace操作。





与gather相对应的逆操作是scatter_,gather把数据从input中按index取出,而scatter_是把取出的数据再放回去。注意scatter_函数是inplace操作。

out = input.gather(dim, index)
-->近似逆操作
out = Tensor()
out.scatter_(dim, index)

根据StackOverflow上的问题修改代码如下:

输入

# 把两个对角线元素放回去到指定位置
c = t.zeros(4,4)
c.scatter_(1, index, b.float())

输出

tensor([[ 0.,  0.,  0.,  3.],
[ 0., 5., 6., 0.],
[ 0., 9., 10., 0.],
[12., 0., 0., 15.]])

关于Pytorch的二维tensor的gather和scatter_操作用法分析的更多相关文章

  1. 2017.11.17 C++系列---用malloc动态给c++二维数组的申请与释放操作

    方法一:利用二级指针申请一个二维数组. #include<stdio.h> #include<stdlib.h> int main() { int **a; //用二级指针动态 ...

  2. Pytorch学习笔记(二)——Tensor

    一.对Tensor的操作 从接口的角度讲,对Tensor的操作可以分为两类: (1)torch.function (2)tensor.function 比如torch.sum(a, b)实际上和a.s ...

  3. Javascript生成二维码(QR)

    网络上已经有非常多的二维码编码和解码工具和代码,很多都是服务器端的,也就是说需要一台服务器才能提供二维码的生成.本着对服务器性能的考虑,这种小事情都让服务器去做,感觉对不住服务器,尤其是对于大流量的网 ...

  4. 二维指针*(void **)的研究(uC/OS-II案例) 《转载》

    uC/OS-II内存管理函数内最难理解的部分就是二维指针,本文以图文并茂的方式对二维指针进行了详细分析与讲解.看完本文,相信对C里面指针的概念又会有进一步的认识. 一.OSMemCreate( ) 函 ...

  5. Stars(二维树状数组)

    Stars Time Limit: 5000/2000 MS (Java/Others) Memory Limit: 32768/65536 K (Java/Others) Total Submiss ...

  6. iOS原生实现二维码拉近放大

    http://www.cocoachina.com/ios/20180416/23033.html 2018-04-16 15:34 编辑: yyuuzhu 分类:iOS开发 来源:程序鹅 8 300 ...

  7. Java打印M图形(二维数组)——(九)

    对于平面图形输出集合图形与数字组合的,用二维数组.先在Excel表格中分析一下,找到简单的规律.二维数组的行数为行高,列数为最后一个数大小. 对于减小再增大再减小再增大的,可以用一个boolean标志 ...

  8. 混沌数学之二维logistic模型

    上一节讲了logistic混沌模型,这一节对其扩充一下讲二维 Logistic映射.它起着从一维到高维的衔接作用,对二维映射中混沌现象的研究有助于认识和预测更复杂的高维动力系统的性态.通过构造一次藕合 ...

  9. UVa 11297 Census (二维线段树)

    题意:给定上一个二维矩阵,有两种操作 第一种是修改 c x y val 把(x, y) 改成 val 第二种是查询 q x1 y1 x2 y2 查询这个矩形内的最大值和最小值. 析:二维线段树裸板. ...

随机推荐

  1. 学习php必须要了解的一些知识

    前言:每个人的成功都是用辛勤的劳动换来的 一.网络的基础知识 IP地址:Internet protocol address 指的是互联网协议地址,由二进制构成,(IPV4是32位的二进制),我们人为的 ...

  2. 查看java进程的所有信息

    查看java 进程下的所有信息 ll /proc/pid/fd ru:ll /proc/24047/fd

  3. linux 终端操作快捷键

    熟练使用快捷键可以很大的提高效率,以下列出一些常用的快捷键命令方便随时查阅 1. 移动光标 Ctrl + a 标移到行首.它在多数文本编辑器和 Mozilla 的 URL 字段内可以使用.Ctrl + ...

  4. You can add an index on a column that can have NULL values if you are using the MyISAM, InnoDB, or MEMORY storage engine.

    w https://dev.mysql.com/doc/refman/5.7/en/create-index.html MySQL :: MySQL 5.7 Reference Manual :: B ...

  5. echarts系列之动态加载数据

    1.echarts学习前言 最近接触到echarts,发现数据可视化真的是魅力无穷啊,各种变幻的曲线交错,以及‘曼妙’的动画效果真是让人如痴如醉! 下面就来一起欣赏她的美... “ ECharts是中 ...

  6. 使用 adb logcat 显示 Android 日志

    本文为转载.  地址:http://www.hanshuliang.com/?post=32 eclipse 自带的 LogCat 工具太垃圾了, 开始用 adb logcat 在终端查看日志; 1. ...

  7. Linux源码包安装和脚本安装

    能够先 vi INSTALL 看看安装过程. 1.源码包安装 2.脚本安装

  8. DOM 查找标签

    1.直接查找 document.getElementById // 根据ID获取一个标签 document.getElementsByClassName //根据class属性获取 document. ...

  9. 2. 安装 Kerberos

    2.1. 环境配置 安装kerberos前,要确保主机名可以被解析. 主机名 内网IP 角色 Vmw201 172.16.18.201 Master KDC Vmw202 172.16.18.202 ...

  10. Version 1.5 of the JVM is not suitable for this product.Version:1.6 or greater is required

    近期在公司涉及到了服务器等的扩展,smartfoxserver扩展使用的Eclipse,尽管没学过java.可是咱毕竟是C++起价的,其它语言看看也就会了,项目依然做着,近期看到某同学有一些java的 ...