虽然 Gluon 提供了大量常用的层,但有时候我们依然希望自定义层。本节将介绍如何使用 NDArray 来自定义一个 Gluon 的层,从而以后可以被重复调用。

不含模型参数的自定义层

我们先介绍如何定义一个不含模型参数的自定义层。事实上,这和 “模型构造” 中介绍的使用 Block 构造模型类似。

通过继承 Block 自定义了一个将输入减掉均值的层:CenteredLayer 类,并将层的计算放在 forward 函数里。

class CenteredLayer(nn.Block):
def __init__(self, **kwargs):
super(CenteredLayer, self).__init__(**kwargs) def forward(self, x):
return x - x.mean() layer = CenteredLayer()
layer(nd.array([1, 2, 3, 4, 5]))
# output
[-2. -1. 0. 1. 2.]
<NDArray 5 @cpu(0)>

我们也可以用它来构造更复杂的模型

net = nn.Sequential()
net.add(nn.Dense(128))
net.add(nn.Dense(10))
net.add(CenteredLayer())
net.initialize()
y = net(nd.random.uniform(shape=(4, 8)))
y.mean()

含模型参数的自定义层

在自定义层的时候我们还可以使用 Block 自带的 ParameterDict 类型的成员变量 params。顾名思义,这是一个由字符串类型的参数名字映射到 Parameter 类型的模型参数的字典。我们可以通过 get 函数从 ParameterDict 创建 Parameter。

params = gluon.ParameterDict()
params.get('param2', shape=(2, 3))
params
# ouput
(
Parameter param2 (shape=(2, 3), dtype=<class 'numpy.float32'>)
)

现在我们看下如何实现一个含权重参数和偏差参数的全连接层。它使用 ReLU 作为激活函数。其中 in_units 和 units 分别是输入单元个数和输出单元个数。

class MyDense(nn.Block):
def __init__(self, units, in_units, **kwargs):
super(MyDense, self).__init__(**kwargs)
self.weight = self.params.get('weight', shape=(in_units, units))
self.bias = self.params.get('bias', shape=(units,)) def forward(self, x):
linear = nd.dot(x, self.weight.data()) + self.bias.data()
return nd.relu(linear)

我们实例化 MyDense 类来看下它的模型参数。

# units:该层的输出个数;in_units:该层的输入个数。
dense = MyDense(units=5, in_units=10)
dense.params
# output
mydense0_ (
Parameter mydense0_weight (shape=(10, 5), dtype=<class 'numpy.float32'>)
Parameter mydense0_bias (shape=(5,), dtype=<class 'numpy.float32'>)
)

我们也可以使用自定义层构造模型。它用起来和 Gluon 的其他层很类似。

net = nn.Sequential()
net.add(MyDense(32, in_units=64))
net.add(MyDense(2, in_units=32))
net.initialize()
net(nd.random.uniform(shape=(2, 64)))

