如何写新的Python OP

Paddle 通过 py_func 接口支持在Python端自定义OP。 py_func的设计原理在于Paddle中的Tensor可以与numpy数组可以方便的互相转换,从而可以使用Python中的numpy API来自定义一个Python OP。

py_func接口概述

py_func 具体接口为:

def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):

pass

其中,

  • x 是Python Op的输入变量,可以是单个 Tensor | tuple[Tensor] | list[Tensor] 。多个Tensor以tuple[Tensor]或list[Tensor]的形式传入。
  • out 是Python Op的输出变量,可以是单个 Tensor | tuple[Tensor] | list[Tensor],也可以是Numpy Array 。
  • func 是Python Op的前向函数。在运行网络前向时,框架会调用 out = func(*x) ,根据前向输入 x 和前向函数 func 计算前向输出 out。在 func 建议先主动将Tensor转换为numpy数组,方便灵活的使用numpy相关的操作,如果未转换成numpy,则可能某些操作无法兼容。
  • backward_func 是Python Op的反向函数。若 backward_func 为 None ,则该Python Op没有反向计算逻辑; 若 backward_func 不为 None,则框架会在运行网路反向时调用 backward_func 计算前向输入 x 的梯度。
  • skip_vars_in_backward_input 为反向函数 backward_func 中不需要的输入,可以是单个 Tensor | tuple[Tensor] | list[Tensor] 。

如何使用py_func编写Python Op

以下以tanh为例,介绍如何利用 py_func 编写Python Op。

  • 第一步:定义前向函数和反向函数

前向函数和反向函数均由Python编写,可以方便地使用Python与numpy中的相关API来实现一个自定义的OP。

若前向函数的输入为 x_1, x_2, …, x_n ,输出为y_1, y_2, …, y_m,则前向函数的定义格式为:

def foward_func(x_1, x_2, ..., x_n):

...

return y_1, y_2, ..., y_m

默认情况下,反向函数的输入参数顺序为:所有前向输入变量 + 所有前向输出变量 + 所有前向输出变量的梯度,因此对应的反向函数的定义格式为:

def backward_func(x_1, x_2, ..., x_n, y_1, y_2, ..., y_m, dy_1, dy_2, ..., dy_m):

...

return dx_1, dx_2, ..., dx_n

若反向函数不需要某些前向输入变量或前向输出变量,可设置 skip_vars_in_backward_input 进行排除(步骤三中会叙述具体的排除方法)。

注:,x_1, …, x_n为输入的多个Tensor,请以tuple(Tensor)或list[Tensor]的形式在py_func中传入。建议先主动将Tensor通过numpy.array转换为数组,否则Python与numpy中的某些操作可能无法兼容使用在Tensor上。

利用numpy的相关API完成tanh的前向函数和反向函数编写。下面给出多个前向与反向函数定义的示例:

import numpy as np

# 前向函数1:模拟tanh激活函数

def tanh(x):

# 可以直接将Tensor作为np.tanh的输入参数

return np.tanh(x)

# 前向函数2:将两个2-D Tenosr相加,输入多个Tensor以list[Tensor]或tuple(Tensor)形式

def element_wise_add(x, y):

# 必须先手动将Tensor转换为numpy数组,否则无法支持numpy的shape操作

x = np.array(x)

y = np.array(y)

if x.shape != y.shape:

raise AssertionError("the shape of inputs must be the same!")

result = np.zeros(x.shape, dtype='int32')

for i in range(len(x)):

for j in range(len(x[0])):

result[i][j] = x[i][j] + y[i][j]

return result

# 前向函数3:可用于调试正在运行的网络(打印值)

def debug_func(x):

# 可以直接将Tensor作为print的输入参数

print(x)

# 前向函数1对应的反向函数,默认的输入顺序为:x、out、out的梯度

def tanh_grad(x, y, dy):

# 必须先手动将Tensor转换为numpy数组,否则"+/-"等操作无法使用

