einsum函数说明

pytorch文档说明:\(torch.einsum(equation, **operands)\) 使用基于爱因斯坦求和约定的符号,将输入operands的元素沿指定的维数求和。einsum允许计算许多常见的多维线性代数阵列运算,方法是基于爱因斯坦求和约定以简写格式表示它们。主要是省略了求和号,总体思路是在箭头左边用一些下标标记输入operands的每个维度,并在箭头右边定义哪些下标是输出的一部分。通过将operands元素与下标不属于输出的维度的乘积求和来计算输出。其方便之处在于可以直接通过求和公式写出运算代码。

# 矩阵乘法例子引入
a = torch.rand(2,3)
b = torch.rand(3,4)
c = torch.einsum("ik,kj->ij", [a, b])
# 等价操作 torch.mm(a, b)

两个基本概念,自由索引/自由标(Free indices)和求和索引/哑标(Summation indices):

  • 自由索引,出现在箭头右边的索引
  • 求和索引,只出现在箭头左边的索引,表示中间计算结果需要这个维度上求和之后才能得到输出,

接着是介绍三条基本规则:

  • 规则一,equation 箭头左边,在不同输入之间重复出现的索引表示,把输入张量沿着该维度做乘法操作,比如还是以上面矩阵乘法为例, "ik,kj->ij",k 在输入中重复出现,所以就是把 a 和 b 沿着 k 这个维度作相乘操作;
  • 规则二,只出现在 equation 箭头左边的索引,表示中间计算结果需要在这个维度上求和,也就是上面提到的求和索引;
  • 规则三,equation 箭头右边的索引顺序可以是任意的,比如上面的 "ik,kj->ij" 如果写成 "ik,kj->ji",那么就是返回输出结果的转置,用户只需要定义好索引的顺序,转置操作会在 einsum 内部完成。

两条特殊规则:

  • equation 可以不写包括箭头在内的右边部分,那么在这种情况下,输出张量的维度会根据默认规则推导。就是把输入中只出现一次的索引取出来,然后按字母表顺序排列,比如上面的矩阵乘法 "ik,kj->ij" 也可以简化为 "ik,kj",根据默认规则,输出就是 "ij" 与原来一样;
  • equation 中支持 "..." 省略号,用于表示用户并不关心的索引,详见下方转置例子

单操作数

获取对角线元素diagonal

einsum 可以不做求和。举个例子,获取二维方阵的对角线元素,结果放入一维向量。

\[A_i = B_{ii}
\]

上面,A 是一维向量,B 是二维方阵。使用 einsum 记法,可以写作 ii->i

torch.einsum('ii->i', torch.randn(4, 4))

# 以下操作互相等价
a = torch.randn(4,4)
c = torch.einsum('ii->i', a)
c = torch.diagonal(a, 0)

迹trace

求解矩阵的迹(trace),即对角线元素的和。

\[t = \Sigma_{i=1}^{n} A_{ii}
\]

t 是常量,A 是二维方阵。按照前面的做法,省略 ΣΣ,左右两边对调,省去矩阵和 t,剩下的就是ii->或省略箭头ii

torch.einsum('ii', torch.randn(4, 4))

矩阵转置

\[A_{ij} = B_{ji}
\]

A 和 B 都是二维方阵。einsum 可以表达为 ij->ji

torch.einsum('ij -> ji',a)

pytorch 中,还支持省略前面的维度。比如,只转置最后两个维度,可以表达为 ...ij->...ji。下面展示了一个含有四个二维矩阵的三维矩阵,转置三维矩阵中的每个二维矩阵。

A = torch.randn(2, 3, 4, 5)
torch.einsum('...ij->...ji', A).shape
# torch.Size([2, 3, 5, 4]) # 等价操作
A.permute(0,1,3,2)
A.transpose(2,3)

求和

\[b=\sum_{i} \sum_{j} A_{i j}=A_{i j}
\]
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->', [a])
tensor(15.)

列求和:

\[b_{j}=\sum_{i} A_{i j}=A_{i j}
\]
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->j', [a])
tensor([ 3., 5., 7.]) # 等价操作
torch.sum(a, 0) # (dim参数0) means the dimension or dimensions to reduce.

双操作数

矩阵乘法

\[A_{ij} = \Sigma_{k=1}^{n} B_{ik} C_{kj}
\]

第一个学习的 einsum 表达式是,ik,kj->ij。前面提到过,爱因斯坦求和记法可以理解为懒人求和记法。将上述公式中的 ΣΣ 去掉,并且将左右两边对调一下,省去矩阵之后,剩下的就是 ik,kj->ij 了。

torch.einsum('ik,kj->ij', a, b) 

