如果模型中有些算子不被ONNX算子库支持,我们就需要利用ONNXRuntime提供的API手动添加新算子。在官方文档中已经对如何添加定制算子进行了介绍(https://onnxruntime.ai/docs/reference/operators/add-custom-op.html ),这里我们主要把源码中对应的流程给捋清楚。

添加定制算子(Custom Operators)主要分为三步:

  1. 创建一个定制算子域(CusttomOpDomain);
  2. 创建一个定制算子(CustomOp),并将该算子添加到定制算子域中;
  3. 将定制算子域添加到 SessionOption 中

首先看看源码中给出的定制算子样例:

// file path: onnxruntime/test/shared_lib/custom_op_utils.h

// 首先定义定制算子的核
struct MyCustomKernel {
MyCustomKernel(Ort::CustomOpApi ort, const OrtKernelInfo* /*info*/, void* compute_stream)
: ort_(ort), compute_stream_(compute_stream) {
} void Compute(OrtKernelContext* context); private:
Ort::CustomOpApi ort_;
void* compute_stream_;
}; // 然后定义定制算子的各个操作,各个成员函数均已实现,其中 CreateKernel 会返回前面定义的算子核对象
struct MyCustomOp : Ort::CustomOpBase<MyCustomOp, MyCustomKernel> {
explicit MyCustomOp(const char* provider, void* compute_stream) : provider_(provider), compute_stream_(compute_stream) {} void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const { return new MyCustomKernel(api, info, compute_stream_); };
const char* GetName() const { return "Foo"; };
const char* GetExecutionProviderType() const { return provider_; }; size_t GetInputTypeCount() const { return 2; };
ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
// Both the inputs need to be necessarily of float type
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}; size_t GetOutputTypeCount() const { return 1; };
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; private:
const char* provider_;
void* compute_stream_;
};

在上面代码中,我们看到定制算子继承自 Ort::CustomOpBase<MyCustomOp, MyCustomKernel>,这种扩展类作为模板基类的模板参数的方式又被称为CRTP,接着深入到这个模板类内部:

// file path: include/onnxruntime/core/session/onnxruntime_cxx_api.h

template <typename TOp, typename TKernel>
struct CustomOpBase : OrtCustomOp {
CustomOpBase() {
OrtCustomOp::version = ORT_API_VERSION;
OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); }; OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); }; OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); }; OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); }; OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast<TKernel*>(op_kernel)->Compute(context); };
OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); }; OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); };
OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index); };
} // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
const char* GetExecutionProviderType() const { return nullptr; } // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
// (inputs and outputs are required by default)
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const {
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
} OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
}
};

这里的 CustomOpBase 又继承自 OrtCustomOp

// include/onnxruntime/core/session/onnxruntime_c_api.h

struct OrtCustomOp;
typedef struct OrtCustomOp OrtCustomOp; struct OrtCustomOp {
uint32_t version; // Must be initialized to ORT_API_VERSION // This callback creates the kernel, which is a user defined parameter that is passed to the Kernel* callbacks below.
void*(ORT_API_CALL* CreateKernel)(_In_ const struct OrtCustomOp* op, _In_ const OrtApi* api,
_In_ const OrtKernelInfo* info); // Returns the name of the op
const char*(ORT_API_CALL* GetName)(_In_ const struct OrtCustomOp* op); // Returns the type of the execution provider, return nullptr to use CPU execution provider
const char*(ORT_API_CALL* GetExecutionProviderType)(_In_ const struct OrtCustomOp* op); // Returns the count and types of the input & output tensors
ONNXTensorElementDataType(ORT_API_CALL* GetInputType)(_In_ const struct OrtCustomOp* op, _In_ size_t index);
size_t(ORT_API_CALL* GetInputTypeCount)(_In_ const struct OrtCustomOp* op);
ONNXTensorElementDataType(ORT_API_CALL* GetOutputType)(_In_ const struct OrtCustomOp* op, _In_ size_t index);
size_t(ORT_API_CALL* GetOutputTypeCount)(_In_ const struct OrtCustomOp* op); // Op kernel callbacks
void(ORT_API_CALL* KernelCompute)(_In_ void* op_kernel, _In_ OrtKernelContext* context);
void(ORT_API_CALL* KernelDestroy)(_In_ void* op_kernel); // Returns the characteristics of the input & output tensors
OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetInputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index);
OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetOutputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index);
};

