einsum函数说明

pytorch文档说明:\(torch.einsum(equation, **operands)\) 使用基于爱因斯坦求和约定的符号,将输入operands的元素沿指定的维数求和。einsum允许计算许多常见的多维线性代数阵列运算,方法是基于爱因斯坦求和约定以简写格式表示它们。主要是省略了求和号,总体思路是在箭头左边用一些下标标记输入operands的每个维度,并在箭头右边定义哪些下标是输出的一部分。通过将operands元素与下标不属于输出的维度的乘积求和来计算输出。其方便之处在于可以直接通过求和公式写出运算代码。

# 矩阵乘法例子引入
a = torch.rand(2,3)
b = torch.rand(3,4)
c = torch.einsum("ik,kj->ij", [a, b])
# 等价操作 torch.mm(a, b)

两个基本概念,自由索引/自由标(Free indices)和求和索引/哑标(Summation indices):

  • 自由索引,出现在箭头右边的索引
  • 求和索引,只出现在箭头左边的索引,表示中间计算结果需要这个维度上求和之后才能得到输出,

接着是介绍三条基本规则:

  • 规则一,equation 箭头左边,在不同输入之间重复出现的索引表示,把输入张量沿着该维度做乘法操作,比如还是以上面矩阵乘法为例, "ik,kj->ij",k 在输入中重复出现,所以就是把 a 和 b 沿着 k 这个维度作相乘操作;
  • 规则二,只出现在 equation 箭头左边的索引,表示中间计算结果需要在这个维度上求和,也就是上面提到的求和索引;
  • 规则三,equation 箭头右边的索引顺序可以是任意的,比如上面的 "ik,kj->ij" 如果写成 "ik,kj->ji",那么就是返回输出结果的转置,用户只需要定义好索引的顺序,转置操作会在 einsum 内部完成。

两条特殊规则:

  • equation 可以不写包括箭头在内的右边部分,那么在这种情况下,输出张量的维度会根据默认规则推导。就是把输入中只出现一次的索引取出来,然后按字母表顺序排列,比如上面的矩阵乘法 "ik,kj->ij" 也可以简化为 "ik,kj",根据默认规则,输出就是 "ij" 与原来一样;
  • equation 中支持 "..." 省略号,用于表示用户并不关心的索引,详见下方转置例子

单操作数

获取对角线元素diagonal

einsum 可以不做求和。举个例子,获取二维方阵的对角线元素,结果放入一维向量。

\[A_i = B_{ii}
\]

上面,A 是一维向量,B 是二维方阵。使用 einsum 记法,可以写作 ii->i

torch.einsum('ii->i', torch.randn(4, 4))

# 以下操作互相等价
a = torch.randn(4,4)
c = torch.einsum('ii->i', a)
c = torch.diagonal(a, 0)

迹trace

求解矩阵的迹(trace),即对角线元素的和。

\[t = \Sigma_{i=1}^{n} A_{ii}
\]

t 是常量,A 是二维方阵。按照前面的做法,省略 ΣΣ,左右两边对调,省去矩阵和 t,剩下的就是ii->或省略箭头ii

torch.einsum('ii', torch.randn(4, 4))

矩阵转置

\[A_{ij} = B_{ji}
\]

A 和 B 都是二维方阵。einsum 可以表达为 ij->ji

torch.einsum('ij -> ji',a)

pytorch 中,还支持省略前面的维度。比如,只转置最后两个维度,可以表达为 ...ij->...ji。下面展示了一个含有四个二维矩阵的三维矩阵,转置三维矩阵中的每个二维矩阵。

A = torch.randn(2, 3, 4, 5)
torch.einsum('...ij->...ji', A).shape
# torch.Size([2, 3, 5, 4]) # 等价操作
A.permute(0,1,3,2)
A.transpose(2,3)

求和

\[b=\sum_{i} \sum_{j} A_{i j}=A_{i j}
\]
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->', [a])
tensor(15.)

列求和:

\[b_{j}=\sum_{i} A_{i j}=A_{i j}
\]
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->j', [a])
tensor([ 3., 5., 7.]) # 等价操作
torch.sum(a, 0) # (dim参数0) means the dimension or dimensions to reduce.

双操作数

矩阵乘法

\[A_{ij} = \Sigma_{k=1}^{n} B_{ik} C_{kj}
\]

第一个学习的 einsum 表达式是,ik,kj->ij。前面提到过,爱因斯坦求和记法可以理解为懒人求和记法。将上述公式中的 ΣΣ 去掉,并且将左右两边对调一下,省去矩阵之后,剩下的就是 ik,kj->ij 了。

torch.einsum('ik,kj->ij', a, b) 

