如下的技术点梳理仅以「日常优化工作」为牵引点,涉及哪个模块,就具体去看哪个模块的代码。

一、CINN 框架

CINN 中CodeGen之后的代码编译主要交给了Compiler类来负责。核心的函数主要是:

  • Build(ir::Module&, string& code)
  • Lookup(string& fn_name)
class Compiler final {
public:
static std::unique_ptr<Compiler> Create(const Target& target) {
return std::unique_ptr<Compiler>(new Compiler(target));
} /**
* Compile and link to a CINN module.
*/
void Build(const ir::Module& module, const std::string& code = ""); void ExportObject(const std::string& path); std::string GetSourceCode(const ir::Module& module); void BuildDefault(const ir::Module& module); /**
* Retrieve a function by \p fn_name.
* @return function address or null if not exists.
*/
void* Lookup(absl::string_view fn_name); private:
void CompileCudaModule(const ir::Module& module, const std::string& code = ""); void CompileX86Module(const ir::Module& module); explicit Compiler(const Target& target) : target_(target), engine_(ExecutionEngine::Create(ExecutionOptions())) {} CINN_DISALLOW_COPY_AND_ASSIGN(Compiler); private:
Target target_;
std::unique_ptr<ExecutionEngine> engine_; #ifdef CINN_WITH_CUDA
std::unique_ptr<runtime::cuda::CUDAModule> cuda_module_;
#endif
};

Build()方法中,是通过target来进行逻辑分发的:

void Compiler::Build(const Module& module, const std::string& code) {
if (target_.arch == Target::Arch::NVGPU) {
CompileCudaModule(module, code); // <----- GPU 上编译逻辑
} else if (target_.arch == Target::Arch::X86) {
CompileX86Module(module); // <------ X86 CPU 上编译逻辑
} else {
CINN_NOT_IMPLEMENTED
}
}

我们来详细研习下 CompileCudaModule() 函数的实现逻辑:

  • step 1: 调用 SplitCudaAndHostModule()ir::Module 切分成 host_moduledevice_module
  • step 2: 借助 CodeGenCUDA_Dev 模块,对 device_module 进行代码生成,得到 source_code;也支持用户直接通过code参数指定
  • step 3: 构造一个 nvrtc::Compiler 对象,将 source_code编译为 ptx 中间代码
  • step 4: 以ptx构造一个runtime::cuda::CUDAModule对象,可以选择是 CUBIN 或者 PTXkind 类型
  • step 5: 根据 device_module 中的 fn_namecuda_module_ 中取出对应的fn_kernelCUfunction对象),并在RuntimeSymbols中注册(fn_name__ptr__void*)
  • step 6: 以 RuntimeSymbols 构造 ExecutionEngine,负责将 CUDA KernelHost 端API进行链接
void Compiler::CompileCudaModule(const Module& module, const std::string& code) {
// step 1: 调用SplitCudaAndHostModule()将ir::Module切分成 host_module和device_module
auto _host_module_device_module_ = SplitCudaAndHostModule(module); // NOLINT
auto& host_module = std::get<0>(_host_module_device_module_);
auto& device_module = std::get<1>(_host_module_device_module_); // step 2: 借助CodeGenCUDA_Dev模块,对device_module进行代码生成,得到 source_code;也支持用户直接通过code参数指定
std::string source_code;
CodeGenCUDA_Dev codegen(target_);
source_code = codegen.Compile(device_module); // step 3: 构造一个nvrtc::Compiler对象,将source_code编译为ptx中间代码
backends::nvrtc::Compiler compiler;
auto ptx = compiler(source_code); // step 4: 以ptx构造一个runtime::cuda::CUDAModule对象,可以选择是CUBIN 或者 PTX 的kind类型
using runtime::cuda::CUDAModule;
cuda_module_.reset(
new CUDAModule(ptx, compiler.compile_to_cubin() ? CUDAModule::Kind::CUBIN : CUDAModule::Kind::PTX)); // step 5: 根据device_module中的fn_name在cuda_module_中取出对应的fn_kernel(CUfunction对象),并在RuntimeSymbols中注册(fn_name__ptr__、void*)
RuntimeSymbols symbols;
for (auto& fn : device_module.functions()) {
std::string kernel_fn_name = fn->name;
auto fn_kernel = cuda_module_->GetFunction(0, kernel_fn_name);
CHECK(fn_kernel);
symbols.RegisterVar(kernel_fn_name + "_ptr_", reinterpret_cast<void*>(fn_kernel));
} // step 6: 以RuntimeSymbols构造ExecutionEngine,负责将CUDA Kernel 与 Host 端API进行链接
engine_ = ExecutionEngine::Create(ExecutionOptions(), std::move(symbols));
engine_->Link<CodeGenCUDA_Host>(host_module); }

