一、不含参数层

通过继承Block自定义了一个将输入减掉均值的层:CenteredLayer类,并将层的计算放在forward函数里,

from mxnet import nd, gluon
from mxnet.gluon import nn class CenteredLayer(nn.Block):
def __init__(self, **kwargs):
super(CenteredLayer, self).__init__(**kwargs) def forward(self, x):
return x - x.mean() # 直接使用这个层
layer = CenteredLayer()
# layer(nd.array([1, 2, 3, 4, 5])) # 构建更复杂模型
net = nn.Sequential()
net.add(nn.Dense(128))
net.add(nn.Dense(10))
net.add(CenteredLayer()) # 初始化、运行……
net.initialize()
y = net(nd.random.uniform(shape=(4, 8)))

二、含参数层

注意,本节实现的自定义层不能自动推断输入尺寸,需要手动指定

见上节『MXNet』第三弹_Gluon模型参数在自定义层的时候我们常使用Block自带的ParameterDict类添加成员变量params,如下,

from mxnet import gluon
from mxnet.gluon import nn class MyDense(nn.Block):
def __init__(self, units, in_units, **kwargs):
super(MyDense, self).__init__(**kwargs)
self.weight = self.params.get('weight', shape=(in_units, units))
self.bias = self.params.get('bias', shape=(units,)) def forward(self, x):
linear = nd.dot(x, self.weight.data()) + self.bias.data()
return nd.relu(linear) # 实际运行
dense = MyDense(5, in_units=10)

如果不想使用ParameterDict类则需要一下操作

# self.weight = self.params.get('weight', shape=(in_units, units))
self.weight = gluon.Parameter('weight', shape=(in_units, units))
self.params.update({'weight':self.weight})

否则在net.initialize()初始化时是初始化不到ParameterDict外变量的。

有关这一点详见下面:

    def __init__(self, conv_arch, dropout_keep_prob, **kwargs):
super(SSD, self).__init__(**kwargs)
self.vgg_conv = nn.Sequential()
self.vgg_conv.add(repeat(*conv_arch[0], pool=False))
[self.vgg_conv.add(repeat(*conv_arch[i])) for i in range(1, len(conv_arch))]
# 迭代器对象只能进行单次迭代,所以将之转化为tuple,否则识别参数处迭代后forward再次迭代直接跳出循环
# self.vgg_conv = tuple([repeat(*conv_arch[i])
# for i in range(len(conv_arch))])
# 只能识别实例属性直接为mx层函数或者mx序列对象的参数,如果使用其他容器,需要将参数收集进参数字典
# _ = [self.params.update(block.collect_params()) for block in self.vgg_conv] def forward(self, x, feat_layers):
end_points = {'block0': x}
for (index, block) in enumerate(self.vgg_conv):
end_points.update({'block{:d}'.format(index+1): block(end_points['block{:d}'.format(index)])})
return end_points

属性对象是mxnet的对象时才能默认识别层中的参数,否则需要显式收集进self.params中。

测试代码:

if __name__ == '__main__':

    ssd = SSD(conv_arch=((2, 64), (2, 128), (3, 256), (3, 512), (3, 512)),
dropout_keep_prob=0.5)
ssd.initialize()
X = mx.ndarray.random.uniform(shape=(1, 1, 304, 304))
import pprint as pp
pp.pprint([x[1].shape for x in ssd(X).items()])

自行验证即可。

