最近在摸mxnet和tensorflow。两个我都搭起来了。tensorflow跑了不少代码,总的来说用得比较顺畅,文档很丰富,api熟悉熟悉写代码没什么问题。

今天把两个平台做了一下对比。同是跑mnist,tensorflow 要比mxnet 慢一二十倍。mxnet只需要半分钟,tensorflow跑了13分钟。

在mxnet中如何开跑?

cd /mxnet/example/image-classification
python train_mnist.py 我用的是最新的mxnet版本。运行脚本它会自动下载数据集。
然后刷刷刷的刷屏了。
我们来看看这个脚本如何写的,从而建立mxnet编程思路:
import find_mxnet
import mxnet as mx
import argparse
import os, sys
import train_model def _download(data_dir):
    if not os.path.isdir(data_dir):
        os.system("mkdir " + data_dir)
    os.chdir(data_dir)
    if (not os.path.exists('train-images-idx3-ubyte')) or \
       (not os.path.exists('train-labels-idx1-ubyte')) or \
       (not os.path.exists('t10k-images-idx3-ubyte')) or \
       (not os.path.exists('t10k-labels-idx1-ubyte')):
        os.system("wget http://data.dmlc.ml/mxnet/data/mnist.zip")
        os.system("unzip -u mnist.zip; rm mnist.zip")
    os.chdir("..") def get_loc(data, attr={'lr_mult':'0.01'}):
    """
    the localisation network in lenet-stn, it will increase acc about more than 1%,
    when num-epoch >=15
    """
    loc = mx.symbol.Convolution(data=data, num_filter=30, kernel=(5, 5), stride=(2,2))
    loc = mx.symbol.Activation(data = loc, act_type='relu')
    loc = mx.symbol.Pooling(data=loc, kernel=(2, 2), stride=(2, 2), pool_type='max')
    loc = mx.symbol.Convolution(data=loc, num_filter=60, kernel=(3, 3), stride=(1,1), pad=(1, 1))
    loc = mx.symbol.Activation(data = loc, act_type='relu')
    loc = mx.symbol.Pooling(data=loc, global_pool=True, kernel=(2, 2), pool_type='avg')
    loc = mx.symbol.Flatten(data=loc)
    loc = mx.symbol.FullyConnected(data=loc, num_hidden=6, name="stn_loc", attr=attr)
    return loc def get_mlp():
    """
    multi-layer perceptron
    """
    data = mx.symbol.Variable('data')
    fc1  = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
    act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
    fc2  = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
    act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
    fc3  = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10)
    mlp  = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax')
    return mlp def get_lenet(add_stn=False):
    """
    LeCun, Yann, Leon Bottou, Yoshua Bengio, and Patrick
    Haffner. "Gradient-based learning applied to document recognition."
    Proceedings of the IEEE (1998)
    """
    data = mx.symbol.Variable('data')
    if(add_stn):
        data = mx.sym.SpatialTransformer(data=data, loc=get_loc(data), target_shape = (28,28),
                                         transform_type="affine", sampler_type="bilinear")
    # first conv
    conv1 = mx.symbol.Convolution(data=data, kernel=(5,5), num_filter=20)
    tanh1 = mx.symbol.Activation(data=conv1, act_type="tanh")
    pool1 = mx.symbol.Pooling(data=tanh1, pool_type="max",
                              kernel=(2,2), stride=(2,2))
    # second conv
    conv2 = mx.symbol.Convolution(data=pool1, kernel=(5,5), num_filter=50)
    tanh2 = mx.symbol.Activation(data=conv2, act_type="tanh")
    pool2 = mx.symbol.Pooling(data=tanh2, pool_type="max",
                              kernel=(2,2), stride=(2,2))
    # first fullc
    flatten = mx.symbol.Flatten(data=pool2)
    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500)
    tanh3 = mx.symbol.Activation(data=fc1, act_type="tanh")
    # second fullc
    fc2 = mx.symbol.FullyConnected(data=tanh3, num_hidden=10)
    # loss
    lenet = mx.symbol.SoftmaxOutput(data=fc2, name='softmax')
    return lenet def get_iterator(data_shape):
    def get_iterator_impl(args, kv):
        data_dir = args.data_dir
        if '://' not in args.data_dir:
            _download(args.data_dir)
        flat = False if len(data_shape) == 3 else True         train           = mx.io.MNISTIter(
            image       = data_dir + "train-images-idx3-ubyte",
            label       = data_dir + "train-labels-idx1-ubyte",
            input_shape = data_shape,
            batch_size  = args.batch_size,
            shuffle     = True,
            flat        = flat,
            num_parts   = kv.num_workers,
            part_index  = kv.rank)         val = mx.io.MNISTIter(
            image       = data_dir + "t10k-images-idx3-ubyte",
            label       = data_dir + "t10k-labels-idx1-ubyte",
            input_shape = data_shape,
            batch_size  = args.batch_size,
            flat        = flat,
            num_parts   = kv.num_workers,
            part_index  = kv.rank)         return (train, val)
    return get_iterator_impl def parse_args():
    parser = argparse.ArgumentParser(description='train an image classifer on mnist')
    parser.add_argument('--network', type=str, default='mlp',
                        choices = ['mlp', 'lenet', 'lenet-stn'],
                        help = 'the cnn to use')
    parser.add_argument('--data-dir', type=str, default='mnist/',
                        help='the input data directory')
    parser.add_argument('--gpus', type=str,
                        help='the gpus will be used, e.g "0,1,2,3"')
    parser.add_argument('--num-examples', type=int, default=60000,
                        help='the number of training examples')
    parser.add_argument('--batch-size', type=int, default=128,
                        help='the batch size')
    parser.add_argument('--lr', type=float, default=.1,
                        help='the initial learning rate')
    parser.add_argument('--model-prefix', type=str,
                        help='the prefix of the model to load/save')
    parser.add_argument('--save-model-prefix', type=str,
                        help='the prefix of the model to save')
    parser.add_argument('--num-epochs', type=int, default=10,
                        help='the number of training epochs')
    parser.add_argument('--load-epoch', type=int,
                        help="load the model on an epoch using the model-prefix")
    parser.add_argument('--kv-store', type=str, default='local',
                        help='the kvstore type')
    parser.add_argument('--lr-factor', type=float, default=1,
                        help='times the lr with a factor for every lr-factor-epoch epoch')
    parser.add_argument('--lr-factor-epoch', type=float, default=1,
                        help='the number of epoch to factor the lr, could be .5')
    return parser.parse_args() if __name__ == '__main__':
    args = parse_args()     if args.network == 'mlp':
        data_shape = (784, )
        net = get_mlp()
    elif args.network == 'lenet-stn':
        data_shape = (1, 28, 28)
        net = get_lenet(True)
    else:
        data_shape = (1, 28, 28)
        net = get_lenet()     # train
    train_model.fit(args, net, get_iterator(data_shape)) 先看Main函数,就是读配置参数,读网络结构,包括设置数据的大小,然后就是调用已有的包train_model。然后传入这之前设置的三个参数。就开始训练了。