二、TVM 框架

runtime::Module BuildCUDA(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenCUDA cg;
cg.Init(output_ssa); for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f);
} std::string code = cg.Finish(); if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) {
code = (*f)(code).operator std::string();
}
std::string fmt = "ptx";
std::string ptx;
const auto* f_enter = Registry::Get("target.TargetEnterScope");
(*f_enter)(target);
if (const auto* f = Registry::Get("tvm_callback_cuda_compile")) { // <---- 可以借助python端的函数,注册一个直接编译 cubin 的函数
ptx = (*f)(code).operator std::string();
// Dirty matching to check PTX vs cubin.
// TODO(tqchen) more reliable checks
if (ptx[0] != '/') fmt = "cubin";
} else {
ptx = NVRTCCompile(code, cg.need_include_path());
}
const auto* f_exit = Registry::Get("target.TargetExitScope");
(*f_exit)(target);
return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
} TVM_REGISTER_GLOBAL("target.build.cuda").set_body_typed(BuildCUDA);

其中 NVRTCCompile 的实现:

std::string NVRTCCompile(const std::string& code, bool include_path = false) {
std::vector<std::string> compile_params;
std::vector<const char*> param_cstrings{};
nvrtcProgram prog;
std::string cc = "30";
int major, minor;
cudaError_t e1 = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0);
cudaError_t e2 = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0); if (e1 == cudaSuccess && e2 == cudaSuccess) {
cc = std::to_string(major) + std::to_string(minor);
} else {
LOG(WARNING) << "cannot detect compute capability from your device, "
<< "fall back to compute_30.";
} compile_params.push_back("-arch=compute_" + cc); if (include_path) {
std::string include_option = "--include-path=" + FindCUDAIncludePath(); compile_params.push_back(include_option);
} for (const auto& string : compile_params) {
param_cstrings.push_back(string.c_str());
}
NVRTC_CALL(nvrtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, nullptr));
nvrtcResult compile_res = nvrtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data()); size_t log_size;
NVRTC_CALL(nvrtcGetProgramLogSize(prog, &log_size));
std::string log;
log.resize(log_size);
NVRTC_CALL(nvrtcGetProgramLog(prog, &log[0]));
ICHECK_EQ(compile_res, NVRTC_SUCCESS) << log;
size_t ptx_size;
NVRTC_CALL(nvrtcGetPTXSize(prog, &ptx_size)); std::string ptx;
ptx.resize(ptx_size);
NVRTC_CALL(nvrtcGetPTX(prog, &ptx[0]));
NVRTC_CALL(nvrtcDestroyProgram(&prog)); return ptx;
}

可以在 python 端注册一个自定义函数,直接编译为 cubin 文件,用法见TVM中的单测:

