https://blog.csdn.net/weixin_34260991/article/details/87106463

这里使用比较简单的定义方式,只是在原有的激活函数调用中加入。

准备工作
下载MXNet源代码,确认可以顺利编译通过。推荐在Linux下进行此操作:

https://mxnet.incubator.apache.org/get_started/install.html

编写激活函数先前和先后传递
在src/operator/mshadow_op.h里面,加入新的激活函数向前传递和向后的函数:

/*!
* \brief RBF Unit
* \author Yuzhong Liu
*/
struct rbf {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType x) {
return DType(expf(-x*x));
}
};

struct rbf_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType x, DType a) {
return DType(-2 * x * a);
}
};
添加调用方法
在src/operator/leaky_relu-inl.h里面,激活函数的调用方式:

namespace leakyrelu {
enum LeakyReLUOpInputs {kData, kGamma};
enum LeakyReLUOpOutputs {kOut, kMask};
# 定义新的激活函数名称
enum LeakyReLUOpType {kLeakyReLU, kPReLU, kRReLU, kELU, kRBF};
enum LeakyReLUOpResource {kRandom};
} // namespace leakyrelu

struct LeakyReLUParam : public dmlc::Parameter<LeakyReLUParam> {
// use int for enumeration
int act_type;
float slope;
float lower_bound;
float upper_bound;
DMLC_DECLARE_PARAMETER(LeakyReLUParam) {
DMLC_DECLARE_FIELD(act_type).set_default(leakyrelu::kLeakyReLU)
.add_enum("rrelu", leakyrelu::kRReLU)
.add_enum("leaky", leakyrelu::kLeakyReLU)
.add_enum("prelu", leakyrelu::kPReLU)
.add_enum("elu", leakyrelu::kELU)
# 添加激活函数枚举
.add_enum("rbf", leakyrelu::kRBF)
.describe("Activation function to be applied.");
DMLC_DECLARE_FIELD(slope).set_default(0.25f)
.describe("Init slope for the activation. (For leaky and elu only)");
DMLC_DECLARE_FIELD(lower_bound).set_default(0.125f)
.describe("Lower bound of random slope. (For rrelu only)");
DMLC_DECLARE_FIELD(upper_bound).set_default(0.334f)
.describe("Upper bound of random slope. (For rrelu only)");
}
};