return np.array(dy) * (1 - np.square(np.array(y)))

注意,前向函数和反向函数的输入均是 Tensor 类型,输出可以是Numpy Array或 Tensor。 由于 Tensor 实现了Python的buffer protocol协议,因此即可通过 numpy.array 直接将 Tensor 转换为numpy Array来进行操作,也可直接将 Tensor 作为numpy函数的输入参数。但建议先主动转换为numpy Array,则可以任意的使用python与numpy中的所有操作(例如”numpy array的+/-/shape”)。

tanh的反向函数不需要前向输入x,因此我们可定义一个不需要前向输入x的反向函数,并在后续通过 skip_vars_in_backward_input 进行排除 :

def tanh_grad_without_x(y, dy):

return np.array(dy) * (1 - np.square(np.array(y)))

  • 第二步:创建前向输出变量

需调用 Program.current_block().create_var 创建前向输出变量。在创建前向输出变量时,必须指明变量的名称name、数据类型dtype和维度shape。

import paddle

paddle.enable_static()

def create_tmp_var(program, name, dtype, shape):

return program.current_block().create_var(name=name, dtype=dtype, shape=shape)

in_var = paddle.static.data(name='input', dtype='float32', shape=[-1, 28, 28])

# 手动创建前向输出变量

out_var = create_tmp_var(paddle.static.default_main_program(), name='output', dtype='float32', shape=[-1, 28, 28])

  • 第三步:调用 py_func 组建网络

py_func 的调用方式为:

paddle.static.nn.py_func(func=tanh, x=in_var, out=out_var, backward_func=tanh_grad)

若我们不希望在反向函数输入参数中出现前向输入,则可使用 skip_vars_in_backward_input 进行排查,简化反向函数的参数列表。

paddle.static.nn.py_func(func=tanh, x=in_var, out=out_var, backward_func=tanh_grad_without_x,

skip_vars_in_backward_input=in_var)

至此,使用 py_func 编写Python Op的步骤结束。可以与使用其他Op一样进行网路训练/预测。

注意事项

  • py_func 的前向函数和反向函数内部不应调用 paddle.xx组网接口 ,因为前向函数和反向函数是在网络运行时调用的,而 paddle.xx 是在组建网络的阶段调用 。
  • skip_vars_in_backward_input 只能跳过前向输入变量和前向输出变量,不能跳过前向输出的梯度。
  • 若某个前向输出变量没有梯度,则 backward_func 将接收到 None 的输入。若某个前向输入变量没有梯度,则我们应在 backward_func 中主动返回 None。

