摘要:简要介绍一下akg正反向算子的注册和关联流程。

本文分享自华为云社区《AKG正反向算子注册+关联》,作者:木子_007 。

一、环境

硬件:eulerosv2r8.aarch64

mindspore:1.1

算子注册需要编译安装框架才能生效,所以默认环境中已经有了mindspore的源码,并且已经可以编译安装

二、正向算子制作及测试

这里制作一个计算向量平方的算子

正向:y = x**2

反向:y = 2*x

先介绍正向

2.1 定义正向算子

路径:mindspore/akg/python/akg/ms/cce/,创建cus_square.py

参照同级目录下计算逻辑的定义,定义向量平方的计算逻辑

"""cus_square"""
from akg.tvm.hybrid import script
from akg.ops.math import mul
import akg
def CusSquare(x):
output_shape = x.shape
k = output_shape[0]
n = output_shape[1] @script
def cus_square_compute(x):
y = output_tensor(output_shape, dtype=x.dtype)
for i in range(k):
for j in range(n):
y[i, j] = x[i, j] * x[i, j]
return y output = cus_square_compute(x) attrs = {
'enable_post_poly_loop_partition': False,
'enable_double_buffer': False,
'enable_feature_library': True,
'RewriteVarTensorIdx': True
} return output, attrs

然后在同级目录下的__init__.py文件中添加内容

from .cus_square import CusSquare

2.2 注册算子

到路径:mindspore/ops/_op_impl/akg/ascend,创建cus_square.py,添加如下代码

"""CusSquare op"""
from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT op_info = AkgAscendRegOp("CusSquare") \
.fusion_type("ELEMWISE") \
.input(0, "x") \
.output(0, "output") \
.dtype_format(DT.F32_Default, DT.F32_Default) \
.get_op_info()
@op_info_register(op_info)
def _cus_square_akg():
"""CusSquare Akg register"""
return

然后在同级目录的__init__.py添加如下代码

from .cus_square import _cus_square_akg

2.3 定义算子原语

到:mindspore/ops/operations,新创建一个_cus_ops.py,添加如下代码

描述算子的输入:x,输出output

infer_shape:描述输出数据的shape

infer_dtype:说明输出数据的类型

x1_shape:指的是第一个输入的shape

x1_dtype:指的是第一个输入参数的dtype

import math

from ..primitive import prim_attr_register, PrimitiveWithInfer
from ...common import dtype as mstype
from ..._checkparam import Validator as validator
from ..._checkparam import Rel class CusSquare(PrimitiveWithInfer):
"""CusSquare""" @prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=['x'], outputs=['output']) def infer_shape(self, x1_shape):
return x1_shape def infer_dtype(self, x1_dtype):
return x1_dtype

然后在同目录下的__init__.py文件中添加原语信息

from ._cus_ops import CusSquare

2.4 在ccsrc中添加算子的查询信息

在mindspore/ccsrc/backend/kernel_compiler/http://kernel_query.cc的KernelQuery函数中添加如下信息

// cus_square
const PrimitivePtr kPrimCusSquare = std::make_shared<Primitive>("CusSquare");
if (IsPrimitiveCNode(kernel_node, kPrimCusSquare)) {
kernel_type = KernelType::AKG_KERNEL;
}

2.5 编译安装框架

回到mindspore根目录

bash build.sh -e ascend -j4
cd ./build/package
pip install mindspore_ascend-1.1.2-cp37-cp37m-linux_aarch64.whl --force-reinstall

2.6 测试

import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.square = P.CusSquare() def construct(self, data):
return self.square(data) def test_net():
x = np.array([[1.0, 4.0, 9.0]]).astype(np.float32)
net = Net()
output = net(Tensor(x))
print("x: ", x)
print("output: ", output)
if __name__ == "__main__":
test_net()

输出

三、反向算子的制作和测试

3.1 制作流程

