定义(What)

公共子表达式消除 就是如果表达式E的值已经计算的到了,并且自计算的到值后E的值就不再改变了,就说,表达式E在后续计算中是一个公共表达式。

简单说,该表达式上面已经执行过了,下面没必要再执行了

举个例子:

import tvm
from tvm import relay
from tvm.relay import transform def run_opt_pass(expr, opt_pass):
assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body def before():
x = relay.var("x", shape=(1, 16))
y1 = relay.nn.relu(x)
y2 = relay.nn.relu(x)
y1 = relay.add(y1, relay.const(1.0, "float32"))
y2 = relay.add(y2, relay.const(1.0, "float32"))
y = relay.add(y1, y2)
f = relay.Function([x], y)
return f z = before()
print("before")
print(z)
z = run_opt_pass(z, transform.EliminateCommonSubexpr())
print("after")
print(z)

通过print(z)打印公共子表达式消除前IRModule对象内容,如下:



消除之后的IRModule对象内容如下:



可以发现Relay图中的y2 = relay.nn.relu(x)节点被清除

因为表达式y2 = relay.nn.relu(x)在前一个表达式y1 = relay.nn.relu(x)中已经计算过了,只需要用前面计算过的表达式结果代替即可

作用 (Why)

意义就很简单了,为了避免重新计算表达式E,浪费计算资源,影响运行效率

怎么做(How)

上面的例子可看到,公共子表达式消除主要调用的是relay.transform.EliminateCommonSubexpr()接口,这个接口是对已注册的公共子表达式消除pass的封装。可见路径:python/tvm/relay/transform/transform.py

def EliminateCommonSubexpr(fskip=None):
"""Eliminate common subexpressions. Parameters
----------
fskip: Callable
The callback function that decides whether an expression should be
skipped. Returns
-------
ret : tvm.transform.Pass
The registered pass that eliminates common subexpressions.
"""
return _ffi_api.EliminateCommonSubexpr(fskip)

通过PackFunc机制,_ffi_api.EliminateCommonSubexpr接口最后会通过_LIB.TVMFuncGetGlobal函数获取到C++端注册的EliminateCommonSubexpr函数。

C++端EliminateCommonSubexpr注册代码如下:

Pass EliminateCommonSubexpr(PackedFunc fskip) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(EliminateCommonSubexpr(f, fskip));
};
return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", {"InferType"});
} TVM_REGISTER_GLOBAL("relay._transform.EliminateCommonSubexpr")
.set_body_typed(EliminateCommonSubexpr);

上述代码,CreateFunctionPass()函数作用是生成FunctionPass对象,FunctionPass工作在Relay模块中的每一个Relay函数对象上。

  • FunctionPass() 函数的第一个参数pass_func是TypedPackedFunc对象,真正的pass优化功能由该对象调用pass函数EliminateCommonSubexpr()完成。
  • 第二个参数是优化级别(当通过pass基础架构调用该pass时,会检查pass的优化级别,只有当该pass的优化级别不低于pass上下文配置中的优化级别时,才能启用执行该pass);
  • 第三个参数是函数pass名称;
  • 第四个参数是{} 中列出了公共子表达式消除pass依赖的其他pass,如InferType,因为需要类型信息,所以参数中列出了InferType pass名称

EliminateCommonSubexpr()函数的函数体是CommonSubexprEliminator()函数,它主要通过实现遍历Relay IR,完成Relay IR中的公共子表达式消除功能。

Relay IR遍历的C++实现类是ExprFunctor类的派生类,继承关系如下:

CommonSubexprEliminator()类通过重载Rewrite_()方法实现公共子表达式消除功能。该方法将处理过的表达式都存储在unordered_map变量expr _map_中。在每次通过ReWrite_方法处理当前表达式时,会先从expr_map_中查找是否有相同操作类型的已处理表达式,如果有,在判断当前表达式与已处理表达式的属性和参数是否相同,如果这些条件都满足,则返回满足条件的一处理表达式。

expr_map_定义如下:

  std::unordered_map<Expr, std::vector<Expr>, ObjectPtrHash, ObjectPtrEqual> expr_map_;