可以发现,OrtCustomOp 中定义了定制算子应该实现的模式,其中的一系列回调函数由其派生类一一实现,比如上文提到的 CustomOpBase 在其构造函数中,以 lambda 函数的方式实现各个回调函数。

至此,我们已经完整地梳理了定义定制算子在源码内部是如何实现的,接下来介绍如何将定义好的定制算子使用起来。

从如下官方测试代码开始分析:

// file path: onnxruntime/test/shared_lib/test_inference.cc

TEST(CApiTest, custom_op_handler) {
std::cout << "Running custom op inference" << std::endl; std::vector<Input> inputs(1);
Input& input = inputs[0];
input.name = "X";
input.dims = {3, 2};
input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; // prepare expected inputs and outputs
std::vector<int64_t> expected_dims_y = {3, 2};
std::vector<float> expected_values_y = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f}; // 创建定制算子(MyCustomOp)
#ifdef USE_CUDA
cudaStream_t compute_stream = nullptr; // 声明一个 cuda stream
cudaStreamCreateWithFlags(&compute_stream, cudaStreamNonBlocking); // 创建一个 cuda stream
MyCustomOp custom_op{onnxruntime::kCudaExecutionProvider, compute_stream};
#else
MyCustomOp custom_op{onnxruntime::kCpuExecutionProvider, nullptr};
#endif // 创建定制算子域(CustomOpDomain)
Ort::CustomOpDomain custom_op_domain("");
// 在定制算子域中添加定制算子
custom_op_domain.Add(&custom_op); // 进入 TestInference
#ifdef USE_CUDA
TestInference<float>(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 1,
custom_op_domain, nullptr, nullptr, false, compute_stream);
cudaStreamDestroy(compute_stream);
#else
TestInference<float>(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 0,
custom_op_domain, nullptr);
#endif
}

以上代码需要特别注意的是,需要根据宏(USE_CUDA)用来判断是否使用CUDA。如果使用 CUDA:

  • 当模型运行在GPU上,而插入的是 CPU 定制算子,那么 ONNXRuntime 会在 CPU 定制算子前后分别插入两个操作 MemcpyToHost、MemcpyFromHost,这两个操作负责内存拷贝,即首先从 Device 拷贝到 Host,再从 Host 拷贝到 Device;
  • 如果插入的是 GPU 定制算子,为了确保 ORT 的 CUDA kernels 和定制 CUDA kernels 之间的同步,它们必须使用同一个 CUDA 计算流。具体细节在下一个代码继续分析。

这里创建 cuda stream 的方式是 cudaStreamCreateWithFlags,该函数和 cudaStreamCreate 不同,后者在多次调用时是串行方式执行,而前者可同步执行。如果将参数 cudaStreamNonBlocking 替换为 cudaStreamDefault,则 cudaStreamCreateWithFlags 的行为将和 cudaStreamCreate 相同。【参考内容:CUDA 5.0 中cudaStreamCreateWithFlags 的用法

无论是否使用CDUA,我们都需要创建定制算子(MyCustomOp)。

进入 TestInference 函数内部:

// file path: onnxruntime/test/shared_lib/test_inference.cc

