如下的技术点梳理仅以「日常优化工作」为牵引点,涉及哪个模块,就具体去看哪个模块的代码。

一、CINN 框架

CINN 中CodeGen之后的代码编译主要交给了Compiler类来负责。核心的函数主要是:

  • Build(ir::Module&, string& code)
  • Lookup(string& fn_name)
class Compiler final {
public:
static std::unique_ptr<Compiler> Create(const Target& target) {
return std::unique_ptr<Compiler>(new Compiler(target));
} /**
* Compile and link to a CINN module.
*/
void Build(const ir::Module& module, const std::string& code = ""); void ExportObject(const std::string& path); std::string GetSourceCode(const ir::Module& module); void BuildDefault(const ir::Module& module); /**
* Retrieve a function by \p fn_name.
* @return function address or null if not exists.
*/
void* Lookup(absl::string_view fn_name); private:
void CompileCudaModule(const ir::Module& module, const std::string& code = ""); void CompileX86Module(const ir::Module& module); explicit Compiler(const Target& target) : target_(target), engine_(ExecutionEngine::Create(ExecutionOptions())) {} CINN_DISALLOW_COPY_AND_ASSIGN(Compiler); private:
Target target_;
std::unique_ptr<ExecutionEngine> engine_; #ifdef CINN_WITH_CUDA
std::unique_ptr<runtime::cuda::CUDAModule> cuda_module_;
#endif
};

Build()方法中,是通过target来进行逻辑分发的:

void Compiler::Build(const Module& module, const std::string& code) {
if (target_.arch == Target::Arch::NVGPU) {
CompileCudaModule(module, code); // <----- GPU 上编译逻辑
} else if (target_.arch == Target::Arch::X86) {
CompileX86Module(module); // <------ X86 CPU 上编译逻辑
} else {
CINN_NOT_IMPLEMENTED
}
}

我们来详细研习下 CompileCudaModule() 函数的实现逻辑:

  • step 1: 调用 SplitCudaAndHostModule()ir::Module 切分成 host_moduledevice_module
  • step 2: 借助 CodeGenCUDA_Dev 模块,对 device_module 进行代码生成,得到 source_code;也支持用户直接通过code参数指定
  • step 3: 构造一个 nvrtc::Compiler 对象,将 source_code编译为 ptx 中间代码
  • step 4: 以ptx构造一个runtime::cuda::CUDAModule对象,可以选择是 CUBIN 或者 PTXkind 类型
  • step 5: 根据 device_module 中的 fn_namecuda_module_ 中取出对应的fn_kernelCUfunction对象),并在RuntimeSymbols中注册(fn_name__ptr__void*)
  • step 6: 以 RuntimeSymbols 构造 ExecutionEngine,负责将 CUDA KernelHost 端API进行链接
void Compiler::CompileCudaModule(const Module& module, const std::string& code) {
// step 1: 调用SplitCudaAndHostModule()将ir::Module切分成 host_module和device_module
auto _host_module_device_module_ = SplitCudaAndHostModule(module); // NOLINT
auto& host_module = std::get<0>(_host_module_device_module_);
auto& device_module = std::get<1>(_host_module_device_module_); // step 2: 借助CodeGenCUDA_Dev模块,对device_module进行代码生成,得到 source_code;也支持用户直接通过code参数指定
std::string source_code;
CodeGenCUDA_Dev codegen(target_);
source_code = codegen.Compile(device_module); // step 3: 构造一个nvrtc::Compiler对象,将source_code编译为ptx中间代码
backends::nvrtc::Compiler compiler;
auto ptx = compiler(source_code); // step 4: 以ptx构造一个runtime::cuda::CUDAModule对象,可以选择是CUBIN 或者 PTX 的kind类型
using runtime::cuda::CUDAModule;
cuda_module_.reset(
new CUDAModule(ptx, compiler.compile_to_cubin() ? CUDAModule::Kind::CUBIN : CUDAModule::Kind::PTX)); // step 5: 根据device_module中的fn_name在cuda_module_中取出对应的fn_kernel(CUfunction对象),并在RuntimeSymbols中注册(fn_name__ptr__、void*)
RuntimeSymbols symbols;
for (auto& fn : device_module.functions()) {
std::string kernel_fn_name = fn->name;
auto fn_kernel = cuda_module_->GetFunction(0, kernel_fn_name);
CHECK(fn_kernel);
symbols.RegisterVar(kernel_fn_name + "_ptr_", reinterpret_cast<void*>(fn_kernel));
} // step 6: 以RuntimeSymbols构造ExecutionEngine,负责将CUDA Kernel 与 Host 端API进行链接
engine_ = ExecutionEngine::Create(ExecutionOptions(), std::move(symbols));
engine_->Link<CodeGenCUDA_Host>(host_module); }

