AI编译器CINN v.s TVM 中CodeGen 源码解读
如下的技术点梳理仅以「日常优化工作」为牵引点,涉及哪个模块,就具体去看哪个模块的代码。
一、CINN 框架
CINN 中CodeGen之后的代码编译主要交给了Compiler类来负责。核心的函数主要是:
Build(ir::Module&, string& code)Lookup(string& fn_name)
class Compiler final {
public:
static std::unique_ptr<Compiler> Create(const Target& target) {
return std::unique_ptr<Compiler>(new Compiler(target));
}
/**
* Compile and link to a CINN module.
*/
void Build(const ir::Module& module, const std::string& code = "");
void ExportObject(const std::string& path);
std::string GetSourceCode(const ir::Module& module);
void BuildDefault(const ir::Module& module);
/**
* Retrieve a function by \p fn_name.
* @return function address or null if not exists.
*/
void* Lookup(absl::string_view fn_name);
private:
void CompileCudaModule(const ir::Module& module, const std::string& code = "");
void CompileX86Module(const ir::Module& module);
explicit Compiler(const Target& target) : target_(target), engine_(ExecutionEngine::Create(ExecutionOptions())) {}
CINN_DISALLOW_COPY_AND_ASSIGN(Compiler);
private:
Target target_;
std::unique_ptr<ExecutionEngine> engine_;
#ifdef CINN_WITH_CUDA
std::unique_ptr<runtime::cuda::CUDAModule> cuda_module_;
#endif
};
在Build()方法中,是通过target来进行逻辑分发的:
void Compiler::Build(const Module& module, const std::string& code) {
if (target_.arch == Target::Arch::NVGPU) {
CompileCudaModule(module, code); // <----- GPU 上编译逻辑
} else if (target_.arch == Target::Arch::X86) {
CompileX86Module(module); // <------ X86 CPU 上编译逻辑
} else {
CINN_NOT_IMPLEMENTED
}
}
我们来详细研习下 CompileCudaModule() 函数的实现逻辑:
- step 1: 调用
SplitCudaAndHostModule()将ir::Module切分成host_module和device_module - step 2: 借助
CodeGenCUDA_Dev模块,对device_module进行代码生成,得到source_code;也支持用户直接通过code参数指定 - step 3: 构造一个
nvrtc::Compiler对象,将source_code编译为ptx中间代码 - step 4: 以
ptx构造一个runtime::cuda::CUDAModule对象,可以选择是CUBIN或者PTX的kind类型 - step 5: 根据
device_module中的fn_name在cuda_module_中取出对应的fn_kernel(CUfunction对象),并在RuntimeSymbols中注册(fn_name__ptr__、void*) - step 6: 以
RuntimeSymbols构造ExecutionEngine,负责将CUDA Kernel与Host端API进行链接
void Compiler::CompileCudaModule(const Module& module, const std::string& code) {
// step 1: 调用SplitCudaAndHostModule()将ir::Module切分成 host_module和device_module
auto _host_module_device_module_ = SplitCudaAndHostModule(module); // NOLINT
auto& host_module = std::get<0>(_host_module_device_module_);
auto& device_module = std::get<1>(_host_module_device_module_);
// step 2: 借助CodeGenCUDA_Dev模块,对device_module进行代码生成,得到 source_code;也支持用户直接通过code参数指定
std::string source_code;
CodeGenCUDA_Dev codegen(target_);
source_code = codegen.Compile(device_module);
// step 3: 构造一个nvrtc::Compiler对象,将source_code编译为ptx中间代码
backends::nvrtc::Compiler compiler;
auto ptx = compiler(source_code);
// step 4: 以ptx构造一个runtime::cuda::CUDAModule对象,可以选择是CUBIN 或者 PTX 的kind类型
using runtime::cuda::CUDAModule;
cuda_module_.reset(
new CUDAModule(ptx, compiler.compile_to_cubin() ? CUDAModule::Kind::CUBIN : CUDAModule::Kind::PTX));
// step 5: 根据device_module中的fn_name在cuda_module_中取出对应的fn_kernel(CUfunction对象),并在RuntimeSymbols中注册(fn_name__ptr__、void*)
RuntimeSymbols symbols;
for (auto& fn : device_module.functions()) {
std::string kernel_fn_name = fn->name;
auto fn_kernel = cuda_module_->GetFunction(0, kernel_fn_name);
CHECK(fn_kernel);
symbols.RegisterVar(kernel_fn_name + "_ptr_", reinterpret_cast<void*>(fn_kernel));
}
// step 6: 以RuntimeSymbols构造ExecutionEngine,负责将CUDA Kernel 与 Host 端API进行链接
engine_ = ExecutionEngine::Create(ExecutionOptions(), std::move(symbols));
engine_->Link<CodeGenCUDA_Host>(host_module);
}
二、TVM 框架
runtime::Module BuildCUDA(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenCUDA cg;
cg.Init(output_ssa);
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f);
}
std::string code = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) {
code = (*f)(code).operator std::string();
}
std::string fmt = "ptx";
std::string ptx;
const auto* f_enter = Registry::Get("target.TargetEnterScope");
(*f_enter)(target);
if (const auto* f = Registry::Get("tvm_callback_cuda_compile")) { // <---- 可以借助python端的函数,注册一个直接编译 cubin 的函数
ptx = (*f)(code).operator std::string();
// Dirty matching to check PTX vs cubin.
// TODO(tqchen) more reliable checks
if (ptx[0] != '/') fmt = "cubin";
} else {
ptx = NVRTCCompile(code, cg.need_include_path());
}
const auto* f_exit = Registry::Get("target.TargetExitScope");
(*f_exit)(target);
return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
}
TVM_REGISTER_GLOBAL("target.build.cuda").set_body_typed(BuildCUDA);
其中 NVRTCCompile 的实现:
std::string NVRTCCompile(const std::string& code, bool include_path = false) {
std::vector<std::string> compile_params;
std::vector<const char*> param_cstrings{};
nvrtcProgram prog;
std::string cc = "30";
int major, minor;
cudaError_t e1 = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0);
cudaError_t e2 = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0);
if (e1 == cudaSuccess && e2 == cudaSuccess) {
cc = std::to_string(major) + std::to_string(minor);
} else {
LOG(WARNING) << "cannot detect compute capability from your device, "
<< "fall back to compute_30.";
}
compile_params.push_back("-arch=compute_" + cc);
if (include_path) {
std::string include_option = "--include-path=" + FindCUDAIncludePath();
compile_params.push_back(include_option);
}
for (const auto& string : compile_params) {
param_cstrings.push_back(string.c_str());
}
NVRTC_CALL(nvrtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, nullptr));
nvrtcResult compile_res = nvrtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data());
size_t log_size;
NVRTC_CALL(nvrtcGetProgramLogSize(prog, &log_size));
std::string log;
log.resize(log_size);
NVRTC_CALL(nvrtcGetProgramLog(prog, &log[0]));
ICHECK_EQ(compile_res, NVRTC_SUCCESS) << log;
size_t ptx_size;
NVRTC_CALL(nvrtcGetPTXSize(prog, &ptx_size));
std::string ptx;
ptx.resize(ptx_size);
NVRTC_CALL(nvrtcGetPTX(prog, &ptx[0]));
NVRTC_CALL(nvrtcDestroyProgram(&prog));
return ptx;
}
可以在 python 端注册一个自定义函数,直接编译为 cubin 文件,用法见TVM中的单测:
@tvm._ffi.register_func
def tvm_callback_cuda_compile(code):
"""use nvcc to generate fatbin code for better optimization"""
ptx = compile_cuda(code, target_format="fatbin")
return ptx
def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target=None):
"""Compile cuda code with NVCC from env
"""
if arch is None:
# If None, then it will use `tvm.target.Target.current().arch`.
# Target arch could be a str like "sm_xx", or a list, such as
# [
# "-gencode", "arch=compute_52,code=sm_52",
# "-gencode", "arch=compute_70,code=sm_70"
# ]
compute_version = "".join(
get_target_compute_version(Target.current(allow_none=True)).split(".")
)
arch = ["-gencode", f"arch=compute_{compute_version},code=sm_{compute_version}"]
temp = utils.tempdir()
if target_format not in ["cubin", "ptx", "fatbin"]:
raise ValueError("target_format must be in cubin, ptx, fatbin")
temp_code = temp.relpath("my_kernel.cu")
temp_target = temp.relpath("my_kernel.%s" % target_format)
with open(temp_code, "w") as out_file:
out_file.write(code)
file_target = path_target if path_target else temp_target
cmd = ["nvcc"]
cmd += ["--%s" % target_format, "-O3"]
if isinstance(arch, list):
cmd += arch
elif isinstance(arch, str):
cmd += ["-arch", arch]
if options:
if isinstance(options, str):
cmd += [options]
elif isinstance(options, list):
cmd += options
else:
raise ValueError("options must be str or list of str")
cmd += ["-o", file_target]
cmd += [temp_code]
# NOTE: ccbin option can be used to tell nvcc where to find the c++ compiler
# just in case it is not in the path. On Windows it is not in the path by default.
# However, we cannot use TVM_CXX_COMPILER_PATH because the runtime env.
# Because it is hard to do runtime compiler detection, we require nvcc is configured
# correctly by default.
# if cxx_compiler_path != "":
# cmd += ["-ccbin", cxx_compiler_path]
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = code
msg += "\nCompilation error:\n"
msg += py_str(out)
raise RuntimeError(msg)
with open(file_target, "rb") as f:
data = bytearray(f.read())
if not data:
raise RuntimeError("Compilation error: empty result is generated")
return data
三、参考资料
- https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/#virtual-architecture-feature-list
- https://forums.developer.nvidia.com/t/using-curand-inside-nvrtc-jit-compiled-kernels/193826
- https://github.com/NVIDIA/jitify/issues/43
- https://github.com/NVIDIA/libcudacxx/blob/main/include/cuda/std/detail/libcxx/include/type_traits
AI编译器CINN v.s TVM 中CodeGen 源码解读的更多相关文章
- go中panic源码解读
panic源码解读 前言 panic的作用 panic使用场景 看下实现 gopanic gorecover fatalpanic 总结 参考 panic源码解读 前言 本文是在go version ...
- go中waitGroup源码解读
waitGroup源码刨铣 前言 WaitGroup实现 noCopy state1 Add Wait 总结 参考 waitGroup源码刨铣 前言 学习下waitGroup的实现 本文是在go ve ...
- etcd中watch源码解读
etcd中watch的源码解析 前言 client端的代码 Watch newWatcherGrpcStream run newWatchClient serveSubstream server端的代 ...
- java中jdbc源码解读
在jdbc中一个重要的接口类就是java.sql.Driver,其中有一个重要的方法:Connection connect(String url, java.util.Propeties info); ...
- go中errgroup源码解读
errgroup 前言 如何使用 实现原理 WithContext Go Wait 错误的使用 总结 errgroup 前言 来看下errgroup的实现 如何使用 func main() { var ...
- Vue 源码解读(8)—— 编译器 之 解析(上)
特殊说明 由于文章篇幅限制,所以将 Vue 源码解读(8)-- 编译器 之 解析 拆成了上下两篇,所以在阅读本篇文章时请同时打开 Vue 源码解读(8)-- 编译器 之 解析(下)一起阅读. 前言 V ...
- Vue 源码解读(10)—— 编译器 之 生成渲染函数
前言 这篇文章是 Vue 编译器的最后一部分,前两部分分别是:Vue 源码解读(8)-- 编译器 之 解析.Vue 源码解读(9)-- 编译器 之 优化. 从 HTML 模版字符串开始,解析所有标签以 ...
- go 中 select 源码阅读
深入了解下 go 中的 select 前言 1.栗子一 2.栗子二 3.栗子三 看下源码实现 1.不存在 case 2.select 中仅存在一个 case 3.select 中存在两个 case,其 ...
- Vue 源码解读(8)—— 编译器 之 解析(下)
特殊说明 由于文章篇幅限制,所以将 Vue 源码解读(8)-- 编译器 之 解析 拆成了两篇文章,本篇是对 Vue 源码解读(8)-- 编译器 之 解析(上) 的一个补充,所以在阅读时请同时打开 Vu ...
- Vue 源码解读(9)—— 编译器 之 优化
前言 上一篇文章 Vue 源码解读(8)-- 编译器 之 解析 详细详解了编译器的第一部分,如何将 html 模版字符串编译成 AST.今天带来编译器的第二部分,优化 AST,也是大家常说的静态标记. ...
随机推荐
- sys_stat_statements 扩展使用介绍
sys_stat_statements 模块提供追踪服务器所执行的所有SQL语句的执行统计信息,可以用于统计数据库的资源开销,如分析TOP SQL. KingbaseES V8R6版本该插件已经内置化 ...
- #单调栈#CodeChef Meteor
METEORAK 分析 设 \(dp[l][r]\) 表示第 \(l\) 到 \(r\) 行的答案,可以发现它由 \(f[l][r],dp[l][r+1],dp[l+1][r]\) 转移而来. 关键就 ...
- 如何在OpenHarmony上使用SeetaFace2人脸识别库?
简介 相信大部分同学们都已了解或接触过OpenAtom OpenHarmony(以下简称"OpenHarmony")了,但你一定没在OpenHarmony上实现过人脸识别功能,跟着 ...
- 组合数学——Min-Max容斥
Min-Max 容斥,即 $$\max(S)=\sum_{T\in S,T\neq\emptyset}(-1)^{|T|-1}\min(T)$$ 接下来证明上面那个式子是对的.定义 \(S\) 中共有 ...
- SQL DELETE 语句:删除表中记录的语法和示例,以及 SQL SELECT TOP、LIMIT、FETCH FIRST 或 ROWNUM 子句的使用
SQL DELETE 语句 SQL DELETE 语句用于删除表中的现有记录. DELETE 语法 DELETE FROM 表名 WHERE 条件; 注意:在删除表中的记录时要小心!请注意DELETE ...
- 安装HTMLTestRunner库
安装 HTMLTestRunner 库的方法非常简单,直接 pip 就可以了 pip install html-testRunner 在 https://pypi.org/ 中可以直接搜索到,并且官 ...
- 一键部署openGauss2.0.1 CentOS 7.6
一键部署 openGauss2.0.1[CentOS 7.6] 本文档目的是为了帮助高校学生提供基于 CentOS7.6 操作系统,实现 openGauss 数据库一键式安装的脚本. 该脚本执行成功后 ...
- HarmonyOS 设备管理开发:USB 服务开发指导
基本概念 USB服务是应用访问底层的一种设备抽象概念.开发者根据提供的USB API,可以获取设备列表.控制设备访问权限.以及与连接的设备进行数据传输.控制命令传输等. 运作机制 USB服务系统包 ...
- 开发指导—利用CSS动画实现HarmonyOS动效(一)
注:本文内容分享转载自HarmonyOS Developer官网文档 一. CSS语法参考 CSS是描述HML页面结构的样式语言.所有组件均存在系统默认样式,也可在页面CSS样式文件中对组件.页面自 ...
- redis 简单整理——客户端案例分析[十八]
前言 简单整理一下客户端案例分析. 正文 现象一: 服务端现象:Redis主节点内存陡增,几乎用满maxmemory,而从节点 内存并没有变化. 客户端现象:客户端产生了OOM异常,也就是Redis主 ...