https://mxnet.incubator.apache.org/tutorials/basic/module.html

import logging
import random
logging.getLogger().setLevel(logging.INFO) import mxnet as mx
import numpy as np mx.random.seed(1234)
np.random.seed(1234)
random.seed(1234) # 准备数据
fname = mx.test_utils.download('https://s3.us-east-2.amazonaws.com/mxnet-public/letter_recognition/letter-recognition.data')
data = np.genfromtxt(fname=fname,delimiter=',')[:,1:]
label = np.array([ord(l.split(',')[0])-ord('A') for l in open(fname, 'r')]) batch_size = 32
ntrain = int(data.shape[0]*0.8) train_iter = mx.io.NDArrayIter(data[:ntrain,:],label[:ntrain],batch_size,shuffle=True)
val_iter = mx.io.NDArrayIter(data[ntrain:,:],label[ntrain:],batch_size) # 定义网络
net = mx.sym.Variable('data')
net = mx.sym.FullyConnected(net, name='fc1', num_hidden=64)
net = mx.sym.Activation(net, name='relu1', act_type="relu")
net = mx.sym.FullyConnected(net, name='fc2', num_hidden=26)
net = mx.sym.SoftmaxOutput(net, name='softmax')
mx.viz.plot_network(net, node_attrs={"shape":"oval","fixedsize":"false"}) # # 创建模块
mod = mx.mod.Module(symbol=net,
context=mx.cpu(),
data_names=['data'],
label_names=['softmax_label']) # # 中层接口
# # 训练模型
# mod.bind(data_shapes=train_iter.provide_data,label_shapes=train_iter.provide_label)
# mod.init_params(initializer=mx.init.Uniform(scale=.1))
# mod.init_optimizer(optimizer='sgd',optimizer_params=(('learning_rate',0.1),))
# metric = mx.metric.create('acc')
#
# for epoch in range(100):
# train_iter.reset()
# metric.reset()
# for batch in train_iter:
# mod.forward(batch,is_train=True)
# mod.update_metric(metric,batch.label)
# mod.backward()
# mod.update()
# print('Epoch %d,Training %s' % (epoch,metric.get())) # fit 高层接口
train_iter.reset()
mod = mx.mod.Module(symbol=net,
context=mx.cpu(),
data_names=['data'],
label_names=['softmax_label']) mod.fit(train_iter,
eval_data=val_iter,
optimizer='sgd',
optimizer_params={'learning_rate':0.1},
eval_metric='acc',
num_epoch=10) # 预测和评估
y = mod.predict(val_iter)
assert y.shape == (4000,26) # 评分
score = mod.score(val_iter,['acc'])
print("Accuracy score is %f"%(score[0][1]))
assert score[0][1] > 0.76, "Achieved accuracy (%f) is less than expected (0.76)" % score[0][1] # 保存和加载
# 构造一个回调函数保存检查点
model_prefix = 'mx_mlp'
checkpoint = mx.callback.do_checkpoint(model_prefix) mod = mx.mod.Module(symbol=net)
mod.fit(train_iter,num_epoch=5,epoch_end_callback=checkpoint) sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
assert sym.tojson() == net.tojson() # assign the loaded parameters to the module
mod.set_params(arg_params, aux_params) mod = mx.mod.Module(symbol=sym)
mod.fit(train_iter,
num_epoch=21,
arg_params=arg_params,
aux_params=aux_params,
begin_epoch=3)
assert score[0][1] > 0.77, "Achieved accuracy (%f) is less than expected (0.77)" % score[0][1]

