本文是MXNet的官网案例: Train MLP on MNIST. MXNet所有的模块如下图所示:

第一步: 准备数据

从下面程序可以看出,MXNet里面的数据是一个4维NDArray.

import mxnet as mx

# mxnet.io.MXDataIter, shape=(128,1,28,28)
train = mx.io.MNISTIter(
image = '/home/zhaopace/MXNet/mxnet/example/adversary/data/train-images-idx3-ubyte',
label = '/home/zhaopace/MXNet/mxnet/example/adversary/data/train-labels-idx1-ubyte',
batch_size = 128,
data_shape = (784, )
)
# mxnet.io.MXDataIter, shape=(128,1,28,28)
val = mx.io.MNISTIter(
image = '/home/zhaopace/MXNet/mxnet/example/adversary/data/t10k-images-idx3-ubyte',
label = '/home/zhaopace/MXNet/mxnet/example/adversary/data/t10k-labels-idx1-ubyte',
batch_size = 128,
data_shape = (784, )
)

Second: 符号式编程, 生成一个两层的MLP

# Declare a two-layer MLP
data = mx.symbol.Variable('data') # data layer
fc1 = mx.symbol.FullyConnected(data=data, num_hidden=128) # full connected layer 1
act1 = mx.symbol.Activation(data=fc1, act_type="relu") # activation layer(relu activation function)
fc2 = mx.symbol.FullyConnected(data=act1, num_hidden=64)
act2 = mx.symbol.Activation(data=fc2, act_type="relu")
fc3 = mx.symbol.FullyConnected(data=act2, num_hidden=10)
mlp = mx.symbol.SoftmaxOutput(data=fc3, name="softmax") # Softmax layer

一个CNN网络最基本的几层:

输入层: mx.symbol.Variable()

激活层: mx.symbol.Activation()

Batch正则化: mx.symbol.BatchNorm()

Dropout: mx.symbol.Dropout()

全连接层: mx.symbol.FullyConnected()

池化层: mx.symbol.Pooling()

卷积层: mx.symbol.Convolution()

Softmax输出: mx.symbol.SoftmaxOutput()

LRN: mx.symbol.LRN()

......

mx.symbol.FullyConnected(*args, **kwargs)

功能: 对input作矩阵乘法, 并且加上一个偏置. 将shape为(batch_size, input_dim)的input变成(batch_size, num_hidden)的输出;

输入参数:

  • data:  Symbol类型, 输入数据;
  • weight:  Symbol类型, 权重矩阵;
  • bias:  Symbol类型, 偏置参数;
  • num_hidden: int型, 必要参数, 隐层节点的数目;
  • no_bias: 布尔型, 可选参数, defalut=False, 表示是否不要偏置参数
  • name:  字符串类型, 可选参数, 计算结果symbol的名称;

输出参数:

  • 输出是一个Symbol: the result symbol

Last: 训练以及测试

# Type: mxnet.model.FeedForward
# Train a model on the data
model = mx.model.FeedForward(
symbol = mlp,
num_epoch = 20,
learning_rate = .1
)
model.fit(X = train, eval_data = val) # Predict
model.predict(X = train)

class mxnet.model.FeedForward(sklearn.base.BaseEstimator)

输入参数:

  • symbol: Symbol类型, 网络的symbol结构配置;
  • ctx:
  • num_epoch: int型, 可选参数,是一个训练参数, 训练的迭代次数;
  • epoch_size: 一次epoch使用的batches数目, 默认情况下为(num_train_examples / batch_size)
  • optimizer:q
  • initializer:
  • numpy_batch_size:
  • ......

图2 mxnet.model函数列表

