AlphaTensor论文阅读分析

目前只是大概了解了AlphaTensor的思路和效果,完善ing

deepmind博客在 https://www.deepmind.com/blog/discovering-novel-algorithms-with-alphatensor

论文是 https://www.nature.com/articles/s41586-022-05172-4

解决"如何快速计算矩阵乘法"的问题

问题建模

变成single-player game

\[\tau_n= \sum_{r=1}^R \textbf{u}^{(r)} \otimes \textbf{v}^{(r)} \otimes \textbf{w}^{(r)}
\]

In \(2*2*2\) case of Strassen, R is 7. (see the fig.c). The goal of DRL algorithm is to minimize R (i.e. total step)

the size of $\textbf{u}^{(r)} $ is \((n^2, R)\).

$ \textbf{u}^{(1)}$ is the first column of u: \((1,0,0,1)^T\)

$ \textbf{v}^{(1)}$ is the first column of v: \((1,0,0,1)^T\)

$\textbf{u}^{(1)} \otimes \textbf{v}^{(1)} = $

\[\begin{bmatrix} 1 & 0 & 0 & 1 \\ 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 \\1 & 0 & 0 & 1 \end{bmatrix}\quad
\]

上面矩阵的第一行代表a1,第四行代表a4,第一列代表b1... (1,1)位置出现一个1,表示当前矩阵代表的式子里面有个\(a_1b_1\) , 上面这个矩阵对应的是m1=(a1+a4)(b1+b4)

$\textbf{u}^{(1)} \otimes \textbf{v}^{(1)} \otimes \textbf{w}^{(1)} $ 就是再结合上ci,哪些ci中包括m1这一项。最终三者外积得到的是\(n*n*n\)的张量,ci对应的\(n*n\)矩阵内记录的就是ci需要哪些ab的乘积项来组合出来。当然,最终需要R个这样的三维张量才能达到正确的矩阵乘法。

(第一步是选择mi如何由ai bi组成,这对应上面那个\(n*n\)的矩阵。第二步是选择ci如何由mi组成,这对应着\(\textbf{w}\)那个\((n^2, R)\)的矩阵。两步合在一起得到R个\(n*n*n\)的三维张量,R个三维张量加起来得到\(\tau_n\),\(\tau_n\)中挑出ci那一维,对应的矩阵就是ci如何由ai bi组成)。

按照朴素矩阵乘法,\(c_1=a_1*b_1+a_2*b_3\) ,因此,无论采用什么路径, 合计出来的三维张量\(\tau_n\),在c1这个维度上都必须是

\[\begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 0 \\0 & 0 & 0 & 0 \end{bmatrix}\quad
\]

因此,可以用朴素矩阵乘法算出最终的目标,即\(\tau_n\) 。

step

在step 0, \(S_0=\tau_n\). (target)

在游戏的step t, player选择一个三元组 \((u^{(t)}, v^{(t)}, w^{(t)})\) : $S_t \leftarrow S_{t-1} - \textbf{u}^{(t)} \otimes \textbf{v}^{(t)} \otimes \textbf{w}^{(t)} $

目标是用最少的步数达到zero tensor \(S_t=\vec 0\)

所以 action space 是 \(\{0,1\}^{n^2} \times \{0,1\}^{n^2} \times \{0,1\}^{n^2}\)

为了避免游戏被拉得太长: \(R \le R_{limit}\) ( \(R_{limit}\) 步之后终止)

reward:

每一个step: -1 reward (为了找到最短路)

如果在non-zero tensor终止: \(-\gamma(S_{R_{limit}})\) reward

(\(\gamma(S_{R_{limit}})\) 是terminal tensor的rank的上界)

constrain \(\{u^{(t)}, v^{(t)}, w^{(t)}\}\) in a user-specified discrete set of coeffients F

AlphaTensor

有些类似于 AlphaZero

  • 一个deep nn 去指导 MCTS.
  • state作为输入, policy (action上的一个概率分布) 和 value作为输出

算出最优策略下每一步的action: \(\{(u^{(r)}, v^{(r)}, w^{(r)})\}^R_{r=1}\) 之后,就可以拿uvw用于矩阵乘法了

效果

可以看到,AlphaTensor搜索出来的计算方法,在部分矩阵规模上达到了更优的结果,即乘法次数更少。

在第四行,(5,5,5)情形下的矩阵乘法,AlphaTensor计算出来的方法可以在博客里面看到,非常复杂,为了减少两次乘法,却耗费了数几十次加法。因此AlphaTensor只能做到渐进时间复杂度更优,在大矩阵情形下达到更快的速度。

值得关注的是,他们在\(8192*8192\)的方阵乘法上进行了测试,采用\(4*4\)分块的方式(这样每个子矩阵的大小就是\(2048*2048\)规模的了),AlphaTensor方法比Strassen的方法减少了两次矩阵乘法,因此加速比从1.043提升至1.085。这说明这一方法相比coppersmith-winograd方法(\(O(n^{2.37})\))那种银河算法更加实用,常数更低,在8192规模的矩阵就能生效了。而且,计算矩阵乘法的Algorithm 1也方便在GPU和TPU上并行。