二、TVM 框架

runtime::Module BuildCUDA(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenCUDA cg;
cg.Init(output_ssa); for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f);
} std::string code = cg.Finish(); if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) {
code = (*f)(code).operator std::string();
}
std::string fmt = "ptx";
std::string ptx;
const auto* f_enter = Registry::Get("target.TargetEnterScope");
(*f_enter)(target);
if (const auto* f = Registry::Get("tvm_callback_cuda_compile")) { // <---- 可以借助python端的函数,注册一个直接编译 cubin 的函数
ptx = (*f)(code).operator std::string();
// Dirty matching to check PTX vs cubin.
// TODO(tqchen) more reliable checks
if (ptx[0] != '/') fmt = "cubin";
} else {
ptx = NVRTCCompile(code, cg.need_include_path());
}
const auto* f_exit = Registry::Get("target.TargetExitScope");
(*f_exit)(target);
return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
} TVM_REGISTER_GLOBAL("target.build.cuda").set_body_typed(BuildCUDA);

其中 NVRTCCompile 的实现:

std::string NVRTCCompile(const std::string& code, bool include_path = false) {
std::vector<std::string> compile_params;
std::vector<const char*> param_cstrings{};
nvrtcProgram prog;
std::string cc = "30";
int major, minor;
cudaError_t e1 = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0);
cudaError_t e2 = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0); if (e1 == cudaSuccess && e2 == cudaSuccess) {
cc = std::to_string(major) + std::to_string(minor);
} else {
LOG(WARNING) << "cannot detect compute capability from your device, "
<< "fall back to compute_30.";
} compile_params.push_back("-arch=compute_" + cc); if (include_path) {
std::string include_option = "--include-path=" + FindCUDAIncludePath(); compile_params.push_back(include_option);
} for (const auto& string : compile_params) {
param_cstrings.push_back(string.c_str());
}
NVRTC_CALL(nvrtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, nullptr));
nvrtcResult compile_res = nvrtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data()); size_t log_size;
NVRTC_CALL(nvrtcGetProgramLogSize(prog, &log_size));
std::string log;
log.resize(log_size);
NVRTC_CALL(nvrtcGetProgramLog(prog, &log[0]));
ICHECK_EQ(compile_res, NVRTC_SUCCESS) << log;
size_t ptx_size;
NVRTC_CALL(nvrtcGetPTXSize(prog, &ptx_size)); std::string ptx;
ptx.resize(ptx_size);
NVRTC_CALL(nvrtcGetPTX(prog, &ptx[0]));
NVRTC_CALL(nvrtcDestroyProgram(&prog)); return ptx;
}

可以在 python 端注册一个自定义函数,直接编译为 cubin 文件,用法见TVM中的单测:

