如下的技术点梳理仅以「日常优化工作」为牵引点,涉及哪个模块,就具体去看哪个模块的代码。

一、CINN 框架

CINN 中CodeGen之后的代码编译主要交给了Compiler类来负责。核心的函数主要是:

  • Build(ir::Module&, string& code)
  • Lookup(string& fn_name)
class Compiler final {
public:
static std::unique_ptr<Compiler> Create(const Target& target) {
return std::unique_ptr<Compiler>(new Compiler(target));
} /**
* Compile and link to a CINN module.
*/
void Build(const ir::Module& module, const std::string& code = ""); void ExportObject(const std::string& path); std::string GetSourceCode(const ir::Module& module); void BuildDefault(const ir::Module& module); /**
* Retrieve a function by \p fn_name.
* @return function address or null if not exists.
*/
void* Lookup(absl::string_view fn_name); private:
void CompileCudaModule(const ir::Module& module, const std::string& code = ""); void CompileX86Module(const ir::Module& module); explicit Compiler(const Target& target) : target_(target), engine_(ExecutionEngine::Create(ExecutionOptions())) {} CINN_DISALLOW_COPY_AND_ASSIGN(Compiler); private:
Target target_;
std::unique_ptr<ExecutionEngine> engine_; #ifdef CINN_WITH_CUDA
std::unique_ptr<runtime::cuda::CUDAModule> cuda_module_;
#endif
};

Build()方法中,是通过target来进行逻辑分发的:

void Compiler::Build(const Module& module, const std::string& code) {
if (target_.arch == Target::Arch::NVGPU) {
CompileCudaModule(module, code); // <----- GPU 上编译逻辑
} else if (target_.arch == Target::Arch::X86) {
CompileX86Module(module); // <------ X86 CPU 上编译逻辑
} else {
CINN_NOT_IMPLEMENTED
}
}

我们来详细研习下 CompileCudaModule() 函数的实现逻辑:

  • step 1: 调用 SplitCudaAndHostModule()ir::Module 切分成 host_moduledevice_module
  • step 2: 借助 CodeGenCUDA_Dev 模块,对 device_module 进行代码生成,得到 source_code;也支持用户直接通过code参数指定
  • step 3: 构造一个 nvrtc::Compiler 对象,将 source_code编译为 ptx 中间代码
  • step 4: 以ptx构造一个runtime::cuda::CUDAModule对象,可以选择是 CUBIN 或者 PTXkind 类型
  • step 5: 根据 device_module 中的 fn_namecuda_module_ 中取出对应的fn_kernelCUfunction对象),并在RuntimeSymbols中注册(fn_name__ptr__void*)
  • step 6: 以 RuntimeSymbols 构造 ExecutionEngine,负责将 CUDA KernelHost 端API进行链接
void Compiler::CompileCudaModule(const Module& module, const std::string& code) {
// step 1: 调用SplitCudaAndHostModule()将ir::Module切分成 host_module和device_module
auto _host_module_device_module_ = SplitCudaAndHostModule(module); // NOLINT
auto& host_module = std::get<0>(_host_module_device_module_);
auto& device_module = std::get<1>(_host_module_device_module_); // step 2: 借助CodeGenCUDA_Dev模块,对device_module进行代码生成,得到 source_code;也支持用户直接通过code参数指定
std::string source_code;
CodeGenCUDA_Dev codegen(target_);
source_code = codegen.Compile(device_module); // step 3: 构造一个nvrtc::Compiler对象,将source_code编译为ptx中间代码
backends::nvrtc::Compiler compiler;
auto ptx = compiler(source_code); // step 4: 以ptx构造一个runtime::cuda::CUDAModule对象,可以选择是CUBIN 或者 PTX 的kind类型
using runtime::cuda::CUDAModule;
cuda_module_.reset(
new CUDAModule(ptx, compiler.compile_to_cubin() ? CUDAModule::Kind::CUBIN : CUDAModule::Kind::PTX)); // step 5: 根据device_module中的fn_name在cuda_module_中取出对应的fn_kernel(CUfunction对象),并在RuntimeSymbols中注册(fn_name__ptr__、void*)
RuntimeSymbols symbols;
for (auto& fn : device_module.functions()) {
std::string kernel_fn_name = fn->name;
auto fn_kernel = cuda_module_->GetFunction(0, kernel_fn_name);
CHECK(fn_kernel);
symbols.RegisterVar(kernel_fn_name + "_ptr_", reinterpret_cast<void*>(fn_kernel));
} // step 6: 以RuntimeSymbols构造ExecutionEngine,负责将CUDA Kernel 与 Host 端API进行链接
engine_ = ExecutionEngine::Create(ExecutionOptions(), std::move(symbols));
engine_->Link<CodeGenCUDA_Host>(host_module); }

