einsum函数说明

pytorch文档说明:\(torch.einsum(equation, **operands)\) 使用基于爱因斯坦求和约定的符号,将输入operands的元素沿指定的维数求和。einsum允许计算许多常见的多维线性代数阵列运算,方法是基于爱因斯坦求和约定以简写格式表示它们。主要是省略了求和号,总体思路是在箭头左边用一些下标标记输入operands的每个维度,并在箭头右边定义哪些下标是输出的一部分。通过将operands元素与下标不属于输出的维度的乘积求和来计算输出。其方便之处在于可以直接通过求和公式写出运算代码。

# 矩阵乘法例子引入
a = torch.rand(2,3)
b = torch.rand(3,4)
c = torch.einsum("ik,kj->ij", [a, b])
# 等价操作 torch.mm(a, b)

两个基本概念,自由索引/自由标(Free indices)和求和索引/哑标(Summation indices):

  • 自由索引,出现在箭头右边的索引
  • 求和索引,只出现在箭头左边的索引,表示中间计算结果需要这个维度上求和之后才能得到输出,

接着是介绍三条基本规则:

  • 规则一,equation 箭头左边,在不同输入之间重复出现的索引表示,把输入张量沿着该维度做乘法操作,比如还是以上面矩阵乘法为例, "ik,kj->ij",k 在输入中重复出现,所以就是把 a 和 b 沿着 k 这个维度作相乘操作;
  • 规则二,只出现在 equation 箭头左边的索引,表示中间计算结果需要在这个维度上求和,也就是上面提到的求和索引;
  • 规则三,equation 箭头右边的索引顺序可以是任意的,比如上面的 "ik,kj->ij" 如果写成 "ik,kj->ji",那么就是返回输出结果的转置,用户只需要定义好索引的顺序,转置操作会在 einsum 内部完成。

两条特殊规则:

  • equation 可以不写包括箭头在内的右边部分,那么在这种情况下,输出张量的维度会根据默认规则推导。就是把输入中只出现一次的索引取出来,然后按字母表顺序排列,比如上面的矩阵乘法 "ik,kj->ij" 也可以简化为 "ik,kj",根据默认规则,输出就是 "ij" 与原来一样;
  • equation 中支持 "..." 省略号,用于表示用户并不关心的索引,详见下方转置例子

单操作数

获取对角线元素diagonal

einsum 可以不做求和。举个例子,获取二维方阵的对角线元素,结果放入一维向量。

\[A_i = B_{ii}
\]

上面,A 是一维向量,B 是二维方阵。使用 einsum 记法,可以写作 ii->i

torch.einsum('ii->i', torch.randn(4, 4))

# 以下操作互相等价
a = torch.randn(4,4)
c = torch.einsum('ii->i', a)
c = torch.diagonal(a, 0)

迹trace

求解矩阵的迹(trace),即对角线元素的和。

\[t = \Sigma_{i=1}^{n} A_{ii}
\]

t 是常量,A 是二维方阵。按照前面的做法,省略 ΣΣ,左右两边对调,省去矩阵和 t,剩下的就是ii->或省略箭头ii

torch.einsum('ii', torch.randn(4, 4))

矩阵转置

\[A_{ij} = B_{ji}
\]

A 和 B 都是二维方阵。einsum 可以表达为 ij->ji

torch.einsum('ij -> ji',a)

pytorch 中,还支持省略前面的维度。比如,只转置最后两个维度,可以表达为 ...ij->...ji。下面展示了一个含有四个二维矩阵的三维矩阵,转置三维矩阵中的每个二维矩阵。

A = torch.randn(2, 3, 4, 5)
torch.einsum('...ij->...ji', A).shape
# torch.Size([2, 3, 5, 4]) # 等价操作
A.permute(0,1,3,2)
A.transpose(2,3)

求和

\[b=\sum_{i} \sum_{j} A_{i j}=A_{i j}
\]
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->', [a])
tensor(15.)

列求和:

\[b_{j}=\sum_{i} A_{i j}=A_{i j}
\]
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->j', [a])
tensor([ 3., 5., 7.]) # 等价操作
torch.sum(a, 0) # (dim参数0) means the dimension or dimensions to reduce.

双操作数

矩阵乘法

\[A_{ij} = \Sigma_{k=1}^{n} B_{ik} C_{kj}
\]

第一个学习的 einsum 表达式是,ik,kj->ij。前面提到过,爱因斯坦求和记法可以理解为懒人求和记法。将上述公式中的 ΣΣ 去掉,并且将左右两边对调一下,省去矩阵之后,剩下的就是 ik,kj->ij 了。

torch.einsum('ik,kj->ij', a, b) 