编程架构也蛮清晰的。模块化也搞的好。
接着看看参数设置问题。参数导入了很多配置文件,基本上caffe中的Proto都在这个里面设置了。包括数据集地址,批大小,学习率,损失函数,等等。然后看看读网络结构,
读网络结构就是在一层一层的搭积木,根据之前读入的配置文件或者自己定义一些参数。搭好积木就开始训练了。
caffe的一个缺点是不够灵活,毕竟不是自己写代码,只是写配置文件,总感觉受制于人。mxnet和tensorflow就比较方便,提供api,你可以按你的方式来调用和定义
网络结构。总的说来,其实是后两个框架模块化做的好,提供底层的api支持你写自己的网络。caffe要自己写网络层的话还是很费劲的

mxnet实战系列(一)入门与跑mnist数据集的更多相关文章

  1. Spring全家桶系列–[SpringBoot入门到跑路]

    //本文作者:cuifuan Spring全家桶————[SpringBoot入门到跑路] 对于之前的Spring框架的使用,各种配置文件XML.properties一旦出错之后错误难寻,这也是为什么 ...

  2. Caffe系列4——基于Caffe的MNIST数据集训练与测试(手把手教你使用Lenet识别手写字体)

    基于Caffe的MNIST数据集训练与测试 原创:转载请注明https://www.cnblogs.com/xiaoboge/p/10688926.html  摘要 在前面的博文中,我详细介绍了Caf ...

  3. 关于入门深度学习mnist数据集前向计算的记录

    import osimport lr as lrimport tensorflow as tffrom pyspark.sql.functions import stddevfrom tensorfl ...

  4. 用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别

    用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别 http://phunter.farbox.com/post/mxnet-tutorial1 用MXnet实战深度学 ...

  5. Spark入门实战系列--6.SparkSQL(中)--深入了解SparkSQL运行计划及调优

    [注]该系列文章以及使用到安装包/测试数据 可以在<倾情大奉送--Spark入门实战系列>获取 1.1  运行环境说明 1.1.1 硬软件环境 线程,主频2.2G,10G内存 l  虚拟软 ...

  6. Spark入门实战系列--10.分布式内存文件系统Tachyon介绍及安装部署

    [注]该系列文章以及使用到安装包/测试数据 可以在<倾情大奉送--Spark入门实战系列>获取 .Tachyon介绍 1.1 Tachyon简介 随着实时计算的需求日益增多,分布式内存计算 ...

  7. MXNet学习~第一个例子~跑MNIST

    反正基本上是给自己看的,直接贴写过注释后的代码,可能有的地方理解不对,你多担待,看到了也提出来(基本上对未来的自己说的),三层跑到了97%,毕竟是第一个例子,主要就是用来理解MXNet怎么使用. #导 ...

  8. Spark入门实战系列--1.Spark及其生态圈简介

    [注]该系列文章以及使用到安装包/测试数据 可以在<倾情大奉送--Spark入门实战系列>获取 .简介 1.1 Spark简介 年6月进入Apache成为孵化项目,8个月后成为Apache ...

  9. Spark入门实战系列--2.Spark编译与部署(上)--基础环境搭建

    [注] 1.该系列文章以及使用到安装包/测试数据 可以在<倾情大奉送--Spark入门实战系列>获取: 2.Spark编译与部署将以CentOS 64位操作系统为基础,主要是考虑到实际应用 ...