template<typename xpu>
class LeakyReLUOp : public Operator {
public:
explicit LeakyReLUOp(LeakyReLUParam param) {
param_ = param;
}

virtual void Forward(const OpContext &ctx,
const std::vector<TBlob> &in_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &out_data,
const std::vector<TBlob> &aux_args) {
using namespace mshadow;
using namespace mshadow::expr;
size_t expected = param_.act_type == leakyrelu::kPReLU ? 2 : 1;
CHECK_EQ(in_data.size(), expected);
Stream<xpu> *s = ctx.get_stream<xpu>();
Tensor<xpu, 3> data;
Tensor<xpu, 3> out;
Tensor<xpu, 3> mask;
Tensor<xpu, 1> weight;
int n = in_data[leakyrelu::kData].shape_[0];
int k = in_data[leakyrelu::kData].shape_[1];
Shape<3> dshape = Shape3(n, k, in_data[leakyrelu::kData].Size()/n/k);
data = in_data[leakyrelu::kData].get_with_shape<xpu, 3, real_t>(dshape, s);
out = out_data[leakyrelu::kOut].get_with_shape<xpu, 3, real_t>(dshape, s);
if (param_.act_type == leakyrelu::kRReLU) {
mask = out_data[leakyrelu::kMask].get_with_shape<xpu, 3, real_t>(dshape, s);
}
switch (param_.act_type) {
case leakyrelu::kLeakyReLU: {
Assign(out, req[leakyrelu::kOut], F<mshadow_op::xelu>(data, param_.slope));
break;
}
case leakyrelu::kPReLU: {
weight = in_data[leakyrelu::kGamma].get<xpu, 1, real_t>(s);
Assign(out, req[leakyrelu::kOut],
F<mshadow_op::xelu>(data, broadcast<1>(weight, out.shape_)));
break;
}
case leakyrelu::kRReLU: {
if (ctx.is_train) {
Random<xpu>* prnd = ctx.requested[leakyrelu::kRandom].get_random<xpu, real_t>(s);
mask = prnd->uniform(mask.shape_);
mask = mask * (param_.upper_bound - param_.lower_bound) + param_.lower_bound;
Assign(out, req[leakyrelu::kOut], F<mshadow_op::xelu>(data, mask));
} else {
const float slope = (param_.lower_bound + param_.upper_bound) / 2.0f;
Assign(out, req[leakyrelu::kOut], F<mshadow_op::xelu>(data, slope));
}
break;
}
case leakyrelu::kELU: {
Assign(out, req[leakyrelu::kOut], F<mshadow_op::elu>(data, param_.slope));
break;
}
# RBF向前
case leakyrelu::kRBF: {
Assign(out, req[leakyrelu::kOut], F<mshadow_op::rbf>(data));
break;
}
default:
LOG(FATAL) << "Not implmented";
}
}

virtual void Backward(const OpContext & ctx,
const std::vector<TBlob> &out_grad,
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &in_grad,
const std::vector<TBlob> &aux_args) {
using namespace mshadow;
using namespace mshadow::expr;
size_t expected = param_.act_type == leakyrelu::kPReLU ? 2 : 1;
CHECK_EQ(out_grad.size(), 1U);
CHECK_EQ(req.size(), expected);
CHECK_EQ(in_data.size(), expected);
Stream<xpu> *s = ctx.get_stream<xpu>();
Tensor<xpu, 3> output;
Tensor<xpu, 3> data;
Tensor<xpu, 3> gdata;
Tensor<xpu, 3> grad;
Tensor<xpu, 3> mask;
Tensor<xpu, 1> weight;
Tensor<xpu, 1> grad_weight;
int n = out_grad[leakyrelu::kOut].shape_[0];
int k = out_grad[leakyrelu::kOut].shape_[1];
Shape<3> dshape = Shape3(n, k, out_grad[leakyrelu::kOut].Size()/n/k);
grad = out_grad[leakyrelu::kOut].get_with_shape<xpu, 3, real_t>(dshape, s);
gdata = in_grad[leakyrelu::kData].get_with_shape<xpu, 3, real_t>(dshape, s);
output = out_data[leakyrelu::kOut].get_with_shape<xpu, 3, real_t>(dshape, s);
if (param_.act_type == leakyrelu::kRReLU) {
mask = out_data[leakyrelu::kMask].get_with_shape<xpu, 3, real_t>(dshape, s);
}
if (param_.act_type == leakyrelu::kPReLU) {
data = in_data[leakyrelu::kData].get_with_shape<xpu, 3, real_t>(dshape, s);
}
switch (param_.act_type) {
case leakyrelu::kLeakyReLU: {
Assign(gdata, req[leakyrelu::kData], F<mshadow_op::xelu_grad>(output, param_.slope) * grad);
break;
}
case leakyrelu::kPReLU: {
weight = in_data[leakyrelu::kGamma].get<xpu, 1, real_t>(s);
grad_weight = in_grad[leakyrelu::kGamma].get<xpu, 1, real_t>(s);
grad_weight = sumall_except_dim<1>(F<prelu_grad>(data) * grad);
gdata = F<mshadow_op::xelu_grad>(data, broadcast<1>(weight, data.shape_)) * grad;
break;
}
case leakyrelu::kRReLU: {
Assign(gdata, req[leakyrelu::kData], F<mshadow_op::xelu_grad>(output, mask) * grad);
break;
}
case leakyrelu::kELU: {
Assign(gdata, req[leakyrelu::kData], F<mshadow_op::elu_grad>(output, param_.slope) * grad);
break;
}
# RBF向前
case leakyrelu::kRBF: {
data = in_data[leakyrelu::kData].get_with_shape<xpu, 3, real_t>(dshape, s);
Assign(gdata, req[leakyrelu::kData], F<mshadow_op::rbf_grad>(data, output) * grad);
break;
}
default:
LOG(FATAL) << "Not implmented";
}
}

private:
LeakyReLUParam param_;
}; // class LeakyReLUOp
从重新编译,并测试
import mxnet as mx
from mxnet import autograd
a = mx.nd.random_uniform(-1, 1, shape=[3, 3]) +0.5
a.attach_grad()

with autograd.record():
b = mx.nd.LeakyReLU(data=a, act_type='rbf')

print a, b
参考资料
https://mxnet.incubator.apache.org/how_to/new_op.html
http://blog.csdn.net/qq_20965753/article/details/66975622?utm_source=debugrun&utm_medium=referral
---------------------
作者:weixin_34260991
来源:CSDN
原文:https://blog.csdn.net/weixin_34260991/article/details/87106463
版权声明:本文为博主原创文章,转载请附上博文链接!