# 可用两个矩阵测试以下矩阵乘法操作互相等价
a = torch.randn(2,3)
b = torch.randn(3,4)
c = torch.matmul(a,b)
c = torch.einsum('ik,kj->ij', a, b)
c = a.mm(b)
c = torch.mm(a, b)
c = a @ b

矩阵-向量相乘

\[c_{i}=\sum_{k} A_{i k} b_{k}=A_{i k} b_{k}
\]
a = torch.arange(6).reshape(2, 3)
b = torch.arange(3)
torch.einsum('ik,k->i', [a, b]) tensor([ 5., 14.])

批量矩阵乘 batch matrix multiplication

\[C_{bik}=\sum_{k} A_{bij} B_{bjk}=A_{bij} B_{bjk}
\]
>>> As = torch.randn(3,2,5)
>>> Bs = torch.randn(3,5,4)
>>> torch.einsum('bij,bjk->bik', As, Bs)
tensor([[[-1.0564, -1.5904, 3.2023, 3.1271],
[-1.6706, -0.8097, -0.8025, -2.1183]], [[ 4.2239, 0.3107, -0.5756, -0.2354],
[-1.4558, -0.3460, 1.5087, -0.8530]], [[ 2.8153, 1.8787, -4.3839, -1.2112],
[ 0.3728, -2.1131, 0.0921, 0.8305]]]) # 等价操作
torch.bmm(As, Bs)

向量内积 dot

\[c=\sum_{i} a_{i} b_{i}=a_{i} b_{i}
\]
a = torch.arange(3)
b = torch.arange(3,6) # [3, 4, 5]
torch.einsum('i,i->', [a, b])
# tensor(14.) # 等价操作
torch.dot(a, b)

矩阵内积 dot

\[c=\sum_{i} \sum_{j} A_{i j} B_{i j}=A_{i j} B_{i j}
\]
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->', [a, b])
tensor(145.)

哈达玛积

\[C_{i j}=A_{i j} B_{i j}
\]
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->ij', [a, b])
tensor([[ 0., 7., 16.],
[ 27., 40., 55.]])

外积 outer

\[C_{i j}=a_{i} b_{j}
\]
a = torch.arange(3)
b = torch.arange(3,7)
torch.einsum('i,j->ij', [a, b]) tensor([[ 0., 0., 0., 0.],
[ 3., 4., 5., 6.],
[ 6., 8., 10., 12.]])

einsum其他规则和例子判断:

  • 输入中多次出现的字符,将被用作求和。例子,kj,ji 完整的表达式是 kj,ji->ik,矩阵乘法再相乘。
  • 输出可以指定,但是输出中的每个字符必须在输入中出现至少一次,输出的每个字符在输出中只能出现最多一次。例子,ab->aa 是非法的,ab->c 是非法的,ab->a 是合法的。
  • 省略符 ... 是用来跳过部分维度。例子,...ij,...jk 表示 batch 矩阵乘法。
  • 在输出没有指定的情况下,省略符优先级高于普通字符。例子,b...a 完整的表达式是 b...a->...ab,可以将一个形状为 (a,b,c) 的矩阵变为形状为 (b,c,a) 的矩阵。
  • 允许多个矩阵输入,表达式中使用逗号分开不同矩阵输入的下标。例子,i,i,i 表示将三个一维向量按位相乘,并相加。
  • 除了箭头,其他任何地方都可以加空格。例子,i j , j k -> ik 是合法的,ij,jk - > ik 是非法的。
  • 输入的表达式,维度需要和输入的矩阵对上,不能多也不能少。比如一个 shape 为 (4,3,3) 的矩阵,表达式 ab->a 是非法的,abc-> 是合法的。

实际使用

实现multi headed attention

https://nn.labml.ai/transformers/mha.html

如何优雅地实现多头自注意力

计算注意力score:

\[Q K^{\top} or S_{i j b h}=\sum_{d} Q_{i b h d} K_{j b h d}
\]
# q k v均为 [seq_len, batch_size, heads, d_k]
torch.einsum('ibhd,jbhd->ijbh', query, key) # 理解为ibhd,jbhd->ibhj->ijbh

计算attention输出:

\[\underset{\text { seq }}{\operatorname{softmax}}\left(\frac{Q K^{\top}}{\sqrt{d_{k}}}\right) V
\]
# attn [seq_len, seq_len, batch_size, heads]
# value [seq_len, batch_size, heads, d_k] x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
# x [seq_len, batch_size, heads, d_k]

参考文献:

https://zhuanlan.zhihu.com/p/361209187

如何优雅地实现多头自注意力

https://rockt.github.io/2018/04/30/einsum **

