如下的技术点梳理仅以「日常优化工作」为牵引点,涉及哪个模块,就具体去看哪个模块的代码。

一、CINN 框架

CINN 中CodeGen之后的代码编译主要交给了Compiler类来负责。核心的函数主要是:

  • Build(ir::Module&, string& code)
  • Lookup(string& fn_name)
class Compiler final {
public:
static std::unique_ptr<Compiler> Create(const Target& target) {
return std::unique_ptr<Compiler>(new Compiler(target));
} /**
* Compile and link to a CINN module.
*/
void Build(const ir::Module& module, const std::string& code = ""); void ExportObject(const std::string& path); std::string GetSourceCode(const ir::Module& module); void BuildDefault(const ir::Module& module); /**
* Retrieve a function by \p fn_name.
* @return function address or null if not exists.
*/
void* Lookup(absl::string_view fn_name); private:
void CompileCudaModule(const ir::Module& module, const std::string& code = ""); void CompileX86Module(const ir::Module& module); explicit Compiler(const Target& target) : target_(target), engine_(ExecutionEngine::Create(ExecutionOptions())) {} CINN_DISALLOW_COPY_AND_ASSIGN(Compiler); private:
Target target_;
std::unique_ptr<ExecutionEngine> engine_; #ifdef CINN_WITH_CUDA
std::unique_ptr<runtime::cuda::CUDAModule> cuda_module_;
#endif
};

Build()方法中,是通过target来进行逻辑分发的:

void Compiler::Build(const Module& module, const std::string& code) {
if (target_.arch == Target::Arch::NVGPU) {
CompileCudaModule(module, code); // <----- GPU 上编译逻辑
} else if (target_.arch == Target::Arch::X86) {
CompileX86Module(module); // <------ X86 CPU 上编译逻辑
} else {
CINN_NOT_IMPLEMENTED
}
}

我们来详细研习下 CompileCudaModule() 函数的实现逻辑:

  • step 1: 调用 SplitCudaAndHostModule()ir::Module 切分成 host_moduledevice_module
  • step 2: 借助 CodeGenCUDA_Dev 模块,对 device_module 进行代码生成,得到 source_code;也支持用户直接通过code参数指定
  • step 3: 构造一个 nvrtc::Compiler 对象,将 source_code编译为 ptx 中间代码
  • step 4: 以ptx构造一个runtime::cuda::CUDAModule对象,可以选择是 CUBIN 或者 PTXkind 类型
  • step 5: 根据 device_module 中的 fn_namecuda_module_ 中取出对应的fn_kernelCUfunction对象),并在RuntimeSymbols中注册(fn_name__ptr__void*)
  • step 6: 以 RuntimeSymbols 构造 ExecutionEngine,负责将 CUDA KernelHost 端API进行链接
void Compiler::CompileCudaModule(const Module& module, const std::string& code) {
// step 1: 调用SplitCudaAndHostModule()将ir::Module切分成 host_module和device_module
auto _host_module_device_module_ = SplitCudaAndHostModule(module); // NOLINT
auto& host_module = std::get<0>(_host_module_device_module_);
auto& device_module = std::get<1>(_host_module_device_module_); // step 2: 借助CodeGenCUDA_Dev模块,对device_module进行代码生成,得到 source_code;也支持用户直接通过code参数指定
std::string source_code;
CodeGenCUDA_Dev codegen(target_);
source_code = codegen.Compile(device_module); // step 3: 构造一个nvrtc::Compiler对象,将source_code编译为ptx中间代码
backends::nvrtc::Compiler compiler;
auto ptx = compiler(source_code); // step 4: 以ptx构造一个runtime::cuda::CUDAModule对象,可以选择是CUBIN 或者 PTX 的kind类型
using runtime::cuda::CUDAModule;
cuda_module_.reset(
new CUDAModule(ptx, compiler.compile_to_cubin() ? CUDAModule::Kind::CUBIN : CUDAModule::Kind::PTX)); // step 5: 根据device_module中的fn_name在cuda_module_中取出对应的fn_kernel(CUfunction对象),并在RuntimeSymbols中注册(fn_name__ptr__、void*)
RuntimeSymbols symbols;
for (auto& fn : device_module.functions()) {
std::string kernel_fn_name = fn->name;
auto fn_kernel = cuda_module_->GetFunction(0, kernel_fn_name);
CHECK(fn_kernel);
symbols.RegisterVar(kernel_fn_name + "_ptr_", reinterpret_cast<void*>(fn_kernel));
} // step 6: 以RuntimeSymbols构造ExecutionEngine,负责将CUDA Kernel 与 Host 端API进行链接
engine_ = ExecutionEngine::Create(ExecutionOptions(), std::move(symbols));
engine_->Link<CodeGenCUDA_Host>(host_module); }