@tvm._ffi.register_func
def tvm_callback_cuda_compile(code):
"""use nvcc to generate fatbin code for better optimization"""
ptx = compile_cuda(code, target_format="fatbin")
return ptx def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target=None):
"""Compile cuda code with NVCC from env
"""
if arch is None:
# If None, then it will use `tvm.target.Target.current().arch`.
# Target arch could be a str like "sm_xx", or a list, such as
# [
# "-gencode", "arch=compute_52,code=sm_52",
# "-gencode", "arch=compute_70,code=sm_70"
# ]
compute_version = "".join(
get_target_compute_version(Target.current(allow_none=True)).split(".")
)
arch = ["-gencode", f"arch=compute_{compute_version},code=sm_{compute_version}"] temp = utils.tempdir()
if target_format not in ["cubin", "ptx", "fatbin"]:
raise ValueError("target_format must be in cubin, ptx, fatbin")
temp_code = temp.relpath("my_kernel.cu")
temp_target = temp.relpath("my_kernel.%s" % target_format) with open(temp_code, "w") as out_file:
out_file.write(code) file_target = path_target if path_target else temp_target
cmd = ["nvcc"]
cmd += ["--%s" % target_format, "-O3"]
if isinstance(arch, list):
cmd += arch
elif isinstance(arch, str):
cmd += ["-arch", arch] if options:
if isinstance(options, str):
cmd += [options]
elif isinstance(options, list):
cmd += options
else:
raise ValueError("options must be str or list of str") cmd += ["-o", file_target]
cmd += [temp_code] # NOTE: ccbin option can be used to tell nvcc where to find the c++ compiler
# just in case it is not in the path. On Windows it is not in the path by default.
# However, we cannot use TVM_CXX_COMPILER_PATH because the runtime env.
# Because it is hard to do runtime compiler detection, we require nvcc is configured
# correctly by default.
# if cxx_compiler_path != "":
# cmd += ["-ccbin", cxx_compiler_path] proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) (out, _) = proc.communicate() if proc.returncode != 0:
msg = code
msg += "\nCompilation error:\n"
msg += py_str(out)
raise RuntimeError(msg) with open(file_target, "rb") as f:
data = bytearray(f.read())
if not data:
raise RuntimeError("Compilation error: empty result is generated")
return data

三、参考资料

  1. https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/#virtual-architecture-feature-list
  2. https://forums.developer.nvidia.com/t/using-curand-inside-nvrtc-jit-compiled-kernels/193826
  3. https://github.com/NVIDIA/jitify/issues/43
  4. https://github.com/NVIDIA/libcudacxx/blob/main/include/cuda/std/detail/libcxx/include/type_traits