MXNet官网案例分析--Train MLP on MNIST的更多相关文章

  1. ReactJS 官网案例分析

    案例一.聊天室案例 /** * This file provided by Facebook is for non-commercial testing and evaluation * purpos ...

  2. Django官网案例教程

    1.注意:python manage.py runserver 0:8000(侧任何IP均可访问)

  3. Yeoman 官网教学案例:使用 Yeoman 构建 WebApp

    STEP 1:设置开发环境 与yeoman的所有交互都是通过命令行.Mac系统使用terminal.app,Linux系统使用shell,windows系统可以使用cmder/PowerShell/c ...

  4. 对石家庄铁道大学官网UI设计的分析

    在这一周周一,老师给我们讲了PM,通过对PM的学习,我知道了PM 对项目所有功能的把握, 特别是UI.最差的UI, 体现了团队的组织架构:其次, 体现了产品的内部结构:最好, 体现了用户的自然需求.在 ...

  5. 【官网翻译】性能篇(四)为电池寿命做优化——使用Battery Historian分析电源使用情况

    前言 本文翻译自“为电池寿命做优化”系列文档中的其中一篇,用于介绍如何使用Battery Historian分析电源使用情况. 中国版官网原文地址为:https://developer.android ...

  6. spring原理案例-基本项目搭建 01 spring framework 下载 官网下载spring jar包

    下载spring http://spring.io/ 最重要是在特征下面的这段话,需要注意: All avaible features and modules are described in the ...

  7. 针对石家庄铁道大学官网首页的UI分析

    身为一名光荣的铁大铮铮学子,我对铁大的网站首页非常的情有独钟,下面我就石家庄铁道大学的官网首页进行UI分析: 1.在首页最醒目的地方赫然写着石家庄铁道大学七个大字,让人一眼就豁然开朗. 2.网站有EN ...

  8. “深度评测官”——记2020BUAA软工软件案例分析作业

    项目 内容 这个作业属于哪个课程 2020春季计算机学院软件工程(罗杰 任建) 这个作业的要求在哪里 个人博客作业-软件案例分析 我在这个课程的目标是 完成一次完整的软件开发经历并以博客的方式记录开发 ...

  9. Maccms后门分析复现(并非官网的Maccms){10.15 第二十二天}

    该复现参考网络中的文章,该漏洞复现仅仅是为了学习交流,严禁非法使用!!!! Maccms官网:http://www.maccms.cn/ Maccms网站基于PHP+MYSQL的系统,易用性.功能良好 ...

随机推荐

  1. Js闭包函数

    一.变量的作用域要理解闭包,首先必须理解Javascript特殊的变量作用域.变量的作用域无非就是两种:全局变量和局部变量.Javascript语言的特殊之处,就在于函数内部可以直接读取全局变量. ( ...

  2. 装过photoshop后出现configuration error

    1.你用的应该是精简版的PS,找到ps启动图标,点击右键,以管理员身份运行试试. 2.可以右键你的快捷方式,选择兼容性,后面有个选框“以管理员身份运行”,应用,下次就不报错了.

  3. jQuery formValidator使用入门

    使用插件必须加载的文件 //加载jQuery类库 <script type="text/javascript" src="jquery-1.7.1.min.js&q ...

  4. Access数据库连接方式

    网络连接:Provider=Microsoft.ACE.OLEDB.12.0;Data Source=\\server\share\folder\myAccessFile.accdb;标准安全:Pro ...

  5. Report_SRW在RDF中初始化的重要性(案例)

    2015-02-01 Created By BaoXinjian 一.摘要 在开发oracle report(report 6i)的时候,常常会用到fnd_global或fnd_profile来获取当 ...

  6. NeHe OpenGL教程 第十二课:显示列表

    转自[翻译]NeHe OpenGL 教程 前言 声明,此 NeHe OpenGL教程系列文章由51博客yarin翻译(2010-08-19),本博客为转载并稍加整理与修改.对NeHe的OpenGL管线 ...

  7. OC错误

  8. SqlServr进程内存使用增长的解决办法

    SqlServr进程使用的内存缓慢增长是正常的现象,但在服务器长时间不重启或sql服务不重启的情况下,最终,这个进程会耗尽所有的内存,导致所有终端无法正常与数据库交互. 1.设置数据库最大使用内存的值 ...

  9. No matching bean of type [xx] found for dependency: expected at least 1 bean which qualifies as autowire candidate for this dependency

    这个看起来很弱爆的问题其实是因为其他的配置文件中已经出现了为xx定义好的注入.如果用@Autowired就会得到上面的错误 , 但是用@Resource的时候就会看到类似下面的错误 Bean name ...

  10. c# winform快捷键设置

    设置 Form 的 KeyPreview=true 然后在Form 的案件事件里判断按钮类型进行分别调用就可以了 private void Form1_KeyDown(object sender, K ...