ReWrite_()方法(src/relay/transforms/eliminate_common_subexpr.cc)实现代码如下:

Expr Rewrite_(const CallNode* call, const Expr& post) final {
static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
Expr new_expr = post;
const CallNode* new_call = new_expr.as<CallNode>();
ICHECK(new_call);
const OpNode* op = new_call->op.as<OpNode>();
StructuralEqual attrs_equal; if (new_call->args.size() == 0 || op == nullptr || op_stateful.get(GetRef<Op>(op), false)) {
return new_expr;
}
if (fskip_ != nullptr && fskip_(new_expr)) {
return new_expr;
} auto it = expr_map_.find(new_call->op);
if (it != expr_map_.end()) {
for (const Expr& candidate_expr : it->second) {
if (const CallNode* candidate = candidate_expr.as<CallNode>()) {
bool is_equivalent = true;
if (!attrs_equal(new_call->attrs, candidate->attrs)) {
continue;
}
for (size_t i = 0; i < new_call->args.size(); i++) {
if (!IsEquivalent(new_call->args[i], candidate->args[i])) {
is_equivalent = false;
break;
}
}
if (!is_equivalent) continue;
return GetRef<Call>(candidate);
}
}
}
expr_map_[new_call->op].push_back(new_expr);
return new_expr;
}

在python端调用时,通过CreateFunctionPass()函数返回FunctionPass对象,然后通过该对象调用算子,如上述例子中opt_pass(mod)

它会调用Pass类的__call__方法来调用算子

@tvm._ffi.register_object("transform.Pass")
class Pass(tvm.runtime.Object):
"""The base class of all passes. All methods here are just simple wrappers
that are implemented in the backend. They are defined for users to
conveniently interact with the base class.
""" @property
def info(self):
"""Get the pass meta."""
return _ffi_transform_api.Info(self) def __call__(self, mod):
"""Execute the pass. Note that for sequential pass, the dependency among
different passes will be resolved in the backend. Parameters
----------
mod : tvm.IRModule
The module that a certain optimization is performed on. Returns
-------
mod : tvm.IRModule
The updated module after applying this pass.
"""
return _ffi_transform_api.RunPass(self, mod)

src/ir/transform.cc中ransform.RunPass注册代码如下:

TVM_REGISTER_GLOBAL("transform.RunPass").set_body_typed([](Pass pass, IRModule mod) {
return pass(std::move(mod));
});

此处的pass就是通过CreateFunctionPass()创建的对象,此处会调用pass中operator()重载,最终会调到FunctionPassNode类中的operator()方法,该实现会调到CreateFunctionPass()时保存的真正公共子表达式消除的代码的实现pass_func

总体,该算子优化还算是比较简单

respect~

致敬

