在我们在MXnet中定义好symbol、写好dataiter并且准备好data之后,就可以开开心的去训练了。一般训练一个网络有两种常用的策略,基于model的和基于module的。今天,我想谈一谈他们的使用。

一、Model

  按照老规矩,直接从官方文档里面拿出来的代码看一下:

  

 # configure a two layer neuralnetwork
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(fc1, name='relu1', act_type='relu')
fc2 = mx.symbol.FullyConnected(act1, name='fc2', num_hidden=64)
softmax = mx.symbol.SoftmaxOutput(fc2, name='sm')
# create a model using sklearn-style two-step way
#创建一个model
model = mx.model.FeedForward(
softmax,
num_epoch=num_epoch,
learning_rate=0.01)
#开始训练
model.fit(X=data_set)

  具体的API参照http://mxnet.io/api/python/model.html。

  然后呢,model这部分就说完了。。。之所以这么快主要有两个原因:

    1.确实东西不多,一般都是查一查文档就可以了。

    2.model的可定制性不强,一般我们是很少使用的,常用的还是module。

二、Module

  Module真的是一个很棒的东西,虽然深入了解后,你会觉得“哇,好厉害,但是感觉没什么鸟用呢”这种想法。。实际上我就有过,现在回想起来,从代码的设计和使用的角度来讲,Module确实是一个非常好的东西,它可以为我们的网络计算提高了中级、高级的接口,这样一来,就可以有很多的个性化配置让我们自己来做了。

  Module有四种状态:

    1.初始化状态,就是显存还没有被分配,基本上啥都没做的状态。

    2.binded,在把data和label的shape传到Bind函数里并且执行之后,显存就分配好了,可以准备好计算能力。

    3.参数初始化。就是初始化参数

    3.Optimizer installed 。就是传入SGD,Adam这种optimuzer中去进行训练 

 先上一个简单的代码:

  

import mxnet as mx

    # construct a simple MLP
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
fc2 = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64)
act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu")
fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10)
out = mx.symbol.SoftmaxOutput(fc3, name = 'softmax') # construct the module
mod = mx.mod.Module(out) mod.bind(data_shapes=train_dataiter.provide_data,
label_shapes=train_dataiter.provide_label) mod.init_params()
mod.fit(train_dataiter, eval_data=eval_dataiter,
optimizer_params={'learning_rate':0.01, 'momentum': 0.9},
num_epoch=n_epoch)

  分析一下:首先是定义了一个简单的MLP,symbol的名字就叫做out,然后可以直接用mx.mod.Module来创建一个mod。之后mod.bind的操作是在显卡上分配所需的显存,所以我们需要把data_shapehe label_shape传递给他,然后初始化网络的参数,再然后就是mod.fit开始训练了。这里补充一下。fit这个函数我们已经看见两次了,实际上它是一个集成的功能,mod.fit()实际上它内部的核心代码是这样的:

  

for epoch in range(begin_epoch, num_epoch):
tic = time.time()
eval_metric.reset()
for nbatch, data_batch in enumerate(train_data):
if monitor is not None:
monitor.tic()
self.forward_backward(data_batch) #网络进行一次前向传播和后向传播
self.update() #更新参数
self.update_metric(eval_metric, data_batch.label) #更新metric if monitor is not None:
monitor.toc_print() if batch_end_callback is not None:
batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
eval_metric=eval_metric,
locals=locals())
for callback in _as_list(batch_end_callback):
callback(batch_end_params)

  正是因为module里面我们可以使用很多intermediate的interface,所以可以做出很多改进,举个最简单的例子:如果我们的训练网络是大小可变怎么办? 我们可以实现一个mutumodule,基本上就是,每次data的shape变了的时候,我们就重新bind一下symbol,这样训练就可以照常进行了。

  

  总结:实际上学一个框架的关键还是使用它,要说诀窍的话也就是多看看源码和文档了,我写这些博客的目的,一是为了记录一些东西,二是让后来者少走一些弯路。所以有些东西不会说的很全。。

  

