TVM Pass优化 -- 公共子表达式消除(Common Subexpr Elimination, CSE)
定义(What)
公共子表达式消除 就是如果表达式E的值已经计算的到了,并且自计算的到值后E的值就不再改变了,就说,表达式E在后续计算中是一个公共表达式。
简单说,该表达式上面已经执行过了,下面没必要再执行了
举个例子:
import tvm
from tvm import relay
from tvm.relay import transform
def run_opt_pass(expr, opt_pass):
assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
def before():
x = relay.var("x", shape=(1, 16))
y1 = relay.nn.relu(x)
y2 = relay.nn.relu(x)
y1 = relay.add(y1, relay.const(1.0, "float32"))
y2 = relay.add(y2, relay.const(1.0, "float32"))
y = relay.add(y1, y2)
f = relay.Function([x], y)
return f
z = before()
print("before")
print(z)
z = run_opt_pass(z, transform.EliminateCommonSubexpr())
print("after")
print(z)
通过print(z)打印公共子表达式消除前IRModule对象内容,如下:

消除之后的IRModule对象内容如下:

可以发现Relay图中的y2 = relay.nn.relu(x)节点被清除
因为表达式y2 = relay.nn.relu(x)在前一个表达式y1 = relay.nn.relu(x)中已经计算过了,只需要用前面计算过的表达式结果代替即可
作用 (Why)
意义就很简单了,为了避免重新计算表达式E,浪费计算资源,影响运行效率
怎么做(How)
上面的例子可看到,公共子表达式消除主要调用的是relay.transform.EliminateCommonSubexpr()接口,这个接口是对已注册的公共子表达式消除pass的封装。可见路径:python/tvm/relay/transform/transform.py
def EliminateCommonSubexpr(fskip=None):
"""Eliminate common subexpressions.
Parameters
----------
fskip: Callable
The callback function that decides whether an expression should be
skipped.
Returns
-------
ret : tvm.transform.Pass
The registered pass that eliminates common subexpressions.
"""
return _ffi_api.EliminateCommonSubexpr(fskip)
通过PackFunc机制,_ffi_api.EliminateCommonSubexpr接口最后会通过_LIB.TVMFuncGetGlobal函数获取到C++端注册的EliminateCommonSubexpr函数。
C++端EliminateCommonSubexpr注册代码如下:
Pass EliminateCommonSubexpr(PackedFunc fskip) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(EliminateCommonSubexpr(f, fskip));
};
return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.EliminateCommonSubexpr")
.set_body_typed(EliminateCommonSubexpr);
上述代码,CreateFunctionPass()函数作用是生成FunctionPass对象,FunctionPass工作在Relay模块中的每一个Relay函数对象上。
- FunctionPass() 函数的第一个参数pass_func是TypedPackedFunc对象,真正的pass优化功能由该对象调用pass函数
EliminateCommonSubexpr()完成。 - 第二个参数是优化级别(当通过pass基础架构调用该pass时,会检查pass的优化级别,只有当该pass的优化级别不低于pass上下文配置中的优化级别时,才能启用执行该pass);
- 第三个参数是函数pass名称;
- 第四个参数是{} 中列出了公共子表达式消除pass依赖的其他pass,如InferType,因为需要类型信息,所以参数中列出了InferType pass名称
EliminateCommonSubexpr()函数的函数体是CommonSubexprEliminator()函数,它主要通过实现遍历Relay IR,完成Relay IR中的公共子表达式消除功能。
Relay IR遍历的C++实现类是ExprFunctor类的派生类,继承关系如下:

