AI编译器CINN v.s TVM 中CodeGen 源码解读
如下的技术点梳理仅以「日常优化工作」为牵引点,涉及哪个模块,就具体去看哪个模块的代码。
一、CINN 框架
CINN 中CodeGen之后的代码编译主要交给了Compiler类来负责。核心的函数主要是:
Build(ir::Module&, string& code)Lookup(string& fn_name)
class Compiler final {
public:
static std::unique_ptr<Compiler> Create(const Target& target) {
return std::unique_ptr<Compiler>(new Compiler(target));
}
/**
* Compile and link to a CINN module.
*/
void Build(const ir::Module& module, const std::string& code = "");
void ExportObject(const std::string& path);
std::string GetSourceCode(const ir::Module& module);
void BuildDefault(const ir::Module& module);
/**
* Retrieve a function by \p fn_name.
* @return function address or null if not exists.
*/
void* Lookup(absl::string_view fn_name);
private:
void CompileCudaModule(const ir::Module& module, const std::string& code = "");
void CompileX86Module(const ir::Module& module);
explicit Compiler(const Target& target) : target_(target), engine_(ExecutionEngine::Create(ExecutionOptions())) {}
CINN_DISALLOW_COPY_AND_ASSIGN(Compiler);
private:
Target target_;
std::unique_ptr<ExecutionEngine> engine_;
#ifdef CINN_WITH_CUDA
std::unique_ptr<runtime::cuda::CUDAModule> cuda_module_;
#endif
};
在Build()方法中,是通过target来进行逻辑分发的:
void Compiler::Build(const Module& module, const std::string& code) {
if (target_.arch == Target::Arch::NVGPU) {
CompileCudaModule(module, code); // <----- GPU 上编译逻辑
} else if (target_.arch == Target::Arch::X86) {
CompileX86Module(module); // <------ X86 CPU 上编译逻辑
} else {
CINN_NOT_IMPLEMENTED
}
}
我们来详细研习下 CompileCudaModule() 函数的实现逻辑:
- step 1: 调用
SplitCudaAndHostModule()将ir::Module切分成host_module和device_module - step 2: 借助
CodeGenCUDA_Dev模块,对device_module进行代码生成,得到source_code;也支持用户直接通过code参数指定 - step 3: 构造一个
nvrtc::Compiler对象,将source_code编译为ptx中间代码 - step 4: 以
ptx构造一个runtime::cuda::CUDAModule对象,可以选择是CUBIN或者PTX的kind类型 - step 5: 根据
device_module中的fn_name在cuda_module_中取出对应的fn_kernel(CUfunction对象),并在RuntimeSymbols中注册(fn_name__ptr__、void*) - step 6: 以
RuntimeSymbols构造ExecutionEngine,负责将CUDA Kernel与Host端API进行链接
void Compiler::CompileCudaModule(const Module& module, const std::string& code) {
// step 1: 调用SplitCudaAndHostModule()将ir::Module切分成 host_module和device_module
auto _host_module_device_module_ = SplitCudaAndHostModule(module); // NOLINT
auto& host_module = std::get<0>(_host_module_device_module_);
auto& device_module = std::get<1>(_host_module_device_module_);
// step 2: 借助CodeGenCUDA_Dev模块,对device_module进行代码生成,得到 source_code;也支持用户直接通过code参数指定
std::string source_code;
CodeGenCUDA_Dev codegen(target_);
source_code = codegen.Compile(device_module);
// step 3: 构造一个nvrtc::Compiler对象,将source_code编译为ptx中间代码
backends::nvrtc::Compiler compiler;
auto ptx = compiler(source_code);
// step 4: 以ptx构造一个runtime::cuda::CUDAModule对象,可以选择是CUBIN 或者 PTX 的kind类型
using runtime::cuda::CUDAModule;
cuda_module_.reset(
new CUDAModule(ptx, compiler.compile_to_cubin() ? CUDAModule::Kind::CUBIN : CUDAModule::Kind::PTX));
// step 5: 根据device_module中的fn_name在cuda_module_中取出对应的fn_kernel(CUfunction对象),并在RuntimeSymbols中注册(fn_name__ptr__、void*)
RuntimeSymbols symbols;
for (auto& fn : device_module.functions()) {
std::string kernel_fn_name = fn->name;
auto fn_kernel = cuda_module_->GetFunction(0, kernel_fn_name);
CHECK(fn_kernel);
symbols.RegisterVar(kernel_fn_name + "_ptr_", reinterpret_cast<void*>(fn_kernel));
}
// step 6: 以RuntimeSymbols构造ExecutionEngine,负责将CUDA Kernel 与 Host 端API进行链接
engine_ = ExecutionEngine::Create(ExecutionOptions(), std::move(symbols));
engine_->Link<CodeGenCUDA_Host>(host_module);
}
二、TVM 框架
runtime::Module BuildCUDA(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenCUDA cg;
cg.Init(output_ssa);
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f);
}
std::string code = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) {
code = (*f)(code).operator std::string();
}
std::string fmt = "ptx";
std::string ptx;
const auto* f_enter = Registry::Get("target.TargetEnterScope");
(*f_enter)(target);
if (const auto* f = Registry::Get("tvm_callback_cuda_compile")) { // <---- 可以借助python端的函数,注册一个直接编译 cubin 的函数
ptx = (*f)(code).operator std::string();
// Dirty matching to check PTX vs cubin.
// TODO(tqchen) more reliable checks
if (ptx[0] != '/') fmt = "cubin";
} else {
ptx = NVRTCCompile(code, cg.need_include_path());
}
const auto* f_exit = Registry::Get("target.TargetExitScope");
(*f_exit)(target);
return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
}
TVM_REGISTER_GLOBAL("target.build.cuda").set_body_typed(BuildCUDA);
其中 NVRTCCompile 的实现:
std::string NVRTCCompile(const std::string& code, bool include_path = false) {
std::vector<std::string> compile_params;
std::vector<const char*> param_cstrings{};
nvrtcProgram prog;
std::string cc = "30";
int major, minor;
cudaError_t e1 = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0);
cudaError_t e2 = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0);
if (e1 == cudaSuccess && e2 == cudaSuccess) {
cc = std::to_string(major) + std::to_string(minor);
} else {
LOG(WARNING) << "cannot detect compute capability from your device, "
<< "fall back to compute_30.";
}
compile_params.push_back("-arch=compute_" + cc);
if (include_path) {
std::string include_option = "--include-path=" + FindCUDAIncludePath();
compile_params.push_back(include_option);
}
for (const auto& string : compile_params) {
param_cstrings.push_back(string.c_str());
}
NVRTC_CALL(nvrtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, nullptr));
nvrtcResult compile_res = nvrtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data());
size_t log_size;
NVRTC_CALL(nvrtcGetProgramLogSize(prog, &log_size));
std::string log;
log.resize(log_size);
NVRTC_CALL(nvrtcGetProgramLog(prog, &log[0]));
ICHECK_EQ(compile_res, NVRTC_SUCCESS) << log;
size_t ptx_size;
NVRTC_CALL(nvrtcGetPTXSize(prog, &ptx_size));
std::string ptx;
ptx.resize(ptx_size);
NVRTC_CALL(nvrtcGetPTX(prog, &ptx[0]));
NVRTC_CALL(nvrtcDestroyProgram(&prog));
return ptx;
}
可以在 python 端注册一个自定义函数,直接编译为 cubin 文件,用法见TVM中的单测:
@tvm._ffi.register_func
def tvm_callback_cuda_compile(code):
"""use nvcc to generate fatbin code for better optimization"""
ptx = compile_cuda(code, target_format="fatbin")
return ptx
def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target=None):
"""Compile cuda code with NVCC from env
"""
if arch is None:
# If None, then it will use `tvm.target.Target.current().arch`.
# Target arch could be a str like "sm_xx", or a list, such as
# [
# "-gencode", "arch=compute_52,code=sm_52",
# "-gencode", "arch=compute_70,code=sm_70"
# ]
compute_version = "".join(
get_target_compute_version(Target.current(allow_none=True)).split(".")
)
arch = ["-gencode", f"arch=compute_{compute_version},code=sm_{compute_version}"]
temp = utils.tempdir()
if target_format not in ["cubin", "ptx", "fatbin"]:
raise ValueError("target_format must be in cubin, ptx, fatbin")
temp_code = temp.relpath("my_kernel.cu")
temp_target = temp.relpath("my_kernel.%s" % target_format)
with open(temp_code, "w") as out_file:
out_file.write(code)
file_target = path_target if path_target else temp_target
cmd = ["nvcc"]
cmd += ["--%s" % target_format, "-O3"]
if isinstance(arch, list):
cmd += arch
elif isinstance(arch, str):
cmd += ["-arch", arch]
if options:
if isinstance(options, str):
cmd += [options]
elif isinstance(options, list):
cmd += options
else:
raise ValueError("options must be str or list of str")
cmd += ["-o", file_target]
cmd += [temp_code]
# NOTE: ccbin option can be used to tell nvcc where to find the c++ compiler
# just in case it is not in the path. On Windows it is not in the path by default.
# However, we cannot use TVM_CXX_COMPILER_PATH because the runtime env.
# Because it is hard to do runtime compiler detection, we require nvcc is configured
# correctly by default.
# if cxx_compiler_path != "":
# cmd += ["-ccbin", cxx_compiler_path]
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = code
msg += "\nCompilation error:\n"
msg += py_str(out)
raise RuntimeError(msg)
with open(file_target, "rb") as f:
data = bytearray(f.read())
if not data:
raise RuntimeError("Compilation error: empty result is generated")
return data
三、参考资料
- https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/#virtual-architecture-feature-list
- https://forums.developer.nvidia.com/t/using-curand-inside-nvrtc-jit-compiled-kernels/193826
- https://github.com/NVIDIA/jitify/issues/43
- https://github.com/NVIDIA/libcudacxx/blob/main/include/cuda/std/detail/libcxx/include/type_traits
AI编译器CINN v.s TVM 中CodeGen 源码解读的更多相关文章
- go中panic源码解读
panic源码解读 前言 panic的作用 panic使用场景 看下实现 gopanic gorecover fatalpanic 总结 参考 panic源码解读 前言 本文是在go version ...
- go中waitGroup源码解读
waitGroup源码刨铣 前言 WaitGroup实现 noCopy state1 Add Wait 总结 参考 waitGroup源码刨铣 前言 学习下waitGroup的实现 本文是在go ve ...
- etcd中watch源码解读
etcd中watch的源码解析 前言 client端的代码 Watch newWatcherGrpcStream run newWatchClient serveSubstream server端的代 ...
- java中jdbc源码解读
在jdbc中一个重要的接口类就是java.sql.Driver,其中有一个重要的方法:Connection connect(String url, java.util.Propeties info); ...
- go中errgroup源码解读
errgroup 前言 如何使用 实现原理 WithContext Go Wait 错误的使用 总结 errgroup 前言 来看下errgroup的实现 如何使用 func main() { var ...
- Vue 源码解读(8)—— 编译器 之 解析(上)
特殊说明 由于文章篇幅限制,所以将 Vue 源码解读(8)-- 编译器 之 解析 拆成了上下两篇,所以在阅读本篇文章时请同时打开 Vue 源码解读(8)-- 编译器 之 解析(下)一起阅读. 前言 V ...
- Vue 源码解读(10)—— 编译器 之 生成渲染函数
前言 这篇文章是 Vue 编译器的最后一部分,前两部分分别是:Vue 源码解读(8)-- 编译器 之 解析.Vue 源码解读(9)-- 编译器 之 优化. 从 HTML 模版字符串开始,解析所有标签以 ...
- go 中 select 源码阅读
深入了解下 go 中的 select 前言 1.栗子一 2.栗子二 3.栗子三 看下源码实现 1.不存在 case 2.select 中仅存在一个 case 3.select 中存在两个 case,其 ...
- Vue 源码解读(8)—— 编译器 之 解析(下)
特殊说明 由于文章篇幅限制,所以将 Vue 源码解读(8)-- 编译器 之 解析 拆成了两篇文章,本篇是对 Vue 源码解读(8)-- 编译器 之 解析(上) 的一个补充,所以在阅读时请同时打开 Vu ...
- Vue 源码解读(9)—— 编译器 之 优化
前言 上一篇文章 Vue 源码解读(8)-- 编译器 之 解析 详细详解了编译器的第一部分,如何将 html 模版字符串编译成 AST.今天带来编译器的第二部分,优化 AST,也是大家常说的静态标记. ...
随机推荐
- 算法学习笔记【5】| ST表
ST表 Part 1:ST表解决的问题是什么 ST 表可以用来解决RMQ(区间最值问题)等可重复贡献的问题. ST表基于倍增的思想来实现. Part 2:ST表的实现 ST表通过 O(nlogn)& ...
- Unit 1 Computer hardware【石家庄铁道大学-专业英语课程复习资料】
Unit 1 Computer hardware 1.Introduction of computer A computer is a machine that can be instructed t ...
- #单调栈,树状数组#CF1635F Closest Pair
题目 设 \(f(x,y)=|a_x-a_y|*(w_x+w_y)\),其中 \(a\) 单调递增 多组询问求 \(\min_{l\leq l'<r'\leq r}\{f(l',r')\}\) ...
- #树,搜索#NOIP2020.9.26模拟tom
分析 考虑最极端的情况也就是TOM天天吃早餐肠或者晚餐肠, 那么早餐肠和晚餐肠应分别构成一个互不相交连通块, 所以题目转换成是否有一个点的子树大小为\(a\)或\(b\), 将这个点与它父亲的边断开就 ...
- OpenHarmony 4.1 Release版本正式发布,邀您体验
春风轻拂的4月,OpenAtom OpenHarmony(以下简称"OpenHarmony")4.1 Release版本如期而至,开发套件同步升级到API 11 Release. ...
- 代码覆盖率检查工具 -- Coverage,简单使用
Coverage 一个专门用来检查代码覆盖率的工具,他的使用非常简单,有两种使用方法:[命令行运行,配合测试套件使用] 安装: pip install coverage 一.准备素材 main.py ...
- 机器学习&深度学习 操作tips
1. 在运行程序时,报错如下: usage: run.py [-h] --model MODEL [--embedding EMBEDDING] [--word WORD] run.py: error ...
- docker 应用篇————容器共享数据卷[十五]
前言 简单介绍一下多个容器间容器卷共享. 正文 先启动上一节的test:2.0 这个镜像. docker run --name test01 -it test:2.0 /bin/bash 然后 ctr ...
- 重走py 之路 ——字典和集合(二)
前言 python 中有6大标准类型: 数字(Number) 字符串(String) 列表(List) 元组(Tumple) 集合(Set) 字典(Dictionary) 前面已经介绍了上面4种,还有 ...
- RestfulApi 学习笔记——分页和排序(六)
前言 分页和排序时一些非常常规的操作,同样也有一些我们注意的点. 正文 分页 先来谈及分页. 看下前端传递的参数. public class EmployeeDtoParameters { priva ...