tvm-多线程代码生成和运行
本文链接
https://www.cnblogs.com/wanger-sjtu/p/16818492.html
调用链
tvm搜索算子在需要多线程运行的算子,是在codegen阶段时插入TVMBackendParallelLaunch的调用。
TVMBackendParallelLaunch 是tvm的线程池并行化入口,具体如下
/*!
 * \brief The callback function to execute a parallel lambda
 * \param task_id the task id of the function. //这里实际就是线程池线程编码,对应第几个线程
 * \param penv The parallel environment backs the execution. // num_task, sync
 * \param cdata The supporting closure data.
 */
typedef int (*FTVMParallelLambda)(int task_id, TVMParallelGroupEnv* penv, void* cdata);
/*!
 * \brief Backend function for running parallel jobs.
 *
 * \param flambda The parallel function to be launched.
 * \param cdata The closure data. // 可以认为时循环的变量 codegen时生成
 * \param num_task Number of tasks to launch, can be 0, means launch
 *           with all available threads. // codegen 时写入的是0,运行时根据配置写入
 *
 * \return 0 when no error is thrown, -1 when failure happens
 */
int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task);
flambda的调用在单线程和多线程下略有区别。
单线程运行时
if (num_workers == 1) {
    std::atomic<int32_t> sync_counter{0};
    TVMParallelGroupEnv env;
    env.num_task = 1;
    env.sync_handle = &sync_counter;
    (*flambda)(0, &env, cdata);
    return 0;
  }
多线程运行时
// launcher->Init(flambda, cdata, num_task, need_sync != 0);
this->cdata = cdata;
this->flambda = flambda;
this->env.num_task = num_task;
while (queue->Pop(&task, spin_count)) {
    ICHECK(task.launcher != nullptr);
    TVMParallelGroupEnv* penv = &(task.launcher->env);
    void* cdata = task.launcher->cdata;
    if ((*task.launcher->flambda)(task.task_id, penv, cdata) == 0) {
      task.launcher->SignalJobFinish();
    } else {
      task.launcher->SignalJobError(task.task_id);
    }
  }
可以看到 待并行函数中 TVMParallelGroupEnv* penv 包含了实际的运行时线程,运行时可以根据这个确定每个线程的工作区间和步长。
cdata则是线程运行时需要变量信息,闭包变量。
总结
对要并行的函数,实际上是按照lambda表达式的方式生成的。FTVMParallelLambda 的输入参数前两个是运行时确定的,第三个是捕获的外部变量。
codegen 过程
下面验证一下上述的猜测。
codegen过程中,实际上是在遍历tir Stmt的AST,因为生成的循环都是基于For的,调用过程也比较简单了。
void CodeGenCPU::VisitStmt_(const ForNode* op)  // ->
CreateParallelLaunch(For(op->loop_var, op->min, op->extent, op->kind, op->body,
                        op->thread_binding, op->annotations),
                    0, std::string("loop_parallel_") + op->loop_var->name_hint.c_str());   // ->