二、TVM 框架

runtime::Module BuildCUDA(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenCUDA cg;
cg.Init(output_ssa); for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f);
} std::string code = cg.Finish(); if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) {
code = (*f)(code).operator std::string();
}
std::string fmt = "ptx";
std::string ptx;
const auto* f_enter = Registry::Get("target.TargetEnterScope");
(*f_enter)(target);
if (const auto* f = Registry::Get("tvm_callback_cuda_compile")) { // <---- 可以借助python端的函数,注册一个直接编译 cubin 的函数
ptx = (*f)(code).operator std::string();
// Dirty matching to check PTX vs cubin.
// TODO(tqchen) more reliable checks
if (ptx[0] != '/') fmt = "cubin";
} else {
ptx = NVRTCCompile(code, cg.need_include_path());
}
const auto* f_exit = Registry::Get("target.TargetExitScope");
(*f_exit)(target);
return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
} TVM_REGISTER_GLOBAL("target.build.cuda").set_body_typed(BuildCUDA);

其中 NVRTCCompile 的实现:

std::string NVRTCCompile(const std::string& code, bool include_path = false) {
std::vector<std::string> compile_params;
std::vector<const char*> param_cstrings{};
nvrtcProgram prog;
std::string cc = "30";
int major, minor;
cudaError_t e1 = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0);
cudaError_t e2 = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0); if (e1 == cudaSuccess && e2 == cudaSuccess) {
cc = std::to_string(major) + std::to_string(minor);
} else {
LOG(WARNING) << "cannot detect compute capability from your device, "
<< "fall back to compute_30.";
} compile_params.push_back("-arch=compute_" + cc); if (include_path) {
std::string include_option = "--include-path=" + FindCUDAIncludePath(); compile_params.push_back(include_option);
} for (const auto& string : compile_params) {
param_cstrings.push_back(string.c_str());
}
NVRTC_CALL(nvrtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, nullptr));
nvrtcResult compile_res = nvrtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data()); size_t log_size;
NVRTC_CALL(nvrtcGetProgramLogSize(prog, &log_size));
std::string log;
log.resize(log_size);
NVRTC_CALL(nvrtcGetProgramLog(prog, &log[0]));
ICHECK_EQ(compile_res, NVRTC_SUCCESS) << log;
size_t ptx_size;
NVRTC_CALL(nvrtcGetPTXSize(prog, &ptx_size)); std::string ptx;
ptx.resize(ptx_size);
NVRTC_CALL(nvrtcGetPTX(prog, &ptx[0]));
NVRTC_CALL(nvrtcDestroyProgram(&prog)); return ptx;
}

可以在 python 端注册一个自定义函数,直接编译为 cubin 文件,用法见TVM中的单测:

@tvm._ffi.register_func
def tvm_callback_cuda_compile(code):
"""use nvcc to generate fatbin code for better optimization"""
ptx = compile_cuda(code, target_format="fatbin")
return ptx def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target=None):
"""Compile cuda code with NVCC from env
"""
if arch is None:
# If None, then it will use `tvm.target.Target.current().arch`.
# Target arch could be a str like "sm_xx", or a list, such as
# [
# "-gencode", "arch=compute_52,code=sm_52",
# "-gencode", "arch=compute_70,code=sm_70"
# ]
compute_version = "".join(
get_target_compute_version(Target.current(allow_none=True)).split(".")
)
arch = ["-gencode", f"arch=compute_{compute_version},code=sm_{compute_version}"] temp = utils.tempdir()
if target_format not in ["cubin", "ptx", "fatbin"]:
raise ValueError("target_format must be in cubin, ptx, fatbin")
temp_code = temp.relpath("my_kernel.cu")
temp_target = temp.relpath("my_kernel.%s" % target_format) with open(temp_code, "w") as out_file:
out_file.write(code) file_target = path_target if path_target else temp_target
cmd = ["nvcc"]
cmd += ["--%s" % target_format, "-O3"]
if isinstance(arch, list):
cmd += arch
elif isinstance(arch, str):
cmd += ["-arch", arch] if options:
if isinstance(options, str):
cmd += [options]
elif isinstance(options, list):
cmd += options
else:
raise ValueError("options must be str or list of str") cmd += ["-o", file_target]
cmd += [temp_code] # NOTE: ccbin option can be used to tell nvcc where to find the c++ compiler
# just in case it is not in the path. On Windows it is not in the path by default.
# However, we cannot use TVM_CXX_COMPILER_PATH because the runtime env.
# Because it is hard to do runtime compiler detection, we require nvcc is configured
# correctly by default.
# if cxx_compiler_path != "":
# cmd += ["-ccbin", cxx_compiler_path] proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) (out, _) = proc.communicate() if proc.returncode != 0:
msg = code
msg += "\nCompilation error:\n"
msg += py_str(out)
raise RuntimeError(msg) with open(file_target, "rb") as f:
data = bytearray(f.read())
if not data:
raise RuntimeError("Compilation error: empty result is generated")
return data

三、参考资料

  1. https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/#virtual-architecture-feature-list
  2. https://forums.developer.nvidia.com/t/using-curand-inside-nvrtc-jit-compiled-kernels/193826
  3. https://github.com/NVIDIA/jitify/issues/43
  4. https://github.com/NVIDIA/libcudacxx/blob/main/include/cuda/std/detail/libcxx/include/type_traits

