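Definition

Operator fusion merges several computation units into a single one. This removes the memory reads and writes of intermediate data, which saves computation time.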

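For fusion purposes, TVM classifies operators into seven pattern kinds (the OpPatternKind enum):

  • kElemWise: element-wise operators. Each output element is computed only from the input element(s) at the same position, e.g. relu, sqrt, exp. Relay registers the four arithmetic ops (add, subtract, multiply, divide) as kBroadcast. An edge is downgraded to kElemWise when input and output shapes match (see below).
  • kBroadcast: operators with broadcasting
  • kInjective: operators with a one-to-one (injective) mapping between output and input, e.g. reshape, transpose
  • kCommReduce: many-to-few mappings where the output has reduced dimensionality, e.g. sum/max/min
  • kOutEWiseFusable: more complex operators whose output can be fused with subsequent kElemWise operators, e.g. conv2d, dense
  • kTuple: tuple construction (TupleNode)
  • kOpaque: operators that cannot be fused, e.g. sort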
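The order of these kinds matters. They are values of a C++ enum, and later code such as CombinePattern compares them numerically: a larger value means harder to fuse. Below is a minimal Python sketch of that ordering. The numeric values are assumed to follow TVM's include/tvm/relay/op_attr_types.h (kTuple = 7, kOpaque = 8). The Python names are illustrative only.

```python
from enum import IntEnum


class OpPatternKind(IntEnum):
    # Mirrors the ordering of TVM's C++ enum: smaller means easier to fuse.
    kElemWise = 0
    kBroadcast = 1
    kInjective = 2
    kCommReduce = 3
    kOutEWiseFusable = 4
    kTuple = 7
    kOpaque = 8


def combine_pattern(lhs: OpPatternKind, rhs: OpPatternKind) -> OpPatternKind:
    # Same idea as CombinePattern in fuse_ops.cc: keep the harder-to-fuse pattern.
    return lhs if lhs > rhs else rhs


print(combine_pattern(OpPatternKind.kElemWise, OpPatternKind.kBroadcast).name)  # kBroadcast
```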

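According to the TVM paper, TVM provides three generic fusion rules:

  1. Multiple injective operators can be fused into another injective operator.
  2. A reduction operator can be fused with its input injective operators (e.g. fuse scale and sum).
  3. Complex-out-fusable operators such as conv2d can fuse element-wise operators to their output.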

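Seen from inside a fused operator, fusion effectively pipelines the computation. Intermediate data between two computations is no longer stored to memory and loaded back. It is passed directly to the next computation unit.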
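Here is a toy illustration of the "no store-load of intermediates" idea in plain Python. Lists stand in for tensors, and the function names are made up. The unfused version materializes every intermediate result as a full buffer. The fused version computes relu(x + c) + 0.5 * x element by element in one pass, keeping intermediates in local variables.

```python
from typing import List


def unfused(x: List[float], c: List[float]) -> List[float]:
    # Each "operator" writes a full intermediate buffer that the next one reads back.
    y = [a + b for a, b in zip(x, c)]          # add
    act = [max(v, 0.0) for v in y]             # relu
    mul = [0.5 * v for v in x]                 # multiply
    return [a + b for a, b in zip(act, mul)]   # add


def fused(x: List[float], c: List[float]) -> List[float]:
    # One loop: intermediates live only in local variables (think registers).
    out = []
    for a, b in zip(x, c):
        act = max(a + b, 0.0)
        out.append(act + 0.5 * a)
    return out


x = [-2.0, -0.5, 1.0, 3.0]
c = [1.0, 1.0, 1.0, 1.0]
print(fused(x, c))  # [-1.0, 0.25, 2.5, 5.5]
```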

Example:

import numpy as np
import tvm
import tvm.relay as relay
from tvm.relay.testing import run_opt_pass


def get_relay_ir():
    shape = (1, 3, 14, 14)
    c_data = np.ones(shape).astype('float32')
    c = relay.const(c_data)
    weight = relay.var('weight', shape=(3, 3, 3, 3))
    x = relay.var('x', relay.TensorType((1, 3, 16, 16), 'float32'))
    conv = relay.nn.conv2d(x, weight)
    y = relay.add(conv, c)
    act = relay.nn.relu(y)
    mul = relay.multiply(conv, relay.const(0.5, 'float32'))
    z = act + mul
    return relay.Function([x, weight], z)


f = get_relay_ir()
mod = tvm.IRModule.from_expr(f)
print('src module:')
print(mod)
mod = run_opt_pass(f, relay.transform.FuseOps(fuse_opt_level=4))
print('fuse_ops:')
print(mod)

Output:

def @main(%x: Tensor[(1, 3, 16, 16), float32], %weight: Tensor[(3, 3, 3, 3), float32]) {
%0 = nn.conv2d(%x, %weight, padding=[0, 0, 0, 0]);
%1 = add(%0, meta[relay.Constant][0]);
%2 = nn.relu(%1);
%3 = multiply(%0, 0.5f);
add(%2, %3)
}
fuse_ops:
fn (%x: Tensor[(1, 3, 16, 16), float32], %weight: Tensor[(3, 3, 3, 3), float32]) -> Tensor[(1, 3, 14, 14), float32] {
%4 = fn (%p0: Tensor[(1, 3, 16, 16), float32], %p1: Tensor[(3, 3, 3, 3), float32], %p2: Tensor[(1, 3, 14, 14), float32], Primitive=1) -> Tensor[(1, 3, 14, 14), float32] {
%0 = nn.conv2d(%p0, %p1, padding=[0, 0, 0, 0]);
%1 = add(%0, %p2);
%2 = nn.relu(%1);
%3 = multiply(%0, 0.5f);
add(%2, %3)
}
%4(%x, %weight, meta[relay.Constant][0])
}

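The output shows what the fusion pass did. conv2d, add, relu and multiply are fused into a single function marked Primitive=1, which is invoked through a CallNode (%4(...)).

Purpose

The ultimate goal of operator fusion is to mitigate the memory wall and the parallelism wall of AI processors. It does so by improving the memory-access locality of tensor data.

Implementation

The Python entry point of the fusion pass is in python/tvm/relay/transform/transform.py: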

def FuseOps(fuse_opt_level=-1):
    """Fuse operators in an expr to a larger operator according to some rules.

    Parameters
    ----------
    fuse_opt_level : int
        The level of fuse optimization. -1 indicates that the level will be
        inferred from pass context.

    Returns
    -------
    ret : tvm.transform.Pass
        The registered pass for operator fusion.
    """
    return _ffi_api.FuseOps(fuse_opt_level)

TVM uses the PackedFunc FFI mechanism so that Python and C++ can call each other. The C++ backend of this pass lives in src/relay/transforms/fuse_ops.cc:

Pass FuseOps(int fuse_opt_level) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
bool link_params = false;
Executor executor =
m->GetAttr<Executor>(tvm::attr::kExecutor).value_or(NullValue<Executor>());
link_params = executor.defined()
? executor->attrs.GetAttr<Bool>("link-params").value_or(Bool(link_params))
: link_params;
link_params = pc->GetConfig("relay.FuseOps.link_params", Bool(link_params)).value();
int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
auto max_fuse_depth = pc->GetConfig("relay.FuseOps.max_depth", Integer(kMaxFusedOps));
auto target = Target::Current();
size_t max_function_args =
(target.defined())
? target->GetAttr<Integer>("max_function_args", Integer(0)).value().IntValue()
: 0;
return Downcast<Function>(FuseOps(f, opt_level, max_fuse_depth.value().IntValue(),
max_function_args, link_params, m));
};
return CreateFunctionPass(pass_func, 0, "FuseOps", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.FuseOps").set_body_typed(FuseOps);

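As shown, this is a function-level pass. For now we only need to care about fuse_opt_level. When it is -1, the level is taken from the PassContext (pc->opt_level). The remaining parameters are not used here and keep their default values.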
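The parameter resolution done by the lambda above can be summarized in a few lines of Python. This is a hypothetical sketch: the helper names resolve_fuse_opt_level and resolve_link_params are made up. It only restates the precedence visible in the C++ code:

  • fuse_opt_level = -1 means "use the PassContext opt_level".
  • link_params defaults to false and may be set by the executor's link-params attribute.
  • A relay.FuseOps.link_params config value, if present, overrides both.

```python
from typing import Optional


def resolve_fuse_opt_level(fuse_opt_level: int, pass_ctx_opt_level: int) -> int:
    # -1 means "inherit the level from the PassContext".
    return pass_ctx_opt_level if fuse_opt_level == -1 else fuse_opt_level


def resolve_link_params(executor_link_params: Optional[bool],
                        config_link_params: Optional[bool]) -> bool:
    # Default is False; the executor attribute may set it, and the
    # "relay.FuseOps.link_params" pass config (if present) has the last word.
    link_params = False
    if executor_link_params is not None:
        link_params = executor_link_params
    if config_link_params is not None:
        link_params = config_link_params
    return link_params


print(resolve_fuse_opt_level(-1, 3), resolve_fuse_opt_level(0, 3))  # 3 0
```

So running FuseOps() with its default argument inside a PassContext with opt_level=3 behaves like fuse_opt_level=3.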

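TVM operator fusion workflow

TVM's operator fusion consists of three steps:

  1. Traverse the Relay tree and build a DAG for post-dominator analysis.
  2. Build the post-dominator tree, so that the immediate post-dominator of any node can be found quickly.
  3. Using each node's post-dominator information, run the fusion algorithm on the paths between the node and its post-dominator.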

The main code is as follows:

// Run the transform
Expr Transform(const Expr& body) {
return Transform(body, fuse_opt_level_, max_fuse_depth_, link_params_);
}

// Run the transform
Expr Transform(const Expr& body, int fuse_opt_level, size_t max_fuse_depth, bool link_params) {
// setup the group map.
auto graph = IndexedForwardGraphCreator::Create(&arena_, body);
auto groups = GraphPartitioner(&arena_, fuse_opt_level, max_fuse_depth, max_function_args_)
.Partition(graph);
for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) {
ICHECK(graph.post_dfs_order[nid]->ref != nullptr);
gmap_[graph.post_dfs_order[nid]->ref] = groups[nid];
}
// The following line can be used for debug.
// this->DebugDumpGroup(body);
return this->Mutate(body);
}

Building the DAG

The DAG is built mainly by the following line:

auto graph = IndexedForwardGraphCreator::Create(&arena_, body);

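Here arena_ is the memory arena used for allocation, and body is the Relay IR tree, in this case a FunctionNode.

TVM stores the DAG in the IndexedForwardGraph structure, which IndexedForwardGraphCreator builds. It is defined as follows: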

class IndexedForwardGraph {
public:
struct Node;
/*!
* The forward edge in the dataflow graph.
*/
struct Edge {
/*! \brief The corresponding node */
Node* node{nullptr};
/*! \brief The respective pattern of this op */
OpPatternKind pattern{kOpaque};
};
/*! \brief A node in the graph. */
struct Node {
/*! \brief weak reference to the corresponding edge. */
const tvm::Object* ref{nullptr};
/*! \brief The index of the node in topological order. */
size_t index{0};
/*! \brief Whether this node is referenced by external source */
bool extern_ref{false};
/*! \brief The general pattern in the node */
OpPatternKind pattern{kOpaque};
/*! \brief The outputs of the node. */
LinkedList<Edge> outputs;
};
/*! \brief The node map that maps node to graph */
std::unordered_map<const tvm::Object*, Node*> node_map;
/*! \brief All the nodes in post DFS order */
std::vector<Node*> post_dfs_order;
};

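Node represents a graph node. It stores:

  • ref: the referenced object
  • index: the topological index
  • extern_ref: whether the node is referenced externally
  • pattern: the op pattern
  • outputs: its output edges

Edge represents an edge. It stores the node the edge points to (the consumer) and the op pattern of the edge.

node_map maps Relay objects to graph nodes.

post_dfs_order holds all nodes in post-order DFS order.

IndexedForwardGraphCreator converts the Relay IR into this graph data structure.

IndexedForwardGraphCreator inherits from ExprVisitor and overrides the visitors for FunctionNode, CallNode, ConstantNode and other nodes.

The user passes a FunctionNode to this pass, so processing starts in the FunctionNode handler: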

  // Post order tree
void VisitExpr_(const FunctionNode* op) final {
// Skip the function that should be handled by external codegen.
if (op->GetAttr<String>(attr::kCompiler).defined()) return;

for (auto param : op->params) {
this->Update(param, nullptr, kOpaque);
}
this->Update(op->body, nullptr, kOpaque);
ExprVisitor::VisitExpr_(op);
}

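It first calls Update on the parameters and the function body. It then enters the parent class's VisitExpr_ for the recursive traversal.

  • Update creates or updates a Node in the Graph. If a parent is given, an Edge is also created; otherwise the node is marked extern_ref. The code: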
 // Update the message stored at the node.
void Update(const Expr& node, IndexedForwardGraph::Node* parent, OpPatternKind pattern) {
const tvm::Object* key = node.get();
IndexedForwardGraph::Node* current;
auto it = graph_.node_map.find(key);
if (it != graph_.node_map.end()) {
current = it->second;
} else {
current = arena_->make<IndexedForwardGraph::Node>();
graph_.node_map[key] = current;
}
if (parent != nullptr) {
auto* link = arena_->make<LinkNode<IndexedForwardGraph::Edge>>();
link->value.node = parent;
link->value.pattern = pattern;
current->outputs.Push(link);
} else {
current->extern_ref = true;
}
}
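  • The parent VisitExpr_ first visits the FunctionNode's parameters %x and %weight and updates their nodes. %x gets topological index 0 and %weight gets 1, and both are appended to the graph's post-DFS order: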
void ExprVisitor::VisitExpr_(const FunctionNode* op) {
this->VisitSpan(op->span);
for (auto param : op->params) {
this->VisitExpr(param);
}
...
}
  void VisitExpr_(const VarNode* op) final { this->AddNode(op); }

    void AddNode(const tvm::Object* key) {
auto it = graph_.node_map.find(key);
ICHECK(it != graph_.node_map.end()) << "Cannot find node " << GetRef<ObjectRef>(key);
IndexedForwardGraph::Node* node = it->second;
ICHECK(node->ref == nullptr);
node->ref = key;
node->index = graph_.post_dfs_order.size();
graph_.post_dfs_order.push_back(node);
}

Next the FunctionNode's body is visited. It is a CallNode: add(%2, %3)

void ExprVisitor::VisitExpr_(const FunctionNode* op) {
...
this->VisitExpr(op->body);
}

so control enters the following code:

  void VisitExpr_(const CallNode* call) final {
ICHECK(graph_.node_map.count(call));
IndexedForwardGraph::Node* node = graph_.node_map.at(call);
static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
// Now we set the pattern of this call.
//
// If we see a call mentioning an operator we should mark it with its
// annotated pattern.
//
// If the pattern is not annotated we will default to opaque.
//
// Finally if the operator position is not a call node we will
// need to call Update, as it may be an arbitrary expression.
OpPatternKind op_pattern = kOpaque;
if (auto optional = call->op.as<Op>()) {
auto op = optional.value();
if (IsDynamic(call->checked_type()) && IsDataDependent(call)) {
// output of a shape func can't be fed to a data-dependent shape func
op_pattern = kOpaque;
} else {
op_pattern = static_cast<OpPatternKind>(fpattern[op]);
}
} else {
this->Update(call->op, node, kOpaque);
}
node->pattern = op_pattern;
this->Update(call->op, nullptr, kOpaque);
...
...
}

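When a CallNode is visited, the op's pattern is first looked up in the global registry (TOpPattern); the add above, for example, is kBroadcast. The op itself is then added to the graph through Update.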

  void VisitExpr_(const CallNode* call) final {
...
...
const auto* rtype = call->checked_type().as<TensorTypeNode>();
// pass the analysis back to all the children it references.
for (size_t i = 0; i < call->args.size(); ++i) {
const auto* arg_type = call->args[i]->checked_type().as<TensorTypeNode>();
// specifically check if result type is the same as arguments type
OpPatternKind edge_pattern = op_pattern;
if (edge_pattern == kBroadcast && arg_type != nullptr && rtype != nullptr &&
attr_equal_(rtype->shape, arg_type->shape)) {
edge_pattern = kElemWise;
}
this->Update(call->args[i], node, edge_pattern);
}
ExprVisitor::VisitExpr_(call);
this->AddNode(call);
}

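Next the input args are processed. If an arg's shape equals the shape of the return value, the edge type is changed from kBroadcast to kElemWise. The arg node is then updated, creating an edge from the arg to the CallNode (Call(add, ...)), as shown in stage 1 of the figure below.

  • Next comes ExprVisitor::VisitExpr_(call), the generic CallNode handler, which visits the op and then the arguments (%2, %3) in turn. Processing argument %2 gives stage 2 of the figure.
  • Recursion continues in post-DFS order, as shown in stage 3 of the figure.
  • Once the %2 branch is finished, the graph looks like stage 4.
  • Then the %3 branch is updated until the whole graph is built, as in stage 5.

Compare with the example for a better understanding:

Note: the traversal above follows post-DFS order. After all inputs of a node have been visited, AddNode is called to assign the node's topological index, as marked in the figure. At this point the graph is complete.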
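The traversal above can be reproduced on the example with a small Python model. This is a simplified sketch, not TVM code. Expr, Node and GraphCreator are hypothetical stand-ins, and only variables, constants and calls are modeled. Every non-scalar constant is treated as kOpaque; the real code also checks the dtype. The sketch keeps three key behaviors:

  • Update creates nodes and edges.
  • A kBroadcast edge is downgraded to kElemWise when shapes match.
  • AddNode assigns the post-DFS index after all inputs are visited.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

K_ELEMWISE, K_BROADCAST, K_OUT_EWISE_FUSABLE, K_OPAQUE = 0, 1, 4, 8


@dataclass(eq=False)
class Expr:
    # Toy stand-in for a Relay expression (hypothetical, not TVM's class).
    name: str
    kind: str                      # "var", "const" or "call"
    shape: Tuple[int, ...]
    op_pattern: int = K_OPAQUE     # registered TOpPattern, used for calls only
    args: List["Expr"] = field(default_factory=list)


@dataclass(eq=False)
class Node:
    ref: Optional[Expr] = None
    index: int = 0
    extern_ref: bool = False
    pattern: int = K_OPAQUE
    outputs: List[Tuple["Node", int]] = field(default_factory=list)  # (consumer, edge pattern)


class GraphCreator:
    def __init__(self) -> None:
        self.node_map: Dict[int, Node] = {}   # keyed by id(expr)
        self.post_dfs_order: List[Node] = []
        self.visited = set()

    def update(self, expr: Expr, parent: Optional[Node], pattern: int) -> None:
        # Create the node on first sight; record an edge, or mark it external.
        node = self.node_map.setdefault(id(expr), Node())
        if parent is not None:
            node.outputs.append((parent, pattern))
        else:
            node.extern_ref = True

    def add_node(self, expr: Expr) -> None:
        # Called once all inputs are visited: this fixes the post-DFS index.
        node = self.node_map[id(expr)]
        node.ref, node.index = expr, len(self.post_dfs_order)
        self.post_dfs_order.append(node)

    def visit(self, expr: Expr) -> None:
        if id(expr) in self.visited:
            return
        self.visited.add(id(expr))
        node = self.node_map[id(expr)]
        if expr.kind == "call":
            node.pattern = expr.op_pattern
            for arg in expr.args:
                edge = expr.op_pattern
                if edge == K_BROADCAST and arg.shape == expr.shape:
                    edge = K_ELEMWISE    # same shape: no real broadcasting
                self.update(arg, node, edge)
            for arg in expr.args:
                self.visit(arg)
        elif expr.kind == "const":
            # Scalar constants are fusable; other constants stay opaque.
            node.pattern = K_ELEMWISE if expr.shape == () else K_OPAQUE
        self.add_node(expr)

    def create(self, params: List[Expr], body: Expr) -> List[Node]:
        for p in params + [body]:
            self.update(p, None, K_OPAQUE)
        for p in params:
            self.visit(p)
        self.visit(body)
        return self.post_dfs_order


# The example from this article: relu(conv2d(x, w) + c) + conv2d(x, w) * 0.5
x = Expr("x", "var", (1, 3, 16, 16))
w = Expr("weight", "var", (3, 3, 3, 3))
out = (1, 3, 14, 14)
conv = Expr("conv2d", "call", out, K_OUT_EWISE_FUSABLE, [x, w])
add = Expr("add", "call", out, K_BROADCAST, [conv, Expr("c", "const", out)])
relu = Expr("relu", "call", out, K_ELEMWISE, [add])
mul = Expr("multiply", "call", out, K_BROADCAST, [conv, Expr("0.5", "const", ())])
body = Expr("add_out", "call", out, K_BROADCAST, [relu, mul])

order = GraphCreator().create([x, w], body)
print([n.ref.name for n in order])
# ['x', 'weight', 'conv2d', 'c', 'add', 'relu', '0.5', 'multiply', 'add_out']
```

The printed order matches the topological indices 0 to 8 used in the rest of this article, and nodes 0, 1 and 8 come out with extern_ref set.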

Building the post-dominator tree

The post-dominator tree lets us quickly find the immediate post-dominator of any node. The code:

  Expr Transform(const Expr& body, int fuse_opt_level, size_t max_fuse_depth, bool link_params) {
// setup the group map.
auto graph = IndexedForwardGraphCreator::Create(&arena_, body);
auto groups = GraphPartitioner(&arena_, fuse_opt_level, max_fuse_depth, max_function_args_)
.Partition(graph);
....

Partition is implemented as follows:

std::vector<GraphPartitioner::Group*> GraphPartitioner::Partition(
const IndexedForwardGraph& graph) {
this->InitGroups(graph);
if (opt_level_ == 0) return std::move(groups_);
// get post dominator tree
auto post_dom_tree = DominatorTree::PostDom(arena_, graph);
// run fusion algorithm.
...
}

When opt_level is 0, no fusion is performed: every node stays in its own group.

Let's focus on the post-dominator tree part:

  auto post_dom_tree = DominatorTree::PostDom(arena_, graph);

The DominatorTree data structure:

/*!
* \brief Dominator tree that represent domination or
* post domination relation of the node.
*/
class DominatorTree {
public:
/*!
* \brief A node in the dominator tree.
*/
struct Node {
/*! \brief The node in the tree */
IndexedForwardGraph::Node* gnode{nullptr};
/*! \brief parent of the tree */
Node* parent{nullptr};
/*! \brief current depth*/
int depth{0};
/*! \brief aggregated pattern to parent */
OpPatternKind pattern{kOpaque};
};
// index -> node.
std::vector<Node*> nodes;
...
};

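The dominator tree holds a mapping from index to tree node. Filling in these structures completes the conversion from Graph to DominatorTree. Each node has the following fields:

  • gnode: reference to the corresponding graph node
  • parent: parent node in the tree (the immediate post-dominator)
  • depth: depth in the tree, used to compute the LCA efficiently
  • pattern: the aggregated op pattern on the way to the parent

Now let's look at how the post-dominator tree is computed: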

DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForwardGraph& graph) {
DominatorTree tree;
tree.nodes.resize(graph.post_dfs_order.size(), nullptr);
// reverse topo order
for (size_t i = graph.post_dfs_order.size(); i != 0; --i) {
size_t index = i - 1;
tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]);
}
return tree;
}

Graph nodes are processed in reverse topological order, i.e. nodes 8 -> 7 -> ... -> 0 in the figure above. GetNode is implemented as:


DominatorTree::Node* DominatorTree::GetNode(support::Arena* arena,
IndexedForwardGraph::Node* gnode) {
Node* tnode = arena->make<Node>();
tnode->gnode = gnode;
if (gnode->extern_ref) {
tnode->depth = 1;
tnode->parent = nullptr;
tnode->pattern = kOpaque;
} else {
// find the LCAs of all outputs.
OpPatternKind pattern = kElemWise;
Node* parent = LeastCommonAncestor(gnode->outputs, &pattern);
tnode->depth = parent ? parent->depth + 1 : 1;
tnode->parent = parent;
tnode->pattern = pattern;
}
return tnode;
}

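Nodes 0, 1 and 8 (the two parameters and the function body) have extern_ref set to true; all others are false.

The remaining nodes go through LeastCommonAncestor. CombinePattern returns whichever of the two patterns is harder to fuse, i.e. the larger enum value: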

  // Combine pattern together.
inline static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) {
if (lhs > rhs) return lhs;
return rhs;
}

LeastCommonAncestor is implemented as follows. We trace its execution path using node 2 as an example:

DominatorTree::Node* DominatorTree::LeastCommonAncestor(
const LinkedList<IndexedForwardGraph::Edge>& input_nodes, OpPatternKind* edge_pattern) {
auto link = input_nodes.head;
if (link == nullptr) {
return nullptr;
}
auto get_node = [&](const IndexedForwardGraph::Edge& edge) {
size_t oindex = edge.node->index;
ICHECK_LT(oindex, nodes.size());
Node* onode = nodes[oindex];
ICHECK(onode != nullptr);
return onode;
};
Node* parent = get_node(link->value);
*edge_pattern = CombinePattern(*edge_pattern, link->value.pattern);
link = link->next;
for (; link != nullptr; link = link->next) {
parent = LeastCommonAncestor(parent, get_node(link->value), edge_pattern);
*edge_pattern = CombinePattern(*edge_pattern, link->value.pattern);
}
return parent;
}

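  • Node 2 has two output edges, the first to node 4 and the second to node 7. After the first edge is processed, parent is node 4 and edge_pattern is kElemWise.
  • The second edge goes into the code below. It uses the depth information to find the lowest common ancestor (LCA) of the two nodes, updating edge_pattern along the way.
  • When it finishes, parent is node 8 and edge_pattern is kElemWise.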
DominatorTree::Node* DominatorTree::LeastCommonAncestor(Node* lhs, Node* rhs,
OpPatternKind* edge_pattern) {
while (lhs != rhs) {
if (lhs == nullptr) return nullptr;
if (rhs == nullptr) return nullptr;
if (lhs->depth < rhs->depth) {
edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern);
rhs = rhs->parent;
} else if (rhs->depth < lhs->depth) {
edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern);
lhs = lhs->parent;
} else {
edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern);
edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern);
lhs = lhs->parent;
rhs = rhs->parent;
}
}
return lhs;
}
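Below is a Python sketch of PostDom plus the two LeastCommonAncestor overloads, run on the example graph. The edge list is hand-written from the traversal described earlier. Indices are in post-DFS order, and edge patterns are taken after the kBroadcast -> kElemWise downgrade. The function names are hypothetical, and CombinePattern is modeled as max().

```python
from dataclasses import dataclass
from typing import List, Optional, Set, Tuple

K_ELEMWISE, K_BROADCAST, K_OUT_EWISE_FUSABLE, K_OPAQUE = 0, 1, 4, 8


@dataclass(eq=False)
class DomNode:
    index: int                   # post-DFS index of the graph node
    parent: Optional["DomNode"]  # immediate post-dominator (None for roots)
    depth: int
    pattern: int                 # aggregated pattern on the way to parent


def lca(lhs: Optional[DomNode], rhs: Optional[DomNode],
        pattern: int) -> Tuple[Optional[DomNode], int]:
    # Climb the deeper node first, folding each step's pattern in with max().
    while lhs is not rhs:
        if lhs is None or rhs is None:
            return None, pattern
        if lhs.depth < rhs.depth:
            pattern, rhs = max(pattern, rhs.pattern), rhs.parent
        elif rhs.depth < lhs.depth:
            pattern, lhs = max(pattern, lhs.pattern), lhs.parent
        else:
            pattern = max(pattern, lhs.pattern, rhs.pattern)
            lhs, rhs = lhs.parent, rhs.parent
    return lhs, pattern


def post_dom(outputs: List[List[Tuple[int, int]]], extern: Set[int]) -> List[DomNode]:
    nodes: List[Optional[DomNode]] = [None] * len(outputs)
    for index in reversed(range(len(outputs))):   # reverse topological order
        if index in extern:
            nodes[index] = DomNode(index, None, 1, K_OPAQUE)
            continue
        parent: Optional[DomNode] = None
        pattern = K_ELEMWISE
        for k, (consumer, edge_pattern) in enumerate(outputs[index]):
            if k == 0:
                parent = nodes[consumer]
            else:
                parent, pattern = lca(parent, nodes[consumer], pattern)
            pattern = max(pattern, edge_pattern)
        depth = parent.depth + 1 if parent is not None else 1
        nodes[index] = DomNode(index, parent, depth, pattern)
    return nodes


OUTPUTS = [
    [(2, K_OUT_EWISE_FUSABLE)], [(2, K_OUT_EWISE_FUSABLE)],    # 0: x, 1: weight
    [(4, K_ELEMWISE), (7, K_ELEMWISE)],                        # 2: conv2d
    [(4, K_ELEMWISE)], [(5, K_ELEMWISE)], [(8, K_ELEMWISE)],    # 3: c, 4: add, 5: relu
    [(7, K_BROADCAST)], [(8, K_ELEMWISE)], [],                 # 6: 0.5, 7: multiply, 8: add
]
tree = post_dom(OUTPUTS, extern={0, 1, 8})
print([n.parent.index if n.parent is not None else None for n in tree])
# [None, None, 8, 4, 5, 8, 7, 8, None]
```

The sketch assigns these post-dominators, which the fusion walkthrough relies on:

  • node 2 (conv2d) -> 8
  • node 3 (the constant c) -> 4
  • node 6 (the scalar 0.5) -> 7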

The final DominatorTree is as follows:

Fusion rules

Next, operators are fused according to the fusion rules:

std::vector<GraphPartitioner::Group*> GraphPartitioner::Partition(
const IndexedForwardGraph& graph) {
...
// run fusion algorithm.
for (int phase = 0; phase < 3; ++phase) {
this->RunFuse(graph, post_dom_tree, phase);
}
return std::move(groups_);
}

First, the data structures:

/*!
* \brief A partition of the graph marked by union find data structure.
*/
class GraphPartitioner {
public:
explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth,
size_t max_function_args)
: arena_(arena),
opt_level_(opt_level),
max_fuse_depth_(max_fuse_depth),
max_function_args_(max_function_args) {}
/*!
* \brief Group as a union find data structure.
*/
struct Group {
/*! \brief The parent in the union find data structure. */
Group* parent{nullptr};
/*! \brief The pattern of the group */
OpPatternKind pattern;
/*! \brief reference to the root node. */
const tvm::Object* root_ref{nullptr};
/*!
* \brief Reference to the anchor node,
* this field is not nullptr only if pattern is kOutEWiseFusable.
*/
const tvm::Object* anchor_ref{nullptr};
/*!
* \brief The number of nodes belonging to this group
*/
uint32_t num_nodes{1};
/*!
* \brief The number of function arguments belonging to this group
*/
size_t args_num{0};
/*! \brief Optional attributes to annotate the grouped function. */
runtime::Map<runtime::String, ObjectRef> attrs;
/*!
* \brief Find the group root, perform path compression
* \return The root type node.
*/
Group* FindRoot();
};
/*!
* \brief Partition a graph.
* \return group assignments of each node.
*/
std::vector<Group*> Partition(const IndexedForwardGraph& graph);

private:
/*! \brief The internal arena for temporary space. */
support::Arena* arena_;
/*! \brief optimization level for fuse operation. */
int opt_level_;
/*! \brief The maximum number of operations in one fused function */
size_t max_fuse_depth_;
/*! \brief The maximum number of arguments in one fused function */
size_t max_function_args_;
/*! \brief The internal groups. */
std::vector<Group*> groups_;
/*! \brief internal field used for deduplication */
std::unordered_set<IndexedForwardGraph::Node*> visited_;
/*! \brief The map with nodes which were postponed for fusing. */
std::unordered_multimap<const IndexedForwardGraph::Node*, IndexedForwardGraph::Node*>
postponed_fusing_map_;
...

Group is a union-find (disjoint-set) structure, which can quickly determine whether two nodes belong to the same group.
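To make the union-find part concrete, here is a minimal Python sketch of a Group. It has FindRoot-style path compression and a merge that follows the shape of GraphPartitioner::MergeFromTo (shown later in this article). It is an illustration under simplifying assumptions: args_num, root_ref and attrs are left out, and anchor is just a string standing in for anchor_ref.

```python
from dataclasses import dataclass
from typing import Optional

K_ELEMWISE, K_BROADCAST, K_OUT_EWISE_FUSABLE = 0, 1, 4


@dataclass(eq=False)
class Group:
    pattern: int
    parent: Optional["Group"] = None
    num_nodes: int = 1
    anchor: Optional[str] = None   # stands in for anchor_ref

    def find_root(self) -> "Group":
        # Walk up to the root, then point every visited group directly at it.
        root = self
        while root.parent is not None:
            root = root.parent
        node = self
        while node is not root:
            nxt = node.parent
            node.parent = root
            node = nxt
        return root


def merge_from_to(child: Group, parent: Group) -> None:
    child, parent = child.find_root(), parent.find_root()
    if child is parent:
        return
    parent.num_nodes += child.num_nodes
    child.parent = parent
    if child.anchor is not None:
        # The anchor (e.g. conv2d) moves to the parent, which becomes harder to fuse.
        assert parent.anchor is None
        parent.anchor = child.anchor
        parent.pattern = max(child.pattern, parent.pattern)


conv = Group(K_OUT_EWISE_FUSABLE, anchor="conv2d")
add = Group(K_BROADCAST)
relu = Group(K_ELEMWISE)
merge_from_to(conv, add)
merge_from_to(relu, add)
root = conv.find_root()
print(root is add, root.num_nodes, root.pattern, root.anchor)  # True 3 4 conv2d
```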

Partition first initializes the Group-related state. Each node gets its own group, whose fields are filled from the graph node:

void GraphPartitioner::InitGroups(const IndexedForwardGraph& graph) {
auto args_counter = [this](const tvm::Object* obj) {
size_t args_num = 0;
if (auto call_node = GetRef<ObjectRef>(obj).as<CallNode>()) {
for (auto& it : call_node->args) {
if (it.as<VarNode>() || it.as<TupleGetItemNode>()) {
args_num++;
if (const auto* ttype = it.as<ExprNode>()->checked_type().as<TensorTypeNode>()) {
args_num += CountAdditionalArgs_(ttype);
}
}
}
} else if (auto tuple_node = GetRef<ObjectRef>(obj).as<TupleNode>()) {
for (auto& it : tuple_node->fields) {
if (it.as<VarNode>() || it.as<TupleGetItemNode>()) {
args_num++;
if (const auto* ttype = it.as<ExprNode>()->checked_type().as<TensorTypeNode>()) {
args_num += CountAdditionalArgs_(ttype);
}
}
}
} else if (GetRef<ObjectRef>(obj).as<VarNode>()) {
args_num++;
if (const auto* ttype =
GetRef<ObjectRef>(obj).as<ExprNode>()->checked_type().as<TensorTypeNode>()) {
args_num += CountAdditionalArgs_(ttype);
}
}
return args_num;
};
groups_.resize(graph.post_dfs_order.size());
for (size_t nid = 0; nid < groups_.size(); ++nid) {
const auto* graph_node = graph.post_dfs_order[nid];
auto* group_node = arena_->make<Group>();
group_node->pattern = graph_node->pattern;
group_node->root_ref = graph_node->ref;
// set anchor ref if necessary.
if (group_node->pattern == relay::kOutEWiseFusable) {
group_node->anchor_ref = graph_node->ref;
}
group_node->args_num = args_counter(graph_node->ref);
groups_[nid] = group_node;
}
}

Now let's read RunFuse. As seen above, RunFuse runs in three phases; we look at them in turn:

void RunFuse(const IndexedForwardGraph& graph, const DominatorTree& post_dom_tree, int phase) {
for (size_t nid = 0; nid < groups_.size(); ++nid) {
// Get graph_node, dom_node and group_node.
// the group of current node has been specified already.
auto* graph_node = graph.post_dfs_order[nid];
auto* dom_node = post_dom_tree.nodes[nid];
Group* group_node = groups_[nid];
ICHECK(group_node != nullptr);
// Skip operators that cannot be fused (kOpaque).
if (group_node->pattern == kOpaque) continue;
// Skip nodes without a post-dominator.
if (dom_node->parent == nullptr) continue;
ICHECK(!graph_node->extern_ref);
// Graph index of this node's post-dominator.
size_t dom_parent_gindex = dom_node->parent->gnode->index;
// (Ignored for now.)
// refuse the fusion if too many ops are going to be fused together
if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_)
continue;
// Phase-3 logic (see below).
....
// Skip if the current node is already fused with its post-dominator.
// Skip if current node is already fused to the parent.
if (groups_[dom_parent_gindex] != nullptr &&
group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) {
continue;
}
// Skip tuple-related cases.
// Do not fuse into tuple for now
if (groups_[dom_parent_gindex]->pattern == kTuple) continue;
// Phase 1 handles kOutEWiseFusable (see below).
if (group_node->pattern == kOutEWiseFusable) {
......
}
// Every phase handles kElemWise or kBroadcast (see below).
else if (group_node->pattern <= kBroadcast) {
......
}
// Phase 2 handles kInjective or kTuple (see below).
else if (group_node->pattern == kInjective || group_node->pattern == kTuple) {
......
}
// kCommReduce logic
else {
// do nothing.
ICHECK(group_node->pattern == kCommReduce);
}
}
}

First, CheckPath. It visits every node between the current node and its post-dominator and checks whether each one satisfies the given fcond:

template <typename F>
bool GraphPartitioner::CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink,
F fcond) {
if (visited_.count(src)) return true;
visited_.insert(src);
Group* gnode = groups_[src->index];
ICHECK(gnode != nullptr);
gnode = gnode->FindRoot();
if (!fcond(gnode->pattern, src == sink)) return false;
if (src == sink) return true;
for (auto link = src->outputs.head; link != nullptr; link = link->next) {
if (!CheckPath_(link->value.node, sink, fcond)) return false;
}
return true;
}

template <typename F>
bool GraphPartitioner::CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink,
F fcond) {
ICHECK(!src->extern_ref);
visited_.clear();
ICHECK(src != sink);
for (auto link = src->outputs.head; link != nullptr; link = link->next) {
if (!CheckPath_(link->value.node, sink, fcond)) return false;
}
return true;
}

If CheckPath returns true, CommitFuse is usually performed next. CommitFuse_ does the traversal, and MergeFromTo updates the Group's parent, pattern, num_nodes and other fields:

void GraphPartitioner::MergeFromTo(Group* child, Group* parent) {
child = child->FindRoot();
parent = parent->FindRoot();
if (child == parent) return;
// update the number of nodes of the parent group
parent->num_nodes += child->num_nodes;
parent->args_num += child->args_num;
child->parent = parent;
// update anchor ref and pattern
if (child->anchor_ref != nullptr) {
ICHECK(parent->anchor_ref == nullptr);
parent->anchor_ref = child->anchor_ref;
parent->pattern = CombinePattern(child->pattern, parent->pattern);
}
}

void GraphPartitioner::CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink,
Group* target) {
if (postpone_node_ != nullptr) {
postponed_fusing_map_.insert({postpone_node_, src});
return;
}
if (src == sink) return;
if (visited_.count(src)) return;
visited_.insert(src);
Group* gnode = groups_[src->index];
ICHECK(gnode != nullptr);
// merge the current group to the parent if possible.
MergeFromTo(gnode, target);
for (auto link = src->outputs.head; link != nullptr; link = link->next) {
CommitFuse_(link->value.node, sink, target);
}
}

void GraphPartitioner::CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) {
Group* target = groups_[sink->index];
visited_.clear();
ICHECK(src != sink);
CommitFuse_(src, sink, target);
}
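The CheckPath / CommitFuse pair can be sketched in Python on the example graph, using post-DFS indices with hand-written consumers and op patterns. Assumptions:

  • Groups are plain integer union-find links.
  • The anchor/pattern update of MergeFromTo is omitted.
  • Only the phase-1 condition for conv2d is exercised: every node on the path must be <= kBroadcast.

```python
from typing import Callable, List, Optional, Set

K_ELEMWISE, K_BROADCAST, K_OUT_EWISE_FUSABLE, K_OPAQUE = 0, 1, 4, 8

# Example graph in post-DFS order: consumers and op pattern of each node.
OUTPUTS = [[2], [2], [4, 7], [4], [5], [8], [7], [8], []]
PATTERN = [K_OPAQUE, K_OPAQUE, K_OUT_EWISE_FUSABLE, K_OPAQUE, K_BROADCAST,
           K_ELEMWISE, K_ELEMWISE, K_BROADCAST, K_BROADCAST]

parent: List[Optional[int]] = [None] * len(PATTERN)   # union-find links
num_nodes = [1] * len(PATTERN)


def find_root(i: int) -> int:
    while parent[i] is not None:
        i = parent[i]
    return i


def check_path(src: int, sink: int, fcond: Callable[[int, bool], bool]) -> bool:
    """True if every node reachable from src (up to and including sink) passes fcond."""
    visited: Set[int] = set()

    def walk(i: int) -> bool:
        if i in visited:
            return True
        visited.add(i)
        if not fcond(PATTERN[find_root(i)], i == sink):
            return False
        return i == sink or all(walk(j) for j in OUTPUTS[i])

    return all(walk(j) for j in OUTPUTS[src])


def commit_fuse(src: int, sink: int) -> None:
    """Merge every group on the paths src -> sink into the sink's group."""
    visited: Set[int] = set()

    def walk(i: int) -> None:
        if i == sink or i in visited:
            return
        visited.add(i)
        child, target = find_root(i), find_root(sink)
        if child != target:
            num_nodes[target] += num_nodes[child]
            parent[child] = target
        for j in OUTPUTS[i]:
            walk(j)

    walk(src)


def only_broadcast(kind: int, is_sink: bool) -> bool:
    # Phase-1 condition for a kOutEWiseFusable source such as conv2d.
    return kind <= K_BROADCAST


if check_path(2, 8, only_broadcast):
    commit_fuse(2, 8)
print(sorted(i for i in range(9) if find_root(i) == 8), num_nodes[8])  # [2, 4, 5, 7, 8] 5
```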

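Every phase handles kElemWise and kBroadcast. Fusion is allowed when all of the following hold:

  • the aggregated pattern from the node to its post-dominator is <= kInjective or kCommReduce;
  • every intermediate node on the paths satisfies pattern <= kInjective;
  • the post-dominator itself satisfies pattern <= kOutEWiseFusable.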

else if (group_node->pattern <= kBroadcast) {
// Pre-condition: can only be fused to parent which is injective or reduction.
if (dom_node->parent != nullptr &&
(dom_node->pattern <= kInjective || dom_node->pattern == kCommReduce)) {
// Check if all the intermediate ops are still broadcast.
// The final terminal node can already be fused to a OutEWiseFusable group.
auto fcond = [](OpPatternKind kind, bool is_sink) {
if (!is_sink) {
// Elemwise, broadcast, and injective ops on the parallel branches
// are allowed be fused to the elemwise/broadcast anchor.
return kind <= kInjective;
} else {
return (kind <= kBroadcast || kind == kCommReduce || kind == kInjective ||
kind == kOutEWiseFusable);
}
};
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
CommitFuse(graph_node, dom_node->parent->gnode);
}
}
}

The first phase (phase == 0) handles kOutEWiseFusable:

// Try to fuse current node to its post-dominator.
if (group_node->pattern == kOutEWiseFusable) {
if (phase != 0) continue;
// Path for OutEWiseFusable: conv2d
// Check if the dominator relation is elemwise.
if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) {
ICHECK(dom_node->parent->gnode != nullptr);
// The fuse can be executed if all the intermediate ops are still broadcast.
auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; };
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
CommitFuse(graph_node, dom_node->parent->gnode);
}
}
}

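The current node can be fused into its post-dominator when all of the following hold:

  • the current node is kOutEWiseFusable;
  • the aggregated pattern toward its post-dominator is kElemWise;
  • every op on the paths between the two nodes satisfies pattern <= kBroadcast.

The second phase (phase == 1) handles kInjective and kTuple: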

else if (group_node->pattern == kInjective || group_node->pattern == kTuple) {
// defer injective fusion to second phase.
// so conv2d always finishes fusing.
if (phase != 1) continue;
// Check if all path are injective.
auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
CommitFuse(graph_node, dom_node->parent->gnode);
}
}

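If the current node is kInjective or kTuple, and all nodes on every path to its post-dominator satisfy pattern <= kInjective, they can be fused.

The third phase (phase == 2) tries to fuse ops with pattern <= kInjective into a kTuple: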

  if (phase == 2) {
// Fuse injective ops into intermediate tuples, if any
if (group_node->pattern > relay::kInjective) continue;
Group* dom_parent_group = groups_[dom_parent_gindex];
Group* dom_root_group = dom_parent_group->FindRoot();
// If dom node group has a tuple as its root, we do not fuse tuple fields into it
if (dom_root_group->pattern == relay::kTuple) continue;
if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= relay::kInjective) {
// Now we know the tuple has been fused into subsequent injective ops
auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
// dom_root_group can also be tuple, as in inception layers
// CheckPath is needed to avoid fusing two intermediate tuples
if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
CommitFuse(graph_node, dom_node->parent->gnode);
}
}
continue;
}

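Fusion happens when all of the following hold:

  • the current node satisfies pattern <= kInjective;
  • its post-dominator has pattern kTuple;
  • the root of the post-dominator's group satisfies pattern <= kInjective, meaning the tuple has already been fused into subsequent injective ops.

In fact, our example is already fully fused after the first phase:

  1. Nodes 0 and 1 have no post-dominator (parent), so they are skipped.
  2. When node 2 is processed, its post-dominator is 8. Nodes 4, 5, 8 and 7 are visited in turn (the green dashed part in the figure below), and all satisfy fcond, so they are fused. Now nodes 2, 4, 5 and 7 all have parent 8, and num_nodes of node 8's group is 5.
  3. Node 3 (the non-scalar constant, kOpaque) has post-dominator 4. Because its pattern is kOpaque, it is skipped by the first check and not fused.
  4. When node 6 is processed, its post-dominator is 7 and fcond is satisfied, so it is fused. Node 6's parent is set to the root of node 7's group, i.e. 8, and num_nodes of node 8's group becomes 6.
  5. All remaining nodes are already fused and are skipped.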



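At this point all fusion information is stored in the Group nodes; all that remains is to update the IR accordingly.

Updating the IR

The IR is updated by a mutator pass, which adjusts nodes during traversal according to the Group information obtained above. The fused IR contains a newly inserted FunctionNode and a corresponding CallNode that express the fusion. Let's follow the code to see how this is done.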

def @main(%x: Tensor[(1, 3, 16, 16), float32] /* ty=Tensor[(1, 3, 16, 16), float32] */, %weight: Tensor[(3, 3, 3, 3), float32] /* ty=Tensor[(3, 3, 3, 3), float32] */) -> Tensor[(1, 3, 14, 14), float32] {
%4 = fn (%p0: Tensor[(1, 3, 16, 16), float32] /* ty=Tensor[(1, 3, 16, 16), float32] */, %p1: Tensor[(3, 3, 3, 3), float32] /* ty=Tensor[(3, 3, 3, 3), float32] */, %p2: Tensor[(1, 3, 14, 14), float32] /* ty=Tensor[(1, 3, 14, 14), float32] */, Primitive=1) -> Tensor[(1, 3, 14, 14), float32] {
%0 = nn.conv2d(%p0, %p1, padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 3, 14, 14), float32] */;
%1 = add(%0, %p2) /* ty=Tensor[(1, 3, 14, 14), float32] */;
%2 = nn.relu(%1) /* ty=Tensor[(1, 3, 14, 14), float32] */;
%3 = multiply(%0, 0.5f /* ty=float32 */) /* ty=Tensor[(1, 3, 14, 14), float32] */;
add(%2, %3) /* ty=Tensor[(1, 3, 14, 14), float32] */
} /* ty=fn (Tensor[(1, 3, 16, 16), float32], Tensor[(3, 3, 3, 3), float32], Tensor[(1, 3, 14, 14), float32]) -> Tensor[(1, 3, 14, 14), float32] */;
%4(%x, %weight, meta[relay.Constant][0] /* ty=Tensor[(1, 3, 14, 14), float32] */) /* ty=Tensor[(1, 3, 14, 14), float32] */
}

FuseMutator inherits from MixedModeMutator and overrides the visitors for FunctionNode, CallNode and others.

MixedModeMutator traverses dataflow nodes (CallNode, TupleNode, etc.) in post-order (post-topological order).

class FuseMutator : private MixedModeMutator {
private:
int fuse_opt_level_;
size_t max_fuse_depth_;
bool link_params_;
/*! \brief The group assignment map. */
std::unordered_map<const Object*, GraphPartitioner::Group*> gmap_;
/* \brief Internal group information map. */
std::unordered_map<GraphPartitioner::Group*, GroupInfo> ginfo_;
......
};

First, the FunctionNode handler. Primitive functions are skipped; otherwise the parent class's logic processes the params and then the body:

// Skip primitive function.
Expr VisitExpr_(const FunctionNode* fn_node) {
if (fn_node->HasNonzeroAttr(attr::kPrimitive)) {
return GetRef<Expr>(fn_node);
} else {
return ExprMutator::VisitExpr_(fn_node);
}
}

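The params %x and %weight are processed first, and the results are stored in the member variable memo_ (std::unordered_map<Expr, Expr>). Then the body is processed in post-topological order:

CallNode(conv2d, ...) -> ConstantNode -> CallNode(add, ...) -> CallNode(relu, ...) -> ConstantNode -> CallNode(multiply, ...) -> CallNode(add, ...)

ConstantNode handling is inherited directly from ExprMutator::VisitExpr_(const ConstantNode* op), which stores a reference in memo_.

CallNode handling is the core:

  1. Find the Group the current call belongs to.
  2. Build the input arguments. When an argument belongs to a different Group than the current one, create a formal parameter and record the actual argument.
  3. Build the corresponding new CallNode.
  4. If the current node is not the group's root_ref node, return it directly. Otherwise, build a new FunctionNode and a CallNode from the params and arguments stored in GroupInfo.

Part of the code: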

// Stores the actual arguments and formal parameters of each Group
/*! \brief Temporary information from each group. */
struct GroupInfo {
public:
// The parameters of the function.
Array<Var> params;
// The arguments to call the functions.
Array<Expr> arguments;
// Get a new parameter or allocate an old one
Var GetOrAllocParam(const Expr& expr, const Type& type) {
// run linear scan as most fused groups contain only a few inputs.
for (size_t i = 0; i < arguments.size(); ++i) {
if (expr.same_as(arguments[i])) return params[i];
}
// create a new parameter.
std::ostringstream os;
os << "p" << params.size();
auto var = Var(os.str(), type);
params.push_back(var);
arguments.push_back(expr);
return var;
}
};

Array<Expr> GetNewArguments(const tvm::Array<Expr>& args,
GraphPartitioner::Group* current_group) {
Array<Expr> new_args;
for (auto arg : args) {
auto* arg_group = gmap_.at(arg.get())->FindRoot();
auto type = arg->checked_type();
Expr new_arg = this->Mutate(arg);
if (current_group != arg_group) {
if (!link_params_ || new_arg.as<ConstantNode>() == nullptr) {
Var param = ginfo_[current_group].GetOrAllocParam(new_arg, type);
new_args.push_back(param);
} else {
new_args.push_back(new_arg);
}
} else {
new_args.push_back(new_arg);
}
}
return new_args;
}

Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) {
......
const GroupInfo& ginfo = ginfo_[group];
auto func = Function(ginfo.params, body, ret_type, {});
func = WithAttr(std::move(func), attr::kPrimitive, tvm::Integer(visitor.has_call));
......
return Call(func, ginfo.arguments, Attrs());
}

// Transform calls.
Expr Rewrite_(const CallNode* call, const Expr& post) {
if (call->op.as<OpNode>()) {
......
// Find the group this call belongs to
auto* ret_group = gmap_.at(call)->FindRoot();
// Build the new argument list
Array<Expr> new_args = GetNewArguments(call->args, ret_group);
// Build the new CallNode
auto new_call = Call(call->op, new_args, call->attrs, call->type_args, call->span);

if (ret_group->root_ref == call) {
// This is the root of the group
// create the new call node.
// Build the FunctionNode and the CallNode that invokes it
return MakeNewFunction(ret_group, call->checked_type(), new_call);
} else {
// This is an intermediate node of a fused function
// simply return the new call.
return std::move(new_call);
}
} else {
return ExprMutator::VisitExpr_(call);
}
}
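The parameter handling in GroupInfo::GetOrAllocParam and GetNewArguments can be sketched in Python by replaying the calls of the fused group in post-order. Assumptions:

  • Expressions are represented by their names.
  • The group assignment is taken from the fusion walkthrough.
  • The link_params_ branch, which keeps constants inline, is ignored.

```python
from typing import Dict, List, Tuple


class GroupInfo:
    """Per-group formal parameters and actual arguments, modeled on the C++ GroupInfo."""

    def __init__(self) -> None:
        self.params: List[str] = []
        self.arguments: List[str] = []

    def get_or_alloc_param(self, expr: str) -> str:
        # Linear scan: reuse the parameter if this expression was already passed in.
        for param, arg in zip(self.params, self.arguments):
            if arg == expr:
                return param
        param = f"p{len(self.params)}"
        self.params.append(param)
        self.arguments.append(expr)
        return param


def get_new_arguments(args: List[str], group_of: Dict[str, int], current: int,
                      info: GroupInfo) -> List[str]:
    # Arguments from another group become parameters; same-group ones stay inline.
    return [arg if group_of[arg] == current else info.get_or_alloc_param(arg)
            for arg in args]


# Group assignment from the walkthrough: group 8 holds conv2d..add, while x,
# weight and the constant c each stay in their own group.
GROUP_OF = {"x": 0, "weight": 1, "conv": 8, "c": 3, "add": 8, "relu": 8,
            "half": 8, "mul": 8}
CALLS: List[Tuple[str, List[str]]] = [   # call nodes in post-order
    ("conv", ["x", "weight"]),
    ("add", ["conv", "c"]),
    ("relu", ["add"]),
    ("mul", ["conv", "half"]),
    ("add_out", ["relu", "mul"]),
]

info = GroupInfo()
for name, args in CALLS:
    print(name, get_new_arguments(args, GROUP_OF, 8, info))
print(info.params, info.arguments)  # ['p0', 'p1', 'p2'] ['x', 'weight', 'c']
```

The three allocated parameters p0, p1 and p2 are bound to x, weight and the constant. This matches the signature of the fused function %4 in the IR above. The scalar 0.5 stays inline because it belongs to the same group.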

Reference: https://zhuanlan.zhihu.com/p/589619468