# 可用两个矩阵测试以下矩阵乘法操作互相等价
a = torch.randn(2,3)
b = torch.randn(3,4)
c = torch.matmul(a,b)
c = torch.einsum('ik,kj->ij', a, b)
c = a.mm(b)
c = torch.mm(a, b)
c = a @ b

矩阵-向量相乘

\[c_{i}=\sum_{k} A_{i k} b_{k}=A_{i k} b_{k}
\]
a = torch.arange(6).reshape(2, 3)
b = torch.arange(3)
torch.einsum('ik,k->i', [a, b]) tensor([ 5., 14.])

批量矩阵乘 batch matrix multiplication

\[C_{bik}=\sum_{k} A_{bij} B_{bjk}=A_{bij} B_{bjk}
\]
>>> As = torch.randn(3,2,5)
>>> Bs = torch.randn(3,5,4)
>>> torch.einsum('bij,bjk->bik', As, Bs)
tensor([[[-1.0564, -1.5904, 3.2023, 3.1271],
[-1.6706, -0.8097, -0.8025, -2.1183]], [[ 4.2239, 0.3107, -0.5756, -0.2354],
[-1.4558, -0.3460, 1.5087, -0.8530]], [[ 2.8153, 1.8787, -4.3839, -1.2112],
[ 0.3728, -2.1131, 0.0921, 0.8305]]]) # 等价操作
torch.bmm(As, Bs)

向量内积 dot

\[c=\sum_{i} a_{i} b_{i}=a_{i} b_{i}
\]
a = torch.arange(3)
b = torch.arange(3,6) # [3, 4, 5]
torch.einsum('i,i->', [a, b])
# tensor(14.) # 等价操作
torch.dot(a, b)

矩阵内积 dot

\[c=\sum_{i} \sum_{j} A_{i j} B_{i j}=A_{i j} B_{i j}
\]
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->', [a, b])
tensor(145.)

哈达玛积

\[C_{i j}=A_{i j} B_{i j}
\]
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->ij', [a, b])
tensor([[ 0., 7., 16.],
[ 27., 40., 55.]])

外积 outer

\[C_{i j}=a_{i} b_{j}
\]
a = torch.arange(3)
b = torch.arange(3,7)
torch.einsum('i,j->ij', [a, b]) tensor([[ 0., 0., 0., 0.],
[ 3., 4., 5., 6.],
[ 6., 8., 10., 12.]])

einsum其他规则和例子判断:

  • 输入中多次出现的字符,将被用作求和。例子,kj,ji 完整的表达式是 kj,ji->ik,矩阵乘法再相乘。
  • 输出可以指定,但是输出中的每个字符必须在输入中出现至少一次,输出的每个字符在输出中只能出现最多一次。例子,ab->aa 是非法的,ab->c 是非法的,ab->a 是合法的。
  • 省略符 ... 是用来跳过部分维度。例子,...ij,...jk 表示 batch 矩阵乘法。
  • 在输出没有指定的情况下,省略符优先级高于普通字符。例子,b...a 完整的表达式是 b...a->...ab,可以将一个形状为 (a,b,c) 的矩阵变为形状为 (b,c,a) 的矩阵。
  • 允许多个矩阵输入,表达式中使用逗号分开不同矩阵输入的下标。例子,i,i,i 表示将三个一维向量按位相乘,并相加。
  • 除了箭头,其他任何地方都可以加空格。例子,i j , j k -> ik 是合法的,ij,jk - > ik 是非法的。
  • 输入的表达式,维度需要和输入的矩阵对上,不能多也不能少。比如一个 shape 为 (4,3,3) 的矩阵,表达式 ab->a 是非法的,abc-> 是合法的。

实际使用

实现multi headed attention

https://nn.labml.ai/transformers/mha.html

如何优雅地实现多头自注意力

计算注意力score:

\[Q K^{\top} or S_{i j b h}=\sum_{d} Q_{i b h d} K_{j b h d}
\]
# q k v均为 [seq_len, batch_size, heads, d_k]
torch.einsum('ibhd,jbhd->ijbh', query, key) # 理解为ibhd,jbhd->ibhj->ijbh

计算attention输出:

\[\underset{\text { seq }}{\operatorname{softmax}}\left(\frac{Q K^{\top}}{\sqrt{d_{k}}}\right) V
\]
# attn [seq_len, seq_len, batch_size, heads]
# value [seq_len, batch_size, heads, d_k] x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
# x [seq_len, batch_size, heads, d_k]

参考文献:

https://zhuanlan.zhihu.com/p/361209187

如何优雅地实现多头自注意力

https://rockt.github.io/2018/04/30/einsum **

