DLPack构建跨框架的深度学习编译器

Tensorflow,PyTorch和ApacheMxNet等深度学习框架提供了一个功能强大的工具包,可用于快速进行原型设计和部署深度学习模型。易用性通常是以碎片为代价的:孤立地使用每个框架是很容易的。垂直集成已使常见用例的开发流程简化了,但是冒险走过的路可能很棘手。

一个支持不佳的方案是将张量直接从一个框架传递到内存中的另一个框架,而没有任何数据重复或复制。支持这种用例使用户能够将管道串联在一起,其中某些算子在一个框架中得到比在另一个框架中得到更好的支持(或更快速)。框架之间共享的数据表示形式也将弥合这一差距,并在为算子生成代码时,允许编译器堆栈以单一格式为目标。

DLPack是用于张量数据结构的中间内存表示标准。使用DLPack作为通用表示,传统上只能依赖供应商提供的库的框架编写的脚本中利用TVM。TVM打包函数可以在DLPack张量上运行,提供包装程序以桥接带有零数据副本的框架(例如PyTorch和MxNet)中的张量数据结构。

DLPack提供了一种简单的可移植内存数据结构:

/*!
 * \brief Plain C Tensor object, does not manage memory.
 */
typedef struct {
  /*!
   * \brief The opaque data pointer points to the allocated data.
   *  This will be CUDA device pointer or cl_mem handle in OpenCL.
   *  This pointer is always aligns to 256 bytes as in CUDA.
   */
  void* data;
  /*! \brief The device context of the tensor */
  DLContext ctx;
  /*! \brief Number of dimensions */
  int ndim;
  /*! \brief The data type of the pointer*/
  DLDataType dtype;
  /*! \brief The shape of the tensor */
  int64_t* shape;
  /*!
   * \brief strides of the tensor,
   *  can be NULL, indicating tensor is compact.
   */
  int64_t* strides;
  /*! \brief The offset in bytes to the beginning pointer to data */
  uint64_t byte_offset;
} DLTensor;

例如,在TVM中声明并编译一个矩阵乘法算子,并构建一个使用DLPack表示形式的包装器wrapper,允许该算子支持PyTorch张量。还使用MxNet重复此演示。此扩展使机器学习开发人员可以在不牺牲性能的情况下,将代码快速移植到相对不受支持的硬件平台上。

DLPack如何提供框架和TVM之间共享的中间包wrapper的说明:

图1

首先,在PyTorch中计算参考输出:

    import torch
    x = torch.rand(56,56)
    y = torch.rand(56,56)
    z = x.mm(y)

然后,使用默认调度定义并构建TVM矩阵乘法算子:

    n = tvm.convert(56)
    X = tvm.placeholder((n,n), name='X')
    Y = tvm.placeholder((n,n), name='Y')
 
    k = tvm.reduce_axis((0, n), name='k')
    Z = tvm.compute((n,n), lambda i,j : tvm.sum(X[i,k]*Y[k,j], axis=k))
    s = tvm.create_schedule(Z.op)
    fmm = tvm.build(s, [X, Y, Z], target_host='llvm', name='fmm')

为简便起见,没有涵盖可用于优化矩阵乘法的TVM大量的调度原语集合。如果希望使自定义GEMM算子在的硬件设备上快速运行,请参考详细的教程。

然后,将TVM函数转换为支持PyTorch张量的函数:

    from tvm.contrib.dlpack import to_pytorch_func
    # fmm is the previously built TVM function (Python function)
    # fmm is the wrapped TVM function (Python function)
    fmm_pytorch = to_pytorch_func(fmm)
    z2 = torch.empty(56,56)
    fmm_pytorch(x, y, z2)
    np.testing.assert_allclose(z.numpy(), z2.numpy())

并验证结果是否匹配。

可以重复相同的示例,但是使用MxNet代替:

    import mxnet
    from tvm.contrib.mxnet import to_mxnet_func
    ctx = mxnet.cpu(0)
    x = mxnet.nd.uniform(shape=(56,56), ctx=ctx)
    y = mxnet.nd.uniform(shape=(56,56), ctx=ctx)
    z = mxnet.nd.empty(shape=(56,56), ctx=ctx)
    f = tvm.build(s, [X, Y, Z], target_host='llvm', name='f')
    f_mxnet = to_mxnet_func(f)
    f_mxnet(x, y, z)
    np.testing.assert_allclose(z.asnumpy(), x.asnumpy().dot(y.asnumpy()))

在PyTorch示例的幕后

由于TVM提供了将dlpack张量转换为tvm的功能NDArray反之亦然,因此,通过wrapper功能,所需的只是一些语法 syntactic sugar 。 convert_func是用于使用具有dlpack支持的张量的框架的通用转换器,可以用于实现方便的转换器,例如 to_pytorch_func

def convert_func(tvm_func, tensor_type, to_dlpack_func):
    assert callable(tvm_func)
 
    def _wrapper(*args):
        args = tuple(ndarray.from_dlpack(to_dlpack_func(arg))\
            if isinstance(arg, tensor_type) else arg for arg in args)
        return tvm_func(*args)
 
    return _wrapper
 
def to_pytorch_func(tvm_func):
    import torch
    import torch.utils.dlpack
    return convert_func(tvm_func, torch.Tensor, torch.utils.dlpack.to_dlpack)

