最近工作要开始用到MXnet,然而MXnet的文档写的实在是.....所以在这记录点东西,方便自己,也方便大家。

  我觉得搞清楚一个框架怎么使用,第一步就是用它来训练自己的数据,这是个很关键的一步。

一、MXnet数据预处理

  整个数据预处理的代码都集成在了toosl/im2rec.py中了,这个首先要造出一个list文件,lst文件有三列,分别是index label 图片路径。如下图所示:

  

  我这个label是瞎填的,所以都是0。另外最新的MXnet上面的im2rec是有问题的,它生成的list所有的index都是0,不过据说这个index没什么用.....但我还是改了一下。把yield生成器换成直接append即可。

  执行的命令如下:

    sudo python im2rec.py --list=True /home/erya/dhc/result/try /home/erya/dhc/result/ --recursive=True --shuffle=true --train-ratio=0.8

  每个参数的意义在代码内部都可以查到,简单说一下这里用到的:--list=True说明这次的目的是make list,后面紧跟的是生成的list的名字的前缀,我这里是加了路径,然后是图片所在文件夹的路径,recursive是是否迭代的进入文件夹读取图片,--train-ratio则表示train和val在数据集中的比例。

执行上面的命令后,会得到三个文件:

 

然后再执行下面的命令生成最后的rec文件:

  sudo python im2rec.py /home/erya/dhc/result/try_val.lst  /home/erya/dhc/result --quality=100

以及,sudo python im2rec.py /home/erya/dhc/result/try_train.lst  /home/erya/dhc/result --quality=100

 来生成相应的lst文件的rec文件,参数意义太简单就不说了..看着就明白,result是我存放图片的目录。

 

  这样最终就完成了数据的预处理,简单的说,就是先生成lst文件,这个其实完全可以自己做,而且后期我做segmentation的时候,label就是图片了..

二、非常简单的小demo

先上代码:

  

 import mxnet as mx
import logging
import numpy as np logger = logging.getLogger()
logger.setLevel(logging.DEBUG)#暂时不需要管的log
def ConvFactory(data, num_filter, kernel, stride=(1,1), pad=(0, 0), act_type="relu"):
conv = mx.symbol.Convolution(data=data, workspace=256,
num_filter=num_filter, kernel=kernel, stride=stride, pad=pad)
return conv #我把这个删除到只有一个卷积的操作
def DownsampleFactory(data, ch_3x3):
# conv 3x3
conv = ConvFactory(data=data, kernel=(3, 3), stride=(2, 2), num_filter=ch_3x3, pad=(1, 1))
# pool
pool = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type='max')
# concat
concat = mx.symbol.Concat(*[conv, pool])
return concat
def SimpleFactory(data, ch_1x1, ch_3x3):
# 1x1
conv1x1 = ConvFactory(data=data, kernel=(1, 1), pad=(0, 0), num_filter=ch_1x1)
# 3x3
conv3x3 = ConvFactory(data=data, kernel=(3, 3), pad=(1, 1), num_filter=ch_3x3)
#concat
concat = mx.symbol.Concat(*[conv1x1, conv3x3])
return concat
if __name__ == "__main__":
batch_size = 1
train_dataiter = mx.io.ImageRecordIter(
shuffle=True,
path_imgrec="/home/erya/dhc/result/try_train.rec",
rand_crop=True,
rand_mirror=True,
data_shape=(3,28,28),
batch_size=batch_size,
preprocess_threads=1)#这里是使用我们之前的创造的数据,简单的说就是要自己写一个iter,然后把相应的参数填进去。
test_dataiter = mx.io.ImageRecordIter(
path_imgrec="/home/erya/dhc/result/try_val.rec",
rand_crop=False,
rand_mirror=False,
data_shape=(3,28,28),
batch_size=batch_size,
round_batch=False,
preprocess_threads=1)#同理
data = mx.symbol.Variable(name="data")
conv1 = ConvFactory(data=data, kernel=(3,3), pad=(1,1), num_filter=96, act_type="relu")
in3a = SimpleFactory(conv1, 32, 32)
fc = mx.symbol.FullyConnected(data=in3a, num_hidden=10)
softmax = mx.symbol.SoftmaxOutput(name='softmax',data=fc)#上面就是定义了一个巨巨巨简单的结构
# For demo purpose, this model only train 1 epoch
# We will use the first GPU to do training
num_epoch = 1
model = mx.model.FeedForward(ctx=mx.gpu(), symbol=softmax, num_epoch=num_epoch,
learning_rate=0.05, momentum=0.9, wd=0.00001) #将整个model训练的架构定下来了,类似于caffe里面solver所做的事情。 # we can add learning rate scheduler to the model
# model = mx.model.FeedForward(ctx=mx.gpu(), symbol=softmax, num_epoch=num_epoch,
# learning_rate=0.05, momentum=0.9, wd=0.00001,
# lr_scheduler=mx.misc.FactorScheduler(2))
model.fit(X=train_dataiter,
eval_data=test_dataiter,
eval_metric="accuracy",
batch_end_callback=mx.callback.Speedometer(batch_size))#开跑数据。

  