template <typename OutT>
static void TestInference(Ort::Env& env, const std::basic_string<ORTCHAR_T>& model_uri,
const std::vector<Input>& inputs,
const char* output_name,
const std::vector<int64_t>& expected_dims_y,
const std::vector<OutT>& expected_values_y,
int provider_type,
OrtCustomOpDomain* custom_op_domain_ptr,
const char* custom_op_library_filename,
void** library_handle = nullptr,
bool test_session_creation_only = false,
void* cuda_compute_stream = nullptr) {
Ort::SessionOptions session_options; if (provider_type == 1) {
#ifdef USE_CUDA
std::cout << "Running simple inference with cuda provider" << std::endl;
auto cuda_options = CreateDefaultOrtCudaProviderOptionsWithCustomStream(cuda_compute_stream);
session_options.AppendExecutionProvider_CUDA(cuda_options);
#else
ORT_UNUSED_PARAMETER(cuda_compute_stream);
return;
#endif
} else if (provider_type == 2) {
#ifdef USE_DNNL
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Dnnl(session_options, 1));
std::cout << "Running simple inference with dnnl provider" << std::endl;
#else
return;
#endif
} else if (provider_type == 3) {
#ifdef USE_NUPHAR
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nuphar(session_options,
/*allow_unaligned_buffers*/ 1, ""));
std::cout << "Running simple inference with nuphar provider" << std::endl;
#else
return;
#endif
} else {
std::cout << "Running simple inference with default provider" << std::endl;
}
if (custom_op_domain_ptr) {
session_options.Add(custom_op_domain_ptr);
} if (custom_op_library_filename) {
Ort::ThrowOnError(Ort::GetApi().RegisterCustomOpsLibrary(session_options,
custom_op_library_filename, library_handle));
} // if session creation passes, model loads fine
Ort::Session session(env, model_uri.c_str(), session_options); // caller wants to test running the model (not just loading the model)
if (!test_session_creation_only) {
// Now run
auto default_allocator = std::make_unique<MockedOrtAllocator>(); //without preallocated output tensor
RunSession<OutT>(default_allocator.get(),
session,
inputs,
output_name,
expected_dims_y,
expected_values_y,
nullptr);
//with preallocated output tensor
Ort::Value value_y = Ort::Value::CreateTensor<float>(default_allocator.get(),
expected_dims_y.data(), expected_dims_y.size()); //test it twice
for (int i = 0; i != 2; ++i)
RunSession<OutT>(default_allocator.get(),
session,
inputs,
output_name,
expected_dims_y,
expected_values_y,
&value_y);
}
}

前文提到,如果对应EP是CUDA,需要确保 ORT 的 CUDA kernels 和定制 CUDA kernels 之间的同步。为了实现这一目标,首先通过 CreateDefaultOrtCudaProviderOptionsWithCustomStream 函数将新创建的 CUDA 计算流以 OrtCudaProviderOptions 的形式传递给 SessionOptions:

OrtCUDAProviderOptions cuda_options = CreateDefaultOrtCudaProviderOptionsWithCustomStream(cuda_compute_stream);
session_options.AppendExecutionProvider_CUDA(cuda_options)

之后,将定制算子域也添加到 SessionOptions 中:

if (custom_op_domain_ptr) {
session_options.Add(custom_op_domain_ptr);
}

至此,SessionOptions 已经构建完成,下面创建 Session 并通过 model_uri 加载模型:

Ort::Session session(env, model_uri.c_str(), session_options);

这里的(1)Ort::Session 是在 onnxruntime_cxx_api.h 文件中声明的类,(2)对应的构造函数在 onnxruntime_cxx_inline.h 中实现,(3)实现方式是进一步调用 onnxruntime_c_api.h 中定义的 API,该 API 也仅仅是声明,(4)最终对应的实现在 onnxruntime_c_api.cc 文件中:

// (1) include/onnxruntime/core/session/onnxruntime_cxx_api.h
struct Session : Base<OrtSession> {
explicit Session(std::nullptr_t) {}
Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options);
} // (2) include/onnxruntime/core/session/onnxruntime_cxx_inline.h
inline Session::Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
ThrowOnError(GetApi().CreateSession(env, model_path, options, &p_));
} // (3) include/onnxruntime/core/session/onnxruntime_c_api.h
ORT_API2_STATUS(CreateSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path,
_In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); // (4) onnxruntime/core/session/onnxruntime_c_api.cc
ORT_API_STATUS_IMPL(OrtApis::CreateSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path,
_In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out) {
API_IMPL_BEGIN
std::unique_ptr<onnxruntime::InferenceSession> sess;
OrtStatus* status = nullptr;
*out = nullptr; ORT_TRY {
ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, sess));
ORT_API_RETURN_IF_ERROR(InitializeSession(options, sess)); *out = reinterpret_cast<OrtSession*>(sess.release());
}
ORT_CATCH(const std::exception& e) {
ORT_HANDLE_EXCEPTION([&]() {
status = OrtApis::CreateStatus(ORT_FAIL, e.what());
});
} return status;
API_IMPL_END
}