随机推荐

  1. Javascript模式(第二章基本技巧)------读书笔记

    本章主要帮助大家写出高质量的JS代码的方法,模式和习惯,例如:避免使用全局变量,使用单个的var变量声明,缓存for循环的长度变量length等 一.尽量避免使用全局变量 1 每一个js环境都有一个全 ...

  2. Jeesite的cahche工具类

    本CacheUtils主要是基于shiro的cache进行处理. 其他选择: 类似的我们可以选择java cache ,spring cahche等方案.                   再进一步 ...

  3. visual studio 2013快捷键与2012不同

    升级了Visual Studio2013后发现有些快捷键不能使用,于是自己尝试设置找回,还真给发现了: 依次选择(工具-->选项-->环境-->键盘)把映射方案改成Visual C# ...

  4. servlet简单用法和配置示例及说明

    学习原因和目的:   我如今所接触的项目都是bs模式的web应用,而里边基本上都是用的spring MVC和前台交互,servlet貌似用的很少.   但是即便是用spring和spring MVC, ...

  5. 打造H5里的“3D全景漫游”秘籍

    近来风生水起的VR虚拟现实技术,抽空想起年初完成的“星球计划”项目,总结篇文章与各位分享一下制作基于Html5的3D全景漫游秘籍. QQ物联与深圳市天文台合作,在手Q“发现新设备”-“公共设备”里,连 ...

  6. jvm是如何管理内存的

    1.JVM是如何管理内存的 Java中,内存管理是JVM自动进行的,无需人为干涉. 了解Java内存模型看这里:java内存模型是什么样的 了解jvm实例结构看这里:jvm实例的结构是什么样的 创建对 ...

  7. 在VS中添加lib的简单方法

    1.工程---属性---配置属性---VC++ Directories---General---Include Directories:加上头文件存放目录 2.VC中,切换到"解决方案视图& ...

  8. TCP/IP详解 (转)

    TCP/IP详解学习笔记(1)-基本概念 为什么会有TCP/IP协议 在世界上各地,各种各样的电脑运行着各自不同的操作系统为大家服务,这些电脑在表达同一种信息的时候所使用的方法是千差万别.就好像圣经中 ...

  9. 解决java文件编码和windows7系统(中文版)默认编码冲突所导致的乱码情况

    开篇从一个比较简单但是也比较蛋疼的问题开始吧. 背景介绍:我是新手小白,初学java. 问题介绍:在使用UTF-8编码格式写java文件时,编译出现问题. 原因分析:1.java文件的编码格式是UTF ...

  10. xmpp关于后台挂起的消息接收,后台消息推送,本地发送通知

    想问下,在xmpp即时通讯的项目中,我程序如果挂起了,后台有消息过来,我这边的推送不过来,所以我的通知就会收不到消息,当我重新唤醒应用的时候,他才会接收到通知,消息就会推送过来,我在plist哪里设置 ...