MXNet 定义新激活函数(Custom new activation function)的更多相关文章

  1. 浅谈深度学习中的激活函数 - The Activation Function in Deep Learning

    原文地址:http://www.cnblogs.com/rgvb178/p/6055213.html版权声明:本文为博主原创文章,未经博主允许不得转载. 激活函数的作用 首先,激活函数不是真的要去激活 ...

  2. The Activation Function in Deep Learning 浅谈深度学习中的激活函数

    原文地址:http://www.cnblogs.com/rgvb178/p/6055213.html 版权声明:本文为博主原创文章,未经博主允许不得转载. 激活函数的作用 首先,激活函数不是真的要去激 ...

  3. 《Noisy Activation Function》噪声激活函数(一)

    本系列文章由 @yhl_leo 出品,转载请注明出处. 文章链接: http://blog.csdn.net/yhl_leo/article/details/51736830 Noisy Activa ...

  4. 激活函数:Swish: a Self-Gated Activation Function

    今天看到google brain 关于激活函数在2017年提出了一个新的Swish 激活函数. 叫swish,地址:https://arxiv.org/abs/1710.05941v1 pytorch ...

  5. caffe中的sgd,与激活函数(activation function)

    caffe中activation function的形式,直接决定了其训练速度以及SGD的求解. 在caffe中,不同的activation function对应的sgd的方式是不同的,因此,在配置文 ...

  6. TensorFlow Activation Function 1

    部分转自:https://blog.csdn.net/caicaiatnbu/article/details/72745156 激活函数(Activation Function)运行时激活神经网络中某 ...

  7. ML 激励函数 Activation Function (整理)

    本文为内容整理,原文请看url链接,感谢几位博主知识来源 一.什么是激励函数 激励函数一般用于神经网络的层与层之间,上一层的输出通过激励函数的转换之后输入到下一层中.神经网络模型是非线性的,如果没有使 ...

  8. 转载-聊一聊深度学习的activation function

    目录 1. 背景 2. 深度学习中常见的激活函数 2.1 Sigmoid函数 2.2 tanh函数 2.3 ReLU函数 2.4 Leaky ReLu函数 2.5 ELU(Exponential Li ...

  9. TensorFlow实战第一课(session、Variable、Placeholder、Activation Function)

    莫烦tensorflow教学 1.session会话控制 Tensorflow 中的Session, Session是 Tensorflow 为了控制,和输出文件的执行的语句. 运行session.r ...

随机推荐

  1. nodejs SSL Error: CERT_UNTRUSTED while using npm command 错误

    SSH 使用错误,其实我们关掉HTTPS就好了 npm config set strict-ssl false 或者 npm config set registry="http://regi ...

  2. 【RAC】 RAC For W2K8R2 安装--dbca创建数据库(七)

    [RAC] RAC For W2K8R2 安装--dbca创建数据库(七) 一.1  BLOG文档结构图 一.2  前言部分 一.2.1  导读 各位技术爱好者,看完本文后,你可以掌握如下的技能,也可 ...

  3. Android笔记(七十六) 点菜DEMO

    一个朋友让看一下他的代码,一个点菜的功能,他和我一样,初学者,代码比我的都混乱,也是醉了,干脆想着自己写个demo给他看,原本想着听简单,半个小时应该就可以搞定,真正写的时候,画了3h+,汗颜... ...

  4. docker复制文件到宿主机

    从主机复制到容器 sudo docker cp host_path containerID:container_path 从容器复制到主机 sudo docker cp containerID:con ...

  5. mysql数据库查询缓存总结

    概述 查询缓存(Query Cache,简称QC),存储SELECT语句及其产生的数据结果.闲来无事,做一下这块的总结,也做个备忘! 工作原理 查询缓存工作原理如下: 缓存SELECT操作的结果集和S ...

  6. CentOS 7源码安装MYSQL-5.6

    一. 环境准备 Linux CentOS7.3系统一台主机即可: MYSQL官网:https://www.mysql.com/ MYSQL软件下载:http://ftp.kaist.ac.kr/mys ...

  7. 【转】GnuPG使用介绍

    一.什么是 GPG 要了解什么是 GPG,就要先了解 PGP. 1991 年,程序员 Phil Zimmermann 为了避开政府监视,开发了加密软件 PGP.这个软件非常好用,迅速流传开来,成了许多 ...

  8. 解决页面初始化vue加载代码问题

    <style type="text/css"> /* 解决页面初始化vue加载代码问题 */ [v-cloak] { display: none; } </sty ...

  9. 10.31-11.1Test(未完)

    10.31-11.1Test 题目 描述 做法 \(BSOJ5177\) 求在\(n\)个数里选\(K\)个的所有方案的异或和之和 按位讨论,组合数算 \(BSOJ5178\) 化简\(\displa ...

  10. php 正则达达示中的模式修正符

    我们通过元字符和原子完成了正则表达示的入门.有一些特殊情况我们依然需要来处理.深圳dd马达 如果abc在第二行的开始处如何匹配?我不希望正则表达示特别贪婪的匹配全部,只匹配一部份怎么办? 这个时候,我 ...