AlphaTensor论文阅读分析的更多相关文章

  1. BITED数学建模七日谈之三:怎样进行论文阅读

    前两天,我和大家谈了如何阅读教材和备战数模比赛应该积累的内容,本文进入到数学建模七日谈第三天:怎样进行论文阅读. 大家也许看过大量的数学模型的书籍,学过很多相关的课程,但是若没有真刀真枪地看过论文,进 ...

  2. 【医学图像】3D Deep Leaky Noisy-or Network 论文阅读(转)

    文章来源:https://blog.csdn.net/u013058162/article/details/80470426 3D Deep Leaky Noisy-or Network 论文阅读 原 ...

  3. Event StoryLine Corpus 论文阅读

    Event StoryLine Corpus 论文阅读 本文是对 Caselli T, Vossen P. The event storyline corpus: A new benchmark fo ...

  4. 论文阅读:Face Recognition: From Traditional to Deep Learning Methods 《人脸识别综述:从传统方法到深度学习》

     论文阅读:Face Recognition: From Traditional to Deep Learning Methods  <人脸识别综述:从传统方法到深度学习>     一.引 ...

  5. 论文阅读:《Bag of Tricks for Efficient Text Classification》

    论文阅读:<Bag of Tricks for Efficient Text Classification> 2018-04-25 11:22:29 卓寿杰_SoulJoy 阅读数 954 ...

  6. Nature/Science 论文阅读笔记

    Nature/Science 论文阅读笔记 Unsupervised word embeddings capture latent knowledge from materials science l ...

  7. [论文阅读]阿里DIN深度兴趣网络之总体解读

    [论文阅读]阿里DIN深度兴趣网络之总体解读 目录 [论文阅读]阿里DIN深度兴趣网络之总体解读 0x00 摘要 0x01 论文概要 1.1 概括 1.2 文章信息 1.3 核心观点 1.4 名词解释 ...

  8. [论文阅读]阿里DIEN深度兴趣进化网络之总体解读

    [论文阅读]阿里DIEN深度兴趣进化网络之总体解读 目录 [论文阅读]阿里DIEN深度兴趣进化网络之总体解读 0x00 摘要 0x01论文概要 1.1 文章信息 1.2 基本观点 1.2.1 DIN的 ...

  9. [论文阅读] RNN 在阿里DIEN中的应用

    [论文阅读] RNN 在阿里DIEN中的应用 0x00 摘要 本文基于阿里推荐DIEN代码,梳理了下RNN一些概念,以及TensorFlow中的部分源码.本博客旨在帮助小伙伴们详细了解每一步骤以及为什 ...

随机推荐

  1. 【Azure 事件中心】使用Azure AD认证方式创建Event Hub Consume Client + 自定义Event Position

    问题描述 当使用SDK连接到Azure Event Hub时,最常规的方式为使用连接字符串.这种做法参考官网文档就可成功完成代码:https://docs.azure.cn/zh-cn/event-h ...

  2. SAM复杂度证明

    关于$SAM$的复杂度证明(大部分是对博客的我自己的理解和看法) 这部分是我的回忆,可省略 先回忆一下$SAM$ 我所理解的$SAM$,首先扒一张图 初始串$aabbabd$ 首先发现,下图里的$S- ...

  3. AgileFontSet迅捷字体设置程序

    AgileFontSet迅捷字体设置程序-用户手册  AgileFontSet的完整代码,参见 https://www.cnblogs.com/ybmj/p/11683291.html 1.程序特点和 ...

  4. SpringBoot接收MultipartFile文件,并保存文件

    package com.hrw.controller; import com.hrw.pojo.Result; import org.apache.tomcat.util.http.fileuploa ...

  5. Docke 搭建 apache2 + php8 + MySQL8 环境

    Docker 安装 执行 Docker 安装命令 curl -fsSL https://get.docker.com/ | sh 启动 Docker 服务 sudo service docker st ...

  6. 【NOI P模拟赛】最短路(树形DP,树的直径)

    题面 给定一棵 n n n 个结点的无根树,每条边的边权均为 1 1 1 . 树上标记有 m m m 个互不相同的关键点,小 A \tt A A 会在这 m m m 个点中等概率随机地选择 k k k ...

  7. 2020牛客NOIP赛前集训营-提高组(第三场) C - 牛半仙的妹子Tree (树链剖分)

    昨天教练问我:你用树剖做这道题,怎么全部清空状态呢?    我:???不是懒标记就完了???    教练:树剖不是要建很多棵线段树吗,不止log个,你要一个一个清?    我:为什么要建很多棵线段树? ...

  8. 【MySQL】从入门到精通8-SQL数据库编程

    上期:[MySQL]从入门到精通7-设计多对多数据库 第零章:Mac用户看这里: mac终端写MySQL和windows基本相同,除了配置环境变量和启动有些许不同以外. 先配置环境变量,在终端输入vi ...

  9. 在 node 中使用 jquery ajax

    对于前端同学来说,ajax 请求应该不会陌生.jquery 真的ajax请求做了封装,可以通过下面的方式发送一个请求并获取相应结果: $.ajax({ url: "https://echo. ...

  10. KingbaseES 数据库Windows环境下注册失败分析

    关键字: KingbaseES.Java.Register.服务注册 一.安装前准备 1.1 软件环境要求 金仓数据库管理系统KingbaseES V8.0支持微软Windows 7.Windows ...