@tvm._ffi.register_func
def tvm_callback_cuda_compile(code):
"""use nvcc to generate fatbin code for better optimization"""
ptx = compile_cuda(code, target_format="fatbin")
return ptx def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target=None):
"""Compile cuda code with NVCC from env
"""
if arch is None:
# If None, then it will use `tvm.target.Target.current().arch`.
# Target arch could be a str like "sm_xx", or a list, such as
# [
# "-gencode", "arch=compute_52,code=sm_52",
# "-gencode", "arch=compute_70,code=sm_70"
# ]
compute_version = "".join(
get_target_compute_version(Target.current(allow_none=True)).split(".")
)
arch = ["-gencode", f"arch=compute_{compute_version},code=sm_{compute_version}"] temp = utils.tempdir()
if target_format not in ["cubin", "ptx", "fatbin"]:
raise ValueError("target_format must be in cubin, ptx, fatbin")
temp_code = temp.relpath("my_kernel.cu")
temp_target = temp.relpath("my_kernel.%s" % target_format) with open(temp_code, "w") as out_file:
out_file.write(code) file_target = path_target if path_target else temp_target
cmd = ["nvcc"]
cmd += ["--%s" % target_format, "-O3"]
if isinstance(arch, list):
cmd += arch
elif isinstance(arch, str):
cmd += ["-arch", arch] if options:
if isinstance(options, str):
cmd += [options]
elif isinstance(options, list):
cmd += options
else:
raise ValueError("options must be str or list of str") cmd += ["-o", file_target]
cmd += [temp_code] # NOTE: ccbin option can be used to tell nvcc where to find the c++ compiler
# just in case it is not in the path. On Windows it is not in the path by default.
# However, we cannot use TVM_CXX_COMPILER_PATH because the runtime env.
# Because it is hard to do runtime compiler detection, we require nvcc is configured
# correctly by default.
# if cxx_compiler_path != "":
# cmd += ["-ccbin", cxx_compiler_path] proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) (out, _) = proc.communicate() if proc.returncode != 0:
msg = code
msg += "\nCompilation error:\n"
msg += py_str(out)
raise RuntimeError(msg) with open(file_target, "rb") as f:
data = bytearray(f.read())
if not data:
raise RuntimeError("Compilation error: empty result is generated")
return data

三、参考资料

  1. https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/#virtual-architecture-feature-list
  2. https://forums.developer.nvidia.com/t/using-curand-inside-nvrtc-jit-compiled-kernels/193826
  3. https://github.com/NVIDIA/jitify/issues/43
  4. https://github.com/NVIDIA/libcudacxx/blob/main/include/cuda/std/detail/libcxx/include/type_traits