可以发现,Ort::Session 内部还是调用了 onnxruntime::InferenceSession

扯远了,下面回归主题。

创建 Session 完成之后,便开始运行,进入 RunSession 函数内部:

// file path: onnxruntime/test/shared_lib/test_inference.cc

template <typename OutT>
void RunSession(OrtAllocator* allocator, Ort::Session& session_object,
const std::vector<Input>& inputs,
const char* output_name,
const std::vector<int64_t>& dims_y,
const std::vector<OutT>& values_y,
Ort::Value* output_tensor) { // 构建模型输入
std::vector<Ort::Value> ort_inputs;
std::vector<const char*> input_names;
for (size_t i = 0; i < inputs.size(); i++) {
input_names.emplace_back(inputs[i].name);
ort_inputs.emplace_back(
Ort::Value::CreateTensor<float>(allocator->Info(allocator), const_cast<float*>(inputs[i].values.data()),
inputs[i].values.size(), inputs[i].dims.data(), inputs[i].dims.size()));
} // 运行 RUN
std::vector<Ort::Value> ort_outputs;
if (output_tensor)
session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
&output_name, output_tensor, 1);
else {
ort_outputs = session_object.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
&output_name, 1);
ASSERT_EQ(ort_outputs.size(), 1u);
output_tensor = &ort_outputs[0];
} auto type_info = output_tensor->GetTensorTypeAndShapeInfo();
ASSERT_EQ(type_info.GetShape(), dims_y);
size_t total_len = type_info.GetElementCount();
ASSERT_EQ(values_y.size(), total_len); OutT* f = output_tensor->GetTensorMutableData<OutT>();
for (size_t i = 0; i != total_len; ++i) {
ASSERT_EQ(values_y[i], f[i]);
}
}

这里使用了一些GTest中的断言来判定运行结果是否符合预期。

至此,我们已经完整地分析了定制算子从定义到使用的全部流程。

文档中还提到了 Contrib ops,这类算子归属于 contrib ops domain,是嵌入到 runtime 内部的,对于一些使用低频的算子最好不要加入这个域中,否则会导致运行时库(runtime library)过大。

官方文档中给出了添加算子到这个域中的方法,这里就不再进行介绍了,以后用到了再说吧。

【推理引擎】如何在 ONNXRuntime 中添加新的算子的更多相关文章

  1. 如何在Linux中添加新的系统调用

    系统调用是应用程序和操作系统内核之间的功能接口.其主要目的是使得用户 可以使用操作系统提供的有关设备管理.输入/输入系统.文件系统和进程控制. 通信以及存储管理等方面的功能,而不必了解系统程序的内部结 ...

  2. 如何在Pycharm中添加新的模块

    在使用Pycharm编写程序时,我们时常需要调用某些模块,但有些模块事先是没有的,我们需要把模块添加上去. 最近在学习爬虫,写了下面几行代码: 结果出现错误 错误ModuleNotFoundError ...

  3. 如何在niosII中添加i2c外设_winday_新浪博客

    如何在niosII中添加i2c外设_winday_新浪博客 如何在niosII中添加i2c外设 winday 摘要:本文说明了如何在niosII添加第三方i2c外设,以供参考. 由于本人使用的Alte ...

  4. 如何在Eclipse中添加Tomcat的jar包

    原文:如何在Eclipse中添加Tomcat的jar包 右键项目工程,点击Java Build Path 点击Add Library,选择Server Runtime 选择Tomcat版本 此时就看到 ...

  5. 向CDH5集群中添加新的主机节点

    向CDH5集群中添加新的主机节点 步骤一:首先得在新的主机环境中安装JDK,关闭防火墙.修改selinux.NTP时钟与主机同步.修改hosts.与主机配置ssh免密码登录.保证安装好了perl和py ...

  6. (原)torch7中添加新的层

    转载请注明出处: http://www.cnblogs.com/darkknightzh/p/6069627.html 参考网址: http://torch.ch/docs/developer-doc ...

  7. Windows Server 2008 如何在IIS中添加MIME类型

    用户可以通过使用MIME以设置服务器传送多媒体文件,如声音和视频等.MIME是一种技术规范,现在可以用于浏览器上,传送可以供浏览器识别的信息 如果我们的网站提供下载服务,有时传上去的文件比如 xxx. ...

  8. 如何在Eclipse中添加Servlet-api.jar的方法

    方法一: 点击窗口->首选项->java->构建路径->类路径变量->新建:将你的tomcat目录下的common/lib/servlet.jar加进来.如果你建立了一个 ...

  9. gitignore文件中添加新过滤文件,但是此文件已经提交,如何解决?

    gitignore文件中添加新过滤文件,但是此文件已经提交到远程库,如何解决? 第一步,为避免冲突需要先同步下远程仓库 git pull 第二步,在本地项目目录下删除缓存 git rm -r --ca ...