AI编译器CINN v.s TVM 中CodeGen 源码解读的更多相关文章

  1. go中panic源码解读

    panic源码解读 前言 panic的作用 panic使用场景 看下实现 gopanic gorecover fatalpanic 总结 参考 panic源码解读 前言 本文是在go version ...

  2. go中waitGroup源码解读

    waitGroup源码刨铣 前言 WaitGroup实现 noCopy state1 Add Wait 总结 参考 waitGroup源码刨铣 前言 学习下waitGroup的实现 本文是在go ve ...

  3. etcd中watch源码解读

    etcd中watch的源码解析 前言 client端的代码 Watch newWatcherGrpcStream run newWatchClient serveSubstream server端的代 ...

  4. java中jdbc源码解读

    在jdbc中一个重要的接口类就是java.sql.Driver,其中有一个重要的方法:Connection connect(String url, java.util.Propeties info); ...

  5. go中errgroup源码解读

    errgroup 前言 如何使用 实现原理 WithContext Go Wait 错误的使用 总结 errgroup 前言 来看下errgroup的实现 如何使用 func main() { var ...

  6. Vue 源码解读(8)—— 编译器 之 解析(上)

    特殊说明 由于文章篇幅限制,所以将 Vue 源码解读(8)-- 编译器 之 解析 拆成了上下两篇,所以在阅读本篇文章时请同时打开 Vue 源码解读(8)-- 编译器 之 解析(下)一起阅读. 前言 V ...

  7. Vue 源码解读(10)—— 编译器 之 生成渲染函数

    前言 这篇文章是 Vue 编译器的最后一部分,前两部分分别是:Vue 源码解读(8)-- 编译器 之 解析.Vue 源码解读(9)-- 编译器 之 优化. 从 HTML 模版字符串开始,解析所有标签以 ...

  8. go 中 select 源码阅读

    深入了解下 go 中的 select 前言 1.栗子一 2.栗子二 3.栗子三 看下源码实现 1.不存在 case 2.select 中仅存在一个 case 3.select 中存在两个 case,其 ...

  9. Vue 源码解读(8)—— 编译器 之 解析(下)

    特殊说明 由于文章篇幅限制,所以将 Vue 源码解读(8)-- 编译器 之 解析 拆成了两篇文章,本篇是对 Vue 源码解读(8)-- 编译器 之 解析(上) 的一个补充,所以在阅读时请同时打开 Vu ...

  10. Vue 源码解读(9)—— 编译器 之 优化

    前言 上一篇文章 Vue 源码解读(8)-- 编译器 之 解析 详细详解了编译器的第一部分,如何将 html 模版字符串编译成 AST.今天带来编译器的第二部分,优化 AST,也是大家常说的静态标记. ...

随机推荐

  1. 算法学习笔记【5】| ST表

    ST表 Part 1:ST表解决的问题是什么 ST 表可以用来解决RMQ(区间最值问题)等可重复贡献的问题. ST表基于倍增的思想来实现. Part 2:ST表的实现 ST表通过 O(nlog⁡n)& ...

  2. Unit 1 Computer hardware【石家庄铁道大学-专业英语课程复习资料】

    Unit 1 Computer hardware 1.Introduction of computer A computer is a machine that can be instructed t ...

  3. #单调栈,树状数组#CF1635F Closest Pair

    题目 设 \(f(x,y)=|a_x-a_y|*(w_x+w_y)\),其中 \(a\) 单调递增 多组询问求 \(\min_{l\leq l'<r'\leq r}\{f(l',r')\}\) ...

  4. #树,搜索#NOIP2020.9.26模拟tom

    分析 考虑最极端的情况也就是TOM天天吃早餐肠或者晚餐肠, 那么早餐肠和晚餐肠应分别构成一个互不相交连通块, 所以题目转换成是否有一个点的子树大小为\(a\)或\(b\), 将这个点与它父亲的边断开就 ...

  5. OpenHarmony 4.1 Release版本正式发布,邀您体验

    春风轻拂的4月,OpenAtom OpenHarmony(以下简称"OpenHarmony")4.1 Release版本如期而至,开发套件同步升级到API 11 Release. ...

  6. 代码覆盖率检查工具 -- Coverage,简单使用

    Coverage 一个专门用来检查代码覆盖率的工具,他的使用非常简单,有两种使用方法:[命令行运行,配合测试套件使用] 安装: pip install coverage 一.准备素材 main.py ...

  7. 机器学习&深度学习 操作tips

    1. 在运行程序时,报错如下: usage: run.py [-h] --model MODEL [--embedding EMBEDDING] [--word WORD] run.py: error ...

  8. docker 应用篇————容器共享数据卷[十五]

    前言 简单介绍一下多个容器间容器卷共享. 正文 先启动上一节的test:2.0 这个镜像. docker run --name test01 -it test:2.0 /bin/bash 然后 ctr ...

  9. 重走py 之路 ——字典和集合(二)

    前言 python 中有6大标准类型: 数字(Number) 字符串(String) 列表(List) 元组(Tumple) 集合(Set) 字典(Dictionary) 前面已经介绍了上面4种,还有 ...

  10. RestfulApi 学习笔记——分页和排序(六)

    前言 分页和排序时一些非常常规的操作,同样也有一些我们注意的点. 正文 分页 先来谈及分页. 看下前端传递的参数. public class EmployeeDtoParameters { priva ...