TVM Pass优化 -- InferType 类型推导
定义(What)
InferType,类型推断,顾名思义,给表达式进行类型的推断
直接上代码
import tvm
from tvm import relay
import numpy as np
def get_demo_mod():
a = relay.var("a", shape=(2, 3, 10), dtype="float32")
b = relay.var("b", shape=(1, 10), dtype="float32")
c = relay.add(a, b)
func = relay.Function([a, b], c)
mod = tvm.IRModule.from_expr(func)
return mod
mod = get_demo_mod()
print("------before InferType------")
try:
print(mod["main"].body.checked_type)
except Exception:
print("can't get checked_type")
print("------after InferType------")
mod = relay.transform.InferType()(mod)
print(mod["main"].body.checked_type)
执行结果如下:
作用 (Why)
推断表达式的类型及输入输出尺寸
另:在 Relay 优化过程中, 每个 pass 都可以修改/添加/删除 op, 所以每个 pass 之后都需要重新 InferType
如,TVM Pass优化 -- 公共子表达式消除(Common Subexpr Elimination, CSE)对公共子表达式消除一节中FunctionPass()
第四个参数就是InferType进行类型推断
怎么做(How)
这块代码主要在src/relay/transforms/type_infer.cc文件中,具体实现如下:
Pass InferType() {
auto pass_info = PassInfo(0, "InferType", {});
return tvm::transform::CreateModulePass(
[=](IRModule mod, const PassContext& pass_ctx) {
...
AddGlobalTypes(mod);
VLOG(1) << "AddGlobalTypes'" << PrettyPrint(mod);
std::vector<std::pair<GlobalVar, Function>> updates;
for (const auto& it : updated_mod->functions) {
if (auto func = it.second.as<Function>()) {
auto inferencer = TypeInferencer(mod, pass_ctx->diag_ctx.value());
VLOG(1) << "it.first'" << PrettyPrint(it.first) << "it.second"<< PrettyPrint(it.second);
auto updated_func = inferencer.Infer(it.first, func.value());
VLOG(1) << "updated_func'" << PrettyPrint(updated_func);
...
it.first->checked_type_ = updated_func->checked_type();
if (!WellFormed(updated_func, pass_ctx->diag_ctx)) {
LOG(FATAL) << "The type checked intermediate representation is malformed";
}
auto free_tvars = FreeTypeVars(updated_func, mod);
ICHECK(free_tvars.size() == 0)
<< "Found unbound type variables in " << updated_func << ": " << free_tvars;
EnsureCheckedType(updated_func);
updates.push_back({it.first, Downcast<Function>(updated_func)});
}
}
for (const auto& pair : updates) {
updated_mod->Add(pair.first, pair.second, true);
}
return updated_mod;
},
0, "InferType", {});
}
TVM_REGISTER_GLOBAL("relay._transform.InferType").set_body_typed([]() { return InferType(); });
和公共子表达式消除的实现可发现,该算子调用的是CreateModulePass
,因此它是一个模块级的优化,
模块级优化用于实现过程间优化和分析,模块级优化pass工作在tvm.IRModule对象上,将整个程序作为处理单元,几乎可以对程序执行任何操作。
其中,AddGlobalTypes
给mod添加全局参数,为后续的参数推断做准备,
真正进行推断的是TypeInferencer
类的Infer()方法
,实现如下:
Expr TypeInferencer::Infer(GlobalVar var, Function function) {
...
// Step 1: Populate the constraints.
GetType(function);
// Step 2: Solve the constraints.
Solve();
// Step 3: Attach resolved types to checked_type field.
auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(function);
...
}
return resolved_expr;
}
第一步,填充约束
Type GetType(const Expr& expr) {
auto it = type_map_.find(expr);
if (it != type_map_.end() && it->second.checked_type.defined()) {
return it->second.checked_type;
}
Type ret = this->VisitExpr(expr);
ICHECK(ret.defined()) << "expression:" << std::endl << PrettyPrint(expr);
KindCheck(ret, mod_, this->diag_ctx);
ResolvedTypeInfo& rti = type_map_[expr];
rti.checked_type = ret;
return ret;
}
会先从type_map_
map表中查找该Expr,第一次执行,如果type_map_
中未找到该expr,便会通过VisitExpr()
方法在该map表中添加,具体实现如下:
void VisitLeaf(const Expr& expr) {
if (!memo_.count(expr)) {
Type ret = this->DispatchVisitExpr(expr);
memo_[expr] = ret;
}
}
bool CheckVisited(const Expr& expr) {
if (memo_.count(expr)) {
return true;
} else {
return false;
}
}
Type DispatchVisitExpr(const Expr& expr) { return ExprFunctor::VisitExpr(expr); }
Type VisitExpr(const Expr& expr) final {
auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); };
auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); };
if (memo_.count(expr)) {
return memo_[expr];
} else {
ExpandDataflow(expr, fcheck_visited, fvisit_leaf);
return memo_[expr];
}
}
其中fcheck_visited()
匿名函数通过调用VisitLeaf方法中的DispatchVisitExpr方法,该函数会调用到ExprFunctor类
中构建的包含各种类型的虚表中,根据类型调用对应的VisitExpr_
方法,如CallNode类型的参数,代码如下:
Type VisitExpr_(const CallNode* call) final {
Array<Type> arg_types;
for (Expr arg : call->args) {
arg_types.push_back(GetType(arg));
}
if (const OpNode* opnode = call->op.as<OpNode>()) {
Type rtype =
PrimitiveCall(opnode->op_type.as<FuncTypeNode>(), arg_types, call->attrs, call->span);
if (rtype.defined()) {
AddTypeArgs(GetRef<Call>(call), arg_types);
return rtype;
}
}
其中,AddTypeArgs()
会向type_map_表中插入该expr
void AddTypeArgs(const Expr& expr, Array<Type> type_args) {
auto type_info = type_map_.find(expr);
if (type_info == type_map_.end()) {
type_map_.insert({expr, ResolvedTypeInfo(Type(), type_args)});
} else {
ICHECK(!type_info->second.type_args.defined());
type_info->second.type_args = type_args;
}
}
第二步,解决约束
bool TypeSolver::Solve() {
while (!update_queue_.empty()) {
RelationNode* rnode = update_queue_.front();
const auto& rel = rnode->rel;
update_queue_.pop();
ICHECK(!rnode->resolved);
// update the relation with given evidence.
Array<Type> args;
for (auto* tlink = rnode->type_list.head; tlink != nullptr; tlink = tlink->next) {
args.push_back(Resolve(tlink->value->FindRoot()->resolved_type));
ICHECK_LE(args.size(), rel->args.size());
}
// We need to set this in order to understand where unification
// errors generated by the error reporting are coming from.
reporter_->SetSpan(rnode->span);
try {
// Call the Type Relation's function.
bool resolved = rel->func(args, rel->num_inputs, rel->attrs, reporter_);
if (resolved) {
++num_resolved_rels_;
}
rnode->resolved = resolved;
} catch (const CompileError& err) {
this->Emit(Diagnostic::Error(rnode->span) << err.what());
rnode->resolved = false;
}
// Mark inqueue as false after the function call
// so that rnode itself won't get enqueued again.
rnode->inqueue = false;
}
// This criterion is not necessarily right for all the possible cases
// TODO(tqchen): We should also count the number of in-complete types.
return num_resolved_rels_ == rel_nodes_.size();
}
通过调用 Solve() 方法,我们求解填充好的类型约束。解决约束的过程使用了类型约束求解器(constraint solver)来尝试找到满足约束条件的类型赋值方案。
第三步,
Resolver(const std::unordered_map<Expr, ResolvedTypeInfo, ObjectPtrHash, ObjectPtrEqual>& tmap,
TypeSolver* solver)
: tmap_(tmap), solver_(solver) {}
Expr MixedModeMutator::VisitExpr(const Expr& expr) {
auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); };
auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); };
if (memo_.count(expr)) {
return memo_[expr];
} else {
ExpandDataflow(expr, fcheck_visited, fvisit_leaf);
return memo_[expr];
}
}
使用 Resolver 类
的实例来将解析后的类型信息附加到已解析的表达式的checked_type
字段上。Resolver 类是负责类型解析和处理的工具类。它通过访问表达式的结构,并使用之前求解出的类型信息来确定每个表达式的准确类型。
respect~
TVM Pass优化 -- InferType 类型推导的更多相关文章
- TVM Pass IR如何使用
TVM Pass IR如何使用 随着Relay / tir中优化遍数的增加,执行并手动维护其依赖关系变得很棘手.引入了一个基础结构来管理优化过程,并应用于TVM堆栈中IR的不同层. Relay / t ...
- 如何使用TVM Pass红外线
如何使用TVM Pass红外线 随着Relay / tir中优化遍数的增加,执行并手动维护其依赖关系变得很棘手.引入了一个基础结构来管理优化过程,将其应用于TVM堆栈中IR的不同层. Relay / ...
- 类型推导:函数模板与auto
1.从函数模板谈起 函数模板的类型推导机制是在c++98时代就有的,auto的类型推导机制与其基本一致,所以先理解函数模板类型推导. 函数模板可以用如下代码框架表示: #template<typ ...
- 第1课 类型推导(1)_auto关键字
1. auto关键字 (1)auto的作用是让编译器自动推断变量的类型,而不需要显式指定类型.这种隐式类型的推导发生在编译期. (2)auto并不能代表实际的类型声明,只是一个类型声明的“占位符” ...
- 模板类型推导、auto推导
effective modern c++ 果然是神书,干货满满,简单记录下. item1 模板推倒 典型的模板函数 temlate<class T> void fn(ParamType p ...
- C++11(列表初始化+变量类型推导+类型转换+左右值概念、引用+完美转发和万能应用+定位new+可变参数模板+emplace接口)
列表初始化 用法 在C++98中,{}只能够对数组元素进行统一的列表初始化,但是对应自定义类型,无法使用{}进行初始化,如下所示: // 数组类型 int arr1[] = { 1,2,3,4 }; ...
- Java 8 新特性之泛型的类型推导
1. 泛型究竟是什么? 在讨论类型推导(type inference)之前,必须回顾一下什么是泛型(Generic).泛型是Java SE 1.5的新特性,泛型的本质是参数化类型,也就是说所操作的数据 ...
- C++11 - 类型推导auto关键字
在C++11中,auto关键字被作为类型自动类型推导关键字 (1)基本用法 C++98:类型 变量名 = 初值; int i = 10; C++11:auto 变量名 = 初值; auto i ...
- 图说函数模板右值引用参数(T&&)类型推导规则(C++11)
见下图: 规律总结: 只要我们传递一个基本类型是A④的左值,那么,传递后,T的类型就是A&,形参在函数体中的类型就是A&. 只要我们传递一个基本类型是A的右值,那么,传递后,T的类型就 ...
- C++11 图说VS2013下的引用叠加规则和模板参数类型推导规则
背景: 最近在学习C++STL,出于偶然,在C++Reference上看到了vector下的emplace_back函数,不想由此引发了一系列的“探索”,于是就有了现在这篇博文. 前言: ...
随机推荐
- element vue 动态单选_VUE 动态构建混合数据Treeselect选择树,同时解决巨树问题
今天在项目中需要通过行政区域选择,然后选择该行政区域下面的景区,也就是要构建行政区划.景区两表数据表的树.全国的行政区域到县已经3500多了,再加上景区会有几万个点,这棵选择树不论是在后台还是在前台构 ...
- Qt个人项目总结 —— MySQL数据库查询与断言
3.Qt项目总结--数据库查询断言问题 问题: 当我使用MySQL数据库的查询操作时, 如果查询的数据在数据库中不存在,那么Qt会直接被干崩溃 但是?为什么呢?不应该是返回if语句中的结果吗,为什么会 ...
- [tldr]GO使用正则表达式
简述如何使用GO调用正则表达式 是否符合条件 使用MatchString方法实现 _, err := regexp.MatchString(regex, str) 提取内容 Compile 第一步需要 ...
- 前端解析excel表格实现
1. 背景:在做react项目时,遇到一个解析excel的需求变更,把从原来后端解析变更为前端解析. 1.1 由于后端解析excel文件有安全隐患,因为项目中后端不允许上传文件,当然后端解析对前端来说 ...
- Golang 入门 : 类型系统介绍
Go语言类型系统 从计算机底层看,所有的数据都是由比特组成,但计算机一般操作的是固定大小的数,如整数.浮点数.比特数组.内存地址等.但是直接操控底层计算机指令进行编程是非常繁琐和容易出错的,所以Go语 ...
- xshell连接Win10下子系统Unbuntu
自带的ssh server不好用,需要先卸载再安装. 1. 卸载 ssh server sudo apt-get remove openssh-server 2. 安装 ssh server sudo ...
- Python实现PDF转换文件格式
最近工作中经常遇到收到其他人提供的pdf文档,想要编辑修改下或者复制部分内容比较困难,想通过现有的pdf工具软件转换文档格式,基本都要充钱,为了免费实现pdf转换工具,网上查了下相关技术方案,整理了下 ...
- C# 13 中的新增功能实操
前言 今天大姚带领大家一起来看看 C# 13 中的新增几大功能,并了解其功能特性和实际应用场景. 前提准备 要体验 C# 13 新增的功能可以使用最新的 Visual Studio 2022 版本或 ...
- 【单片机】滑稽AT89C52表情实现
[单片机]滑稽AT89C52表情实现 零.原因 在群里看到了这样一个表情: 这是用51做的,刚好开发板上有8个小灯,想实现一下. 一.代码 新建工程,写入如下代码: #include<reg52 ...
- Java+Appium+Junit实现app自动化demo
1.新建maven工程和引入库 步骤参考https://www.cnblogs.com/wanyuan/p/16408758.html 2.编写代码 代码如下: import org.junit.Af ...