如何写新的Python OP的更多相关文章

  1. 如何写新的C++ OP

    如何写新的C++ OP 概念简介 简单介绍需要用到基类,详细介绍请参考设计文档. framework::OperatorBase: Operator(简写,Op)基类. framework::OpKe ...

  2. 一款新的PYTHON数据科学利器:yhat

    偶然看文章发现的一个新的python IDE,外表很清新,也很新颖. https://www.yhat.com/products/rodeo 看图说话,软件的布局确实很新颖,和Notebook类似,也 ...

  3. 萌新学习Python爬取B站弹幕+R语言分词demo说明

    代码地址如下:http://www.demodashi.com/demo/11578.html 一.写在前面 之前在简书首页看到了Python爬虫的介绍,于是就想着爬取B站弹幕并绘制词云,因此有了这样 ...

  4. 大数据萌新的Python学习之路(三)

    笔记内容:  一.集合及其运算 在之列表中我们可以存储数据,并且对数据进行各种各样的操作.但是如果我们想要对数据进行去重时是十分麻烦的,需要使用循环,要建立新的列表,还要 进行对比,十分的麻烦,还消耗 ...

  5. Python学习宝典,Python400集让你成为从零基础到手写神经网络的Python大神

    当您学完Python,你学到了什么? 开发网站! 或者, 基础语法要点.函数.面向对象编程.调试.IO编程.进程与线程.正则表达式... 当你学完Python,你可以干什么? 当程序员! 或者, 手写 ...

  6. 人人都可以写的可视化Python小程序第二篇:旋转的烟花

    兴趣是最好的老师 枯燥的编程容易让人放弃,兴趣才是最好的老师.无论孩子还是大人,只有发现这件事情真的有趣,我们才会非常执着的去做这件事,比如打游戏.如果编程能像玩游戏一样变得有趣,我相信很多人就特别愿 ...

  7. 代码这样写更优雅(Python版)

    要写出 Pythonic(优雅的.地道的.整洁的)代码,还要平时多观察那些大牛代码,Github 上有很多非常优秀的源代码值得阅读,比如:requests.flask.tornado,笔者列举一些常见 ...

  8. 大数据萌新的Python学习之路(一)

    笔记开始简介 从2018年9月份正式进入大学的时代,大数据和人工智能的崛起让我选择了计算机专业学习数据科学与大数据技术专业,接触的第一门语言就是C语言,后来因为同学推荐的原因进入了学校的人工智能研究协 ...

  9. 大数据萌新的Python学习之路(二)

    笔记内容: 一.模块 Python越来越被广大程序员使用,越来越火爆的原因是因为Python有非常丰富和强大标准库和第三方库,几乎可以实现你所想要实现的任何功能,并且都有相应的Python库支持,比如 ...

随机推荐

  1. hdu2482 字典树+spfa

    题意:       给你一个地图,地图上有公交站点和路线,问你从起点到终点至少要换多少次公交路线. 思路:       首先上面的题意说的和笼统,没说详细是因为这个题目叙述的很多,描述起来麻烦, 下面 ...

  2. LA4851餐厅(求好的坐标的个数)

    题意:       有一个m*m的格子,左下角(0,0)右上角(m-1,m-1),网格里面有两个y坐标相同的宾馆(A,B),每个宾馆里面有一个餐厅,一共用n个餐厅,第1,2个都在宾馆里,3,4...在 ...

  3. hdu4932 小贪心

    题意:      给了一些处在x轴上的点,要求我们用长度相等的线段覆盖所有点,线段和线段之间不能重叠,问线段最长可以使多长. 思路:       一开始一直在想二分,哎!感觉这个题目很容易就往二分上去 ...

  4. Python中sys模块的使用

    目录 sys模块 sys.argv() sys.exit(0) sys.path sys.modules sys模块负责程序与python解释器的交互,提供了一系列的函数和变量,用于操控python的 ...

  5. 日志框架整合报错Class path contains multiple SLF4J bindings.

    在进行SSM框架的日志框架统一管理时,报错Class path contains multiple SLF4J bindings 如下图 意思是类路径下包含重复的SLF4J绑定,然后给出了重复的两个全 ...

  6. PHP基础—PHP的数据类型与常量使用

  7. SpringBoot简单尝试

    一.spring boot核心 配置在类路径下autoconfigure下(多瞅瞅) @SpringBootApplication里的重要注解(@Configuration,@EnableAutoCo ...

  8. 『动善时』JMeter基础 — 17、JMeter配置元件【HTTP请求默认值】

    目录 1.HTTP请求默认值介绍 2.HTTP请求默认值界面 3.HTTP请求默认值的使用 (1)用于演示的项目说明 (2)测试计划内包含的元件 (3)说明HTTP请求默认值用法 4.总结 5.拓展知 ...

  9. 为什么数字被int格式化后依旧可以用%s占位(勉强已答)

    为什么数字被int格式化后依旧可以用%s占位 答:可以看作str(obj)

  10. [源码分析] 定时任务调度框架 Quartz 之 故障切换

    [源码分析] 定时任务调度框架 Quartz 之 故障切换 目录 [源码分析] 定时任务调度框架 Quartz 之 故障切换 0x00 摘要 0x01 基础概念 1.1 分布式 1.1.1 功能方面 ...