在Tensorflow、Numpy和PyTorch中都提供了使用einsum的api,einsum是一种能够简洁表示点积、外积、转置、矩阵-向量乘法、矩阵-矩阵乘法等运算的领域特定语言。在Tensorflow等计算框架中使用einsum,操作矩阵运算时可以免于记忆和使用特定的函数,并且使得代码简洁,高效。

如对矩阵\(A\in \mathbb{R}^{I×K}​\)和矩阵\(B\in \mathbb{R}^{K×J}​\)做矩阵乘,然后对列求和,最终得到向量\(c\in \mathbb{R}^J​\),即:

\[\mathbb{R}^{I×K}\bigotimes \mathbb{R}^{K×J}\to \mathbb{R}^{I×J}\to \mathbb{R}^{J}
\]

使用爱因斯坦求和约定表示为:

\[c_j=\sum_i\sum_kA_{ik}B_{kj}=A_{ik}B_{kj}
\]

在Tensorflow、Numpy和PyTorch中对应的einsum字符串为:

ik,kj->j

在上面的字符串中,隐式地省略了重复的下标\(k\),表示在该维度矩阵乘;另外输出中未指明下标\(i\),表示在该维度累加。

Numpy、PyTorch和Tensorflow中的einsum

einsum在Numpy中的实现为np.einsum,在PyTorch中的实现为torch.einsum,在Tensorflow中的实现为tf.einsum,均使用同样的函数签名einsum(equation,operands),其中,equation传入爱因斯坦求和约定的字符串,而operands则是张量序列。在Numpy、Tensorflow中是变长参数列表,而在PyTorch中是列表。上述例子中,在Tensorflow中可写作:

tf.einsum('ik,kj->j',mat1,mat2)

其中,mat1、mat2为执行该运算的两个张量。注意:这里的(i,j,k)的命名是任意的,但在一个表达式中要一致。

PyTorch和Tensorflow像Numpy支持einsum的好处之一就是,einsum可以用于深度网络架构的任意计算图,并且可以反向传播。在Numpy和Tensorflow中的调用格式如下:

\[result=\mathop{einsum}('\square \square, \square \square \square,\square \square\to \square \square',arg1,arg2,arg3)
\]

其中,\(\square\)是占位符,表示张量维度;arg1,arg3是矩阵,arg2是三阶张量,运算结果是矩阵。注意:einsum处理可变数量的输入。上面例子中,einsum制定了三个参数的操作,但同样可以操作一个参数、两个参数和三个参数及以上的操作。

典型的einsum表达式

前置知识

  • 内积

    又称点积、点乘,对应位置数字相乘,结果是一个标量,有见向量内积和矩阵内积等。

    向量\(\vec a\)和向量\(\vec b\)的内积:

    \[\vec a=[a_1,a_2,...,a_n]\\
    \vec b=[b_1,b_2,...,b_n]\\
    \vec a\cdot \vec b^T=a_1b_1+a_2b_2+...+a_nb_n
    \]

    内积几何意义:

    \[\vec a \cdot \vec b^T=|\vec a||\vec b|\mathop{cos}\theta
    \]

  • 外积

    又称叉乘、叉积、向量积,行向量矩阵乘列向量,结果是二阶张量。注意到:张量的外积作为张量积的同义词。外积是一种特殊的克罗内克积。

    向量\(\vec a\)和向量\(\vec b\)的外积:

    \[\begin{bmatrix}
    b_1
    \\b_2
    \\ b_3
    \\ b_4
    \end{bmatrix}\bigotimes[a_1,a_2,a_3]=\begin{bmatrix}
    a_1b_1 & a_2b_1 & a_3b_1 \\
    a_1b_2 & a_2b_2 & a_3b_2 \\
    a_1b_3 & a_2b_3 & a_3b_3 \\
    a_1b_4 & a_2b_4 & a_3b_4 \\
    \end{bmatrix}
    \]

    外积的几何意义:

    \[\vec a=(x_1,y_1,z_1)\\
    \vec b=(x_2,y_2,z_2)\\
    \vec a\bigotimes\vec b=\begin{vmatrix}
    i & j & k\\
    x_1 & y_1 & z_1\\
    x_2 & y_2 & z_2
    \end{vmatrix}=(y_1z_2-y_2z_1)\vec i-(x_1z_2-x_2z_1)\vec j+(x_1y_2-x_2y_1)\vec k
    \]

    其中,

    \[\vec i=(1,0,0)\\
    \vec j=(0,1,0)\\
    \vec k=(0,0,1)
    \]