TVM Pass优化 -- 公共子表达式消除(Common Subexpr Elimination, CSE)的更多相关文章

  1. 🏆【Java技术专区】「编译器专题」重塑认识Java编译器的执行过程(消除数组边界检查+公共子表达式)!

    前提概要 Java的class字节码并不是机器语言,要想让机器能够执行,还需要把字节码翻译成机器指令.这个过程是Java虚拟机做的,这个过程也叫编译.是更深层次的编译. 在编译原理中,把源代码翻译成机 ...

  2. [Inside HotSpot] C1编译器优化:条件表达式消除

    1. 条件传送指令 日常编程中有很多根据某个条件对变量赋不同值这样的模式,比如: int cmov(int num) { int result = 10; if(num<10){ result ...

  3. DB2 公共表表达式(WITH语句的使用)

    ----start 说起WITH 语句,除了那些第一次听说WITH语句的人,大部分人都觉得它是用来做递归查询的.其实那只是它的一个用途而已,它的本名正如我们标题写的那样,叫做:公共表表达式(Commo ...

  4. 公共表达式消除(UVa 12219)

    紫书354页的题,将表达式树进行公共表达式消除,化为等价的图.因为需要判断某一个是否出现过,所以需要快速比较,采用哈希表的形式,将a~b与1~27一一对应,不采用0,因为0与0000是相同的,对于每一 ...

  5. TVM Pass IR如何使用

    TVM Pass IR如何使用 随着Relay / tir中优化遍数的增加,执行并手动维护其依赖关系变得很棘手.引入了一个基础结构来管理优化过程,并应用于TVM堆栈中IR的不同层. Relay / t ...

  6. 如何使用TVM Pass红外线

    如何使用TVM Pass红外线 随着Relay / tir中优化遍数的增加,执行并手动维护其依赖关系变得很棘手.引入了一个基础结构来管理优化过程,将其应用于TVM堆栈中IR的不同层. Relay / ...

  7. 动态规划求最长公共子序列(Longest Common Subsequence, LCS)

    1. 问题描述 子串应该比较好理解,至于什么是子序列,这里给出一个例子:有两个母串 cnblogs belong 比如序列bo, bg, lg在母串cnblogs与belong中都出现过并且出现顺序与 ...

  8. 转:CTE(公共表表达式)——WITH子句

    来自:<Microsoft SQL Server 2008技术内幕:T-SQL语言基础> 一.公共表表达式(CTE,Common Table Expression)是在SQL Server ...

  9. Leetcode之深度优先搜索(DFS)专题-1123. 最深叶节点的最近公共祖先(Lowest Common Ancestor of Deepest Leaves)

    Leetcode之深度优先搜索(DFS)专题-1123. 最深叶节点的最近公共祖先(Lowest Common Ancestor of Deepest Leaves) 深度优先搜索的解题详细介绍,点击 ...

  10. TVM图优化与算子融合

    TVM图优化与算子融合 计算图的定义 Computational graphs: a common way to represent programs in deep learning framewo ...

随机推荐

  1. Git的快速使用

    Git的快速使用 git提交到gitee 1.初始化本地仓库 git init 2.拉取远程仓库代码 git clone https://gitee.com/sword-level_0/mount-t ...

  2. 解决easyexcel合并单元格数组求和重复问题

    背景 EasyExcel(根据条件动态合并单元格的重复数据))_Violet-CSDN博客_easyexcel动态合并单元格现有的订单导出是使用的easyExcel完成的.对于相同单元格的合并是自定义 ...

  3. Hadoop - [01] 概述

    Hadoop官网:https://hadoop.apache.org/ Hadoop下载:https://archive.apache.org/dist/hadoop/common/ 一.Hadoop ...

  4. Docker - 在docker中部署Nginx

    1.docker search 查找ngix 2.docker pull下载镜像 3.查看镜像列表 4.docker run启动容器 5.测试nginx容器是否启动成功 1.docker search ...

  5. 【渗透测试】Vulnhub DarkHole

    渗透环境 攻击机:   IP: 192.168.216.129(Kali) 靶机:     IP:192.168.216.130 靶机下载地址:https://www.vulnhub.com/entr ...

  6. 机器学习 | 强化学习(3) | 无模型预测(Model-Free Prediction)

    无模型预测(Model-Free Prediction) 无模型预测概论 上一节课<通过DP求解>可以解决一个已知的马尔科夫决策过程 本节课 实践无模型预测 解决或者估计一个未知马尔科夫决 ...

  7. MySQL超大表删除数据过程

    背景 笔者在公司负责公司的OpenAPI应用,估产生了调用审计的需求.对于存储这些AccessLog,虽然业界有很合适的架构和理论,奈何我司已成本优先,且作为toB的项目,调用量并不算特别大,每天也就 ...

  8. go ceph s3文件管理

    导入依赖 go get gopkg.in/amz.v1/aws go get gopkg.in/amz.v1/s3 创建用户 在初始化连接之前,我们需要创建一个用户得到accessKey和secret ...

  9. 在Ubuntu上安装php7.2、php7.3、php7.4

    目录 开始之前 在Ubuntu 18.04或16.04上安装PHP 7.4 更新Ubuntu 添加PHP存储库 安装PHP 7.4 在Ubuntu 16.04上安装PHP 7.2 更新Ubuntu 添 ...

  10. 编写你的第一个 Django 应用程序,第6部分

    本教程从教程 5 停止的地方开始.我们已经构建了一个经过测试的网络投票应用程序,现在我们将添加一个样式表和一个图像. 除了服务器生成的 HTML 之外,Web 应用程序通常需要提供呈现完整网页所需的其 ...