MXNET:深度学习计算-自定义层的更多相关文章

  1. MXNet深度学习库简介

    MXNet深度学习库简介 摘要: MXNet是一个深度学习库, 支持C++, Python, R, Scala, Julia, Matlab以及JavaScript等语言; 支持命令和符号编程; 可以 ...

  2. MXNET:深度学习计算-模型构建

    进入更深的层次:模型构造.参数访问.自定义层和使用 GPU. 模型构建 在多层感知机的实现中,我们首先构造 Sequential 实例,然后依次添加两个全连接层.其中第一层的输出大小为 256,即隐藏 ...

  3. 深度学习中 batchnorm 层是咋回事?

    作者:Double_V_ 来源:CSDN 原文:https://blog.csdn.net/qq_25737169/article/details/79048516 版权声明:本文为博主原创文章,转载 ...

  4. Caffe深度学习计算框架

    Caffe | Deep Learning Framework是一个清晰而高效的深度学习框架,其作者是博士毕业于UC Berkeley的 Yangqing Jia,目前在Google工作.Caffe是 ...

  5. MXNET:深度学习计算-模型参数

    我们将深入讲解模型参数的访问和初始化,以及如何在多个层之间共享同一份参数. 之前我们一直在使用默认的初始函数,net.initialize(). from mxnet import init, nd ...

  6. MXNET:深度学习计算-GPU

    mxnet的设备管理 MXNet 使用 context 来指定用来存储和计算的设备,例如可以是 CPU 或者 GPU.默认情况下,MXNet 会将数据创建在主内存,然后利用 CPU 来计算.在 MXN ...

  7. mxnet深度学习实战学习笔记-9-目标检测

    1.介绍 目标检测是指任意给定一张图像,判断图像中是否存在指定类别的目标,如果存在,则返回目标的位置和类别置信度 如下图检测人和自行车这两个目标,检测结果包括目标的位置.目标的类别和置信度 因为目标检 ...

  8. DeepLearning.ai学习笔记(一)神经网络和深度学习--Week3浅层神经网络

    介绍 DeepLearning课程总共五大章节,该系列笔记将按照课程安排进行记录. 另外第一章的前两周的课程在之前的Andrew Ng机器学习课程笔记(博客园)&Andrew Ng机器学习课程 ...

  9. deeplearning.ai 神经网络和深度学习 week3 浅层神经网络 听课笔记

    1. 第i层网络 Z[i] = W[i]A[i-1] + B[i],A[i] = f[i](Z[i]). 其中, W[i]形状是n[i]*n[i-1],n[i]是第i层神经元的数量: A[i-1]是第 ...

随机推荐

  1. Linux系统 vi/vim文本编辑器

    Linux系统 vi/vim文本编辑器 (一)Vim/Vi简介 (二)Vim/Vi工作模式 (三)Vim/Vi基本使用 (四)Vim/Vi应用技巧 (一)Vim/Vi简介 Vim/Vi是一个功能强大的 ...

  2. web扫描工具-Nikto介绍与使用

    Nikto Perl语言开发的开源Web安全扫描器 web扫描模式:截断代理主动扫描 可以扫描的方面:软件版本搜索存在安全隐患的文件服务器配置漏洞WEB Application层面的安全隐患避免404 ...

  3. Flutter常用组件(Widget)解析-Image

    显示图片的组件 以下是几种加载图片路径方式: Image.asset 加载asset项目资源中的文件 Image.network 加载网络资源图片,通过url加载 Image.file 加载本地文件中 ...

  4. 【Java并发核心五】Future 和 Callable

    默认情况下,线程Thread对象不具有返回值的功能,如果在需要取得返回值的情况下会极为不方便.jdk1.5中可以使用Future 和 Callable 来获取线程返回值. Callable 可以 看成 ...

  5. 进程队列补充-创建进程队列的另一个类JoinableQueue

    JoinableQueue同样通过multiprocessing使用. 创建队列的另外一个类: JoinableQueue([maxsize]):这就像是一个Queue对象,但队列允许项目的使用者通知 ...

  6. loj#6076「2017 山东一轮集训 Day6」三元组 莫比乌斯反演 + 三元环计数

    题目大意: 给定\(a, b, c\),求\(\sum \limits_{i = 1}^a \sum \limits_{j = 1}^b \sum \limits_{k = 1}^c [(i, j) ...

  7. ECS之旅——常用的linux指令

    玩过Linux的人都会知道,Linux中的命令的确是非常多,但是玩过Linux的人也从来不会因为Linux的命令如此之多而烦恼,因为我们只需要掌握我们最常用的命令就可以了.当然你也可以在使用时去找一下 ...

  8. php get_magic_quotes_gpc()函数使用

    magic_quotes_gpc函数在php中的作用是判断解析用户提示的数据,如包括有:post.get.cookie过来的数据增加转义字符"\",以确保这些数据不会引起程序,特别 ...

  9. Java虚拟机详解----常用JVM配置参数

    本文主要内容: Trace跟踪参数 堆的分配参数 栈的分配参数 零.在IDE的后台打印GC日志: 既然学习JVM,阅读GC日志是处理Java虚拟机内存问题的基础技能,它只是一些人为确定的规则,没有太多 ...

  10. C语言学习中遇到的小问题(一)

    C语言小白学习C语言的记录1 一.scanf一次性接收连续的数字 1.已知数量,且个数较少:scanf("%d%d%d",&a&b&c); 2.已知数量,但 ...