从零开始学习MXnet(三)之Model和Module的更多相关文章

  1. 从零开始学习jQuery (三) 管理jQuery包装集

    本系列文章导航 从零开始学习jQuery (三) 管理jQuery包装集 一.摘要 在使用jQuery选择器获取到jQuery包装集后, 我们需要对其进行操作. 本章首先讲解如何动态的创建元素, 接着 ...

  2. 从零开始学习MXnet(五)MXnet的黑科技之显存节省大法

    写完发现名字有点拗口..- -# 大家在做deep learning的时候,应该都遇到过显存不够用,然后不得不去痛苦的减去batchszie,或者砍自己的网络结构呢? 最后跑出来的效果不尽如人意,总觉 ...

  3. 从零开始学习MXnet(一)

    最近工作要开始用到MXnet,然而MXnet的文档写的实在是.....所以在这记录点东西,方便自己,也方便大家. 我觉得搞清楚一个框架怎么使用,第一步就是用它来训练自己的数据,这是个很关键的一步. 一 ...

  4. 从零开始学习MXnet(四)计算图和粗细粒度以及自动求导

    这篇其实跟使用MXnet的关系不大,但对于我们理解深度学习的框架设计还是很有帮助的. 首先还是对promgramming models的一个简单介绍,这个东西实际上是在编译里面经常出现的东西,我们在编 ...

  5. 从零开始学习Vue(三)

    我们从一个例子来学习组件,vuejs2.0实战:仿豆瓣app项目,创建自定义组件tabbar 这个例子用到其他组件,对于初学者来说,一下子要了解那么多组件的使用,会变得一头雾水.所以我把这个例子改写了 ...

  6. 从零开始学习MXnet(二)之dataiter

    MXnet的设计结构是C++做后端运算,python.R等做前端来使用,这样既兼顾了效率,又让使用者方便了很多,完整的使用MXnet训练自己的数据集需要了解几个方面.今天我们先谈一谈Data iter ...

  7. oracle从零开始学习笔记 三

    高级查询 随机返回5条记录 select * from (select ename,job from emp order by dbms_random.value())where rownum< ...

  8. 从零开始学习jQuery(转)

    本系列文章导航 从零开始学习jQuery (一) 开天辟地入门篇 从零开始学习jQuery (二) 万能的选择器 从零开始学习jQuery (三) 管理jQuery包装集 从零开始学习jQuery ( ...

  9. 从零开始学习jQuery

    转自:http://www.cnblogs.com/zhangziqiu/archive/2009/04/30/jQuery-Learn-1.html 本系列文章导航 从零开始学习jQuery (一) ...

随机推荐

  1. Python函数中的参数

    形参:形式参数 实参:实际参数 1.普通参数:严格按照顺序将实参赋值给形参. 2.默认参数:必须放置在参数列表的最后. 3.指定参数:将实参赋值给制定参数. 4.动态参数: *:默认将传入的参数,全部 ...

  2. linux中常用命令总结

    一关机/重启/注销 关机 shutdown -h now //立即关机 重启 shutdown -r now //立即重启 reboot 重新启动 注销 logout //退出注销当前用户窗口 exi ...

  3. 【Leetcode】413. Arithmetic Slices

    Description A sequence of number is called arithmetic if it consists of at least three elements and ...

  4. React 省市区三级联动

    省市区所对应的数据来自:http://www.zgguan.com/zsfx/jsjc/6541.html react中的代码是: export default class AddReceive ex ...

  5. 安装sql server

    因为电脑中只有mysql数据库,所以昨天准备安装一个sql server.安装中出现了许多问题,首先第一遍的时候,安装组件中没有勾选管理工具这个选项,所以在最后的时候,文件夹中只有配置管理器,没有数据 ...

  6. P2340 奶牛会展(状压dp)

    P2340 奶牛会展 题目背景 奶牛想证明它们是聪明而风趣的.为此,贝西筹备了一个奶牛博览会,她已经对N 头奶牛进行 了面试,确定了每头奶牛的智商和情商. 题目描述 贝西有权选择让哪些奶牛参加展览.由 ...

  7. spring读取properties和其他配置文件的几种方式

    1.因为spring容器的一些机制,在读取配置文件进行数据库的配置等等是很有必要的,所以我们要考虑配置文件的的读取方式以及各个方式的实用性 2.配置文件的读取方式我这里介绍2种,目的是掌握这2种就可以 ...

  8. VM打开虚拟机文件报错

    用VM打开以前的虚拟机文件报错 Cannot open the disk 'F:/****.vmdk' or one of the snapshot disks it depends on. 这种问题 ...

  9. web开发微信文章目录

    Web开发微信文章目录 2015-12-13 Web开发 本文是Web开发微信的文章目录.通过目录查看文章编号,回复文章编号就能查看文章全文. 回复编号查看全文,搜索分类名可以获得该分类下的文章.   ...

  10. VS2013生产过程问题及解决

    TRK0002错误 现象:编译器.链接器交替报错,不能正常生成 环境:Win8.1 + VS2013 + 百度杀毒 解决:退出百度杀毒,重启VS,再进行生成 修订:发现问题依旧,经过多次试验,发现与杀 ...