从零开始学习MXnet(一)的更多相关文章

  1. 从零开始学习MXnet(四)计算图和粗细粒度以及自动求导

    这篇其实跟使用MXnet的关系不大,但对于我们理解深度学习的框架设计还是很有帮助的. 首先还是对promgramming models的一个简单介绍,这个东西实际上是在编译里面经常出现的东西,我们在编 ...

  2. 从零开始学习MXnet(五)MXnet的黑科技之显存节省大法

    写完发现名字有点拗口..- -# 大家在做deep learning的时候,应该都遇到过显存不够用,然后不得不去痛苦的减去batchszie,或者砍自己的网络结构呢? 最后跑出来的效果不尽如人意,总觉 ...

  3. 从零开始学习MXnet(三)之Model和Module

    在我们在MXnet中定义好symbol.写好dataiter并且准备好data之后,就可以开开心的去训练了.一般训练一个网络有两种常用的策略,基于model的和基于module的.今天,我想谈一谈他们 ...

  4. 从零开始学习MXnet(二)之dataiter

    MXnet的设计结构是C++做后端运算,python.R等做前端来使用,这样既兼顾了效率,又让使用者方便了很多,完整的使用MXnet训练自己的数据集需要了解几个方面.今天我们先谈一谈Data iter ...

  5. ASP.NET从零开始学习EF的增删改查

           ASP.NET从零开始学习EF的增删改查           最近辞职了,但是离真正的离职还有一段时间,趁着这段空档期,总想着写些东西,想来想去,也不是很明确到底想写个啥,但是闲着也是够 ...

  6. 从零开始学习jQuery (五) 事件与事件对象

    本系列文章导航 从零开始学习jQuery (五) 事件与事件对象 一.摘要 事件是脚本编程的灵魂. 所以本章内容也是jQuery学习的重点. 本文将对jQuery中的事件处理以及事件对象进行详细的讲解 ...

  7. 从零开始学习jQuery (四) 使用jQuery操作元素的属性与样式

    本系列文章导航 从零开始学习jQuery (四) 使用jQuery操作元素的属性与样式 一.摘要 本篇文章讲解如何使用jQuery获取和操作元素的属性和CSS样式. 其中DOM属性和元素属性的区分值得 ...

  8. 从零开始学习jQuery (三) 管理jQuery包装集

    本系列文章导航 从零开始学习jQuery (三) 管理jQuery包装集 一.摘要 在使用jQuery选择器获取到jQuery包装集后, 我们需要对其进行操作. 本章首先讲解如何动态的创建元素, 接着 ...

  9. 从零开始学习jQuery (二) 万能的选择器

    本系列文章导航 从零开始学习jQuery (二) 万能的选择器 一.摘要 本章讲解jQuery最重要的选择器部分的知识. 有了jQuery的选择器我们几乎可以获取页面上任意的一个或一组对象, 可以明显 ...

随机推荐

  1. Python的matplotlib模块的使用-Github仓库

    import matplotlib.pyplot as plt import numpy as np import requests url='https://api.github.com/searc ...

  2. Python3爬虫(九) 数据存储之关系型数据库MySQL

    Infi-chu: http://www.cnblogs.com/Infi-chu/ 关系型数据库关系型数据库是基于关系模型的数据库,而关系模型是通过二维表来保存的,所以关系型数据库的存储方式就是行列 ...

  3. ListView学习

    ListView类 常用的基本属性 FullRowSelect:设置是否行选择模式.(默认为false)提示:只有在Details视图,该属性有效. GridLines:设置行和列之间是否显示网格线. ...

  4. R语言学习笔记(七): 排序函数:sort(), rank(), order()

    sort() sort()函数直接对函数进行排序,并返回排序结果. > a <- c(12,4,6,5) > sort(a) [1] 4 5 6 12 rank() rank()函数 ...

  5. xampps 不能配置非安装目录虚拟主机解决方案

    今天将前几天安装好的xampps配置下,准备开始php开发之旅,在我信心满满的将工作目录定在非安装目录上(安装目录在:D:\Program Files\xampps\apache\htdocs  我将 ...

  6. 【转】Ubuntu 14.04下Django+MySQL安装部署全过程

    一.简要步骤.(阿里云Ubuntu14.04) Python安装 Django Mysql的安装与配置 记录一下我的部署过程,也方便一些有需要的童鞋,大神勿喷~ 二.Python的安装 由于博主使用的 ...

  7. mvc4 Forms验证存储 两种登录代码

    自己也不知道网上看到的第一种居多,第二种用到的人很少,第二种代码十分简洁,就是不清楚是否有安全隐患. 要采用Forms身份验证,先要在应用程序根目录中的Web.config中做相应的设置: <a ...

  8. Jmeter学习(三)

    Apache JMeter是Apache组织开发的基于Java的压力测试工具.用于对软件做压力测试,它最初被设计用于Web应用测试,但后来扩展到其他测试.(来自百度) jmeter的特点: 开源免费. ...

  9. HDFS常用文件操作

    put 上传文件    hadoop fs -put wordcount.txt   /data/wordcount/ text 查看文件内容   hadoop fs -text /output/wo ...

  10. static 关键字解析(转)

    static关键字解析   Java中的static关键字解析 static关键字是很多朋友在编写代码和阅读代码时碰到的比较难以理解的一个关键字,也是各大公司的面试官喜欢在面试时问到的知识点之一.下面 ...