二、TVM 框架

runtime::Module BuildCUDA(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenCUDA cg;
cg.Init(output_ssa); for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f);
} std::string code = cg.Finish(); if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) {
code = (*f)(code).operator std::string();
}
std::string fmt = "ptx";
std::string ptx;
const auto* f_enter = Registry::Get("target.TargetEnterScope");
(*f_enter)(target);
if (const auto* f = Registry::Get("tvm_callback_cuda_compile")) { // <---- 可以借助python端的函数,注册一个直接编译 cubin 的函数
ptx = (*f)(code).operator std::string();
// Dirty matching to check PTX vs cubin.
// TODO(tqchen) more reliable checks
if (ptx[0] != '/') fmt = "cubin";
} else {
ptx = NVRTCCompile(code, cg.need_include_path());
}
const auto* f_exit = Registry::Get("target.TargetExitScope");
(*f_exit)(target);
return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
} TVM_REGISTER_GLOBAL("target.build.cuda").set_body_typed(BuildCUDA);

其中 NVRTCCompile 的实现:

std::string NVRTCCompile(const std::string& code, bool include_path = false) {
std::vector<std::string> compile_params;
std::vector<const char*> param_cstrings{};
nvrtcProgram prog;
std::string cc = "30";
int major, minor;
cudaError_t e1 = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0);
cudaError_t e2 = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0); if (e1 == cudaSuccess && e2 == cudaSuccess) {
cc = std::to_string(major) + std::to_string(minor);
} else {
LOG(WARNING) << "cannot detect compute capability from your device, "
<< "fall back to compute_30.";
} compile_params.push_back("-arch=compute_" + cc); if (include_path) {
std::string include_option = "--include-path=" + FindCUDAIncludePath(); compile_params.push_back(include_option);
} for (const auto& string : compile_params) {
param_cstrings.push_back(string.c_str());
}
NVRTC_CALL(nvrtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, nullptr));
nvrtcResult compile_res = nvrtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data()); size_t log_size;
NVRTC_CALL(nvrtcGetProgramLogSize(prog, &log_size));
std::string log;
log.resize(log_size);
NVRTC_CALL(nvrtcGetProgramLog(prog, &log[0]));
ICHECK_EQ(compile_res, NVRTC_SUCCESS) << log;
size_t ptx_size;
NVRTC_CALL(nvrtcGetPTXSize(prog, &ptx_size)); std::string ptx;
ptx.resize(ptx_size);
NVRTC_CALL(nvrtcGetPTX(prog, &ptx[0]));
NVRTC_CALL(nvrtcDestroyProgram(&prog)); return ptx;
}

可以在 python 端注册一个自定义函数,直接编译为 cubin 文件,用法见TVM中的单测:

@tvm._ffi.register_func
def tvm_callback_cuda_compile(code):
"""use nvcc to generate fatbin code for better optimization"""
ptx = compile_cuda(code, target_format="fatbin")
return ptx def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target=None):
"""Compile cuda code with NVCC from env
"""
if arch is None:
# If None, then it will use `tvm.target.Target.current().arch`.
# Target arch could be a str like "sm_xx", or a list, such as
# [
# "-gencode", "arch=compute_52,code=sm_52",
# "-gencode", "arch=compute_70,code=sm_70"
# ]
compute_version = "".join(
get_target_compute_version(Target.current(allow_none=True)).split(".")
)
arch = ["-gencode", f"arch=compute_{compute_version},code=sm_{compute_version}"] temp = utils.tempdir()
if target_format not in ["cubin", "ptx", "fatbin"]:
raise ValueError("target_format must be in cubin, ptx, fatbin")
temp_code = temp.relpath("my_kernel.cu")
temp_target = temp.relpath("my_kernel.%s" % target_format) with open(temp_code, "w") as out_file:
out_file.write(code) file_target = path_target if path_target else temp_target
cmd = ["nvcc"]
cmd += ["--%s" % target_format, "-O3"]
if isinstance(arch, list):
cmd += arch
elif isinstance(arch, str):
cmd += ["-arch", arch] if options:
if isinstance(options, str):
cmd += [options]
elif isinstance(options, list):
cmd += options
else:
raise ValueError("options must be str or list of str") cmd += ["-o", file_target]
cmd += [temp_code] # NOTE: ccbin option can be used to tell nvcc where to find the c++ compiler
# just in case it is not in the path. On Windows it is not in the path by default.
# However, we cannot use TVM_CXX_COMPILER_PATH because the runtime env.
# Because it is hard to do runtime compiler detection, we require nvcc is configured
# correctly by default.
# if cxx_compiler_path != "":
# cmd += ["-ccbin", cxx_compiler_path] proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) (out, _) = proc.communicate() if proc.returncode != 0:
msg = code
msg += "\nCompilation error:\n"
msg += py_str(out)
raise RuntimeError(msg) with open(file_target, "rb") as f:
data = bytearray(f.read())
if not data:
raise RuntimeError("Compilation error: empty result is generated")
return data

三、参考资料

  1. https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/#virtual-architecture-feature-list
  2. https://forums.developer.nvidia.com/t/using-curand-inside-nvrtc-jit-compiled-kernels/193826
  3. https://github.com/NVIDIA/jitify/issues/43
  4. https://github.com/NVIDIA/libcudacxx/blob/main/include/cuda/std/detail/libcxx/include/type_traits