由于PyTorch可以实时输出运算结果,以PyTorch使用einsum表达式为例。

  • 矩阵转置

    \[B_{ji}=A_{ij}
    \]

    a=torch.arange(6).reshape(2,3)
    >>>tensor([[0, 1, 2],
    [3, 4, 5]])
    torch.einsum('ij->ji',[a])
    >>>tensor([[0, 3],
    [1, 4],
    [2, 5]])
  • 求和

    \[b=\sum_{i}\sum_{j}A_{ij}
    \]

    a=torch.arange(6).reshape(2,3)
    >>>tensor([[0, 1, 2],
    [3, 4, 5]])
    torch.einsum('ij->',[a])
    >>>tensor(15)
  • 列求和(列维度不变,行维度消失)

    \[b_j=\sum_iA_{ij}
    \]

    a=torch.arange(6).reshape(2,3)
    >>>tensor([[0, 1, 2],
    [3, 4, 5]])
    torch.einsum('ij->j',[a])
    >>>tensor([ 3.,  5.,  7.])
  • 列求和(列维度不变,行维度消失)

    \[b_i=\sum_jA_{ij}
    \]

    a=torch.arange(6).reshape(2,3)
    >>>tensor([[0, 1, 2],
    [3, 4, 5]])
    torch.einsum('ij->i', [a])
    >>>tensor([  3.,  12.])
  • 矩阵-向量相乘

    \[c_i=\sum_k A_{ik}b_k
    \]

    a=torch.arange(6).reshape(2,3)
    >>>tensor([[0, 1, 2],
    [3, 4, 5]])
    torch.einsum('ik,k->i',[a,b])
    >>>tensor([  5.,  14.])
  • 矩阵-矩阵乘法

    \[C_{ij}=\sum_{k}A_{ik}B_{kj}
    \]

    a=torch.arange(6).reshape(2,3)
    b=torch.arange(15).reshape(3,5)
    >>>tensor([[0, 1, 2],
    [3, 4, 5]]) >>>tensor([[ 0, 1, 2, 3, 4],
    [ 5, 6, 7, 8, 9],
    [10, 11, 12, 13, 14]])
    torch.einsum('ik,kj->ij',[a,b])
    >>>tensor([[ 25,  28,  31,  34,  37],
    [ 70, 82, 94, 106, 118]])
  • 点积

    • 向量

      \[c=\sum_i a_i b_i
      \]

      a=torch.arange(3)
      b=torch.arange(3,6)
      >>>tensor([0, 1, 2])
      >>>tensor([3, 4, 5])
      torch.einsum('i,i->',[a,b])
      >>>tensor(14.)
    • 矩阵

      \[c=\sum_i\sum_j A_{ij}B_{ij}
      \]

      a=torch.arange(6).reshape(2,3)
      b=torch.arange(6,12).reshape(2,3)
      >>>tensor([[0, 1, 2],
      [3, 4, 5]]) >>>tensor([[ 6, 7, 8],
      [ 9, 10, 11]])
      torch.einsum('ij,ij->',[a,b])
      >>>tensor(145.)
  • 外积

    \[C_{ij}=a_i b_j
    \]

    a=torch.arange(3)
    b=torch.arange(3,7)
    >>>tensor([0, 1, 2])
    >>>tensor([3, 4, 5, 6])
    torch.einsum('i,j->ij',[a,b])
    >>>tensor([[  0.,   0.,   0.,   0.],
    [ 3., 4., 5., 6.],
    [ 6., 8., 10., 12.]])
  • batch矩阵乘

    \[C_{ijl}=\sum_{k}A_{ijk}B_{ikl}
    \]

    a=torch.randn(3,2,5)
    b=torch.randn(3,5,3)
    >>>tensor([[[-1.4131e+00,  3.8372e-02,  1.2436e+00,  5.4757e-01,  2.9478e-01],
    [ 1.3314e+00, 4.4003e-01, 2.3410e-01, -5.3948e-01, -9.9714e-01]], [[-4.6552e-01, 5.4318e-01, 2.1284e+00, 9.5029e-01, -8.2193e-01],
    [ 7.0617e-01, 9.8252e-01, -1.4406e+00, 1.0071e+00, 5.9477e-01]], [[-1.0482e+00, 4.7110e-02, 1.0014e+00, -6.0593e-01, -3.2076e-01],
    [ 6.6210e-01, 3.7603e-01, 1.0198e+00, 4.6591e-01, -7.0637e-04]]]) >>>tensor([[[-2.1797e-01, 3.1329e-04, 4.3139e-01],
    [-1.0621e+00, -6.0904e-01, -4.6225e-01],
    [ 8.5050e-01, -5.8867e-01, 4.8824e-01],
    [ 2.8561e-01, 2.6806e-01, 2.0534e+00],
    [-5.5719e-01, -3.3391e-01, 8.4069e-03]], [[ 5.2877e-01, 1.4361e+00, -6.4232e-01],
    [ 1.0813e+00, 8.5241e-01, -1.1759e+00],
    [ 4.9389e-01, -1.7523e-01, -9.5224e-01],
    [-1.3484e+00, -5.4685e-01, 8.5539e-01],
    [ 3.7036e-01, 3.4368e-01, -4.9617e-01]], [[-2.1564e+00, 3.0861e-01, 3.4261e-01],
    [-2.3679e+00, -2.5035e-01, 1.8104e-02],
    [ 1.1075e+00, 7.2465e-01, -2.0981e-01],
    [-6.5387e-01, -1.3914e-01, 1.5205e+00],
    [-1.6561e+00, -3.5294e-01, 1.9589e+00]]])
    torch.einsum('ijk,ikl->ijl',[a,b])
    >>>tensor([[[ 1.3170, -0.7075,  1.1067],
    [-0.1569, -0.2170, -0.6309]], [[-0.1935, -1.3806, -1.1458],
    [-0.4135, 1.7577, 0.3293]], [[ 4.1854, 0.5879, -2.1180],
    [-1.4922, 0.7846, 0.7267]]])
  • 张量缩约

    batch矩阵相乘是张量缩约的一个特例,比如有两个张量,一个n阶张量\(A\in \mathbb{R}^{I_1×l_2×...×I_n}​\),一个m阶张量\(B\in \mathbb{R}^{J_1×J_2×...×J_m}​\)。取n=4,m=5,假定维度\(I_2=J_3​\)且\(I_3=J_5​\),将这两个张量在这两个维度上(A张量的第2、3维度,B张量的第3、5维度)相乘,获得新张量\(C\in \mathbb{R}^{I_1×I_4×J_1×J_2×J_4}​\),如下所示:

    \[C_{I_1×I_4×J_1×J_2×J_4}=\sum_{I_2==J_3}\sum_{I_3==J_5}A_{I_1×I_2×I_3×I_4}B_{J_1×J_2×J_3×J_4×J_5}
    \]

    a=torch.randn(2,3,5,7)
    b=torch.randn(11,13,3,17,5) torch.einsum('pqrs,tuqvr->pstuv', [a, b]).shape
    >>>torch.Size([2, 7, 11, 13, 17])
  • 多张量计算

    如前所述,einsum可用于超过两个张量的计算,以双线性变换为例:

    \[D_ij=\sum_k\sum_lA_{ik}B_{jkl}C_{il}
    \]

    a=torch.randn(2,3)
    b=torch.randn(5,3,7)
    c=torch.randn(2,7) torch.einsum('ik,jkl,il->ij',[a,b,c]).shape
    >>>torch.Size([2,5])