einsum函数介绍-张量常用操作的更多相关文章

  1. Git介绍及常用操作演示(一)--技术流ken

    Git介绍 Git(读音为/gɪt/.)是一个开源的分布式版本控制系统,可以有效.高速的处理从很小到非常大的项目版本管理. Git 是 Linus Torvalds 为了帮助管理 Linux 内核开发 ...

  2. CI 知识 :Git介绍及常用操作

    Git介绍 Git(读音为/gɪt/.)是一个开源的分布式版本控制系统,可以有效.高速的处理从很小到非常大的项目版本管理. Git 是 Linus Torvalds 为了帮助管理 Linux 内核开发 ...

  3. python 文件操作: 文件操作的函数, 模式及常用操作.

    1.文件操作的函数: open("文件名(路径)", mode = '模式', encoding = "字符集") 2.模式: r , w , a , r+ , ...

  4. JavaScript基础DOM介绍和常用操作(5)

    day53 参考:https://www.cnblogs.com/liwenzhou/p/8011504.html JavaScript引入方式 location对象 window.location ...

  5. 简单的git入门介绍及常用操作

    集中式版本控制系统采用中央服务器上存储的所有文件和实现团队协作.但是CVCS主要缺点是中央服务器的单点故障,即故障.不幸的是,如果中央服务器宕机一小时,然后在该时段没有人可以合作.即使在最坏的情况下, ...

  6. Docker介绍及常用操作演示(一)--技术流ken

    Docker简介 Docker 是一个开源的应用容器引擎,让开发者可以打包他们的应用以及依赖包到一个可移植的容器中,然后发布到任何流行的Linux机器上,也可以实现虚拟化.容器是完全使用沙箱机制,相互 ...

  7. Docker介绍及常用操作演示(一)

    Docker简介 Docker 是一个开源的应用容器引擎,让开发者可以打包他们的应用以及依赖包到一个可移植的容器中,然后发布到任何流行的Linux机器上,也可以实现虚拟化.容器是完全使用沙箱机制,相互 ...

  8. Docker常用命令汇总,和常用操作举例

    Docker命令 docker 常用命令如下 管理命令: container 管理容器 image 管理镜像 network 管理网络 node 管理Swarm节点 plugin 管理插件 secre ...

  9. go语言之进阶篇字符串操作常用函数介绍

    下面这些函数来自于strings包,这里介绍一些我平常经常用到的函数,更详细的请参考官方的文档. 一.字符串操作常用函数介绍 1.Contains func Contains(s, substr st ...

随机推荐

  1. 【001】学习前提——安装linux虚拟机,搭建docker

    1. 配置linux 1.1 修改配置 安装virtualbox的过程略过. 进入cd /etc/sysconfig/network-scripts,编辑:vi ifcfg-enp0s3 1>将 ...

  2. java语言和jdk、jre基础

    Java语言平台 * J2SE(Java 2 Platform Standard Edition)标准版  * 是为开发普通桌面和商务应用程序提供的解决方案,该技术体系是其他两者的基础,可以完成一些桌 ...

  3. echarts中boundaryGap属性

    boundaryGap:false boundaryGap:true 代码处: xAxis: { type: "category", data: ["06-01" ...

  4. vue实现省市区三级联动

    npm 安装 npm install v-distpicker --save Vue全局引入组件 import Distpicker from 'v-distpicker' Vue.component ...

  5. DRF-认证权限频率

    目录 DRF-认证权限频率 认证 登录接口 认证 权限 作用 使用 频率 作用 使用 认证权限频率+五个接口 模型 视图 序列化器 认证权限频率类 配置文件 路由 DRF-认证权限频率 前后端混合开发 ...

  6. maven导入依赖了提示can't resolved

    maven导入依赖显红报错 网上有很多解决方案,我试过几个但是都不是很好用,推荐一个我自己一直在用的解决方案 在终端执行命令 mvn idea:idea 无法解析的原因基本上是因为包没下载完整,执行这 ...

  7. sqli-labs环境搭建

    1 下载phpStudy 下载地址:https://www.xp.cn/download.html 由于sqli-lib最后一次提交代码的时候是2014年,所以高版本的phpStudy可能不兼容了,推 ...

  8. python黑帽子(第四章)

    Scapy窃取ftp登录账号密码 sniff函数的参数 filter 过滤规则,默认是嗅探所有数据包,具体过滤规则与wireshark相同. iface 参数设置嗅探器索要嗅探的网卡,默认对所有的网卡 ...

  9. Hyperledger Fabric 2.2 学习笔记:测试网络test-network

    写在前面 最近被Hyperledger Fabric折磨,归根结底还是因为自己太菜了qwq.学习路漫漫,笔记不能少.下面的步骤均是基于已经成功搭建了Fabric2.2环境,并且拉取fabric-sam ...

  10. 攻防世界-MISC:pure_color

    这是攻防世界高手进阶区的第六题,题目如下: 点击下载附件一,得到一张空白的png图片 用StegSolve打开,然后点击箭头(如下图所示) 多点击几次,即可得到flag 所以,这道题的flag如下: ...