How to create own operator with python in mxnet?
继承CustomOp
- 定义操作符,重写前向后向方法,此时可以通过
_init__
方法传递需要用到的参数
class LossLayer(mxnet.operator.CustomOp):
def __init__(self, *args, **kwargs):
super(LossLayer, self).__init__()
# recipe some arguments for forward or backward calculation def forward(self, is_train, req, in_data, out_data, aux):
"""
in_data是一个列表,其中tensor的顺序和对应属性类中定义的list_arguments()参数一一对应
out_data输出列表
is_train 是否是训练过程
req [Null, write or inplace, add]指如何处理对应的复制操作
"""
pass
# 函数最后一般调用父类的self.assign(dst, req[0], src)进行赋值操作
# 但对于dst或者src是list类型的时候要调用多次assign函数处理,此时也可以直接自己赋值
# dst[:]=src def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
"""
out_grad 上一层反传的误差
in_data 输入数据,list
out_data 输出的数据,由forward方法确定, 其类型大小和out_grad一致
in_grad 需要计算的回传误差
"""
pass
# 其操作值得复制操作类似于forward方法
- 定义好操作符之后还需要定义其对应的属性类,并将其注册到operator中
@mx.operator.register('losslayer') # 注意这里注册的名字将是后面调用该操作符使用的类型名
- 重写对应的属性类
class LossLayerProp(mx.operator.CustomOpProp): # 这里的名字并非必须对应操作类名称,被@修饰符修饰
def __init__(self, params):
super(LossLayerProp,self).__init__(need_top_grad=False)
# 最后的损失层不需要接收上层的误差,则将need_top_grad设置为False
# 可以传递一些参数用以传递给操作类 def list_arguments(self):
# 这个方法非常重要,定义了该操作符的输入参数,当绑定对应操作符时,输入量由该方法指定
return ['data1','data2','data3','label'] def list_outputs(self):
# 同样返回的是列表,表示输出的量,这个其实是输出变量的后缀suffix
# 若返回的是['output1','output2']则输出为 操作类的名称name加上对应后缀的量[name_output1, name_output2]
return ['output'] def infer_shape(self, in_shape):
# 给定in_shape,显示每一个变量的对应大小,以判断大小是否一致
return [],[],[]
# 返回的必须是3个列表,即使列表为空,分别对应着输入参数的大小、输出数据的大小、aux参数的大小,一般最后一个为空 def infer_type(self, in_type):
# 该方法类似于infer_shape,推断数据类型 def create_operator(self, ctx, shapes, dtypes):
# 该方法真正的创建操作类对象,默认调用
return LossLayer()
- 自定义操作符的使用
data1=mx.sym.Variable('data1')
data2=mx.sym.Variable('data2')
data3=mx.sym.Variable('data3')
label = mx.sym.Variable('label')
# 下面这句调用很重要,显示指定输入的symbol,然后指定自定义操作符类型
net = mx.sym.Custom(data1=data1, data2=data2, data3=data3, label=label, name='net', op_type='losslayer')
# 输出操作符的相关属性
print(net.infer_shape(data1=(4,1,10,10), data2=(4,1,10,10),data3=(4,1,10,10) label=(4,)))
# data1=(4,1,10,10)表示对应symbol的shape
print(net.infer_type(data1=np.int, data2=np.int, data3=np.int, label=np.int))
# data1=np.int 标识对应symbol的数据类型
print(net.list_arguments()) # 变量参数
print(net.list_outputs()) #输出的变量参数 ex = net.simple_bind(ctx=mx.gpu(0), data1=(4,1,10,10), data2=(4,1,10,10),data3=(4,1,10,10) label=(4,)) # simple_bind只需要指定输入参数的大小
ex.forward(data1=data1, data2=data2, label=label))
print(ex.outputs[0])
- 上面是没有参数的层,创建带有参数的中间层和上面类似, 只是修改下面部分代码
def list_arguments(self):
return ['data','weight', 'bias'] def infer_shape(self, in_shape):
data_shape = in_shape[0]
weight_shape = ...
bias_shape = ...
output_shape = ...
return [data_shape, weight_shape, bias_shape], [output_shape], []
调用方式:
net = mx.symbol.Custom(data, name='newLayer', op_type='myLayer')
包含参数的layer在定义backward方法时要注意梯度的更新方式,即req的选择
NOTE:
有参数的操作符中,一般使用‘weight’和‘bias’作为参数, 该参数会最为后缀加到 opname_weight, opname_bias中,因为mxnet默认的参数初始化方法只认‘weight’, 'bias', 'gamma', 'beta'四个量, 对于自己新定义的量,比如weight2, 需要指定初始化方法
Default initialization is now limited to "weight", "bias", "gamma" (1.0), and "beta" (0.0).
Please use mx.sym.Variable(init=mx.init.*) to set initialization pattern
How to create own operator with python in mxnet?的更多相关文章
- error: could not create '/System/Library/Frameworks/Python.framework/Versions/2.7/share': Operation not permitted
参考: Python pip安装模块报错 Mac升级到EI Captain之后pip install 无法使用问题 error: could not create '/System/Library/F ...
- Create your first isolated Python environment
# Install virtualenv for Python 2.7 and create a sandbox called my27project: pip2. install virtualen ...
- [Python] Object spread operator in Python
In JS, we have object spread opreator: const x = { a: '1', b: '2' } const y = { c: '3', d: '4' } con ...
- 使用python创建mxnet操作符(网络层)
对cuda了解不多,所以使用python创建新的操作层是个不错的选择,当然这个性能不如cuda编写的代码. 在MXNET源码的example/numpy-ops/下有官方提供的使用python编写新操 ...
- How to create PDF files in a Python/Django application using ReportLab
https://assist-software.net/blog/how-create-pdf-files-python-django-application-using-reportlab CONT ...
- Think Python - Chapter 17 - Classes and methods
17.1 Object-oriented featuresPython is an object-oriented programming language, which means that it ...
- Think Python - Chapter 11 - Dictionaries
Dictionaries A dictionary is like a list, but more general. In a list, the indices have to be intege ...
- Data manipulation primitives in R and Python
Data manipulation primitives in R and Python Both R and Python are incredibly good tools to manipula ...
- caffe2 教程入门(python版)
学习思路 1.先看官方文档,学习如何使用python调用caffe2包,包括 Basics of Caffe2 - Workspaces, Operators, and Nets Toy Regres ...
随机推荐
- node.js cookie session使用教程
众所周知,HTTP 是一个无状态协议,所以客户端每次发出请求时,下一次请求无法得知上一次请求所包含的状态数据,如何能把一个用户的状态数据关联起来呢? cookie 首先产生了 cookie 这门技术来 ...
- IEEE发布2017年编程语言排行榜:Python高居首位,java第三,php第八
2017年7月18日,IEEE Spectrum 发布了第四届顶级编程语言交互排行榜.因为有各种不同语言的排行,所以 IEEE Spectrum 依据不同的变量对流行度进行了排行.据 IEEE Spe ...
- python之路----hashlib模块
在平时生活中,有很多情况下,你在不知不觉中,就用到了hashlib模块,比如:注册和登录认证注册和登录认真过程,就是把注册用的账户密码进行:加密 --> 解密 的过程,在加密.解密过程中,用的了 ...
- UVA302 John's trip(欧拉回路)
UVA302 John's trip 欧拉回路 attention: 如果有多组解,按字典序输出. 起点为每组数据所给的第一条边的编号较小的路口 每次输出完额外换一行 保证连通性 每次输入数据结束后, ...
- RabbitMQ-C 客户端接口使用说明
rabbitmq-c是一个用于C语言的,与AMQP server进行交互的client库.AMQP协议为版本0-9-1.rabbitmq-c与server进行交互前需要首先进行login操作,在操作后 ...
- SIFT在OpenCV中的调用和具体实现(HELU版)
前面我们对sift算法的流程进行简要研究,那么在OpenCV中,sift是如何被调用的?又是如何被实现出来的了? 特别是到了3.0以后,OpenCV对特征点提取这个方面进行了系统重构,那么整个代码结构 ...
- 20145317彭垚_Web基础
20145317彭垚_Web基础 基础知识 Apache一个开放源码的网页服务器,可以在大多数计算机操作系统中运行,由于其多平台和安全性被广泛使用,是最流行的Web服务器端软件之一.它快速.可靠并且可 ...
- vijos 1096 津津的储存计划
题目描述 Description 津津的零花钱一直都是自己管理.每个月的月初妈妈给津津300元钱,津津会预算这个月的花销,并且总能做到实际花销和预算的相同. 为了让津津学习如何储蓄,妈妈提出,津津可以 ...
- C++ 表示一个区间值得方法
C++中不允许这样的写法 85<= score <=100;你要想表示85<=score<=100的话只能这么写score>=85&&score<= ...
- BZOJ2982: combination Lucas
Description LMZ有n个不同的基友,他每天晚上要选m个进行[河蟹],而且要求每天晚上的选择都不一样.那么LMZ能够持续多少个这样的夜晚呢?当然,LMZ的一年有10007天,所以他想知道答案 ...