反向算子的计算逻辑:对向量元素进行求导,如 y = x^2,则求导之后 y` = 2x

实际例子就是输入向量[1, 4, 9] 输出就是 [2, 8, 18]

反向算子明明为CusSquareGrad,与前边的计算平方的算子流程相同,这里只贴一下关键代码,流程不再赘述

计算逻辑代码cus_square_grad.py

"""cus_square_grad"""
from akg.tvm.hybrid import script
import akg def CusSquareGrad(x):
output_shape = x.shape
k = output_shape[0]
n = output_shape[1] @script
def cus_square_compute_grad(x):
y = output_tensor(output_shape, dtype=x.dtype)
for i in range(k):
for j in range(n):
y[i, j] = x[i, j] * 2
return y output = cus_square_compute_grad(x) attrs = {
'enable_post_poly_loop_partition': False,
'enable_double_buffer': False,
'enable_feature_library': True,
'RewriteVarTensorIdx': True
} return output, attrs

注册原语

class CusSquareGrad(PrimitiveWithInfer):
"""
CusSquareGrad
""" @prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=['x'], outputs=['output']) def infer_shape(self, x1_shape):
return x1_shape def infer_dtype(self, x1_dtype):
return x1_dtype

3.2 测试

import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.square = P.CusSquareGrad() # 替换为grad算子 def construct(self, data):
return self.square(data) def test_net():
x = np.array([[1.0, 4.0, 9.0]]).astype(np.float32)
net = Net()
output = net(Tensor(x))
print("x: ", x)
print("output: ", output)
if __name__ == "__main__":
test_net()

输出

四、正反向算子关联及测试

在源码 mindspore/mindspore/ops/_grad/grad_array_ops.py中添加如下代码

@bprop_getters.register(P.CusSquare)
def get_bprop_cussquare(self):
"""Generate bprop of CusSquare"""
cus_square_grad = P.CusSquareGrad()
matmul = ops.Mul()
def bprop(x, out, dout):
gradient = cus_square_grad(x)
dx = matmul(gradient, dout)
return (dx,)
return bprop

bprop函数的输入是,正向的输入x,正向的输出out,反向的梯度输入dout

上面代码的意思是指定算子CusSquare的反向梯度的计算方法,CusSquareGrad作为其中的一个函数使用

gradient = cus_square_grad(x)计算的是本平方算子的梯度,但并不能直接返回这个梯度

反向网络到该算子,最后返回的是dx,注意算子的反向梯度计算一定要放在整个网络的反向链式梯度计算中

测试

import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import composite as C context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.square = P.CusSquare() def construct(self, data):
return self.square(data) def test_net():
x = Tensor(np.array([[1.0, 4.0, 9.0]]).astype(np.float32))
grad = C.GradOperation(get_all=True) # 计算网络梯度
net = Net()
output = grad(net)(x)
print("x: ", x)
print("output: ", output)
if __name__ == "__main__":
test_net()

输出

点击关注,第一时间了解华为云新鲜技术~

带你了解AKG正反向算子注册+关联流程的更多相关文章

  1. [源码分析] 带你梳理 Flink SQL / Table API内部执行流程

    [源码分析] 带你梳理 Flink SQL / Table API内部执行流程 目录 [源码分析] 带你梳理 Flink SQL / Table API内部执行流程 0x00 摘要 0x01 Apac ...

  2. 如何设计一个 App 的注册登录流程?

    移 动设备发力之前的登录方式很简单:用户名/邮箱+密码+确认密码,所有的用户登录注册都是围绕着邮箱来做.随着移动设备和社交网络的普及,邮箱不再是唯 一,渐渐的出现了微博,QQ,微信等第三方登录方式,手 ...

  3. 第三节:带你详解Java的操作符,控制流程以及数组

    前言 大家好,给大家带来带你详解Java的操作符,控制流程以及数组的概述,希望你们喜欢 操作符 算数操作符 一般的 +,-,*,/,还有两个自增 自减 ,以及一个取模 % 操作符. 这里的操作算法,一 ...

  4. Android Market google play store帐号注册方法流程 及发布应用注意事项

    Android Market google play store帐号申请 注册方法流程 在 Google Play 中发布软件之前,您需要完成以下三项工作: 创建开发人员个人资料 接受开发人员分发协议 ...

  5. Android Market google play store帐号注册方法流程 及发布应用注意事项【转载】

    [转载]http://www.cnblogs.com/zdz8207/archive/2012/07/09/google-play-store-registered.html Android Mark ...

  6. Nacos(二)源码分析Nacos服务端注册示例流程

    上回我们讲解了客户端配置好nacos后,是如何进行注册到服务器的,那我们今天来讲解一下服务器端接收到注册实例请求后会做怎么样的处理. 首先还是把博主画的源码分析图例发一下,让大家对整个流程有一个大概的 ...

  7. Spring Security 的注册登录流程

    Spring Security 的注册登录流程 数据库字段设计 主要数据库字段要有: 用户的 ID 用户名称 联系电话 登录密码(非明文) UserDTO对象 需要一个数据传输对象来将所有注册信息发送 ...

  8. 使用Microsoft自带的小工具将可执行文件(.exe)注册为系统服务

    首先,我们从Microsoft下载Windows Resource Kits,Download 下载完成后,运行rktools.exe进行安装. 安装完成后,我们打开安装目录,将其中的"in ...

  9. Nacos(一)源码分析Nacos注册示例流程

    nacos官方地址:https://nacos.io/zh-cn/ 大家可以看一下nacos的中文手册以及官方源码,博主就不带领大家快速入门 了,官方文档中都有而且非常标准,比其他博客写的好多了并且还 ...

  10. Spring Cloud Eureka源码分析之服务注册的流程与数据存储设计!

    Spring Cloud是一个生态,它提供了一套标准,这套标准可以通过不同的组件来实现,其中就包含服务注册/发现.熔断.负载均衡等,在spring-cloud-common这个包中,org.sprin ...

随机推荐

  1. SQL基础应用

    SQL基础应用 更多详细内容请查阅:https://www.jianshu.com/p/08c4b78402ff 1.SQL介绍 结构化查询语言 5.7 以后符合SQL92严格模式 通过sql_mod ...

  2. 题解 CF1401C

    题目大意: 给定一序列 \(A\),定义当且仅当 \(\gcd(a_i,a_j)=a_{min}\) 时,元素 \(a_i\) 和 \(a_j\) 可以交换. 问当前给定的序列 \(A\) 能否转化为 ...

  3. R数据分析:集成学习方法之随机生存森林的原理和做法,实例解析

    很久很久以前给大家写过决策树,非常简单明了的算法.今天给大家写随机(生存)森林,随机森林是集成了很多个决策数的集成模型.像随机森林这样将很多个基本学习器集合起来形成一个更加强大的学习器的这么一种集成思 ...

  4. 总结(3)--- 知识总结(内存管理、线程阻塞、GIL锁)

    一.Python中是如何进行内存管理的? 垃圾回收:Python不像C++,Java等语言一样,他们可以不用事先声明变量类型而直接对变量进行赋值.对Python而言,对象的类型和内存都是在运行时确定的 ...

  5. 【scipy 基础】--空间计算

    scipy.spatial子模块提供了一系列用于处理和计算空间数据和几何形状的算法和工具,在许多领域都有广泛的应用,例如计算机视觉.地理信息系统.机器人学.医学影像分析等. 下面,来具体看看scipy ...

  6. 前端解析excel表格

    需求如下: 前端拿到表格中的数据,对数据做以下判断,并将拿到的数据转换成以下json格式,传给后端. 具体实现: 下载npm包:npm install xlsx --save 在vue文件中引入依赖: ...

  7. Excel表格函数公式出现溢出怎么办?

    Excel是一款广泛使用的电子表格软件,它可以帮助我们进行各种计算.数据分析与处理等操作.在使用Excel时,我们通常需要使用到各种函数公式来完成不同的任务.然而,在使用函数公式时有时会出现" ...

  8. 聊聊大数据框架的数据更新策略: COW,MOR,MOW

    大数据框架下,常用的数据更新策略有三种: COW: copy-on-write, 写时复制; MOR: merge-on-read, 读时合并; MOW: merge-on-write, 写时合并; ...

  9. 开发环境搭建:CubeMX、Keil MDK-ARM、仿真器驱动程序

    来源:成电<微机原理与嵌入式系统>漆强 第三章 STM32微控制器开发环境的搭建 一.STM32 CubeMX的安装 1.STM32 CubeMX的下载和安装 先安装java环境安装 下载 ...

  10. Go 语言区块链测试:实践指南

    引言 Go 语言在区块链开发中的应用日益增多,凭借其简洁的语法和强大的并发支持,成为开发区块链应用的热门选择.理解和实践 Go 语言的单元测试对于保证区块链应用的质量和稳定性至关重要. Go 单元测试 ...