在Tensorflow、Numpy和PyTorch中都提供了使用einsum的api,einsum是一种能够简洁表示点积、外积、转置、矩阵-向量乘法、矩阵-矩阵乘法等运算的领域特定语言。在Tensorflow等计算框架中使用einsum,操作矩阵运算时可以免于记忆和使用特定的函数,并且使得代码简洁,高效。

如对矩阵\(A\in \mathbb{R}^{I×K}​\)和矩阵\(B\in \mathbb{R}^{K×J}​\)做矩阵乘,然后对列求和,最终得到向量\(c\in \mathbb{R}^J​\),即:

\[\mathbb{R}^{I×K}\bigotimes \mathbb{R}^{K×J}\to \mathbb{R}^{I×J}\to \mathbb{R}^{J}
\]

使用爱因斯坦求和约定表示为:

\[c_j=\sum_i\sum_kA_{ik}B_{kj}=A_{ik}B_{kj}
\]

在Tensorflow、Numpy和PyTorch中对应的einsum字符串为:

ik,kj->j

在上面的字符串中,隐式地省略了重复的下标\(k\),表示在该维度矩阵乘;另外输出中未指明下标\(i\),表示在该维度累加。

Numpy、PyTorch和Tensorflow中的einsum

einsum在Numpy中的实现为np.einsum,在PyTorch中的实现为torch.einsum,在Tensorflow中的实现为tf.einsum,均使用同样的函数签名einsum(equation,operands),其中,equation传入爱因斯坦求和约定的字符串,而operands则是张量序列。在Numpy、Tensorflow中是变长参数列表,而在PyTorch中是列表。上述例子中,在Tensorflow中可写作:

tf.einsum('ik,kj->j',mat1,mat2)

其中,mat1、mat2为执行该运算的两个张量。注意:这里的(i,j,k)的命名是任意的,但在一个表达式中要一致。

PyTorch和Tensorflow像Numpy支持einsum的好处之一就是,einsum可以用于深度网络架构的任意计算图,并且可以反向传播。在Numpy和Tensorflow中的调用格式如下:

\[result=\mathop{einsum}('\square \square, \square \square \square,\square \square\to \square \square',arg1,arg2,arg3)
\]

其中,\(\square\)是占位符,表示张量维度;arg1,arg3是矩阵,arg2是三阶张量,运算结果是矩阵。注意:einsum处理可变数量的输入。上面例子中,einsum制定了三个参数的操作,但同样可以操作一个参数、两个参数和三个参数及以上的操作。

典型的einsum表达式

前置知识

  • 内积

    又称点积、点乘,对应位置数字相乘,结果是一个标量,有见向量内积和矩阵内积等。

    向量\(\vec a\)和向量\(\vec b\)的内积:

    \[\vec a=[a_1,a_2,...,a_n]\\
    \vec b=[b_1,b_2,...,b_n]\\
    \vec a\cdot \vec b^T=a_1b_1+a_2b_2+...+a_nb_n
    \]

    内积几何意义:

    \[\vec a \cdot \vec b^T=|\vec a||\vec b|\mathop{cos}\theta
    \]

  • 外积

    又称叉乘、叉积、向量积,行向量矩阵乘列向量,结果是二阶张量。注意到:张量的外积作为张量积的同义词。外积是一种特殊的克罗内克积。

    向量\(\vec a\)和向量\(\vec b\)的外积:

    \[\begin{bmatrix}
    b_1
    \\b_2
    \\ b_3
    \\ b_4
    \end{bmatrix}\bigotimes[a_1,a_2,a_3]=\begin{bmatrix}
    a_1b_1 & a_2b_1 & a_3b_1 \\
    a_1b_2 & a_2b_2 & a_3b_2 \\
    a_1b_3 & a_2b_3 & a_3b_3 \\
    a_1b_4 & a_2b_4 & a_3b_4 \\
    \end{bmatrix}
    \]

    外积的几何意义:

    \[\vec a=(x_1,y_1,z_1)\\
    \vec b=(x_2,y_2,z_2)\\
    \vec a\bigotimes\vec b=\begin{vmatrix}
    i & j & k\\
    x_1 & y_1 & z_1\\
    x_2 & y_2 & z_2
    \end{vmatrix}=(y_1z_2-y_2z_1)\vec i-(x_1z_2-x_2z_1)\vec j+(x_1y_2-x_2y_1)\vec k
    \]

    其中,

    \[\vec i=(1,0,0)\\
    \vec j=(0,1,0)\\
    \vec k=(0,0,1)
    \]