CodeGenCPU::VisitStmt_(const ForNode* op);
当遍历到For节点时, 根据属性判断是否并行加速。这里只分析加速场景。此时parallel_env_.penv == nullptr 创建多线程调用函数,进入CreateParallelLaunch函数。
然后 再生成 For的遍历逻辑。this->VisitStmt(body); 这里的body其实还是For ,这时候就进入
} else {
      // already in parallel env.
前文的猜测也在这里得到验证。
void CodeGenCPU::VisitStmt_(const ForNode* op) {
  ICHECK(is_zero(op->min));
  if (op->kind == ForKind::kSerial || op->kind == ForKind::kUnrolled) {
    CodeGenLLVM::VisitStmt_(op);
  } else if (op->kind == ForKind::kParallel) {
    if (parallel_env_.penv == nullptr) {
      CreateParallelLaunch(For(op->loop_var, op->min, op->extent, op->kind, op->body,
                               op->thread_binding, op->annotations),
                           0, std::string("loop_parallel_") + op->loop_var->name_hint.c_str());
    } else {
      // already in parallel env.
      ICHECK(parallel_env_.task_id.defined());
      ICHECK(parallel_env_.num_task.defined());
      ICHECK(parallel_env_.penv != nullptr);
      DataType t = op->extent.dtype();
      PrimExpr num_task = cast(t, parallel_env_.num_task);
      PrimExpr task_id = cast(t, parallel_env_.task_id);
      ICHECK(!parallel_env_.in_parallel_loop)
          << "Nested parallel loop is not supported by threadpool, try fuse them instead";
      parallel_env_.in_parallel_loop = true;
      if (parallel_env_.stride_pattern) {
        CreateSerialFor(MakeValue(task_id), MakeValue(op->extent), MakeValue(num_task),
                        op->loop_var, op->body);
      } else {
        PrimExpr step = (op->extent + num_task - make_const(t, 1)) / num_task;
        PrimExpr begin = min(task_id * step, op->extent);
        PrimExpr end = min((task_id + make_const(t, 1)) * step, op->extent);
        CreateSerialFor(MakeValue(begin), MakeValue(end),
                        llvm::ConstantInt::getSigned(GetLLVMType(end), 1), op->loop_var, op->body);
      }
      parallel_env_.in_parallel_loop = false;
      ++parallel_env_.parallel_loop_count;
    }
  } else {
    LOG(FATAL) << "cannot handle for type " << op->kind;
  }
}
/*
    const Stmt& body  For 循环的statement
    int num_task, 这里设置的是0,根据运行时参数确定使用线程
    std::string name
*/
void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task, std::string name) {
  // closure data
  llvm::Function* f =
      llvm::Function::Create(ftype_tvm_parallel_lambda_, llvm::Function::PrivateLinkage,
                             "__tvm_parallel_lambda", module_.get());
  SetTargetAttributes(f);
  // allocate and setup the closure, call the closure. //For 循环内部变量。这里需要声明一下
  Array<Var> vfields = tir::UndefinedVars(body, {});
  uint64_t nbytes;
  TypedPointer cdata = PackClosureData(vfields, &nbytes, "closure_" + name); // 可以认为时循环的变量
#if TVM_LLVM_VERSION >= 90
  auto launch_callee = llvm::FunctionCallee(ftype_tvm_parallel_launch_, RuntimeTVMParallelLaunch());
#else
  auto launch_callee = RuntimeTVMParallelLaunch();
#endif
  llvm::BasicBlock* par_launch_end = CheckCallSuccess(builder_->CreateCall(
      launch_callee,
      {f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(num_task)}));
  // Setup the closure function.
  auto* lambda_entry =
      llvm::BasicBlock::Create(*llvm_target_->GetContext(), "parallel_closure_entry", f);
  builder_->SetInsertPoint(lambda_entry);
  auto it = f->arg_begin();
  llvm::Value* task_id = &(*it++);
  task_id->setName("task_id");
  llvm::Value* penv = &(*it++);
  cdata.addr = builder_->CreatePointerCast(&(*it++), cdata.addr->getType());
  // setup new variable map, swap it with current var context.
  std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
  UnpackClosureData(cdata, vfields, &new_vmap);
  // setup parallel env
  ParallelEnv par_env;
  par_env.task_id = Var("task_id", DataType::Int(32));
  par_env.num_task = Var("num_task", DataType::Int(32));
  new_vmap[par_env.task_id.get()] = task_id;
  new_vmap[par_env.num_task.get()] = builder_->CreateLoad(
      t_int32_,
      builder_->CreateInBoundsGEP(t_tvm_parallel_group_env_, penv, {ConstInt32(0), ConstInt32(1)}),
      "num_task");
  par_env.penv = penv;
  auto new_analyzer = std::make_unique<arith::Analyzer>();
  std::swap(function_, f);
  std::swap(parallel_env_, par_env);
  std::swap(analyzer_, new_analyzer);
  std::swap(var_map_, new_vmap);
  this->VisitStmt(body);
  builder_->CreateRet(ConstInt32(0));
  // swap the var map back, now we are back on track.
  std::swap(var_map_, new_vmap);
  std::swap(analyzer_, new_analyzer);
  std::swap(parallel_env_, par_env);
  std::swap(function_, f);
  ICHECK_NE(par_env.parallel_loop_count, 0) << "Cannot find parallel loop within parallel launch";
  builder_->SetInsertPoint(par_launch_end);
}
												
											tvm-多线程代码生成和运行的更多相关文章
- ExecutorService java多线程分割list运行
		