CommonSubexprEliminator()类通过重载Rewrite_()方法实现公共子表达式消除功能。该方法将处理过的表达式都存储在unordered_map变量expr _map_中。在每次通过ReWrite_方法处理当前表达式时,会先从expr_map_中查找是否有相同操作类型的已处理表达式,如果有,在判断当前表达式与已处理表达式的属性和参数是否相同,如果这些条件都满足,则返回满足条件的一处理表达式。
expr_map_定义如下:
std::unordered_map<Expr, std::vector<Expr>, ObjectPtrHash, ObjectPtrEqual> expr_map_;
ReWrite_()方法(src/relay/transforms/eliminate_common_subexpr.cc)实现代码如下:
Expr Rewrite_(const CallNode* call, const Expr& post) final {
static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
Expr new_expr = post;
const CallNode* new_call = new_expr.as<CallNode>();
ICHECK(new_call);
const OpNode* op = new_call->op.as<OpNode>();
StructuralEqual attrs_equal;
if (new_call->args.size() == 0 || op == nullptr || op_stateful.get(GetRef<Op>(op), false)) {
return new_expr;
}
if (fskip_ != nullptr && fskip_(new_expr)) {
return new_expr;
}
auto it = expr_map_.find(new_call->op);
if (it != expr_map_.end()) {
for (const Expr& candidate_expr : it->second) {
if (const CallNode* candidate = candidate_expr.as<CallNode>()) {
bool is_equivalent = true;
if (!attrs_equal(new_call->attrs, candidate->attrs)) {
continue;
}
for (size_t i = 0; i < new_call->args.size(); i++) {
if (!IsEquivalent(new_call->args[i], candidate->args[i])) {
is_equivalent = false;
break;
}
}
if (!is_equivalent) continue;
return GetRef<Call>(candidate);
}
}
}
expr_map_[new_call->op].push_back(new_expr);
return new_expr;
}
在python端调用时,通过CreateFunctionPass()函数返回FunctionPass对象,然后通过该对象调用算子,如上述例子中opt_pass(mod)
它会调用Pass类的__call__方法来调用算子
@tvm._ffi.register_object("transform.Pass")
class Pass(tvm.runtime.Object):
"""The base class of all passes. All methods here are just simple wrappers
that are implemented in the backend. They are defined for users to
conveniently interact with the base class.
"""
@property
def info(self):
"""Get the pass meta."""
return _ffi_transform_api.Info(self)
def __call__(self, mod):
"""Execute the pass. Note that for sequential pass, the dependency among
different passes will be resolved in the backend.
Parameters
----------
mod : tvm.IRModule
The module that a certain optimization is performed on.
Returns
-------
mod : tvm.IRModule
The updated module after applying this pass.
"""
return _ffi_transform_api.RunPass(self, mod)
src/ir/transform.cc中ransform.RunPass注册代码如下:
TVM_REGISTER_GLOBAL("transform.RunPass").set_body_typed([](Pass pass, IRModule mod) {
return pass(std::move(mod));
});
此处的pass就是通过CreateFunctionPass()创建的对象,此处会调用pass中operator()重载,最终会调到FunctionPassNode类中的operator()方法,该实现会调到CreateFunctionPass()时保存的真正公共子表达式消除的代码的实现pass_func
总体,该算子优化还算是比较简单
respect~
致敬
TVM Pass优化 -- 公共子表达式消除(Common Subexpr Elimination, CSE)的更多相关文章
- 🏆【Java技术专区】「编译器专题」重塑认识Java编译器的执行过程(消除数组边界检查+公共子表达式)!
前提概要 Java的class字节码并不是机器语言,要想让机器能够执行,还需要把字节码翻译成机器指令.这个过程是Java虚拟机做的,这个过程也叫编译.是更深层次的编译. 在编译原理中,把源代码翻译成机 ...
- [Inside HotSpot] C1编译器优化:条件表达式消除
1. 条件传送指令 日常编程中有很多根据某个条件对变量赋不同值这样的模式,比如: int cmov(int num) { int result = 10; if(num<10){ result ...
- DB2 公共表表达式(WITH语句的使用)
----start 说起WITH 语句,除了那些第一次听说WITH语句的人,大部分人都觉得它是用来做递归查询的.其实那只是它的一个用途而已,它的本名正如我们标题写的那样,叫做:公共表表达式(Commo ...
- 公共表达式消除(UVa 12219)
紫书354页的题,将表达式树进行公共表达式消除,化为等价的图.因为需要判断某一个是否出现过,所以需要快速比较,采用哈希表的形式,将a~b与1~27一一对应,不采用0,因为0与0000是相同的,对于每一 ...
- TVM Pass IR如何使用
TVM Pass IR如何使用 随着Relay / tir中优化遍数的增加,执行并手动维护其依赖关系变得很棘手.引入了一个基础结构来管理优化过程,并应用于TVM堆栈中IR的不同层. Relay / t ...
- 如何使用TVM Pass红外线
如何使用TVM Pass红外线 随着Relay / tir中优化遍数的增加,执行并手动维护其依赖关系变得很棘手.引入了一个基础结构来管理优化过程,将其应用于TVM堆栈中IR的不同层. Relay / ...
- 动态规划求最长公共子序列(Longest Common Subsequence, LCS)
1. 问题描述 子串应该比较好理解,至于什么是子序列,这里给出一个例子:有两个母串 cnblogs belong 比如序列bo, bg, lg在母串cnblogs与belong中都出现过并且出现顺序与 ...
- 转:CTE(公共表表达式)——WITH子句
来自:<Microsoft SQL Server 2008技术内幕:T-SQL语言基础> 一.公共表表达式(CTE,Common Table Expression)是在SQL Server ...
- Leetcode之深度优先搜索(DFS)专题-1123. 最深叶节点的最近公共祖先(Lowest Common Ancestor of Deepest Leaves)
Leetcode之深度优先搜索(DFS)专题-1123. 最深叶节点的最近公共祖先(Lowest Common Ancestor of Deepest Leaves) 深度优先搜索的解题详细介绍,点击 ...
- TVM图优化与算子融合
TVM图优化与算子融合 计算图的定义 Computational graphs: a common way to represent programs in deep learning framewo ...
随机推荐
- Java后台获取微信小程序用户信息、openid
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 3 ...
- 清华大学推出第四讲使用 DeepSeek + DeepResearch 让科研像聊天一样简单!
前言 清华大学推出第四讲使用 DeepSeek + DeepResearch 让科研像聊天一样简单.该份教程旨在通过AI技术重构传统科研模式,提升研究效率与智能化水平. DeepSeek访问地址:ht ...
- ITSM运维管理整理总结
ITSM 和我们平常所说的软件管理最大的不同? 目标不是管理技术,主要任务是管理用户和客户的IT需求 2.人员.技术.流程[重要] 3.几大模块 模块名称 干什么 备注 服务台 1.对接客户的前方,负 ...
- 超详细移动端侧AI口罩识别实现与部署(含源码)
开发环境 数据标注:label studio :https://labelstud.io/ 模型训练:tensorflow 附完整的训练源码和数据 部署开发:Android studio + tens ...
- Java 线程安全的集合
Vector ArrayList 的线程安全版本,对所有的修改方法都进行了 synchronized 同步处理.适用于多线程环境下对数据一致性要求高,且读写操作相对比较均衡,不需要很高并发性能的场景. ...
- Windows 提权-服务_弱注册表权限
本文通过 Google 翻译 Weak Registry Key Permissions – Windows Privilege Escalation 这篇文章所产生,本人仅是对机器翻译中部分表达别扭 ...
- Browser-use 详细介绍&使用文档
Browser-use 详细介绍&使用文档 一.概述 Browser-use 是一个旨在将 AI "智能体"(Agents)与真实浏览器进行交互的 Python 库,可以轻 ...
- api使用流程、Scanner键盘录入字符串
1.api介绍 2.api使用流程 API帮助文档的使用流程 在索引位置搜索自己要查看的类 看包 目的: 是不是java.lang包(核心包), 不需要编写导包代码(import) - 不是java. ...
- 【Linux】U-Boot 加载并启动 Linux 系统程序
U-Boot 加载并启动 Linux 系统程序 零.介绍 最近在玩一些嵌入式的开发板,在引导操作系统时需要用到U-Boot,故此研究一下. U-Boot(Universal Bootloader)是一 ...
- AI团队比单打独斗强!CrewAI多智能体协作系统开发踩坑全解析
AI团队比单打独斗强!CrewAI多智能体协作系统开发踩坑全解析 阅读时间: 5分钟 | 字数: 1500+ "你是否曾为单个大模型难以解决复杂专业问题而苦恼?是否想过,如果能像组建专业团队 ...