从零开始学习MXnet(三)之Model和Module
在我们在MXnet中定义好symbol、写好dataiter并且准备好data之后,就可以开开心的去训练了。一般训练一个网络有两种常用的策略,基于model的和基于module的。今天,我想谈一谈他们的使用。
一、Model
按照老规矩,直接从官方文档里面拿出来的代码看一下:
# configure a two layer neuralnetwork
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(fc1, name='relu1', act_type='relu')
fc2 = mx.symbol.FullyConnected(act1, name='fc2', num_hidden=64)
softmax = mx.symbol.SoftmaxOutput(fc2, name='sm')
# create a model using sklearn-style two-step way
#创建一个model
model = mx.model.FeedForward(
softmax,
num_epoch=num_epoch,
learning_rate=0.01)
#开始训练
model.fit(X=data_set)
具体的API参照http://mxnet.io/api/python/model.html。
然后呢,model这部分就说完了。。。之所以这么快主要有两个原因:
1.确实东西不多,一般都是查一查文档就可以了。
2.model的可定制性不强,一般我们是很少使用的,常用的还是module。
二、Module
Module真的是一个很棒的东西,虽然深入了解后,你会觉得“哇,好厉害,但是感觉没什么鸟用呢”这种想法。。实际上我就有过,现在回想起来,从代码的设计和使用的角度来讲,Module确实是一个非常好的东西,它可以为我们的网络计算提高了中级、高级的接口,这样一来,就可以有很多的个性化配置让我们自己来做了。
Module有四种状态:
1.初始化状态,就是显存还没有被分配,基本上啥都没做的状态。
2.binded,在把data和label的shape传到Bind函数里并且执行之后,显存就分配好了,可以准备好计算能力。
3.参数初始化。就是初始化参数
3.Optimizer installed 。就是传入SGD,Adam这种optimuzer中去进行训练
先上一个简单的代码:
import mxnet as mx
# construct a simple MLP
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
fc2 = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64)
act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu")
fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10)
out = mx.symbol.SoftmaxOutput(fc3, name = 'softmax')
# construct the module
mod = mx.mod.Module(out)
mod.bind(data_shapes=train_dataiter.provide_data,
label_shapes=train_dataiter.provide_label)
mod.init_params()
mod.fit(train_dataiter, eval_data=eval_dataiter,
optimizer_params={'learning_rate':0.01, 'momentum': 0.9},
num_epoch=n_epoch)
分析一下:首先是定义了一个简单的MLP,symbol的名字就叫做out,然后可以直接用mx.mod.Module来创建一个mod。之后mod.bind的操作是在显卡上分配所需的显存,所以我们需要把data_shapehe label_shape传递给他,然后初始化网络的参数,再然后就是mod.fit开始训练了。这里补充一下。fit这个函数我们已经看见两次了,实际上它是一个集成的功能,mod.fit()实际上它内部的核心代码是这样的:
for epoch in range(begin_epoch, num_epoch):
tic = time.time()
eval_metric.reset()
for nbatch, data_batch in enumerate(train_data):
if monitor is not None:
monitor.tic()
self.forward_backward(data_batch) #网络进行一次前向传播和后向传播
self.update() #更新参数
self.update_metric(eval_metric, data_batch.label) #更新metric if monitor is not None:
monitor.toc_print() if batch_end_callback is not None:
batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
eval_metric=eval_metric,
locals=locals())
for callback in _as_list(batch_end_callback):
callback(batch_end_params)
正是因为module里面我们可以使用很多intermediate的interface,所以可以做出很多改进,举个最简单的例子:如果我们的训练网络是大小可变怎么办? 我们可以实现一个mutumodule,基本上就是,每次data的shape变了的时候,我们就重新bind一下symbol,这样训练就可以照常进行了。
总结:实际上学一个框架的关键还是使用它,要说诀窍的话也就是多看看源码和文档了,我写这些博客的目的,一是为了记录一些东西,二是让后来者少走一些弯路。所以有些东西不会说的很全。。
从零开始学习MXnet(三)之Model和Module的更多相关文章
- 从零开始学习jQuery (三) 管理jQuery包装集
本系列文章导航 从零开始学习jQuery (三) 管理jQuery包装集 一.摘要 在使用jQuery选择器获取到jQuery包装集后, 我们需要对其进行操作. 本章首先讲解如何动态的创建元素, 接着 ...
- 从零开始学习MXnet(五)MXnet的黑科技之显存节省大法
写完发现名字有点拗口..- -# 大家在做deep learning的时候,应该都遇到过显存不够用,然后不得不去痛苦的减去batchszie,或者砍自己的网络结构呢? 最后跑出来的效果不尽如人意,总觉 ...
- 从零开始学习MXnet(一)
最近工作要开始用到MXnet,然而MXnet的文档写的实在是.....所以在这记录点东西,方便自己,也方便大家. 我觉得搞清楚一个框架怎么使用,第一步就是用它来训练自己的数据,这是个很关键的一步. 一 ...
- 从零开始学习MXnet(四)计算图和粗细粒度以及自动求导
这篇其实跟使用MXnet的关系不大,但对于我们理解深度学习的框架设计还是很有帮助的. 首先还是对promgramming models的一个简单介绍,这个东西实际上是在编译里面经常出现的东西,我们在编 ...
- 从零开始学习Vue(三)
我们从一个例子来学习组件,vuejs2.0实战:仿豆瓣app项目,创建自定义组件tabbar 这个例子用到其他组件,对于初学者来说,一下子要了解那么多组件的使用,会变得一头雾水.所以我把这个例子改写了 ...
- 从零开始学习MXnet(二)之dataiter
MXnet的设计结构是C++做后端运算,python.R等做前端来使用,这样既兼顾了效率,又让使用者方便了很多,完整的使用MXnet训练自己的数据集需要了解几个方面.今天我们先谈一谈Data iter ...
- oracle从零开始学习笔记 三
高级查询 随机返回5条记录 select * from (select ename,job from emp order by dbms_random.value())where rownum< ...
- 从零开始学习jQuery(转)
本系列文章导航 从零开始学习jQuery (一) 开天辟地入门篇 从零开始学习jQuery (二) 万能的选择器 从零开始学习jQuery (三) 管理jQuery包装集 从零开始学习jQuery ( ...
- 从零开始学习jQuery
转自:http://www.cnblogs.com/zhangziqiu/archive/2009/04/30/jQuery-Learn-1.html 本系列文章导航 从零开始学习jQuery (一) ...
随机推荐
- STM32(7)——通用定时器PWM输出
1.TIMER输出PWM基本概念 脉冲宽度调制(PWM),是英文“Pulse Width Modulation”的缩写,简称脉宽调制,是利用微处理器的数字输出来对模拟电路进行控制的一种 ...
- unity独立游戏开发日志2018/09/22
f::很头痛之前rm做的游戏在新电脑工程打不开了...只能另起炉灶... 还不知道新游戏叫什么名...暂且叫方块世界.(素材已经授权) 首先是规划下场景和素材文件夹的建立. unity常用的文件夹有: ...
- Redis的RDB与AOF介绍(Redis DateBase与Append Only File)
RedisRDB介绍(Redis DateBase) 在指定的时间间隔内将内存中的数据集快照写入磁盘,也就是行话讲的Snapshot快照,它恢复时是将快照文件直接读到内存里 一.是什么? Redis会 ...
- go学习笔记-语言指针
语言指针 定义及使用 变量是一种使用方便的占位符,用于引用计算机内存地址.取地址符是 &,放到一个变量前使用就会返回相应变量的内存地址. 一个指针变量指向了一个值的内存地址.类似于变量和常量, ...
- guacamole实现剪切复制
主要功能是实现把堡垒机的内容复制到浏览器端,把浏览器端的文本复制到堡垒机上. 借助一个中间的文本框,现将堡垒机内容复制到一个文本框,然后把文本框内容复制出来.或者将需要传递到堡垒机的内容先复制到文本框 ...
- AV Foundation 实现文字转语音
AV Foundation 主要框架 CoreAudio 音频处理框架 扩展学习:<Learning CoreAudio> CoreVideo 视频处理的管道模式,逐帧访问 CoreMed ...
- Clean Code 《代码整洁之道》前四章读书笔记
第一章: 整洁的代码只做好一件事 减少重复代码 提高表达力 提早构建简单抽象 让营地比你来时更干净 第二章:有意义的命名 名副其实:如果名称需要注释来补充,就不算是名副其实. ...
- day-9 sklearn库和python自带库实现最近邻KNN算法
K最近邻(k-Nearest Neighbor,KNN)分类算法,是一个理论上比较成熟的方法,也是最简单的机器学习算法之一.该方法的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的 ...
- Chrome 与 Firefox-Dev 的 DevTools
不管是做爬虫还是写 Web App,Chrome 和 Firefox 的 DevTools 都是超常用的,但是经常发现别人的截图有什么字段我找不到,别人的什么功能我的 Chrome 没有,仔细一搜索才 ...
- Android Studio环境解读
一.使用IDE开发APP的流程 要熟悉一个新的IDE,可依次完成以下流程: 二.相关术语解析 Dalvik: Android特有的虚拟机,和JVM不同,Dalvik虚拟机非常适合在移动终端上使用! A ...