AI编译器CINN v.s TVM 中CodeGen 源码解读的更多相关文章

  1. go中panic源码解读

    panic源码解读 前言 panic的作用 panic使用场景 看下实现 gopanic gorecover fatalpanic 总结 参考 panic源码解读 前言 本文是在go version ...

  2. go中waitGroup源码解读

    waitGroup源码刨铣 前言 WaitGroup实现 noCopy state1 Add Wait 总结 参考 waitGroup源码刨铣 前言 学习下waitGroup的实现 本文是在go ve ...

  3. etcd中watch源码解读

    etcd中watch的源码解析 前言 client端的代码 Watch newWatcherGrpcStream run newWatchClient serveSubstream server端的代 ...

  4. java中jdbc源码解读

    在jdbc中一个重要的接口类就是java.sql.Driver,其中有一个重要的方法:Connection connect(String url, java.util.Propeties info); ...

  5. go中errgroup源码解读

    errgroup 前言 如何使用 实现原理 WithContext Go Wait 错误的使用 总结 errgroup 前言 来看下errgroup的实现 如何使用 func main() { var ...

  6. Vue 源码解读(8)—— 编译器 之 解析(上)

    特殊说明 由于文章篇幅限制,所以将 Vue 源码解读(8)-- 编译器 之 解析 拆成了上下两篇,所以在阅读本篇文章时请同时打开 Vue 源码解读(8)-- 编译器 之 解析(下)一起阅读. 前言 V ...

  7. Vue 源码解读(10)—— 编译器 之 生成渲染函数

    前言 这篇文章是 Vue 编译器的最后一部分,前两部分分别是:Vue 源码解读(8)-- 编译器 之 解析.Vue 源码解读(9)-- 编译器 之 优化. 从 HTML 模版字符串开始,解析所有标签以 ...

  8. go 中 select 源码阅读

    深入了解下 go 中的 select 前言 1.栗子一 2.栗子二 3.栗子三 看下源码实现 1.不存在 case 2.select 中仅存在一个 case 3.select 中存在两个 case,其 ...

  9. Vue 源码解读(8)—— 编译器 之 解析(下)

    特殊说明 由于文章篇幅限制,所以将 Vue 源码解读(8)-- 编译器 之 解析 拆成了两篇文章,本篇是对 Vue 源码解读(8)-- 编译器 之 解析(上) 的一个补充,所以在阅读时请同时打开 Vu ...

  10. Vue 源码解读(9)—— 编译器 之 优化

    前言 上一篇文章 Vue 源码解读(8)-- 编译器 之 解析 详细详解了编译器的第一部分,如何将 html 模版字符串编译成 AST.今天带来编译器的第二部分,优化 AST,也是大家常说的静态标记. ...

随机推荐

  1. virtualbox安装windows10出现OOBE,卡在OOBE。

    参照 https://zhuanlan.zhihu.com/p/419237209 https://www.0z.gs/win/781.html 文档 https://learn.microsoft. ...

  2. jenkens2权威指南

    第1章 Jenkins简介 Jenkins 2是什么 JobConfigHistory:这个插件可以追溯XML配置的历史版本信息, 并且允许你查看每次变更的内容. JenkinsFile Jenkin ...

  3. 汇编语言-使用BIOS进行键盘输入和磁盘读写

    int9中断例程对键盘输入的处理   键盘输入将引发9号中断,BIOS提供了int9中断例程.CPU在9号中断发生后,执行int 9中断例程,从60h端口读出扫描码,并将其转化为相应的ASCII码或状 ...

  4. #轮廓线dp#洛谷 1879 [USACO06NOV]Corn Fields G

    题目 分析 考虑状压dp在\(n\leq 21\)的情况下会TLE, 设\(dp[n][m][S]\)表示当前正在处理\((n,m)\)这个格子 并且轮廓线状态为\(S\)的方案数, 考虑可行状态最多 ...

  5. #构造,二分#[AGC006B] [AGC006D] Median Pyramid

    Easy Hard 分析(Easy) 若\(X=1\)或\(X=2n-1\)无解,否则在正中间构造\(X-1,X,X+1\), 其余位置升序铺入剩余数, 若\(X-1\)左侧数大于\(X-1\)那么\ ...

  6. Bootstrap实战 - 评论列表

    一.介绍 社交媒体网站盛行,人们常常会使用评论表达自己的观点,评论功能已然成为网站的一部分. 二.知识点 2.1 媒体对象 官方解释:这是一个抽象的样式,用以构建不同类型的组件,这些组件都具有在文本内 ...

  7. 【直播回顾】OpenHarmony知识赋能六期第三课—OpenHarmony智能家居项目之控制面板功能实现

    7月14日晚上19点,知识赋能第六期第三节直播 <OpenHarmony智能家居项目之控制面板功能实现> ,在OpenHarmony开发者成长计划社群内成功举行. 本次直播是"O ...

  8. CSP-S2021江西自评分数(10-26)

    娱乐性质,不负责任 在机房大佬的努力下,评测完了 总表 姓名 编号 总分 airport bracket palin traffic JX-00001 JX-00001 0 0 0 0 0 JX-00 ...

  9. HarmonyOS SDK,赋能开发者实现更具象、个性化开发诉求

    随着移动互联网的逐步成熟,用户的需求越来越细化.鸿蒙生态为开发者提供的HarmonyOS SDK开放能力,高效赋能美团外卖等合作伙伴实现更具象.个性化的开发诉求,给用户提供更丰富便捷的体验. 点击链接 ...

  10. 编译安装cmake,linux编译安装cmake

    cmake官网:https://cmake.org/ cmake官网下载地址:https://cmake.org/download/ 现在Linux版本最新版是:cmake-3.28.0-rc5.ta ...