# 可用两个矩阵测试以下矩阵乘法操作互相等价
a = torch.randn(2,3)
b = torch.randn(3,4)
c = torch.matmul(a,b)
c = torch.einsum('ik,kj->ij', a, b)
c = a.mm(b)
c = torch.mm(a, b)
c = a @ b

矩阵-向量相乘

\[c_{i}=\sum_{k} A_{i k} b_{k}=A_{i k} b_{k}
\]
a = torch.arange(6).reshape(2, 3)
b = torch.arange(3)
torch.einsum('ik,k->i', [a, b]) tensor([ 5., 14.])

批量矩阵乘 batch matrix multiplication

\[C_{bik}=\sum_{k} A_{bij} B_{bjk}=A_{bij} B_{bjk}
\]
>>> As = torch.randn(3,2,5)
>>> Bs = torch.randn(3,5,4)
>>> torch.einsum('bij,bjk->bik', As, Bs)
tensor([[[-1.0564, -1.5904, 3.2023, 3.1271],
[-1.6706, -0.8097, -0.8025, -2.1183]], [[ 4.2239, 0.3107, -0.5756, -0.2354],
[-1.4558, -0.3460, 1.5087, -0.8530]], [[ 2.8153, 1.8787, -4.3839, -1.2112],
[ 0.3728, -2.1131, 0.0921, 0.8305]]]) # 等价操作
torch.bmm(As, Bs)

向量内积 dot

\[c=\sum_{i} a_{i} b_{i}=a_{i} b_{i}
\]
a = torch.arange(3)
b = torch.arange(3,6) # [3, 4, 5]
torch.einsum('i,i->', [a, b])
# tensor(14.) # 等价操作
torch.dot(a, b)

矩阵内积 dot

\[c=\sum_{i} \sum_{j} A_{i j} B_{i j}=A_{i j} B_{i j}
\]
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->', [a, b])
tensor(145.)

哈达玛积

\[C_{i j}=A_{i j} B_{i j}
\]
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->ij', [a, b])
tensor([[ 0., 7., 16.],
[ 27., 40., 55.]])

外积 outer

\[C_{i j}=a_{i} b_{j}
\]
a = torch.arange(3)
b = torch.arange(3,7)
torch.einsum('i,j->ij', [a, b]) tensor([[ 0., 0., 0., 0.],
[ 3., 4., 5., 6.],
[ 6., 8., 10., 12.]])

einsum其他规则和例子判断:

  • 输入中多次出现的字符,将被用作求和。例子,kj,ji 完整的表达式是 kj,ji->ik,矩阵乘法再相乘。
  • 输出可以指定,但是输出中的每个字符必须在输入中出现至少一次,输出的每个字符在输出中只能出现最多一次。例子,ab->aa 是非法的,ab->c 是非法的,ab->a 是合法的。
  • 省略符 ... 是用来跳过部分维度。例子,...ij,...jk 表示 batch 矩阵乘法。
  • 在输出没有指定的情况下,省略符优先级高于普通字符。例子,b...a 完整的表达式是 b...a->...ab,可以将一个形状为 (a,b,c) 的矩阵变为形状为 (b,c,a) 的矩阵。
  • 允许多个矩阵输入,表达式中使用逗号分开不同矩阵输入的下标。例子,i,i,i 表示将三个一维向量按位相乘,并相加。
  • 除了箭头,其他任何地方都可以加空格。例子,i j , j k -> ik 是合法的,ij,jk - > ik 是非法的。
  • 输入的表达式,维度需要和输入的矩阵对上,不能多也不能少。比如一个 shape 为 (4,3,3) 的矩阵,表达式 ab->a 是非法的,abc-> 是合法的。

实际使用

实现multi headed attention

https://nn.labml.ai/transformers/mha.html

如何优雅地实现多头自注意力

计算注意力score:

\[Q K^{\top} or S_{i j b h}=\sum_{d} Q_{i b h d} K_{j b h d}
\]
# q k v均为 [seq_len, batch_size, heads, d_k]
torch.einsum('ibhd,jbhd->ijbh', query, key) # 理解为ibhd,jbhd->ibhj->ijbh

计算attention输出:

\[\underset{\text { seq }}{\operatorname{softmax}}\left(\frac{Q K^{\top}}{\sqrt{d_{k}}}\right) V
\]
# attn [seq_len, seq_len, batch_size, heads]
# value [seq_len, batch_size, heads, d_k] x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
# x [seq_len, batch_size, heads, d_k]

参考文献:

https://zhuanlan.zhihu.com/p/361209187

如何优雅地实现多头自注意力

https://rockt.github.io/2018/04/30/einsum **