AI编译器CINN v.s TVM 中CodeGen 源码解读的更多相关文章

  1. go中panic源码解读

    panic源码解读 前言 panic的作用 panic使用场景 看下实现 gopanic gorecover fatalpanic 总结 参考 panic源码解读 前言 本文是在go version ...

  2. go中waitGroup源码解读

    waitGroup源码刨铣 前言 WaitGroup实现 noCopy state1 Add Wait 总结 参考 waitGroup源码刨铣 前言 学习下waitGroup的实现 本文是在go ve ...

  3. etcd中watch源码解读

    etcd中watch的源码解析 前言 client端的代码 Watch newWatcherGrpcStream run newWatchClient serveSubstream server端的代 ...

  4. java中jdbc源码解读

    在jdbc中一个重要的接口类就是java.sql.Driver,其中有一个重要的方法:Connection connect(String url, java.util.Propeties info); ...

  5. go中errgroup源码解读

    errgroup 前言 如何使用 实现原理 WithContext Go Wait 错误的使用 总结 errgroup 前言 来看下errgroup的实现 如何使用 func main() { var ...

  6. Vue 源码解读(8)—— 编译器 之 解析(上)

    特殊说明 由于文章篇幅限制,所以将 Vue 源码解读(8)-- 编译器 之 解析 拆成了上下两篇,所以在阅读本篇文章时请同时打开 Vue 源码解读(8)-- 编译器 之 解析(下)一起阅读. 前言 V ...

  7. Vue 源码解读(10)—— 编译器 之 生成渲染函数

    前言 这篇文章是 Vue 编译器的最后一部分,前两部分分别是:Vue 源码解读(8)-- 编译器 之 解析.Vue 源码解读(9)-- 编译器 之 优化. 从 HTML 模版字符串开始,解析所有标签以 ...

  8. go 中 select 源码阅读

    深入了解下 go 中的 select 前言 1.栗子一 2.栗子二 3.栗子三 看下源码实现 1.不存在 case 2.select 中仅存在一个 case 3.select 中存在两个 case,其 ...

  9. Vue 源码解读(8)—— 编译器 之 解析(下)

    特殊说明 由于文章篇幅限制,所以将 Vue 源码解读(8)-- 编译器 之 解析 拆成了两篇文章,本篇是对 Vue 源码解读(8)-- 编译器 之 解析(上) 的一个补充,所以在阅读时请同时打开 Vu ...

  10. Vue 源码解读(9)—— 编译器 之 优化

    前言 上一篇文章 Vue 源码解读(8)-- 编译器 之 解析 详细详解了编译器的第一部分,如何将 html 模版字符串编译成 AST.今天带来编译器的第二部分,优化 AST,也是大家常说的静态标记. ...

随机推荐

  1. SpringMVC转发和重定向的区别是什么

    转发和重定向 视图解析器 <!--视图解析器--> <bean class="org.springframework.web.servlet.view.InternalRe ...

  2. ET介绍——C#更好的协程

    更好的协程 上文讲了一串回调就是协程,显然这样写代码,增加逻辑,插入逻辑非常容易出错.我们需要利用异步语法把这个异步回调的形式改成同步的形式,幸好C#已经帮我们设计好了,看代码 // example2 ...

  3. #计数#CF10C Digital Root

    题目 定义\(d(x)\)为\(x\)的数位和嵌套,直至\(0\leq d(x)<10\) 询问在\([1\sim n]\)中有多少个三元组\((a,b,c)\)满足 \[ab\neq c,d( ...

  4. docker笔记之安装

    本文于2017年上半年完成,发布在个人博客网站上. 考虑个人博客因某种原因无法修复,于是在博客园安家,之前发布的文章逐步搬迁过来. 最近由于工作关系,接触到了docker技术.为了对docker有更多 ...

  5. SQL CREATE INDEX 语句- 提高数据库检索效率的关键步骤

    SQL CREATE INDEX 语句 SQL CREATE INDEX 语句用于在表中创建索引. 索引用于比其他方式更快地从数据库中检索数据.用户无法看到索引,它们只是用于加速搜索/查询. 注意: ...

  6. 挑战吧,HarmonyOS应用开发工程师

      一年一度属于工程师的专属节日1024已过,但程序员多重活动持续进行中~ 参与活动即有机会获得HUAWEI Freebuds 5i 耳机等精美礼品! 点击"阅读原文"查看更多活动 ...

  7. Linux0.12内核源码解读(2)-Bootsect.S

    大家好,我是呼噜噜,在上一篇文章聊聊x86计算机启动发生的事?我们了解了x86计算机启动过程,MBR.0x7c00是什么?其中当bios引导结束后,操作系统接过计算机的控制权后,发生了哪些事?本文将揭 ...

  8. Python3中pip3命令的用法介绍及安装配置

    第一节:pip3是什么?有啥用? pip3:(Python3 Install Package ),这个英文全称是我为了更好的理解这个命令这么叫的,官方没有这对个命令的全称的解释:) python 支持 ...

  9. redis 简单整理——持久化之AOF[二十]

    前言 简单介绍一下AOF. 正文 AOF(append only file)持久化:以独立日志的方式记录每次写命令, 重启时再重新执行AOF文件中的命令达到恢复数据的目的. AOF的主要作用 是解决了 ...

  10. WPF开发随笔收录-获取程序专有内存

    分享一个C#获取程序当前所占用的内存大小的方法,实测跟任务管理器上的内存值一样 /// <summary> /// 性能计数器组件类 /// </summary> privat ...