DLPack构建跨框架的深度学习编译器的更多相关文章

  1. 通过 DLPack 构建跨框架深度学习编译器

    通过 DLPack 构建跨框架深度学习编译器 深度学习框架,如Tensorflow, PyTorch, and ApacheMxNet,快速原型化和部署深度学习模型提供了强大的工具箱.不幸的是,易用性 ...

  2. go微服务框架go-micro深度学习-目录

    go微服务框架go-micro深度学习(一) 整体架构介绍 go微服务框架go-micro深度学习(二) 入门例子 go微服务框架go-micro深度学习(三) Registry服务的注册和发现 go ...

  3. go微服务框架go-micro深度学习(四) rpc方法调用过程详解

    上一篇帖子go微服务框架go-micro深度学习(三) Registry服务的注册和发现详细解释了go-micro是如何做服务注册和发现在,服务端注册server信息,client获取server的地 ...

  4. go微服务框架go-micro深度学习 rpc方法调用过程详解

    摘要: 上一篇帖子go微服务框架go-micro深度学习(三) Registry服务的注册和发现详细解释了go-micro是如何做服务注册和发现在,服务端注册server信息,client获取serv ...

  5. 大数据下基于Tensorflow框架的深度学习示例教程

    近几年,信息时代的快速发展产生了海量数据,诞生了无数前沿的大数据技术与应用.在当今大数据时代的产业界,商业决策日益基于数据的分析作出.当数据膨胀到一定规模时,基于机器学习对海量复杂数据的分析更能产生较 ...

  6. go微服务框架go-micro深度学习(二) 入门例子

    上一篇帖子简单介绍了go-micro的整体框架结构,这一篇主要写go-micro使用方式的例子,中间会穿插一些go-micro的源码,和调用流程图,帮大家更好的理解go-micro的底层.更详细更具体 ...

  7. go微服务框架go-micro深度学习(一) 整体架构介绍

    产品嘴里的一个小项目,从立项到开发上线,随着时间和需求的不断激增,会越来越复杂,变成一个大项目,如果前期项目架构没设计的不好,代码会越来越臃肿,难以维护,后期的每次产品迭代上线都会牵一发而动全身.项目 ...

  8. go微服务框架go-micro深度学习(三) Registry服务的注册和发现

    服务的注册与发现是微服务必不可少的功能,这样系统才能有更高的性能,更高的可用性.go-micro框架的服务发现有自己能用的接口Registry.只要实现这个接口就可以定制自己的服务注册和发现. go- ...

  9. go微服务框架go-micro深度学习(五) stream 调用过程详解

        上一篇写了一下rpc调用过程的实现方式,简单来说就是服务端把实现了接口的结构体对象进行反射,抽取方法,签名,保存,客户端调用的时候go-micro封请求数据,服务端接收到请求时,找到需要调用调 ...

随机推荐

  1. hdu1839 二分最短路

    题意:       给你n个城市,m条双向边,每条边有自己的长度和最大运输量,让你找到一条时间小于等于T的运输能力最大的那条路... 思路:       刚开始以为是费用流呢,后来发现根本不是,因为根 ...

  2. hdu4861 找规律了

    题意:      给你k个球和一个整数p,每个球的价值是 1^i+2^i+...+(p-1)^i (mod p),两个人轮流取球,最后谁的总价值也大谁就赢,问你先手能不能赢. 思路:      一开始 ...

  3. C++处理char*,char[],string三种类型间的转换

    前言 在C和C++中,有一个相当重要的部分,就是字符串的编程描述.在学C的时候,很多人习惯了char[],char*表示法,直到遇见了C++后,出现了第三者:string.这时候,很多初学者就会在这三 ...

  4. 010 Editor体验

    源代码的我们现在拥有各式各样的IDE和编辑器可以去查看,但二进制文件对于大多数软件只能做到显示16进制,而不能按照文件类型的格式去显示.今天我们就用dex文件让010 show. 安装软件: http ...

  5. 还在一个模块打天下嘛?你知道引入Jetpack架构后,你的App会发生哪些奇妙的变化吗?

    前言 上篇文章我给大家分享了我对Android架构的理解,从思想层面去讲述架构的演进过程.很多小伙伴读完后拍手叫好,表示还想听我讲一下对Jetpack 架构的看法,本着帮人帮到底的精神,今天我将再次动 ...

  6. 前端Excel表格导入导出,包括合并单元格,表格自定义样式等

    表格数据导入 读取导入Excel表格数据这里采用的是 xlsx 插件 npm i xlsx 读取excel需要通过 XLSX.read(data, {type: type}) 方法来实现,返回一个叫W ...

  7. iwrite复制攻略

    打开iwrite,一提交作业,发现: 这可咋办啊! 那就跟着步骤来呗: 按F12打开元素审查 点一下左上角 再点一下文本框,就能定位到HTML中的位置 在文本框中写几个字母,康康具体位置: 那就复制进 ...

  8. 解决@Autowired警告

    在使用spring框架中的依赖注入注解@Autowired时,idea报了一个警告 被警告的代码如下: @Autowired UserDao userDao; 警告提示信息:Field injecti ...

  9. 二、jmeter模拟请求头及监听器之结果树

    一.模拟请求头 利用jmeter发送http请求时,被接收的服务端会对发送的该请求进行初步判断,如果不是web端发送的请求就会被打回导致请求不通,这时候需要模拟请求头,模拟正常的用户行为进行发送请求 ...

  10. PageHelper简单使用

    PageHelper的简单使用 先引入对应的依赖 <dependency> <groupId>com.github.pagehelper</groupId> < ...