由于PyTorch可以实时输出运算结果,以PyTorch使用einsum表达式为例。

  • 矩阵转置

    \[B_{ji}=A_{ij}
    \]

    a=torch.arange(6).reshape(2,3)
    >>>tensor([[0, 1, 2],
    [3, 4, 5]])
    torch.einsum('ij->ji',[a])
    >>>tensor([[0, 3],
    [1, 4],
    [2, 5]])
  • 求和

    \[b=\sum_{i}\sum_{j}A_{ij}
    \]

    a=torch.arange(6).reshape(2,3)
    >>>tensor([[0, 1, 2],
    [3, 4, 5]])
    torch.einsum('ij->',[a])
    >>>tensor(15)
  • 列求和(列维度不变,行维度消失)

    \[b_j=\sum_iA_{ij}
    \]

    a=torch.arange(6).reshape(2,3)
    >>>tensor([[0, 1, 2],
    [3, 4, 5]])
    torch.einsum('ij->j',[a])
    >>>tensor([ 3.,  5.,  7.])
  • 列求和(列维度不变,行维度消失)

    \[b_i=\sum_jA_{ij}
    \]

    a=torch.arange(6).reshape(2,3)
    >>>tensor([[0, 1, 2],
    [3, 4, 5]])
    torch.einsum('ij->i', [a])
    >>>tensor([  3.,  12.])
  • 矩阵-向量相乘

    \[c_i=\sum_k A_{ik}b_k
    \]

    a=torch.arange(6).reshape(2,3)
    >>>tensor([[0, 1, 2],
    [3, 4, 5]])
    torch.einsum('ik,k->i',[a,b])
    >>>tensor([  5.,  14.])
  • 矩阵-矩阵乘法

    \[C_{ij}=\sum_{k}A_{ik}B_{kj}
    \]

    a=torch.arange(6).reshape(2,3)
    b=torch.arange(15).reshape(3,5)
    >>>tensor([[0, 1, 2],
    [3, 4, 5]]) >>>tensor([[ 0, 1, 2, 3, 4],
    [ 5, 6, 7, 8, 9],
    [10, 11, 12, 13, 14]])
    torch.einsum('ik,kj->ij',[a,b])
    >>>tensor([[ 25,  28,  31,  34,  37],
    [ 70, 82, 94, 106, 118]])
  • 点积

    • 向量

      \[c=\sum_i a_i b_i
      \]

      a=torch.arange(3)
      b=torch.arange(3,6)
      >>>tensor([0, 1, 2])
      >>>tensor([3, 4, 5])
      torch.einsum('i,i->',[a,b])
      >>>tensor(14.)
    • 矩阵

      \[c=\sum_i\sum_j A_{ij}B_{ij}
      \]

      a=torch.arange(6).reshape(2,3)
      b=torch.arange(6,12).reshape(2,3)
      >>>tensor([[0, 1, 2],
      [3, 4, 5]]) >>>tensor([[ 6, 7, 8],
      [ 9, 10, 11]])
      torch.einsum('ij,ij->',[a,b])
      >>>tensor(145.)
  • 外积

    \[C_{ij}=a_i b_j
    \]

    a=torch.arange(3)
    b=torch.arange(3,7)
    >>>tensor([0, 1, 2])
    >>>tensor([3, 4, 5, 6])
    torch.einsum('i,j->ij',[a,b])
    >>>tensor([[  0.,   0.,   0.,   0.],
    [ 3., 4., 5., 6.],
    [ 6., 8., 10., 12.]])
  • batch矩阵乘

    \[C_{ijl}=\sum_{k}A_{ijk}B_{ikl}
    \]

    a=torch.randn(3,2,5)
    b=torch.randn(3,5,3)
    >>>tensor([[[-1.4131e+00,  3.8372e-02,  1.2436e+00,  5.4757e-01,  2.9478e-01],
    [ 1.3314e+00, 4.4003e-01, 2.3410e-01, -5.3948e-01, -9.9714e-01]], [[-4.6552e-01, 5.4318e-01, 2.1284e+00, 9.5029e-01, -8.2193e-01],
    [ 7.0617e-01, 9.8252e-01, -1.4406e+00, 1.0071e+00, 5.9477e-01]], [[-1.0482e+00, 4.7110e-02, 1.0014e+00, -6.0593e-01, -3.2076e-01],
    [ 6.6210e-01, 3.7603e-01, 1.0198e+00, 4.6591e-01, -7.0637e-04]]]) >>>tensor([[[-2.1797e-01, 3.1329e-04, 4.3139e-01],
    [-1.0621e+00, -6.0904e-01, -4.6225e-01],
    [ 8.5050e-01, -5.8867e-01, 4.8824e-01],
    [ 2.8561e-01, 2.6806e-01, 2.0534e+00],
    [-5.5719e-01, -3.3391e-01, 8.4069e-03]], [[ 5.2877e-01, 1.4361e+00, -6.4232e-01],
    [ 1.0813e+00, 8.5241e-01, -1.1759e+00],
    [ 4.9389e-01, -1.7523e-01, -9.5224e-01],
    [-1.3484e+00, -5.4685e-01, 8.5539e-01],
    [ 3.7036e-01, 3.4368e-01, -4.9617e-01]], [[-2.1564e+00, 3.0861e-01, 3.4261e-01],
    [-2.3679e+00, -2.5035e-01, 1.8104e-02],
    [ 1.1075e+00, 7.2465e-01, -2.0981e-01],
    [-6.5387e-01, -1.3914e-01, 1.5205e+00],
    [-1.6561e+00, -3.5294e-01, 1.9589e+00]]])
    torch.einsum('ijk,ikl->ijl',[a,b])
    >>>tensor([[[ 1.3170, -0.7075,  1.1067],
    [-0.1569, -0.2170, -0.6309]], [[-0.1935, -1.3806, -1.1458],
    [-0.4135, 1.7577, 0.3293]], [[ 4.1854, 0.5879, -2.1180],
    [-1.4922, 0.7846, 0.7267]]])
  • 张量缩约

    batch矩阵相乘是张量缩约的一个特例,比如有两个张量,一个n阶张量\(A\in \mathbb{R}^{I_1×l_2×...×I_n}​\),一个m阶张量\(B\in \mathbb{R}^{J_1×J_2×...×J_m}​\)。取n=4,m=5,假定维度\(I_2=J_3​\)且\(I_3=J_5​\),将这两个张量在这两个维度上(A张量的第2、3维度,B张量的第3、5维度)相乘,获得新张量\(C\in \mathbb{R}^{I_1×I_4×J_1×J_2×J_4}​\),如下所示:

    \[C_{I_1×I_4×J_1×J_2×J_4}=\sum_{I_2==J_3}\sum_{I_3==J_5}A_{I_1×I_2×I_3×I_4}B_{J_1×J_2×J_3×J_4×J_5}
    \]

    a=torch.randn(2,3,5,7)
    b=torch.randn(11,13,3,17,5) torch.einsum('pqrs,tuqvr->pstuv', [a, b]).shape
    >>>torch.Size([2, 7, 11, 13, 17])
  • 多张量计算

    如前所述,einsum可用于超过两个张量的计算,以双线性变换为例:

    \[D_ij=\sum_k\sum_lA_{ik}B_{jkl}C_{il}
    \]

    a=torch.randn(2,3)
    b=torch.randn(5,3,7)
    c=torch.randn(2,7) torch.einsum('ik,jkl,il->ij',[a,b,c]).shape
    >>>torch.Size([2,5])