mxnet 神经网络训练和预测的更多相关文章

  1. 吴裕雄 python 神经网络——TensorFlow 使用卷积神经网络训练和预测MNIST手写数据集

    import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input_dat ...

  2. 利用Matlab神经网络计算包预测近四天除湖北外新增确诊人数:拐点已现

    数据来源: 国家卫健委 已经7连降咯! 1.20-2.10图示(更新中): 神经网络训练并预测数据: clear %除湖北以外全国新增确诊病例数 2020.1.20-2.9 num=[5,44,62, ...

  3. ResNet网络的训练和预测

    ResNet网络的训练和预测 简介 Introduction 图像分类与CNN 图像分类 是指将图像信息中所反映的不同特征,把不同类别的目标区分开来的图像处理方法,是计算机视觉中其他任务,比如目标检测 ...

  4. 神经网络训练中的Tricks之高效BP(反向传播算法)

    神经网络训练中的Tricks之高效BP(反向传播算法) 神经网络训练中的Tricks之高效BP(反向传播算法) zouxy09@qq.com http://blog.csdn.net/zouxy09 ...

  5. mxnet的训练过程——从python到C++

    mxnet的训练过程--从python到C++ mxnet(github-mxnet)的python接口相当完善,我们可以完全不看C++的代码就能直接训练模型,如果我们要学习它的C++的代码,从pyt ...

  6. 神经网络训练tricks

    神经网络构建好,训练不出好的效果怎么办?明明说好的拟合任意函数(一般连续)(为什么?可以参考http://neuralnetworksanddeeplearning.com/),说好的足够多的数据(h ...

  7. tesorflow - create neural network+结果可视化+加速神经网络训练+Optimizer+TensorFlow

    以下仅为了自己方便查看,绝大部分参考来源:莫烦Python,建议去看原博客 一.添加层 def add_layer() 定义 add_layer()函数 在 Tensorflow 里定义一个添加层的函 ...

  8. TensorFlow实战第三课(可视化、加速神经网络训练)

    matplotlib可视化 构件图形 用散点图描述真实数据之间的关系(plt.ion()用于连续显示) # plot the real data fig = plt.figure() ax = fig ...

  9. Pytorch学习记录-torchtext和Pytorch的实例( 使用神经网络训练Seq2Seq代码)

    Pytorch学习记录-torchtext和Pytorch的实例1 0. PyTorch Seq2Seq项目介绍 1. 使用神经网络训练Seq2Seq 1.1 简介,对论文中公式的解读 1.2 数据预 ...

随机推荐

  1. Cannot find module 'object-keys' 的解决办法

    把node_modules文件夹删除,重新cnpm install安装node_modules就好了.

  2. js 常用脚本

    1.判断电话号码和手机号码 var tel = $("#tel").val(); if (isNotBlank($.trim(tel))) { //不为空的情况下判断符合手机号码标 ...

  3. PHP学习3——数组

    主要内容: 简介 常用的方法 循环遍历数组 PHP预定义数组 数组的处理函数 数组 PHP由于是弱类型的语言,他的变量类型是可以自由变换的,他的数组很自由,长度是可以动态增加的. 他的索引默认为数字0 ...

  4. in和not in

    当子查询返回的列的值是多个值,那么就不能使用比较运算符(> < = !=),使用关键字in 语法: select …..from …..where 表达式 in (子查询) 常用in替换等 ...

  5. js返回树形结构数据

    /** * 树形结构转换 * @param a * @param idStr * @param pidStr * @param chindrenStr * @returns {Array} */ fu ...

  6. [生活] 日常英语学习笔记-NEVER HAVE I EVER游戏

    逛油管,看视频,学英语. 大家要过周末了说啥 Happy Sunday Have a restful  Sunday 有个空闲的周末 我们正在看电影 We are watching movie it ...

  7. MVC设计模式实现权限管理登录,超详细

    功能实现:在页面输入给定的用户名之一,可以显示当前用户的权限,也可以在页面更改该用户的权限,更新之后保存.像下面这样. 填写用户名提交: 显示用户AAA的权限: 修改权限(增加article3): 点 ...

  8. JRebel&XRebel

    介绍==>>>> JRebel&XRebel官网 https://zeroturnaround.com/HotSwap和JRebel原理 http://www.holl ...

  9. 微信公众号开发《三》微信JS-SDK之地理位置的获取与在线导航,集成百度地图实现在线地图搜索

    本次讲解微信开发第三篇:获取用户地址位置信息,是非常常用的功能,特别是服务行业公众号,尤为需要该功能,本次讲解的就是如何调用微信JS-SDK接口,获取用户位置信息,并结合百度地铁,实现在线地图搜索,与 ...

  10. Vue中的静态资源管理(src下的assets和static文件夹的区别)

    ### 你可能注意到了我们的静态资源共有两个目录src/assets和static/,你们它们之间有怎样的区别呢? 资源打包 为了回答这个问题,我们需要了解webpack是如何处理静态资源的. 在所有 ...