1. 引言

  最近在刷开源的Pytorch版动手学深度学习,里面谈到几个高级选择函数,如index_select,masked_select,gather等。这些函数大多很容易理解,但是对于gather函数,确实有些难理解,官方文档开始也看得一脸懵,感觉不太直观。下面谈谈我对这几个函数的一些理解。

2. 维度的理解

  对于numpy和pytorch,其数组在做维度运算上刚开始可能会给人一种直观上的误解,以numpy求矩阵某个维度的最大值为例(pytorch的理解也是一样的)

import numpy as np
a = np.arange(1, 13).reshape(3, 4)
"""
result:
a = [[1, 2, 3, 4],
[5, 6, 7, 8,],
[9, 10, 11, 12]]
""" # 对a维度0求最大值
a.max(axis = 0)
"""
result:
[9, 10, 11, 12]
""" # 对a维度1求最大值
a.max(axis = 1)
"""
result:
[4, 8, 12]
"""

  如果对a矩阵在维度0上找最大值,根据我们直观上的经验应该是[4, 8, 12]。即从[1, 2, 3, 4]找到4,从[5, 6, 7, 8]找到8,从[9, 10, 11, 12]找到12。但是从上面结果来看,numpy运算却给了我们直观上认为是列最大值的结果[9, 10, 11, 12]。

  实际numpy(pytorch)运算应该理解为往给定的维度进行移动运算。还是以维度0为例,维度0上有3个向量,分别为[1, 2, 3, 4],[5, 6, 7, 8]和[9, 10, 11, 12]。往维度0移动,即[1, 2, 3, 4]和[5, 6, 7, 8]逐元素计算最大值,得到[5, 6, 7, 8],再和[9, 10, 11, 12]运算得到结果[9, 10, 11, 12]。

  另外,对于维度为3的数组,在numpy和pytorch中,应该把维度0理解为通道数,维度1和维度2才是对应高和宽。如果是3维数组对应着用于多输入通道和单输出通道的卷积核(维度为U x V x D),那么4维数组就对应着用于多输入通道和多输出通道的卷积核(维度为U x V x D x P),此时,维度0则为多通道卷积核数量的方向,维度1为通道数,维度2和3才是分别对应高和宽。

3. gather函数

pytorch和numpy中许多函数都涉及维度运算,gather也不例外,但是它相对于其他函数更难理解。依然先来看一个例子

import torch
a = torch.arange(1, 16).reshape(5, 3)
"""
result:
a = [[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[10, 11, 12],
[13, 14, 15]]
""" # 定义两个index
b = torch.tensor([[0, 1, 2], [2, 3, 4], [0, 2, 4]])
c = torch.tensor([[1, 2, 0, 2, 1], [1, 2, 1, 0, 0]]) # axis=0
output1 = a.gather(0, b)
"""
result:
[[1, 5, 9],
[7, 11, 15],
[1, 8, 15]]
""" # axis=1
output2 = a.gather(1, c)
"""
result:
[[2, 3, 1, 3, 2],
[5, 6, 5, 4, 4]]
"""

上面的例子看起来可能有点复杂,我们来一步步的分析它,先从gather维度为0开始讲起。

  1. a.gather(0, b)分为3个部分,a是需要被提取元素的矩阵,0代表的是提取的维度为0,b是提取元素的索引

    • 其中规定b和a是同维张量,即a是2维张量,b也必须是2维张量
  2. 0除了代表往维度0的方向提取元素外,还有一个特权---提取结果output可以在这个维度上的长度与a不同。打个比方,a现在的shape为(5, 3),那么提取结果output1的shape可以是(1,3),(2, 3),甚至(n, 3)。具体维度0的长度到底为多少由b来决定。
  3. 根据0的特权,导致了给定的b张量除了维度0外,其他的维度大小必须和a一样。其中张量b实际上包含以下两个信息
    • b可以利用除用于gather的维度(此处为维度0)外的维度来定位出唯一一个向量,也就是a[:, ?](三维度也是同理的,有a[:, ?1, ?2]),?的取值范围为a同维度的index。
    • 对于上述定位出的向量,通过b中的元素来定位提取向量中的哪一个元素。
    • 上面说得可能有点抽象,实际上b中的每个元素都能在a中提取出一个元素。举个具体点的例子,按照上面所说的,b[0, 0]可以提取a中的一个元素。对于b[0,0],除了维度0外,可以通过维度1来定位出唯一一个向量a[:, 0]。因为b[0, 0]的元素为0,即提取的是a[:, 0]的第0个元素---1,并将其作为output1[0, 0]的提取结果。

      下图给出了维度0和维度1,gather运算的图示

对于3维或者更高维度的张量gather的原理也是一样的

4. index_select函数

其他的高级选择函数都比较容易理解,这里简单的提一下。torch.index_select主要是根据传入的tensor来往给定的axis方向来选取张量

import torch
a = torch.arange(9).reshape(3, 3)
torch.index_select(a, 0, torch.tensor([0, 2]))
"""
result:
[[0, 1, 2],
[6, 7, 8]]
"""

5. masked_select函数

实际上就是通过掩码条件来选择元素,像torch.masked_select(x, x>0.5),实际上是和x[x>0.5]等价的,最后返回的是一维张量

import torch
a = torch.rand(5, 3) # 结果和a[a > 0.5]等价
torch.masked_select(a, a>0.5)

6. nonzero函数

找到非零元素的index