kimiyoung/transformer-xl的tf部分大量使用了einsum表达式。

einsum满足你一切需要:深度学习中的爱因斯坦求和约定

向量点乘(内积)和叉乘(外积、向量积)概念及几何意义解读

矩阵外积与内积

外积-wiki

einsum:爱因斯坦求和约定的更多相关文章

  1. 爱因斯坦求和约定 (Einstein summation convention)

  2. MindSpore尝鲜之爱因斯坦求和

    技术背景 在前面的博客中,我们介绍过关于numpy中的张量网络的一些应用,同时利用相关的张量网络操作,我们可以实现一些分子动力学模拟中的约束算法,如LINCS等.在最新的nightly版本的MindS ...

  3. einsum函数介绍-张量常用操作

    einsum函数说明 pytorch文档说明:\(torch.einsum(equation, **operands)\) 使用基于爱因斯坦求和约定的符号,将输入operands的元素沿指定的维数求和 ...

  4. NumPy v1.15手册汉化

    NumPy参考 数组创建 零 和 一 empty(shape[, dtype, order]):返回给定形状和类型的新数组,而不初始化条目 empty_like(prototype[, dtype,  ...

  5. numpy函数查询手册

    写了个程序,对Numpy的绝大部分函数及其说明进行了中文翻译. 原网址:https://docs.scipy.org/doc/numpy/reference/routines.html#routine ...

  6. NumPy之:ndarray中的函数

    NumPy之:ndarray中的函数 目录 简介 简单函数 矢量化数组运算 条件逻辑表达式 统计方法 布尔数组 排序 文件 线性代数 随机数 简介 在NumPy中,多维数组除了基本的算数运算之外,还内 ...

  7. Differential Geometry之第四章标架与曲面论的基本定理

    第四章.标架与曲面论的基本定理 1.活动标架 2.自然标架的运动方程 爱因斯坦求和约定(Einstein summation convention) 3.曲面的结构方程 4.曲面的存在唯一性定理 5. ...

  8. 记号(notation)的学习

    数学的记号(notation) 记号具体代表什么含义,取决于你的定义: 比如这样的 d⃗  一个向量,每个分量 d(i) 表示的是从初始结点 v 到当前节点 vi 的最短路径:也即这样的一个向量的每一 ...

  9. 如何基于MindSpore实现万亿级参数模型算法?

    摘要:近来,增大模型规模成为了提升模型性能的主要手段.特别是NLP领域的自监督预训练语言模型,规模越来越大,从GPT3的1750亿参数,到Switch Transformer的16000亿参数,又是一 ...

随机推荐

  1. 【b501】谁拿了最多的奖学金

    Time Limit: 1 second Memory Limit: 50 MB [问题描述] 某校的惯例是在每学期的期末考试之后发放奖学金.发放的奖学金共有五种,获取的条件各自不同:1) 院士奖学金 ...

  2. [GeekBand] 设计模式之观察者模式学习笔记

    本文参考文献::GeekBand课堂内容,授课老师:李建忠 :网络资料: http://blog.csdn.net/hguisu/article/details/7556625 本文仅作为自己的学习笔 ...

  3. linux的开机启动过程:

    简单视图 按下电源开关 开机自检(BIOS)弹笔记本logo的时候,检查cpu 硬盘 这些硬件问题 MBR引导 一般是通过硬盘启动系统 GRUB的菜单 黑底白字有个倒数计时 可以选择内核 yum命令可 ...

  4. Swift API设计原则

    注: 本文摘自 Swift API设计指南 一.基本原则 通俗易懂的API是设计者最重要的目标.实体.变量.函数等都具有一次申明.重复使用的性质,所以一个好的API设计,应该能够使用少量的解读和示例就 ...

  5. Android菜鸟的成长笔记(23)——获取网络和SIM卡信息

    TelephonyManager是一个管理手机通话状态.电话网络信息的服务类,该类提供了大量的getXxx()方法来获取电话网络的相关信息.这些信息包括设备编号.软件版本.网络运营商代号.网络运营商名 ...

  6. CSS拾遗(一)

    重新看<精通CSS(第二版)>做一些记录,方便今后巩固. 1.外边距叠加 只有普通文档流中块框的垂直外边距才会发生外边距叠加.行内框.浮动框.或绝对定位框之间的外边距不会叠加. 2.相对定 ...

  7. Hibernate——(1)Hibernate入门

    一.Hibernate简介 1.Hibernate是一款ORM框架,Object Relation Mapping 对象关系映射. 2.可以将DB映射成Object,这样程序只要对Object对象进行 ...

  8. Android客户端后台发送邮件(JMail)

    今天在做项目的时候要处理用户注册问题,里面有个邮箱验证,网上找了一下果然有人做过,但是我拿过来都运行不起来,或者是发送不了邮件.后来我对这个浅浅的研究了一下,贴出来和大家共享. Activity pa ...

  9. 度小于所述过程:KanboxEnt.exe

    在防火墙管理.见未知的过程"KanboxEnt.exe" 程序信息: 版权声明:本文博主原创文章.博客,未经同意不得转载.

  10. oracle 10g提升cluster失败

    一个今天升级10g集群环境到10.2.0.5.下载补丁p8202632_10205_Linux-x86-64.zip,解压安装并运行后.中途岛错误: I/O ERROR cannt reading o ...