如何写新的Python OP
如何写新的Python OP
Paddle 通过 py_func 接口支持在Python端自定义OP。 py_func的设计原理在于Paddle中的Tensor可以与numpy数组可以方便的互相转换,从而可以使用Python中的numpy API来自定义一个Python OP。
py_func接口概述
py_func 具体接口为:
def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
pass
其中,
- x 是Python Op的输入变量,可以是单个 Tensor | tuple[Tensor] | list[Tensor] 。多个Tensor以tuple[Tensor]或list[Tensor]的形式传入。
- out 是Python Op的输出变量,可以是单个 Tensor | tuple[Tensor] | list[Tensor],也可以是Numpy Array 。
- func 是Python Op的前向函数。在运行网络前向时,框架会调用 out = func(*x) ,根据前向输入 x 和前向函数 func 计算前向输出 out。在 func 建议先主动将Tensor转换为numpy数组,方便灵活的使用numpy相关的操作,如果未转换成numpy,则可能某些操作无法兼容。
- backward_func 是Python Op的反向函数。若 backward_func 为 None ,则该Python Op没有反向计算逻辑; 若 backward_func 不为 None,则框架会在运行网路反向时调用 backward_func 计算前向输入 x 的梯度。
- skip_vars_in_backward_input 为反向函数 backward_func 中不需要的输入,可以是单个 Tensor | tuple[Tensor] | list[Tensor] 。
如何使用py_func编写Python Op
以下以tanh为例,介绍如何利用 py_func 编写Python Op。
- 第一步:定义前向函数和反向函数
前向函数和反向函数均由Python编写,可以方便地使用Python与numpy中的相关API来实现一个自定义的OP。
若前向函数的输入为 x_1, x_2, …, x_n ,输出为y_1, y_2, …, y_m,则前向函数的定义格式为:
def foward_func(x_1, x_2, ..., x_n):
...
return y_1, y_2, ..., y_m
默认情况下,反向函数的输入参数顺序为:所有前向输入变量 + 所有前向输出变量 + 所有前向输出变量的梯度,因此对应的反向函数的定义格式为:
def backward_func(x_1, x_2, ..., x_n, y_1, y_2, ..., y_m, dy_1, dy_2, ..., dy_m):
...
return dx_1, dx_2, ..., dx_n
若反向函数不需要某些前向输入变量或前向输出变量,可设置 skip_vars_in_backward_input 进行排除(步骤三中会叙述具体的排除方法)。
注:,x_1, …, x_n为输入的多个Tensor,请以tuple(Tensor)或list[Tensor]的形式在py_func中传入。建议先主动将Tensor通过numpy.array转换为数组,否则Python与numpy中的某些操作可能无法兼容使用在Tensor上。
利用numpy的相关API完成tanh的前向函数和反向函数编写。下面给出多个前向与反向函数定义的示例:
import numpy as np
# 前向函数1:模拟tanh激活函数
def tanh(x):
# 可以直接将Tensor作为np.tanh的输入参数
return np.tanh(x)
# 前向函数2:将两个2-D Tenosr相加,输入多个Tensor以list[Tensor]或tuple(Tensor)形式
def element_wise_add(x, y):
# 必须先手动将Tensor转换为numpy数组,否则无法支持numpy的shape操作
x = np.array(x)
y = np.array(y)
if x.shape != y.shape:
raise AssertionError("the shape of inputs must be the same!")
result = np.zeros(x.shape, dtype='int32')
for i in range(len(x)):
for j in range(len(x[0])):
result[i][j] = x[i][j] + y[i][j]
return result
# 前向函数3:可用于调试正在运行的网络(打印值)
def debug_func(x):
# 可以直接将Tensor作为print的输入参数
print(x)
# 前向函数1对应的反向函数,默认的输入顺序为:x、out、out的梯度
def tanh_grad(x, y, dy):
# 必须先手动将Tensor转换为numpy数组,否则"+/-"等操作无法使用
return np.array(dy) * (1 - np.square(np.array(y)))
注意,前向函数和反向函数的输入均是 Tensor 类型,输出可以是Numpy Array或 Tensor。 由于 Tensor 实现了Python的buffer protocol协议,因此即可通过 numpy.array 直接将 Tensor 转换为numpy Array来进行操作,也可直接将 Tensor 作为numpy函数的输入参数。但建议先主动转换为numpy Array,则可以任意的使用python与numpy中的所有操作(例如”numpy array的+/-/shape”)。
tanh的反向函数不需要前向输入x,因此我们可定义一个不需要前向输入x的反向函数,并在后续通过 skip_vars_in_backward_input 进行排除 :
def tanh_grad_without_x(y, dy):
return np.array(dy) * (1 - np.square(np.array(y)))
- 第二步:创建前向输出变量
需调用 Program.current_block().create_var 创建前向输出变量。在创建前向输出变量时,必须指明变量的名称name、数据类型dtype和维度shape。
import paddle
paddle.enable_static()
def create_tmp_var(program, name, dtype, shape):
return program.current_block().create_var(name=name, dtype=dtype, shape=shape)
in_var = paddle.static.data(name='input', dtype='float32', shape=[-1, 28, 28])
# 手动创建前向输出变量
out_var = create_tmp_var(paddle.static.default_main_program(), name='output', dtype='float32', shape=[-1, 28, 28])
- 第三步:调用 py_func 组建网络
py_func 的调用方式为:
paddle.static.nn.py_func(func=tanh, x=in_var, out=out_var, backward_func=tanh_grad)
若我们不希望在反向函数输入参数中出现前向输入,则可使用 skip_vars_in_backward_input 进行排查,简化反向函数的参数列表。
paddle.static.nn.py_func(func=tanh, x=in_var, out=out_var, backward_func=tanh_grad_without_x,
skip_vars_in_backward_input=in_var)
至此,使用 py_func 编写Python Op的步骤结束。可以与使用其他Op一样进行网路训练/预测。
注意事项
- py_func 的前向函数和反向函数内部不应调用 paddle.xx组网接口 ,因为前向函数和反向函数是在网络运行时调用的,而 paddle.xx 是在组建网络的阶段调用 。
- skip_vars_in_backward_input 只能跳过前向输入变量和前向输出变量,不能跳过前向输出的梯度。
- 若某个前向输出变量没有梯度,则 backward_func 将接收到 None 的输入。若某个前向输入变量没有梯度,则我们应在 backward_func 中主动返回 None。
如何写新的Python OP的更多相关文章
- 如何写新的C++ OP
如何写新的C++ OP 概念简介 简单介绍需要用到基类,详细介绍请参考设计文档. framework::OperatorBase: Operator(简写,Op)基类. framework::OpKe ...
- 一款新的PYTHON数据科学利器:yhat
偶然看文章发现的一个新的python IDE,外表很清新,也很新颖. https://www.yhat.com/products/rodeo 看图说话,软件的布局确实很新颖,和Notebook类似,也 ...
- 萌新学习Python爬取B站弹幕+R语言分词demo说明
代码地址如下:http://www.demodashi.com/demo/11578.html 一.写在前面 之前在简书首页看到了Python爬虫的介绍,于是就想着爬取B站弹幕并绘制词云,因此有了这样 ...
- 大数据萌新的Python学习之路(三)
笔记内容: 一.集合及其运算 在之列表中我们可以存储数据,并且对数据进行各种各样的操作.但是如果我们想要对数据进行去重时是十分麻烦的,需要使用循环,要建立新的列表,还要 进行对比,十分的麻烦,还消耗 ...
- Python学习宝典,Python400集让你成为从零基础到手写神经网络的Python大神
当您学完Python,你学到了什么? 开发网站! 或者, 基础语法要点.函数.面向对象编程.调试.IO编程.进程与线程.正则表达式... 当你学完Python,你可以干什么? 当程序员! 或者, 手写 ...
- 人人都可以写的可视化Python小程序第二篇:旋转的烟花
兴趣是最好的老师 枯燥的编程容易让人放弃,兴趣才是最好的老师.无论孩子还是大人,只有发现这件事情真的有趣,我们才会非常执着的去做这件事,比如打游戏.如果编程能像玩游戏一样变得有趣,我相信很多人就特别愿 ...
- 代码这样写更优雅(Python版)
要写出 Pythonic(优雅的.地道的.整洁的)代码,还要平时多观察那些大牛代码,Github 上有很多非常优秀的源代码值得阅读,比如:requests.flask.tornado,笔者列举一些常见 ...
- 大数据萌新的Python学习之路(一)
笔记开始简介 从2018年9月份正式进入大学的时代,大数据和人工智能的崛起让我选择了计算机专业学习数据科学与大数据技术专业,接触的第一门语言就是C语言,后来因为同学推荐的原因进入了学校的人工智能研究协 ...
- 大数据萌新的Python学习之路(二)
笔记内容: 一.模块 Python越来越被广大程序员使用,越来越火爆的原因是因为Python有非常丰富和强大标准库和第三方库,几乎可以实现你所想要实现的任何功能,并且都有相应的Python库支持,比如 ...
随机推荐
- img 的data-src 属性及懒加载
一.什么是图片懒加载 当访问一个页面的时候,先把img元素或是其他元素的背景图片路径替换成一张大小为1*1px图片的路径(这样就只需请求一次),当图片出现在浏览器的可视区域内时,才设置图片真正的路径, ...
- LA4851餐厅(求好的坐标的个数)
题意: 有一个m*m的格子,左下角(0,0)右上角(m-1,m-1),网格里面有两个y坐标相同的宾馆(A,B),每个宾馆里面有一个餐厅,一共用n个餐厅,第1,2个都在宾馆里,3,4...在 ...
- adbi学习:安装和使用
adbi 是一个android平台(arm 32 )的so注入+挂钩框架,源码开放在github上 : ADBI 项目 .从github上下载来目录如下: 执行主目录下build.sh编译后目录如下 ...
- hdu4912 LCA+贪心
题意: 给你一棵树和m条边,问你在这些边里面最多能够挑出多少条边,使得这些边之间不能相互交叉. 思路: lca+贪心,首先对于给的每个条边,我们用lca求出他们的公共节点,然后在 ...
- 内核模式下的线程同步的分析(Windows核心编程)
内核模式下的线程同步 内核模式下的线程同步是用户模式下的线程同步的扩展,因为用户模式下的线程同步有一定的局限性.但用户模式下线程同步的好处是速度快,不需要切换到内核模式(需要额外的 CPU 时间).通 ...
- [CTF]维吉尼亚密码(维基利亚密码)
[CTF]维吉尼亚密码(维基利亚密码) ----------------------百度百科 https://baike.baidu.com/item/维吉尼亚密码/4905472?fr=aladdi ...
- Object划分
Object划分 1.PO(persistantobject)持久对象 PO就是对应数据库中某个表中的一条记录,多个记录可以用PO的集合.PO中应该不包 含任何对数据库的操作. 2.DO(Domain ...
- Mac FTP工具推荐-Transmit
Transmit 是专为mac用户设计的一款功能强大的FTP客户端,Transmit5 mac兼容于FTP,SFTP和TLS/SSL协议,提供比Finder更加迅速的iDisk账户接入.与此同时,用户 ...
- 【我给面试官画饼】Python自动化测试面试题精讲
那今天给家分享的是一个面试主题. 就比如说我们的自动化测试,自动化如何去应对面试官,和面试官去聊一聊自动化的心得,自动化你现在去面试的时候是一个非常重要的一个关键点,所以如果你在这方面有一定的心得.那 ...
- Python数模笔记-(1)NetworkX 图的操作
1.NetworkX 图论与网络工具包 NetworkX 是基于 Python 语言的图论与复杂网络工具包,用于创建.操作和研究复杂网络的结构.动力学和功能. NetworkX 可以以标准和非标准的数 ...