einsum函数介绍-张量常用操作的更多相关文章

  1. Git介绍及常用操作演示(一)--技术流ken

    Git介绍 Git(读音为/gɪt/.)是一个开源的分布式版本控制系统,可以有效.高速的处理从很小到非常大的项目版本管理. Git 是 Linus Torvalds 为了帮助管理 Linux 内核开发 ...

  2. CI 知识 :Git介绍及常用操作

    Git介绍 Git(读音为/gɪt/.)是一个开源的分布式版本控制系统,可以有效.高速的处理从很小到非常大的项目版本管理. Git 是 Linus Torvalds 为了帮助管理 Linux 内核开发 ...

  3. python 文件操作: 文件操作的函数, 模式及常用操作.

    1.文件操作的函数: open("文件名(路径)", mode = '模式', encoding = "字符集") 2.模式: r , w , a , r+ , ...

  4. JavaScript基础DOM介绍和常用操作(5)

    day53 参考:https://www.cnblogs.com/liwenzhou/p/8011504.html JavaScript引入方式 location对象 window.location ...

  5. 简单的git入门介绍及常用操作

    集中式版本控制系统采用中央服务器上存储的所有文件和实现团队协作.但是CVCS主要缺点是中央服务器的单点故障,即故障.不幸的是,如果中央服务器宕机一小时,然后在该时段没有人可以合作.即使在最坏的情况下, ...

  6. Docker介绍及常用操作演示(一)--技术流ken

    Docker简介 Docker 是一个开源的应用容器引擎,让开发者可以打包他们的应用以及依赖包到一个可移植的容器中,然后发布到任何流行的Linux机器上,也可以实现虚拟化.容器是完全使用沙箱机制,相互 ...

  7. Docker介绍及常用操作演示(一)

    Docker简介 Docker 是一个开源的应用容器引擎,让开发者可以打包他们的应用以及依赖包到一个可移植的容器中,然后发布到任何流行的Linux机器上,也可以实现虚拟化.容器是完全使用沙箱机制,相互 ...

  8. Docker常用命令汇总,和常用操作举例

    Docker命令 docker 常用命令如下 管理命令: container 管理容器 image 管理镜像 network 管理网络 node 管理Swarm节点 plugin 管理插件 secre ...

  9. go语言之进阶篇字符串操作常用函数介绍

    下面这些函数来自于strings包,这里介绍一些我平常经常用到的函数,更详细的请参考官方的文档. 一.字符串操作常用函数介绍 1.Contains func Contains(s, substr st ...

随机推荐

  1. HTML 初学整理

    一.HTML简介 HTML的概念 HTML是HyperText Markup Language(超文本标记语言)的简写,超文本标记语言,标准通用标记语言下的一个应用."超文本"就是 ...

  2. JavaScript 工作原理之三-内存管理及如何处理 4 类常见的内存泄漏问题(译)

    原文请查阅这里,本文有进行删减,文后增了些经验总结. 本系列持续更新中,Github 地址请查阅这里. 这是 JavaScript 工作原理的第三章. 我们将会讨论日常使用中另一个被开发者越来越忽略的 ...

  3. number(10,6)正则表达式

    /**     * 判断number(10,6)     * @param dateStr     * @return     */    public boolean isNumJW(String ...

  4. 浅谈ES6中的Class

    转载地址:https://www.cnblogs.com/sghy/p/8005857.html 一.定义类(ES6的类,完全可以看做是构造函数的另一种写法) class Greet { constr ...

  5. 无需debug,通过抽象模型快速梳理代码核心流程

    上一篇我们通过DSM来确定了核心对象并构建了抽象模型.本篇是<如何高效阅读源码>专题的第八篇,我们来基于抽象模型来梳理核心流程. 本节主要内容: 如何通过抽象模型来梳理核心流程 从类名和注 ...

  6. javaWeb代码整理03-druid数据库连接池

    jar包: maven坐标: <dependency> <groupId>com.alibaba</groupId> <artifactId>druid ...

  7. 《手把手教你》系列基础篇(九十二)-java+ selenium自动化测试-框架设计基础-POM设计模式简介(详解教程)

    1.简介 页面对象模型(Page Object Model)在Selenium Webdriver自动化测试中使用非常流行和受欢迎,作为自动化测试工程师应该至少听说过POM这个概念.本篇介绍POM的简 ...

  8. 关于visualvm无法监控本地java进程

    https://blog.csdn.net/weixin_43827693/article/details/105990675?spm=1001.2101.3001.6661.1&utm_me ...

  9. 关于 background-image 渐变gradient()那些事!

    大家好,我是半夏,一个刚刚开始写文的沙雕程序员.如果喜欢我的文章,可以关注 点赞 加我微信:frontendpicker,一起学习交流前端,成为更优秀的工程师-关注公众号:搞前端的半夏,了解更多前端知 ...

  10. vmware ubuntu 忘记密码

    1.进入recovery模式 修改root密码 1.启动ubuntu系统,一开始在有进度条的时候按下shift键,出现GRUB选择菜单,选择Advanced options for Ubuntu 按回 ...