『MXNet』第四弹_Gluon自定义层的更多相关文章

  1. 『MXNet』第三弹_Gluon模型参数

    MXNet中含有init包,它包含了多种模型初始化方法. from mxnet import init, nd from mxnet.gluon import nn net = nn.Sequenti ...

  2. 『MXNet』第六弹_Gluon性能提升

    一.符号式编程 1.命令式编程和符号式编程 命令式: def add(a, b): return a + b def fancy_func(a, b, c, d): e = add(a, b) f = ...

  3. 『MXNet』第六弹_Gluon性能提升 静态图 动态图 符号式编程 命令式编程

    https://www.cnblogs.com/hellcat/p/9084894.html 目录 一.符号式编程 1.命令式编程和符号式编程 2.MXNet的符号式编程 二.惰性计算 用同步函数实际 ...

  4. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...

  5. 『MXNet』第十弹_物体检测SSD

    全流程地址 一.辅助API介绍 mxnet.image.ImageDetIter 图像检测迭代器, from mxnet import image from mxnet import nd data_ ...

  6. 『MXNet』第八弹_数据处理API_下_Image IO专题

    想学习MXNet的同学建议看一看这位博主的博客,受益良多. 在本节中,我们将学习如何在MXNet中预处理和加载图像数据. 在MXNet中加载图像数据有4种方式. 使用 mx.image.imdecod ...

  7. 『MXNet』第八弹_数据处理API_上

    一.Gluon数据加载 下面的两个dataset处理类一般会成对出现,两个都可做预处理,但是由于后面还可能用到原始图片,.ImageFolderDataset不加预处理的话可以满足,所以建议在.Dat ...

  8. 『MXNet』第十一弹_符号式编程初探

    一.符号分类 符号对我们想要进行的计算进行了描述, 下图展示了符号如何对计算进行描述. 我们定义了符号变量A, 符号变量B, 生成了符号变量C, 其中, A, B为参数节点, C为内部节点! mxne ...

  9. 『MXNet』第七弹_多GPU并行程序设计

    资料原文 一.概述思路 假设一台机器上有个GPU.给定需要训练的模型,每个GPU将分别独立维护一份完整的模型参数. 在模型训练的任意一次迭代中,给定一个小批量,我们将该批量中的样本划分成份并分给每个G ...

随机推荐

  1. What is event bubbling and capturing?

    What is event bubbling and capturing? 答案1 Event bubbling and capturing are two ways of event propaga ...

  2. Learning to Compare Image Patches via Convolutional Neural Networks --- Reading Summary

    Learning to Compare Image Patches via Convolutional Neural Networks ---  Reading Summary 2017.03.08 ...

  3. 使用Java Api 操作HDFS

    如题 我就是一个标题党  就是使用JavaApi操作HDFS,使用的是MAVEN,操作的环境是Linux 首先要配置好Maven环境,我使用的是已经有的仓库,如果你下载的jar包 速度慢,可以改变Ma ...

  4. Twitter REST API, Streaming API

    原文链接           用Twitter自己的话来说:   REST API The REST API provides simple interfaces for most Twitter f ...

  5. 【译】第5节---Code First约定

    原文:http://www.entityframeworktutorial.net/code-first/code-first-conventions.aspx 我们在上一节中已经看到了EF Code ...

  6. HTML XHTML HTNL5 简介

    XHTML 是HTML与XML(扩展标记语言)的结合物 包含了所有与XML语法结合的HTML 4.01元素 XHTML 指可扩展超文本标签语言(EXtensible HyperText Markup ...

  7. 用python读写excel的强大工具:openpyxl

    最近看到好几次群里有人问xlwt.wlrd的问题,怎么说呢,如果是office2007刚出来,大家用xlsx文件用不习惯,还可以理解,这都10年过去了喂,就算没有进化到office2016,还在用of ...

  8. Codeforces 786 B. Legacy

    题目链接:http://codeforces.com/contest/786/problem/B 典型线段树优化连边,线段树上的每一个点表示这个区间的所有点,然后边数就被优化为了至多${nlogn}$ ...

  9. 【SQL Prompt】SQL Prompt7.2下载及破解教程

    基本介绍 SQL Prompt能根据数据库的对象名称,语法和用户编写的代码片段自动进行检索,智能的为用户提供唯一合适的代码选择.自动脚本设置为用户提供了简单的代码易读性--这在开发者使用的是不大熟悉的 ...

  10. 力扣(LeetCode) 771. 宝石与石头

    给定字符串J 代表石头中宝石的类型,和字符串 S代表你拥有的石头. S 中每个字符代表了一种你拥有的石头的类型,你想知道你拥有的石头中有多少是宝石. J 中的字母不重复,J 和 S中的所有字符都是字母 ...