AI编译器CINN v.s TVM 中CodeGen 源码解读的更多相关文章

  1. go中panic源码解读

    panic源码解读 前言 panic的作用 panic使用场景 看下实现 gopanic gorecover fatalpanic 总结 参考 panic源码解读 前言 本文是在go version ...

  2. go中waitGroup源码解读

    waitGroup源码刨铣 前言 WaitGroup实现 noCopy state1 Add Wait 总结 参考 waitGroup源码刨铣 前言 学习下waitGroup的实现 本文是在go ve ...

  3. etcd中watch源码解读

    etcd中watch的源码解析 前言 client端的代码 Watch newWatcherGrpcStream run newWatchClient serveSubstream server端的代 ...

  4. java中jdbc源码解读

    在jdbc中一个重要的接口类就是java.sql.Driver,其中有一个重要的方法:Connection connect(String url, java.util.Propeties info); ...

  5. go中errgroup源码解读

    errgroup 前言 如何使用 实现原理 WithContext Go Wait 错误的使用 总结 errgroup 前言 来看下errgroup的实现 如何使用 func main() { var ...

  6. Vue 源码解读(8)—— 编译器 之 解析(上)

    特殊说明 由于文章篇幅限制,所以将 Vue 源码解读(8)-- 编译器 之 解析 拆成了上下两篇,所以在阅读本篇文章时请同时打开 Vue 源码解读(8)-- 编译器 之 解析(下)一起阅读. 前言 V ...

  7. Vue 源码解读(10)—— 编译器 之 生成渲染函数

    前言 这篇文章是 Vue 编译器的最后一部分,前两部分分别是:Vue 源码解读(8)-- 编译器 之 解析.Vue 源码解读(9)-- 编译器 之 优化. 从 HTML 模版字符串开始,解析所有标签以 ...

  8. go 中 select 源码阅读

    深入了解下 go 中的 select 前言 1.栗子一 2.栗子二 3.栗子三 看下源码实现 1.不存在 case 2.select 中仅存在一个 case 3.select 中存在两个 case,其 ...

  9. Vue 源码解读(8)—— 编译器 之 解析(下)

    特殊说明 由于文章篇幅限制,所以将 Vue 源码解读(8)-- 编译器 之 解析 拆成了两篇文章,本篇是对 Vue 源码解读(8)-- 编译器 之 解析(上) 的一个补充,所以在阅读时请同时打开 Vu ...

  10. Vue 源码解读(9)—— 编译器 之 优化

    前言 上一篇文章 Vue 源码解读(8)-- 编译器 之 解析 详细详解了编译器的第一部分,如何将 html 模版字符串编译成 AST.今天带来编译器的第二部分,优化 AST,也是大家常说的静态标记. ...

随机推荐

  1. 使用fiddler抓取HTTPS的数据包(抓取App端的数据包)

    众所周知,我们在做接口测试的时候有两种情况: 第一种是先拿到接口测试规范文档,再去做接口测试. 第二种是没有接口文档,只有通过自己抓包. 那么说到抓包,就不得不说抓包工具,对于浏览器web端,我们只需 ...

  2. 测试开发之系统篇-Docker容器安装

    前面文章我们讲到,容器是运行在宿主机上的一个进程,多个容器之间使用同一个宿主机上的操作系统内核.此处以Ubuntu20.04系统为例,介绍Docker容器引擎的安装过程. 安装 安装依赖. sudo ...

  3. 使用CMake启用RUNPATH特性

    使用CMake,启用RUNPATH特性,可以参考官方帖子. 如下源码来自于上述帖子. CMAKE_MINIMUM_REQUIRED(VERSION 2.8 FATAL_ERROR) PROJECT(R ...

  4. OpenHarmony父子组件单项同步使用:@Prop装饰器

      @Prop装饰的变量可以和父组件建立单向的同步关系.@Prop装饰的变量是可变的,但是变化不会同步回其父组件. 说明: 从API version 9开始,该装饰器支持在ArkTS卡片中使用. 概述 ...

  5. HMS Core手语服务荣获2022中国互联网大会“特别推荐案例”:助力建设数字社会

    11月15日,HMS Core手语服务在2022(第二十一届)中国互联网大会 "互联网助力经济社会数字化转型"案例评选活动中,荣获"特别推荐案例". 经过一年多 ...

  6. pc=mobile+pad自适应布局:页面结构与打开方式

    pc=mobile+pad自适应布局 在这篇文章,咱们重点聊聊自适应布局的页面结构,以及打开页面的几种方式.关于pc=mobile+pad自适应布局的起源.概念.效果,参见文章:自适应布局:pc = ...

  7. Windows Server 2008 R2之升级IE8

    前言 先需求将Windows Server 2008 R2的IE8升级至IE9,需要安装系统补丁. 安装补丁 补丁包版本 KB2454826 下载地址 https://www.catalog.upda ...

  8. 如何使用 Grafana 监控文件系统状态

    当 JuiceFS 文件系统部署完成并投入生产环境,接下来就需要着手解决一个非常重要的问题 -- 如何实时监控它的运行状态?毕竟,它可能正在为关键的业务应用或容器工作负载提供持久化存储支持,任何小小的 ...

  9. 树模型-label boosting-GBDT

    GBDT GBDT是boosting系列算法的代表之一,其核心是 梯度+提升+决策树. GBDT回归问题 通俗的理解: 先来个通俗理解:假如有个人30岁,我们首先用20岁去拟合,发现损失有10岁,这时 ...

  10. 【布局进阶】巧用 :has & drop-shadow 实现复杂布局效果

    最近,群里聊到了一个很有意思的布局效果.大致效果如下所示,希望使用 CSS 实现如下所示的布局效果: 正常而言,我们的 HTML 结构大致是如下所示: <div class="g-co ...