Tensor

自从张量(Tensor)计算这个概念出现后,神经网络的算法就可以看作是一系列的张量计算。所谓的张量,它原本是个数学概念,表示各种向量或者数值之间的关系。PyTorch的张量(torch.Tensor)表示的是N维矩阵与一维数组的关系。

torch.Tensor的使用方法和numpy很相似(https://pytorch.org/...tensor-tutorial-py),两者唯一的区别在于torch.Tensor可以使用GPU来计算,这就比用CPU的numpy要快很多。

张量计算的种类有很多,比如加法、乘法、矩阵相乘、矩阵转置等,这些计算被称为算子(Operator),它们是PyTorch的核心组件。

算子的backend一般是C/C++的拓展程序,PyTorch的backend是称为"ATen"的C/C++库,ATen是"A Tensor"的缩写。

Operator

PyTorch所有的Operator都定义在Declarations.cwrap和native_functions.yaml这两个文件中,前者定义了从Torch那继承来的legacy operator(aten/src/TH),后者定义的是native operator,是PyTorch的operator。

相比于用C++开发的native code,legacy code是在PyTorch编译时由gen.py根据Declarations.cwrap的内容动态生成的。因此,如果你想要trace这些code,需要先编译PyTorch。

legacy code的开发要比native code复杂得多。如果可以的话,建议你尽量避开它们。

MatMul

本文会以矩阵相乘--torch.matmul()为例来分析PyTorch算子的工作流程。

我在深入浅出全连接层(fully connected layer)中有讲在GPU层面是如何进行矩阵相乘的。Nvidia、AMD等公司提供了优化好的线性代数计算库--cuBLAS/rocBLAS/openBLAS,PyTorch只需要调用它们的API即可。

Figure 1是torch.matmul()在ATen中的function flow。可以看到,这个flow可不短,这主要是因为不同类型的tensor(2d or Nd, batched gemm or not,with or without bias,cuda or cpu)的操作也不尽相同。

at::matmul()主要负责将Tensor转换成cuBLAS需要的格式。前面说过,Tensor可以是N维矩阵,如果tensor A是3d矩阵,tensor B是2d矩阵,就需要先将3d转成2d;如果它们都是>=3d的矩阵,就要考虑batched matmul的情况;如果bias=True,后续就应该交给at::addmm()来处理;总之,matmul要考虑的事情比想象中要多。

除此之外,不同的dtype、device和layout需要调用不同的操作函数,这部分工作交由c10::dispatcher来完成。

Dispatcher

dispatcher主要用于动态调用dtype、device以及layout等方法函数。用过numpy的都知道,np.array()的数据类型有:float32, float16,int8,int32,.... 如果你了解C++就会知道,这类程序最适合用模板(template)来实现。

很遗憾,由于ATen有一部分operator是用C语言写的(从Torch继承过来),不支持模板功能,因此,就需要dispatcher这样的动态调度器。

类似地,PyTorch的tensor不仅可以运行在GPU上,还可以跑在CPU、mkldnn和xla等设备,Figure 1中的dispatcher4就根据tensor的device调用了mm的GPU实现。

layout是指tensor中元素的排布。一般来说,矩阵的排布都是紧凑型的,也就是strided layout。而那些有着大量0的稀疏矩阵,相应地就是sparse layout。

Figure 2是strided layout的演示实例,这里创建了一个2行2列的矩阵a,它的数据实际存放在一维数组(a.storage)里,2行2列只是这个数组的视图。

stride充当了从数组到视图的桥梁,比如,要打印第2行第2列的元素时,可以通过公式:\(1 * stride(0) + 1 * stride(1)\)来计算该元素在数组中的索引。

除了dtype、device、layout之外,dispatcher还可以用来调用legacy operator。比如说addmm这个operator,它的GPU实现就是通过dispatcher来跳转到legacy::cuda::_th_addmm。

END

到此,就完成了对PyTorch算子的学习。如果你要学习其他算子,可以先从aten/src/ATen/native目录的相关函数入手,从native_functions.yaml中找到dispatch目标函数,详情可以参考Figure 1。


更多精彩文章,欢迎扫码关注下方的公众号, 并访问我的简书博客:https://www.jianshu.com/u/c0fe8671254e

欢迎转发至朋友圈,工作号转载请后台留言申请授权~

深入浅出PyTorch(算子篇)的更多相关文章

  1. Ascend Pytorch算子功能验证

    Ascend Pytorch算子功能验证 编写测试用例 以add算子为例,测试脚本文件命名为:add_testcase.py.以下示例仅为一个简单的用例实现,具体算子的实现,需要根据算子定义进行完整的 ...

  2. Ascend Pytorch算子适配层开发

    Ascend Pytorch算子适配层开发 适配方法 找到和PyTorch算子功能对应的NPU TBE算子,根据算子功能计算出输出Tensor的size,再根据TBE算子原型构造对应的input/ou ...

  3. Spark算子篇 --Spark算子之aggregateByKey详解

    一.基本介绍 rdd.aggregateByKey(3, seqFunc, combFunc) 其中第一个函数是初始值 3代表每次分完组之后的每个组的初始值. seqFunc代表combine的聚合逻 ...

  4. 深入浅出 RPC - 深入篇

    <深入篇>我们主要围绕 RPC 的功能目标和实现考量去展开,一个基本的 RPC 框架应该提供什么功能,满足什么要求以及如何去实现它? RPC 功能目标 RPC 的主要功能目标是让构建分布式 ...

  5. (转)深入浅出 RPC - 深入篇

    版权声明:本文为博主原创文章,未经博主允许不得转载. https://blog.csdn.net/mindfloating/article/details/39474123 <深入篇>我们 ...

  6. 深入浅出RPC——深入篇(转载)

    本文转载自这里是原文 <深入篇>我们主要围绕 RPC 的功能目标和实现考量去展开,一个基本的 RPC 框架应该提供什么功能,满足什么要求以及如何去实现它? RPC 功能目标 RPC的主要功 ...

  7. Spark算子篇 --Spark算子之combineByKey详解

    一.概念 rdd.combineByKey(lambda x:"%d_" %x, lambda a,b:"%s@%s" %(a,b), lambda a,b:& ...

  8. 时序数据库深入浅出之存储篇——本质LSMtree,同时 metric(比如温度)+tags 分片

    什么是时序数据库 先来介绍什么是时序数据.时序数据是基于时间的一系列的数据.在有时间的坐标中将这些数据点连成线,往过去看可以做成多纬度报表,揭示其趋势性.规律性.异常性:往未来看可以做大数据分析,机器 ...

  9. spark算子篇-repartition and coalesce

    我们知道 RDD 是分区的,但有时候我们需要重新设置分区数量,增大还是减少需要结合实际场景,还有可以通过设置 RDD 分区数来指定生成的文件的数量 重新分区有两种方法:repartition and ...

随机推荐

  1. 从Point类继承的Circle类 代码参考

    #include <iostream> #include <cstring> using namespace std; class Point { private: int x ...

  2. Redis详解(十二)------ 缓存穿透、缓存击穿、缓存雪崩

    本篇博客我们来介绍Redis使用过程中需要注意的三种问题:缓存穿透.缓存击穿.缓存雪崩. 1.缓存穿透 一.概念 缓存穿透:缓存和数据库中都没有的数据,可用户还是源源不断的发起请求,导致每次请求都会到 ...

  3. Take advantage of Checkra1n to Jailbreak iDevice for App analysis

    An unpatchable bootrom exploit called "checkm8" works on all iDevices up until the iPhone ...

  4. JAVASE(十八) 反射: Class的获取、ClassLoader、反射的应用、动态代理

    个人博客网:https://wushaopei.github.io/    (你想要这里多有) 1.反射(JAVA Reflection)的理解 1.1 什么是反射(JAVA Reflection) ...

  5. Vue父子组件传值以及父调子方法、子调父方法

    稍微总结了一下Vue中父子间传值以及相互调方法的问题,非常基础.希望可以帮到你!先来个最常用的,直接上代码: 1.父传值给子组件 父组件: <template> <div> & ...

  6. Java实现 LeetCode 331 验证二叉树的前序序列化

    331. 验证二叉树的前序序列化 序列化二叉树的一种方法是使用前序遍历.当我们遇到一个非空节点时,我们可以记录下这个节点的值.如果它是一个空节点,我们可以使用一个标记值记录,例如 #. _9_ / \ ...

  7. java实现第五届蓝桥杯LOG大侠

    LOG大侠 atm参加了速算训练班,经过刻苦修炼,对以2为底的对数算得飞快,人称Log大侠. 一天,Log大侠的好友 drd 有一些整数序列需要变换,Log大侠正好施展法力- 变换的规则是: 对其某个 ...

  8. vi命令总结

    VI常用技巧 ​ VI命令可以说是Unix/Linux世界里最常用的编辑文件的命令了,但是因为它的命令集众多,很多人都不习惯使用它,其实您只需要掌握基本命令,然后加以灵活运用,就会发现它的优势,并会逐 ...

  9. 基于Nginx实现访问控制、连接限制

    0 前言 Nginx自带的模块支持对并发请求数进行限制, 还有对请求来源进行限制.可以用来防止DDOS攻击.阅读本文须知道nginx的配置文件结构和语法. 1. 默认配置语法 nginx.conf作为 ...

  10. javascript内置函数提供的显式绑定

    内置函数提供的显式绑定 最近在开发中遇到使用arr.map(module.fun) 这样的写法时(在一个模块调用了另外一个模块的方法), 造成了函数中this丢失的问题, 显示为undefined, ...