einsum函数介绍-张量常用操作的更多相关文章

  1. Git介绍及常用操作演示(一)--技术流ken

    Git介绍 Git(读音为/gɪt/.)是一个开源的分布式版本控制系统,可以有效.高速的处理从很小到非常大的项目版本管理. Git 是 Linus Torvalds 为了帮助管理 Linux 内核开发 ...

  2. CI 知识 :Git介绍及常用操作

    Git介绍 Git(读音为/gɪt/.)是一个开源的分布式版本控制系统,可以有效.高速的处理从很小到非常大的项目版本管理. Git 是 Linus Torvalds 为了帮助管理 Linux 内核开发 ...

  3. python 文件操作: 文件操作的函数, 模式及常用操作.

    1.文件操作的函数: open("文件名(路径)", mode = '模式', encoding = "字符集") 2.模式: r , w , a , r+ , ...

  4. JavaScript基础DOM介绍和常用操作(5)

    day53 参考:https://www.cnblogs.com/liwenzhou/p/8011504.html JavaScript引入方式 location对象 window.location ...

  5. 简单的git入门介绍及常用操作

    集中式版本控制系统采用中央服务器上存储的所有文件和实现团队协作.但是CVCS主要缺点是中央服务器的单点故障,即故障.不幸的是,如果中央服务器宕机一小时,然后在该时段没有人可以合作.即使在最坏的情况下, ...

  6. Docker介绍及常用操作演示(一)--技术流ken

    Docker简介 Docker 是一个开源的应用容器引擎,让开发者可以打包他们的应用以及依赖包到一个可移植的容器中,然后发布到任何流行的Linux机器上,也可以实现虚拟化.容器是完全使用沙箱机制,相互 ...

  7. Docker介绍及常用操作演示(一)

    Docker简介 Docker 是一个开源的应用容器引擎,让开发者可以打包他们的应用以及依赖包到一个可移植的容器中,然后发布到任何流行的Linux机器上,也可以实现虚拟化.容器是完全使用沙箱机制,相互 ...

  8. Docker常用命令汇总,和常用操作举例

    Docker命令 docker 常用命令如下 管理命令: container 管理容器 image 管理镜像 network 管理网络 node 管理Swarm节点 plugin 管理插件 secre ...

  9. go语言之进阶篇字符串操作常用函数介绍

    下面这些函数来自于strings包,这里介绍一些我平常经常用到的函数,更详细的请参考官方的文档. 一.字符串操作常用函数介绍 1.Contains func Contains(s, substr st ...

随机推荐

  1. javaweb之删除功能

    对数据库的删除,主要是通过表中的一个数据查询来进行逐个删除,否则会清空整张表. 一.dao层 在dao层加入删除方法 public boolean delete(Course n) { boolean ...

  2. c++实现状态模式

    实验:用Java代码模拟实现课堂上的"银行账户"的实例,要求编写客户端测试代码模拟用户存款和取款,注意账户对象状态和行为的变化. 由于是c++,不像java那么灵活,所以类的调用方 ...

  3. js获取url查询字符串参数

    最近看js高级程序设计 对其中查询字符串参数的获得重新写了,当传递一个完整的URL的时候对查询字符串的提取 function getQueryArgs(){ var qs = (location.se ...

  4. Python程序的流程

    1 """ 2 python程序的流程 3 """ 4 # ------------- 分支结构---------------- 5 # i ...

  5. Js中的三个错误语句:try、catch、throw

    Js中的三个错误语句:try.catch.throw

  6. 2022-Aech安装(详细)

    官方wiki:https://wiki.archlinux.org/ 基础安装 一:制作安装介质 下载ISO镜像文件: https://archlinux.org/download/ # 官方下载网址 ...

  7. 企业级 Web 开发的挑战

    本文翻译自土牛Halil ibrahim Kalkan的<Mastering ABP Framework>,是系列翻译的起头,适合ABP开发人员或者想对ABP框架进行深入演进的准架构师. ...

  8. LCA的离线快速求法

    最常见的LCA(树上公共祖先)都是在线算法,往往带了一个log.有一种办法是转化为"+-1最值问题"得到O(n)+O(1)的复杂度,但是原理复杂,常数大.今天介绍一种允许离线时接近 ...

  9. Http GET 请求参数中文乱码

    两种解决方式 第1种:代码里转换 String name = request.getParamter("name"); String nameUtf8 = new String(n ...

  10. redis:缓存穿透、缓存击穿、缓存雪崩

    缓存穿透的解决方案(空标记) 缓存穿透是指,在数据存储系统中不存在的记录,不会被存储到缓存中.这种记录每次的查询流量都会穿透到数据存储层.在高流量的场景下,不断查询空结果会大量消耗数据查询服务的资源, ...