import torch
a = torch.eye(3)
torch.nonzero(a) """
result: 对应着非零元素的index
[[0, 0],
[1, 1],
[2, 2]]
"""

理解pytorch几个高级选择函数(如gather)的更多相关文章

  1. 小白学习之pytorch框架(4)-softmax回归(torch.gather()、torch.argmax()、torch.nn.CrossEntropyLoss())

    学习pytorch路程之动手学深度学习-3.4-3.7 置信度.置信区间参考:https://cloud.tencent.com/developer/news/452418 本人感觉还是挺好理解的 交 ...

  2. 关于Pytorch的二维tensor的gather和scatter_操作用法分析

    看得不明不白(我在下一篇中写了如何理解gather的用法) gather是一个比较复杂的操作,对一个2维tensor,输出的每个元素如下: out[i][j] = input[index[i][j]] ...

  3. 理解PyTorch的自动微分机制

    参考Getting Started with PyTorch Part 1: Understanding how Automatic Differentiation works 非常好的文章,讲解的非 ...

  4. 理解pytorch中的softmax中的dim参数

    import torch import torch.nn.functional as F x1= torch.Tensor( [ [1,2,3,4],[1,3,4,5],[3,4,5,6]]) y11 ...

  5. [深度学习] pytorch学习笔记(1)(数据类型、基础使用、自动求导、矩阵操作、维度变换、广播、拼接拆分、基本运算、范数、argmax、矩阵比较、where、gather)

    一.Pytorch安装 安装cuda和cudnn,例如cuda10,cudnn7.5 官网下载torch:https://pytorch.org/ 选择下载相应版本的torch 和torchvisio ...

  6. 《深入理解Java虚拟机:JVM高级特性与最佳实践》【PDF】下载

    <深入理解Java虚拟机:JVM高级特性与最佳实践>[PDF]下载链接: https://u253469.pipipan.com/fs/253469-230062566 内容简介 作为一位 ...

  7. 什么是pytorch(1开始)(翻译)

    Deep Learning with PyTorch: A 60 Minute Blitz 作者: Soumith Chintala 部分翻译:me 本内容包含: 在高级层面理解pytorch的ten ...

  8. 万字综述,核心开发者全面解读PyTorch内部机制

    斯坦福大学博士生与 Facebook 人工智能研究所研究工程师 Edward Z. Yang 是 PyTorch 开源项目的核心开发者之一.他在 5 月 14 日的 PyTorch 纽约聚会上做了一个 ...

  9. 【小白学PyTorch】15 TF2实现一个简单的服装分类任务

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...

随机推荐

  1. 分享一个登录页面(前端框架layui)-20200318

    效果图 对该页面的总结: 1.前端框架layui layui官网:https://www.layui.com/,下载之后,简单配置就可使用 2.layui模块引用与使用的方式 <script&g ...

  2. python练习 英文字符的鲁棒输入+数字的鲁棒输入

    鲁棒 = Robust 健壮 英文字符的鲁棒输入 描述 获得用户的任何可能输入,将其中的英文字符进行打印输出,程序不出现错误.‪‬‪‬‪‬‪‬‪‬‮‬‫‬‪‬‪‬‪‬‪‬‪‬‪‬‮‬‪‬‫‬‪‬‪‬‪ ...

  3. Istio安全-授权(实操三)

    Istio安全-授权 目录 Istio安全-授权 授权HTTP流量 为使用HTTP流量的负载配置访问控制 卸载 授权TCP流量 部署 配置TCP负载的访问控制 卸载 使用JWT进行授权 部署 使用有效 ...

  4. js中的各种常用方法(持续更新中。。。)

    我看到常用的就写上去,如果你们有,可以在评论上发表,我再把它补充到我的随笔中 some() var ages = [3, 10, 18, 20]; function checkAdult(age) { ...

  5. Visual Studio编译Core程序部署到linux

    一.背景 随着微软拥抱开源,推出Net Core框架,目前已经支持跨平台,能部署到Linux.MacOS.Windows等系统上. 下面我们就来分享一下Visual Studio编译好的代码部署到Li ...

  6. Codeforces Round #669 (Div. 2)A-C题解

    A. Ahahahahahahahaha 题目:http://codeforces.com/contest/1407/problem/A 题解:最多进行n/2的操作次数,我们统计这n个数中1的个数,是 ...

  7. Java里一个线程两次调用start()方法会出现什么情况

    Java的线程是不允许启动两次的,第二次调用必然会抛出IllegalThreadStateException,这是一种运行时异常,多次调用start被认为是编程错误. 如果业务需要线程run中的代码再 ...

  8. YOLOv4: Darknet 如何于 Docker 编译,及训练 COCO 子集

    YOLO 算法是非常著名的目标检测算法.从其全称 You Only Look Once: Unified, Real-Time Object Detection ,可以看出它的特性: Look Onc ...

  9. 5 分钟带你掌握 Makefile 分析

    摘要:Makefile是一个名为GNU-Make软件所需要的脚本文件,该脚本文件可以指导Make软件控制arm-gcc等工具链去编译工程文件最终得到可执行文件,几乎所有的Linux发行版都内置了GNU ...

  10. .NET 5.0 RC1 发布,离正式版发布仅剩两个版本

    原文:http://dwz.win/Qf8 作者:Richard 翻译:精致码农-王亮 说明:1. 本译文并不是完全逐句翻译的,存在部分语句我实在不知道如何翻译或组织就根据个人理解用自己的话表述了.2 ...