调用方法 int threadNum = 7; while(true) { List<FaceAnalyseImage> list = faceAnalyseImageMapper.sel ...
 - Microsoft SDK 中Sample案例之Amcap項目 的运行方法(转)
		
http://blog.csdn.net/erick08/article/details/7194575 Microsoft SDK 中Sample之Amcap 的运行方法 写这篇文章的由 ...
 - [转帖]运行时库(runtime library)
		
运行时库(runtime library) https://blog.csdn.net/xitie8523/article/details/82712105 没学过这些东西 或者当时上课没听 又或者 ...
 - TVM如何训练TinyML
		
TVM如何训练TinyML 机器学习研究人员和从业人员对"裸机"(低功耗,通常没有操作系统)设备产生了广泛的兴趣.尽管专家已经有可能在某些裸机设备上运行某些模型,但是为各种设备优化 ...
 - .NET基础拾遗(5)多线程开发基础
		
Index : (1)类型语法.内存管理和垃圾回收基础 (2)面向对象的实现和异常的处理基础 (3)字符串.集合与流 (4)委托.事件.反射与特性 (5)多线程开发基础 (6)ADO.NET与数据库开 ...
 - java之多线程 二
		
线程的生命周期: 当线程被创建并被启动时,它既不是一启动就进入了执行状态,在线程的生命周期中,它要经过new(新建),就绪(Runnable),运行(Running),阻塞(Blocked),dead ...
 - Python 多线程教程:并发与并行
		
转载于: https://my.oschina.net/leejun2005/blog/398826 在批评Python的讨论中,常常说起Python多线程是多么的难用.还有人对 global int ...
 - WebDriver多线程并发
		
要想多线程并发的运行WebDriver,必须同时满足2个条件,首先你的测试程序是多线程,其次需要用到Selenium Server.下载位置如下图: 下载下来后是一个jar包,需要在命令行中运行.里面 ...
 - JavaEE基础(二十四)/多线程
		
1.多线程(多线程的引入) 1.什么是线程 线程是程序执行的一条路径, 一个进程中可以包含多条线程 多线程并发执行可以提高程序的效率, 可以同时完成多项工作 2.多线程的应用场景 红蜘蛛同时共享屏幕给 ...
 - php模拟多线程
		
一:应该知道的: php本身是不支持多线, 但是php的好搭档,apache和linux是支持的,故lamp才是最佳组合,还在使用win服务器的现在知道为什么要用linux吧.既然是模拟的, 就不是真 ...
 
随机推荐
- ICMP隐蔽隧道攻击分析与检测(二)
			
• ICMP协议流量特征分析 一.ASCII与HEX对照转换表 二.ICMP正常流量分析 经常使用的ping命令就是基于ICMP协议,Windows系统下ping默认传输的是:"abcdef ...
 - 排队论——系统运行指标的R语言实现
			
排队是在日常生活中经常遇到的现象,如顾客到商店购买物品.病人到医院看病常常要排队.此时要求服务的数量超过服务机构(服务台.服务员等)的容量.也就是说,到达的顾客不能立即得到服务,因而出现了排队现象.这 ...
 - python入门教程之二十四Python MySQL - mysql-connector 驱动
			
MySQL 是最流行的关系型数据库管理系统,如果你不熟悉 MySQL,可以阅读我们的 MySQL 教程. 本章节我们为大家介绍使用 mysql-connector 来连接使用 MySQL, mysql ...
 - [Linux]CentOS7搭建/配置:YUM仓库/源[本地源/Web源(Apache HTTP(D))/自建源仓库]
			
若想搞懂整个配置过程和原理,就按照章节(1 / 2)一步一步地来. 若想直接一步到位,不想花过多时间,尽快配好,就直接看附件章节. 什么是yum源? Yum(全称为 Yellow dog Update ...
 - [Java]排序算法>插入排序>【直接插入排序】(O(N*N)/稳定/N较小/有序/顺序存储+链式存储)
			
1 直接插入排序 1.1 算法思想 插入排序的基本思想是:每一趟将1个待排序的记录,按其关键字的大小插入到已经排好序的一组记录的适当位置上,直到所有待排序记录全部插入为止. 1.2 算法特征 属于[插 ...
 - Redis读书笔记(二)
			
Redis对象系统 Redis对象 字符串(String)的底层实现方式 直接保存整数值:字符串对象保存的是整数值,且可以用long类型来表示. embstr编码的SDS:字符串对象保存的是一个长度小 ...
 - 《爆肝整理》保姆级系列教程-玩转Charles抓包神器教程(14)-Charles过滤网络请求
			
1.简介 在日常工作测试中,经常要抓包看请求的request,response是不是传的对,返回的字段值对不对,众多的请求中看得眼花缭乱,如何找到自己想要的请求,那么我们就需要过滤请求.Charles ...
 - TypeScript 学习笔记 — 数组常见的类型转换操作记录(十四)
			
获取长度 length type LengthOfTuple<T extends any[]> = T["length"]; type A = LengthOfTupl ...
 - Html/css 列表项 区分列表首尾
			
列表项,有时需要判断列表首尾,来筛选设置样式 如上图,三个项有间隔,怎么保证设置了列表项之间的距离后,整体还水平居中显示呢? .item:not(:first-child) { margin-left ...
 - 【配置教程】撑起月6亿PV开源监控解决方案
			
上次分享过<一个.Net Core开源监控解决方案,支持Redis.Elasticsearch.SqlServer>,这是Stack Overflow 开源的监控产品,基于.Net Cor ...