AI编译器CINN v.s TVM 中CodeGen 源码解读
如下的技术点梳理仅以「日常优化工作」为牵引点,涉及哪个模块,就具体去看哪个模块的代码。
一、CINN 框架
CINN 中CodeGen之后的代码编译主要交给了Compiler类来负责。核心的函数主要是:
Build(ir::Module&, string& code)Lookup(string& fn_name)
class Compiler final {
public:
static std::unique_ptr<Compiler> Create(const Target& target) {
return std::unique_ptr<Compiler>(new Compiler(target));
}
/**
* Compile and link to a CINN module.
*/
void Build(const ir::Module& module, const std::string& code = "");
void ExportObject(const std::string& path);
std::string GetSourceCode(const ir::Module& module);
void BuildDefault(const ir::Module& module);
/**
* Retrieve a function by \p fn_name.
* @return function address or null if not exists.
*/
void* Lookup(absl::string_view fn_name);
private:
void CompileCudaModule(const ir::Module& module, const std::string& code = "");
void CompileX86Module(const ir::Module& module);
explicit Compiler(const Target& target) : target_(target), engine_(ExecutionEngine::Create(ExecutionOptions())) {}
CINN_DISALLOW_COPY_AND_ASSIGN(Compiler);
private:
Target target_;
std::unique_ptr<ExecutionEngine> engine_;
#ifdef CINN_WITH_CUDA
std::unique_ptr<runtime::cuda::CUDAModule> cuda_module_;
#endif
};
在Build()方法中,是通过target来进行逻辑分发的:
void Compiler::Build(const Module& module, const std::string& code) {
if (target_.arch == Target::Arch::NVGPU) {
CompileCudaModule(module, code); // <----- GPU 上编译逻辑
} else if (target_.arch == Target::Arch::X86) {
CompileX86Module(module); // <------ X86 CPU 上编译逻辑
} else {
CINN_NOT_IMPLEMENTED
}
}
我们来详细研习下 CompileCudaModule() 函数的实现逻辑:
- step 1: 调用
SplitCudaAndHostModule()将ir::Module切分成host_module和device_module - step 2: 借助
CodeGenCUDA_Dev模块,对device_module进行代码生成,得到source_code;也支持用户直接通过code参数指定 - step 3: 构造一个
nvrtc::Compiler对象,将source_code编译为ptx中间代码 - step 4: 以
ptx构造一个runtime::cuda::CUDAModule对象,可以选择是CUBIN或者PTX的kind类型 - step 5: 根据
device_module中的fn_name在cuda_module_中取出对应的fn_kernel(CUfunction对象),并在RuntimeSymbols中注册(fn_name__ptr__、void*) - step 6: 以
RuntimeSymbols构造ExecutionEngine,负责将CUDA Kernel与Host端API进行链接
void Compiler::CompileCudaModule(const Module& module, const std::string& code) {
// step 1: 调用SplitCudaAndHostModule()将ir::Module切分成 host_module和device_module
auto _host_module_device_module_ = SplitCudaAndHostModule(module); // NOLINT
auto& host_module = std::get<0>(_host_module_device_module_);
auto& device_module = std::get<1>(_host_module_device_module_);
// step 2: 借助CodeGenCUDA_Dev模块,对device_module进行代码生成,得到 source_code;也支持用户直接通过code参数指定
std::string source_code;
CodeGenCUDA_Dev codegen(target_);
source_code = codegen.Compile(device_module);
// step 3: 构造一个nvrtc::Compiler对象,将source_code编译为ptx中间代码
backends::nvrtc::Compiler compiler;
auto ptx = compiler(source_code);
// step 4: 以ptx构造一个runtime::cuda::CUDAModule对象,可以选择是CUBIN 或者 PTX 的kind类型
using runtime::cuda::CUDAModule;
cuda_module_.reset(
new CUDAModule(ptx, compiler.compile_to_cubin() ? CUDAModule::Kind::CUBIN : CUDAModule::Kind::PTX));
// step 5: 根据device_module中的fn_name在cuda_module_中取出对应的fn_kernel(CUfunction对象),并在RuntimeSymbols中注册(fn_name__ptr__、void*)
RuntimeSymbols symbols;
for (auto& fn : device_module.functions()) {
std::string kernel_fn_name = fn->name;
auto fn_kernel = cuda_module_->GetFunction(0, kernel_fn_name);
CHECK(fn_kernel);
symbols.RegisterVar(kernel_fn_name + "_ptr_", reinterpret_cast<void*>(fn_kernel));
}
// step 6: 以RuntimeSymbols构造ExecutionEngine,负责将CUDA Kernel 与 Host 端API进行链接
engine_ = ExecutionEngine::Create(ExecutionOptions(), std::move(symbols));
engine_->Link<CodeGenCUDA_Host>(host_module);
}
二、TVM 框架
runtime::Module BuildCUDA(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenCUDA cg;
cg.Init(output_ssa);
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f);
}
std::string code = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) {
code = (*f)(code).operator std::string();
}
std::string fmt = "ptx";
std::string ptx;
const auto* f_enter = Registry::Get("target.TargetEnterScope");
(*f_enter)(target);
if (const auto* f = Registry::Get("tvm_callback_cuda_compile")) { // <---- 可以借助python端的函数,注册一个直接编译 cubin 的函数
ptx = (*f)(code).operator std::string();
// Dirty matching to check PTX vs cubin.
// TODO(tqchen) more reliable checks
if (ptx[0] != '/') fmt = "cubin";
} else {
ptx = NVRTCCompile(code, cg.need_include_path());
}
const auto* f_exit = Registry::Get("target.TargetExitScope");
(*f_exit)(target);
return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
}
TVM_REGISTER_GLOBAL("target.build.cuda").set_body_typed(BuildCUDA);
其中 NVRTCCompile 的实现:
std::string NVRTCCompile(const std::string& code, bool include_path = false) {
std::vector<std::string> compile_params;
std::vector<const char*> param_cstrings{};
nvrtcProgram prog;
std::string cc = "30";
int major, minor;
cudaError_t e1 = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0);
cudaError_t e2 = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0);
if (e1 == cudaSuccess && e2 == cudaSuccess) {
cc = std::to_string(major) + std::to_string(minor);
} else {
LOG(WARNING) << "cannot detect compute capability from your device, "
<< "fall back to compute_30.";
}
compile_params.push_back("-arch=compute_" + cc);
if (include_path) {
std::string include_option = "--include-path=" + FindCUDAIncludePath();
compile_params.push_back(include_option);
}
for (const auto& string : compile_params) {
param_cstrings.push_back(string.c_str());
}
NVRTC_CALL(nvrtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, nullptr));
nvrtcResult compile_res = nvrtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data());
size_t log_size;
NVRTC_CALL(nvrtcGetProgramLogSize(prog, &log_size));
std::string log;
log.resize(log_size);
NVRTC_CALL(nvrtcGetProgramLog(prog, &log[0]));
ICHECK_EQ(compile_res, NVRTC_SUCCESS) << log;
size_t ptx_size;
NVRTC_CALL(nvrtcGetPTXSize(prog, &ptx_size));
std::string ptx;
ptx.resize(ptx_size);
NVRTC_CALL(nvrtcGetPTX(prog, &ptx[0]));
NVRTC_CALL(nvrtcDestroyProgram(&prog));
return ptx;
}
可以在 python 端注册一个自定义函数,直接编译为 cubin 文件,用法见TVM中的单测:
@tvm._ffi.register_func
def tvm_callback_cuda_compile(code):
"""use nvcc to generate fatbin code for better optimization"""
ptx = compile_cuda(code, target_format="fatbin")
return ptx
def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target=None):
"""Compile cuda code with NVCC from env
"""
if arch is None:
# If None, then it will use `tvm.target.Target.current().arch`.
# Target arch could be a str like "sm_xx", or a list, such as
# [
# "-gencode", "arch=compute_52,code=sm_52",
# "-gencode", "arch=compute_70,code=sm_70"
# ]
compute_version = "".join(
get_target_compute_version(Target.current(allow_none=True)).split(".")
)
arch = ["-gencode", f"arch=compute_{compute_version},code=sm_{compute_version}"]
temp = utils.tempdir()
if target_format not in ["cubin", "ptx", "fatbin"]:
raise ValueError("target_format must be in cubin, ptx, fatbin")
temp_code = temp.relpath("my_kernel.cu")
temp_target = temp.relpath("my_kernel.%s" % target_format)
with open(temp_code, "w") as out_file:
out_file.write(code)
file_target = path_target if path_target else temp_target
cmd = ["nvcc"]
cmd += ["--%s" % target_format, "-O3"]
if isinstance(arch, list):
cmd += arch
elif isinstance(arch, str):
cmd += ["-arch", arch]
if options:
if isinstance(options, str):
cmd += [options]
elif isinstance(options, list):
cmd += options
else:
raise ValueError("options must be str or list of str")
cmd += ["-o", file_target]
cmd += [temp_code]
# NOTE: ccbin option can be used to tell nvcc where to find the c++ compiler
# just in case it is not in the path. On Windows it is not in the path by default.
# However, we cannot use TVM_CXX_COMPILER_PATH because the runtime env.
# Because it is hard to do runtime compiler detection, we require nvcc is configured
# correctly by default.
# if cxx_compiler_path != "":
# cmd += ["-ccbin", cxx_compiler_path]
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = code
msg += "\nCompilation error:\n"
msg += py_str(out)
raise RuntimeError(msg)
with open(file_target, "rb") as f:
data = bytearray(f.read())
if not data:
raise RuntimeError("Compilation error: empty result is generated")
return data
三、参考资料
- https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/#virtual-architecture-feature-list
- https://forums.developer.nvidia.com/t/using-curand-inside-nvrtc-jit-compiled-kernels/193826
- https://github.com/NVIDIA/jitify/issues/43
- https://github.com/NVIDIA/libcudacxx/blob/main/include/cuda/std/detail/libcxx/include/type_traits
AI编译器CINN v.s TVM 中CodeGen 源码解读的更多相关文章
- go中panic源码解读
panic源码解读 前言 panic的作用 panic使用场景 看下实现 gopanic gorecover fatalpanic 总结 参考 panic源码解读 前言 本文是在go version ...
- go中waitGroup源码解读
waitGroup源码刨铣 前言 WaitGroup实现 noCopy state1 Add Wait 总结 参考 waitGroup源码刨铣 前言 学习下waitGroup的实现 本文是在go ve ...
- etcd中watch源码解读
etcd中watch的源码解析 前言 client端的代码 Watch newWatcherGrpcStream run newWatchClient serveSubstream server端的代 ...
- java中jdbc源码解读
在jdbc中一个重要的接口类就是java.sql.Driver,其中有一个重要的方法:Connection connect(String url, java.util.Propeties info); ...
- go中errgroup源码解读
errgroup 前言 如何使用 实现原理 WithContext Go Wait 错误的使用 总结 errgroup 前言 来看下errgroup的实现 如何使用 func main() { var ...
- Vue 源码解读(8)—— 编译器 之 解析(上)
特殊说明 由于文章篇幅限制,所以将 Vue 源码解读(8)-- 编译器 之 解析 拆成了上下两篇,所以在阅读本篇文章时请同时打开 Vue 源码解读(8)-- 编译器 之 解析(下)一起阅读. 前言 V ...
- Vue 源码解读(10)—— 编译器 之 生成渲染函数
前言 这篇文章是 Vue 编译器的最后一部分,前两部分分别是:Vue 源码解读(8)-- 编译器 之 解析.Vue 源码解读(9)-- 编译器 之 优化. 从 HTML 模版字符串开始,解析所有标签以 ...
- go 中 select 源码阅读
深入了解下 go 中的 select 前言 1.栗子一 2.栗子二 3.栗子三 看下源码实现 1.不存在 case 2.select 中仅存在一个 case 3.select 中存在两个 case,其 ...
- Vue 源码解读(8)—— 编译器 之 解析(下)
特殊说明 由于文章篇幅限制,所以将 Vue 源码解读(8)-- 编译器 之 解析 拆成了两篇文章,本篇是对 Vue 源码解读(8)-- 编译器 之 解析(上) 的一个补充,所以在阅读时请同时打开 Vu ...
- Vue 源码解读(9)—— 编译器 之 优化
前言 上一篇文章 Vue 源码解读(8)-- 编译器 之 解析 详细详解了编译器的第一部分,如何将 html 模版字符串编译成 AST.今天带来编译器的第二部分,优化 AST,也是大家常说的静态标记. ...
随机推荐
- 使用fiddler抓取HTTPS的数据包(抓取App端的数据包)
众所周知,我们在做接口测试的时候有两种情况: 第一种是先拿到接口测试规范文档,再去做接口测试. 第二种是没有接口文档,只有通过自己抓包. 那么说到抓包,就不得不说抓包工具,对于浏览器web端,我们只需 ...
- 测试开发之系统篇-Docker容器安装
前面文章我们讲到,容器是运行在宿主机上的一个进程,多个容器之间使用同一个宿主机上的操作系统内核.此处以Ubuntu20.04系统为例,介绍Docker容器引擎的安装过程. 安装 安装依赖. sudo ...
- 使用CMake启用RUNPATH特性
使用CMake,启用RUNPATH特性,可以参考官方帖子. 如下源码来自于上述帖子. CMAKE_MINIMUM_REQUIRED(VERSION 2.8 FATAL_ERROR) PROJECT(R ...
- OpenHarmony父子组件单项同步使用:@Prop装饰器
@Prop装饰的变量可以和父组件建立单向的同步关系.@Prop装饰的变量是可变的,但是变化不会同步回其父组件. 说明: 从API version 9开始,该装饰器支持在ArkTS卡片中使用. 概述 ...
- HMS Core手语服务荣获2022中国互联网大会“特别推荐案例”:助力建设数字社会
11月15日,HMS Core手语服务在2022(第二十一届)中国互联网大会 "互联网助力经济社会数字化转型"案例评选活动中,荣获"特别推荐案例". 经过一年多 ...
- pc=mobile+pad自适应布局:页面结构与打开方式
pc=mobile+pad自适应布局 在这篇文章,咱们重点聊聊自适应布局的页面结构,以及打开页面的几种方式.关于pc=mobile+pad自适应布局的起源.概念.效果,参见文章:自适应布局:pc = ...
- Windows Server 2008 R2之升级IE8
前言 先需求将Windows Server 2008 R2的IE8升级至IE9,需要安装系统补丁. 安装补丁 补丁包版本 KB2454826 下载地址 https://www.catalog.upda ...
- 如何使用 Grafana 监控文件系统状态
当 JuiceFS 文件系统部署完成并投入生产环境,接下来就需要着手解决一个非常重要的问题 -- 如何实时监控它的运行状态?毕竟,它可能正在为关键的业务应用或容器工作负载提供持久化存储支持,任何小小的 ...
- 树模型-label boosting-GBDT
GBDT GBDT是boosting系列算法的代表之一,其核心是 梯度+提升+决策树. GBDT回归问题 通俗的理解: 先来个通俗理解:假如有个人30岁,我们首先用20岁去拟合,发现损失有10岁,这时 ...
- 【布局进阶】巧用 :has & drop-shadow 实现复杂布局效果
最近,群里聊到了一个很有意思的布局效果.大致效果如下所示,希望使用 CSS 实现如下所示的布局效果: 正常而言,我们的 HTML 结构大致是如下所示: <div class="g-co ...