随机推荐

  1. CLR 详解

    公共语言运行时就是按照CLI标准制作的执行托管代码的环境.CLR 能运行非托管代码. 公共语言运行的功能:代码JIT/AOT编译.  内存管理 .垃圾回收.异常处理.反射服务.安全服务.程序集加载.本 ...

  2. SpringBoot进阶教程(七十三)整合elasticsearch

    Elasticsearch 是一个分布式.高扩展.高实时的搜索与数据分析引擎.它能很方便的使大量数据具有搜索.分析和探索的能力.充分利用Elasticsearch的水平伸缩性,能使数据在生产环境变得更 ...

  3. ComboBox控件绑定数据源后,添加'请选择'或'全部'

    ComboBox控件绑定数据源后,添加'请选择'或'全部' 当使用ComboBox控件绑定数据源之后,通过Items 属性添加的数据是无效的,此时如果要在所有选项前添加 选项 ,则需要考虑从数据源下手 ...

  4. SQL Server的Linked Servers

    文章搬运自:SQL Server的Linked Servers(链接) 参考引用一下,感谢作者~ 我们在使用SQL Server时,有时会有这种需求,需要从一个SQL Server服务器A中,查询另一 ...

  5. php 23种设计模型 - 装饰模式

    装饰器模式(Decorator) 装饰器模式(Decorator Pattern)允许向一个现有的对象添加新的功能,同时又不改变其结构.这种类型的设计模式属于结构型模式,它是作为现有的类的一个包装. ...

  6. 《手把手教你》系列基础篇(七十五)-java+ selenium自动化测试-框架设计基础-TestNG实现DDT - 中篇(详解教程)

    1.简介 上一篇中介绍了DataProvider如何传递参数,以及和一些其他方法结合传递参数,今天宏哥接着把剩下的一些常用的也做一下简单的介绍和分享. 2.项目实战1 @DataProvider + ...

  7. LGP6694题解

    第一眼似乎很困难,实际上非常简单( 好吧这题我做了一个小时( 首先期望具有线性性,我们转化为计算点对对答案的贡献. 发现相对位置一样的点对对答案的贡献是一样的.我们把相对位置一样的点对铃出来,乘了之后 ...

  8. How Do Vision Transformers Work?[2202.06709] - 论文研读系列(2) 个人笔记

    [论文简析]How Do Vision Transformers Work?[2202.06709] 论文题目:How Do Vision Transformers Work? 论文地址:http:/ ...

  9. OpenCv基础_四

    Harris角点检测 理解 内部点:蓝框所示,无论滑动窗口水平滑动还是竖直滑动,框内像素值都不会发生大的变化 边界点:黑框所示,滑动窗口沿着某一个方向滑动框内像素点不会发生大的改变,但是沿着另一个方向 ...

  10. vue2.x版本中computed和watch的使用入门详解-computed篇

    前言 在基于vue框架的前端项目开发过程中,只要涉及到稍微复杂一点的业务,我们都会用到computed计算属性这个钩子函数,可以用于一些状态的结合处理和缓存的操作. 基础使用 在computed中,声 ...