kimiyoung/transformer-xl的tf部分大量使用了einsum表达式。

einsum满足你一切需要:深度学习中的爱因斯坦求和约定

向量点乘(内积)和叉乘(外积、向量积)概念及几何意义解读

矩阵外积与内积

外积-wiki

einsum:爱因斯坦求和约定的更多相关文章

  1. 爱因斯坦求和约定 (Einstein summation convention)

  2. MindSpore尝鲜之爱因斯坦求和

    技术背景 在前面的博客中,我们介绍过关于numpy中的张量网络的一些应用,同时利用相关的张量网络操作,我们可以实现一些分子动力学模拟中的约束算法,如LINCS等.在最新的nightly版本的MindS ...

  3. einsum函数介绍-张量常用操作

    einsum函数说明 pytorch文档说明:\(torch.einsum(equation, **operands)\) 使用基于爱因斯坦求和约定的符号,将输入operands的元素沿指定的维数求和 ...

  4. NumPy v1.15手册汉化

    NumPy参考 数组创建 零 和 一 empty(shape[, dtype, order]):返回给定形状和类型的新数组,而不初始化条目 empty_like(prototype[, dtype,  ...

  5. numpy函数查询手册

    写了个程序,对Numpy的绝大部分函数及其说明进行了中文翻译. 原网址:https://docs.scipy.org/doc/numpy/reference/routines.html#routine ...

  6. NumPy之:ndarray中的函数

    NumPy之:ndarray中的函数 目录 简介 简单函数 矢量化数组运算 条件逻辑表达式 统计方法 布尔数组 排序 文件 线性代数 随机数 简介 在NumPy中,多维数组除了基本的算数运算之外,还内 ...

  7. Differential Geometry之第四章标架与曲面论的基本定理

    第四章.标架与曲面论的基本定理 1.活动标架 2.自然标架的运动方程 爱因斯坦求和约定(Einstein summation convention) 3.曲面的结构方程 4.曲面的存在唯一性定理 5. ...

  8. 记号(notation)的学习

    数学的记号(notation) 记号具体代表什么含义,取决于你的定义: 比如这样的 d⃗  一个向量,每个分量 d(i) 表示的是从初始结点 v 到当前节点 vi 的最短路径:也即这样的一个向量的每一 ...

  9. 如何基于MindSpore实现万亿级参数模型算法?

    摘要:近来,增大模型规模成为了提升模型性能的主要手段.特别是NLP领域的自监督预训练语言模型,规模越来越大,从GPT3的1750亿参数,到Switch Transformer的16000亿参数,又是一 ...

随机推荐

  1. 【codeforces 782D】 Innokenty and a Football League

    [题目链接]:http://codeforces.com/contest/782 [题意] 每个队名有两种选择, 然后第一个选择队名相同的那些队只能选第二种; 让你安排队名 [题解] 首先全都选成第一 ...

  2. 链表与哈希表基本概念及Java常用集合

    -链表- 是一种物理存储单元上非连续.非顺序的存储结构,数据元素的逻辑顺序是通过链表中的指针链接次序实现的.链表由一系列结点(链表中每一个元素称为结点)组成,结点可以在运行时动态生成.每个结点包括两个 ...

  3. ajax——XMLHttpRequest

    XMLHttpRequest对象.能够让ajax程序在不又一次载入的页面的情况下更新页面数据,页面载入完毕后从server接受发生数据.这样既减轻了server负担又回顾了响应速度,缩短了用户的等待时 ...

  4. SystemServer概述

    SystemServer由Zygote fork生成的,进程名为system_server,该进程承载着framework的核心服务. 调用流程如下: 上图前4步骤(即颜色为紫色的流程)运行在是Zyg ...

  5. CentOS 配置epel源

    先查询下有没有epel rpm -qa|grep epel 没有的话到官网https://fedoraproject.org/wiki/EPEL下载rpm包 然后 rpm -ivh 安装 安装完毕后到 ...

  6. Vue Router的官方示例改造

    基于Vue Router 2018年8月的官方文档示例,改造一下,通过一个最简单的例子,解决很多初学者的一个困惑. 首先是官方文档示例代码 <!DOCTYPE html> <html ...

  7. java开发环境配置(windows下JDK7+tomcat7)

    參考原文:http://www.cnblogs.com/goto/archive/2012/11/16/2772683.html http://www.cnblogs.com/feilong35407 ...

  8. telnet 的使用(ping 与 telnet)

    基本用法 >> telnet localhost 23 // 23 表示 telnet 服务的端口号,不写端口号也可以,telnet 默认绑定的端口号就是 23 // netstat -a ...

  9. 在CSDN博客中添加量子恒道统计功能的做法

    作者:朱金灿 来源:http://blog.csdn.net/clever101 什么是量子恒道统计?量子恒道统计是一套免费的网站流量统计分析系统.致力于为所有个人站长.个人博主.所有网站管理者.第三 ...

  10. WPF多线程UI更新

    前言 在WPF中,在使用多线程在后台进行计算限制的异步操作的时候,如果在后台线程中对UI进行了修改,则会出现一个错误:(调用线程无法访问此对象,因为另一个线程拥有该对象.)这是很常见的一个错误,一不小 ...