AI编译器CINN v.s TVM 中CodeGen 源码解读
如下的技术点梳理仅以「日常优化工作」为牵引点,涉及哪个模块,就具体去看哪个模块的代码。
一、CINN 框架
CINN 中CodeGen之后的代码编译主要交给了Compiler类来负责。核心的函数主要是:
Build(ir::Module&, string& code)Lookup(string& fn_name)
class Compiler final {
public:
static std::unique_ptr<Compiler> Create(const Target& target) {
return std::unique_ptr<Compiler>(new Compiler(target));
}
/**
* Compile and link to a CINN module.
*/
void Build(const ir::Module& module, const std::string& code = "");
void ExportObject(const std::string& path);
std::string GetSourceCode(const ir::Module& module);
void BuildDefault(const ir::Module& module);
/**
* Retrieve a function by \p fn_name.
* @return function address or null if not exists.
*/
void* Lookup(absl::string_view fn_name);
private:
void CompileCudaModule(const ir::Module& module, const std::string& code = "");
void CompileX86Module(const ir::Module& module);
explicit Compiler(const Target& target) : target_(target), engine_(ExecutionEngine::Create(ExecutionOptions())) {}
CINN_DISALLOW_COPY_AND_ASSIGN(Compiler);
private:
Target target_;
std::unique_ptr<ExecutionEngine> engine_;
#ifdef CINN_WITH_CUDA
std::unique_ptr<runtime::cuda::CUDAModule> cuda_module_;
#endif
};
在Build()方法中,是通过target来进行逻辑分发的:
void Compiler::Build(const Module& module, const std::string& code) {
if (target_.arch == Target::Arch::NVGPU) {
CompileCudaModule(module, code); // <----- GPU 上编译逻辑
} else if (target_.arch == Target::Arch::X86) {
CompileX86Module(module); // <------ X86 CPU 上编译逻辑
} else {
CINN_NOT_IMPLEMENTED
}
}
我们来详细研习下 CompileCudaModule() 函数的实现逻辑:
- step 1: 调用
SplitCudaAndHostModule()将ir::Module切分成host_module和device_module - step 2: 借助
CodeGenCUDA_Dev模块,对device_module进行代码生成,得到source_code;也支持用户直接通过code参数指定 - step 3: 构造一个
nvrtc::Compiler对象,将source_code编译为ptx中间代码 - step 4: 以
ptx构造一个runtime::cuda::CUDAModule对象,可以选择是CUBIN或者PTX的kind类型 - step 5: 根据
device_module中的fn_name在cuda_module_中取出对应的fn_kernel(CUfunction对象),并在RuntimeSymbols中注册(fn_name__ptr__、void*) - step 6: 以
RuntimeSymbols构造ExecutionEngine,负责将CUDA Kernel与Host端API进行链接
void Compiler::CompileCudaModule(const Module& module, const std::string& code) {
// step 1: 调用SplitCudaAndHostModule()将ir::Module切分成 host_module和device_module
auto _host_module_device_module_ = SplitCudaAndHostModule(module); // NOLINT
auto& host_module = std::get<0>(_host_module_device_module_);
auto& device_module = std::get<1>(_host_module_device_module_);
// step 2: 借助CodeGenCUDA_Dev模块,对device_module进行代码生成,得到 source_code;也支持用户直接通过code参数指定
std::string source_code;
CodeGenCUDA_Dev codegen(target_);
source_code = codegen.Compile(device_module);
// step 3: 构造一个nvrtc::Compiler对象,将source_code编译为ptx中间代码
backends::nvrtc::Compiler compiler;
auto ptx = compiler(source_code);
// step 4: 以ptx构造一个runtime::cuda::CUDAModule对象,可以选择是CUBIN 或者 PTX 的kind类型
using runtime::cuda::CUDAModule;
cuda_module_.reset(
new CUDAModule(ptx, compiler.compile_to_cubin() ? CUDAModule::Kind::CUBIN : CUDAModule::Kind::PTX));
// step 5: 根据device_module中的fn_name在cuda_module_中取出对应的fn_kernel(CUfunction对象),并在RuntimeSymbols中注册(fn_name__ptr__、void*)
RuntimeSymbols symbols;
for (auto& fn : device_module.functions()) {
std::string kernel_fn_name = fn->name;
auto fn_kernel = cuda_module_->GetFunction(0, kernel_fn_name);
CHECK(fn_kernel);
symbols.RegisterVar(kernel_fn_name + "_ptr_", reinterpret_cast<void*>(fn_kernel));
}
// step 6: 以RuntimeSymbols构造ExecutionEngine,负责将CUDA Kernel 与 Host 端API进行链接
engine_ = ExecutionEngine::Create(ExecutionOptions(), std::move(symbols));
engine_->Link<CodeGenCUDA_Host>(host_module);
}
二、TVM 框架
runtime::Module BuildCUDA(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenCUDA cg;
cg.Init(output_ssa);
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f);
}
std::string code = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) {
code = (*f)(code).operator std::string();
}
std::string fmt = "ptx";
std::string ptx;
const auto* f_enter = Registry::Get("target.TargetEnterScope");
(*f_enter)(target);
if (const auto* f = Registry::Get("tvm_callback_cuda_compile")) { // <---- 可以借助python端的函数,注册一个直接编译 cubin 的函数
ptx = (*f)(code).operator std::string();
// Dirty matching to check PTX vs cubin.
// TODO(tqchen) more reliable checks
if (ptx[0] != '/') fmt = "cubin";
} else {
ptx = NVRTCCompile(code, cg.need_include_path());
}
const auto* f_exit = Registry::Get("target.TargetExitScope");
(*f_exit)(target);
return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
}
TVM_REGISTER_GLOBAL("target.build.cuda").set_body_typed(BuildCUDA);
其中 NVRTCCompile 的实现:
std::string NVRTCCompile(const std::string& code, bool include_path = false) {
std::vector<std::string> compile_params;
std::vector<const char*> param_cstrings{};
nvrtcProgram prog;
std::string cc = "30";
int major, minor;
cudaError_t e1 = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0);
cudaError_t e2 = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0);
if (e1 == cudaSuccess && e2 == cudaSuccess) {
cc = std::to_string(major) + std::to_string(minor);
} else {
LOG(WARNING) << "cannot detect compute capability from your device, "
<< "fall back to compute_30.";
}
compile_params.push_back("-arch=compute_" + cc);
if (include_path) {
std::string include_option = "--include-path=" + FindCUDAIncludePath();
compile_params.push_back(include_option);
}
for (const auto& string : compile_params) {
param_cstrings.push_back(string.c_str());
}
NVRTC_CALL(nvrtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, nullptr));
nvrtcResult compile_res = nvrtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data());
size_t log_size;
NVRTC_CALL(nvrtcGetProgramLogSize(prog, &log_size));
std::string log;
log.resize(log_size);
NVRTC_CALL(nvrtcGetProgramLog(prog, &log[0]));
ICHECK_EQ(compile_res, NVRTC_SUCCESS) << log;
size_t ptx_size;
NVRTC_CALL(nvrtcGetPTXSize(prog, &ptx_size));
std::string ptx;
ptx.resize(ptx_size);
NVRTC_CALL(nvrtcGetPTX(prog, &ptx[0]));
NVRTC_CALL(nvrtcDestroyProgram(&prog));
return ptx;
}
可以在 python 端注册一个自定义函数,直接编译为 cubin 文件,用法见TVM中的单测:
@tvm._ffi.register_func
def tvm_callback_cuda_compile(code):
"""use nvcc to generate fatbin code for better optimization"""
ptx = compile_cuda(code, target_format="fatbin")
return ptx
def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target=None):
"""Compile cuda code with NVCC from env
"""
if arch is None:
# If None, then it will use `tvm.target.Target.current().arch`.
# Target arch could be a str like "sm_xx", or a list, such as
# [
# "-gencode", "arch=compute_52,code=sm_52",
# "-gencode", "arch=compute_70,code=sm_70"
# ]
compute_version = "".join(
get_target_compute_version(Target.current(allow_none=True)).split(".")
)
arch = ["-gencode", f"arch=compute_{compute_version},code=sm_{compute_version}"]
temp = utils.tempdir()
if target_format not in ["cubin", "ptx", "fatbin"]:
raise ValueError("target_format must be in cubin, ptx, fatbin")
temp_code = temp.relpath("my_kernel.cu")
temp_target = temp.relpath("my_kernel.%s" % target_format)
with open(temp_code, "w") as out_file:
out_file.write(code)
file_target = path_target if path_target else temp_target
cmd = ["nvcc"]
cmd += ["--%s" % target_format, "-O3"]
if isinstance(arch, list):
cmd += arch
elif isinstance(arch, str):
cmd += ["-arch", arch]
if options:
if isinstance(options, str):
cmd += [options]
elif isinstance(options, list):
cmd += options
else:
raise ValueError("options must be str or list of str")
cmd += ["-o", file_target]
cmd += [temp_code]
# NOTE: ccbin option can be used to tell nvcc where to find the c++ compiler
# just in case it is not in the path. On Windows it is not in the path by default.
# However, we cannot use TVM_CXX_COMPILER_PATH because the runtime env.
# Because it is hard to do runtime compiler detection, we require nvcc is configured
# correctly by default.
# if cxx_compiler_path != "":
# cmd += ["-ccbin", cxx_compiler_path]
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = code
msg += "\nCompilation error:\n"
msg += py_str(out)
raise RuntimeError(msg)
with open(file_target, "rb") as f:
data = bytearray(f.read())
if not data:
raise RuntimeError("Compilation error: empty result is generated")
return data
三、参考资料
- https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/#virtual-architecture-feature-list
- https://forums.developer.nvidia.com/t/using-curand-inside-nvrtc-jit-compiled-kernels/193826
- https://github.com/NVIDIA/jitify/issues/43
- https://github.com/NVIDIA/libcudacxx/blob/main/include/cuda/std/detail/libcxx/include/type_traits
AI编译器CINN v.s TVM 中CodeGen 源码解读的更多相关文章
- go中panic源码解读
panic源码解读 前言 panic的作用 panic使用场景 看下实现 gopanic gorecover fatalpanic 总结 参考 panic源码解读 前言 本文是在go version ...
- go中waitGroup源码解读
waitGroup源码刨铣 前言 WaitGroup实现 noCopy state1 Add Wait 总结 参考 waitGroup源码刨铣 前言 学习下waitGroup的实现 本文是在go ve ...
- etcd中watch源码解读
etcd中watch的源码解析 前言 client端的代码 Watch newWatcherGrpcStream run newWatchClient serveSubstream server端的代 ...
- java中jdbc源码解读
在jdbc中一个重要的接口类就是java.sql.Driver,其中有一个重要的方法:Connection connect(String url, java.util.Propeties info); ...
- go中errgroup源码解读
errgroup 前言 如何使用 实现原理 WithContext Go Wait 错误的使用 总结 errgroup 前言 来看下errgroup的实现 如何使用 func main() { var ...
- Vue 源码解读(8)—— 编译器 之 解析(上)
特殊说明 由于文章篇幅限制,所以将 Vue 源码解读(8)-- 编译器 之 解析 拆成了上下两篇,所以在阅读本篇文章时请同时打开 Vue 源码解读(8)-- 编译器 之 解析(下)一起阅读. 前言 V ...
- Vue 源码解读(10)—— 编译器 之 生成渲染函数
前言 这篇文章是 Vue 编译器的最后一部分,前两部分分别是:Vue 源码解读(8)-- 编译器 之 解析.Vue 源码解读(9)-- 编译器 之 优化. 从 HTML 模版字符串开始,解析所有标签以 ...
- go 中 select 源码阅读
深入了解下 go 中的 select 前言 1.栗子一 2.栗子二 3.栗子三 看下源码实现 1.不存在 case 2.select 中仅存在一个 case 3.select 中存在两个 case,其 ...
- Vue 源码解读(8)—— 编译器 之 解析(下)
特殊说明 由于文章篇幅限制,所以将 Vue 源码解读(8)-- 编译器 之 解析 拆成了两篇文章,本篇是对 Vue 源码解读(8)-- 编译器 之 解析(上) 的一个补充,所以在阅读时请同时打开 Vu ...
- Vue 源码解读(9)—— 编译器 之 优化
前言 上一篇文章 Vue 源码解读(8)-- 编译器 之 解析 详细详解了编译器的第一部分,如何将 html 模版字符串编译成 AST.今天带来编译器的第二部分,优化 AST,也是大家常说的静态标记. ...
随机推荐
- Jenkins设置定时触发器执行任务
1. 选中任务,选择配置/构建触发器,选择定时构建 2. 填写定时器参数,格式说明如下,共五个参数,参数之间空格隔开,不需要填的直接*号即可. 此处d定时任务的格式遵循 cron 的语法(可以与 c ...
- Python企业面试题2 —— 基础篇
1. re 的match和search区别? re.match 尝试从字符串的起始位置匹配一个模式,如果不是起始位置匹配成功的话,match()就返回none. re.search 扫描整个字符串并返 ...
- #Raney引理,圆排列#洛谷 6672 [清华集训2016] 你的生命已如风中残烛
题目 分析 转化一下条件,就是 \(\sum{w_i}\geq i\),将所有牌权值减一,那就是 \(\sum{w'_i}\geq 0\) 根据Raney引理,总和为 1 的数列,在循环移位时,只有一 ...
- #轮廓线dp,模型转换#洛谷 3226 [HNOI2012]集合选数
题目 问有多少个集合 \(S\) 是 \([1,n]\) 的子集, 并且 \(\forall a,b\in S,a|b\),满足 \(\frac{b}{a}\neq \{2,3\}\) 分析 可以发现 ...
- #约数#洛谷 4296 [AHOI2007]密码箱
题目 给定\(n(n\leq 2*10^9)\),求 \[\sum_{x=1}^n[x^2\bmod n==1] \] 分析 首先当\(n=1\)的时候需要特判, 否则1和\(n-1\)一定是答案, ...
- 中文GPTS,字节中文扣子Coze使用全教程
字节出自己的GPTS了,名字英文名叫coze,中文名叫"扣子".和OpenAI的GPTS类似.具有可定制性和完成特定任务的强大功能,它提供了一种新的GPT方式,可以让用户根据自己的 ...
- 一文总结ACE代码框架
一.前言 ACE_Engine框架是OpenAtom OpenHarmony(简称"OpenHarmony")的UI开发框架,为开发者提供在进行应用UI开发时所必需的各种组件,以及 ...
- C#实现的下拉多选框,下拉多选树,多级节点
今天给大家上个硬货,下拉多选框,同时也是下拉多选树,支持父节点跟子节点!该控件是基于Telerik控件封装实现的,所以大家在使用的过程中需要引用Telerik.WinControls.dll.Tele ...
- 在Mac系统上使用Qt调用摄像头不出图解决方法
需求:在Mac系统上,调用摄像头,实现旋转.缩放.处理视频帧等功能 问题:使用获取视频帧的方法,在Mac上调不起来摄像头 解决方法: 将视频窗口(QVideoWidget)和视频帧(QVideoFra ...
- Ubuntu部署Django一:环境搭建
前言: Ubuntu系统上部署django,使用的部署方案是 Ubuntu + django + uwsgi + nginx Ubuntu系统版本用的是 ubuntu-20.04.2.0-desk ...