1. 导入各种包

from mxnet import gluon
from mxnet.gluon import nn
import matplotlib.pyplot as plt
from mxnet import autograd as autograd
from mxnet import nd
import mxnet as mx
from collections import namedtuple
import random

2. 准备数据

使用和mnist很像的FashionMNIST数据集,使用Gluon下载

def transform(data,label):
return data.astype('float32')/255,label.astype('float32')
fashion_train = gluon.data.vision.FashionMNIST(root='./',train=True,transform=transform)
fashion_test = gluon.data.vision.FashionMNIST(root='./',train=True, transform=transform)
batch_size = 256
train_data = gluon.data.DataLoader(fashion_train,batch_size,shuffle=True)
test_data = gluon.data.DataLoader(fashion_test,batch_size,shuffle=True)

用于显示图像和标签

def show_images(images):
n = images.shape[0]
_, figs = plt.subplots(1, n, figsize=(15, 15))
for i in range(n):
figs[i].imshow(images[i].reshape((28, 28)).asnumpy())
figs[i].axes.get_xaxis().set_visible(False)
figs[i].axes.get_yaxis().set_visible(False)
plt.show()
def get_text_labels(label):
text_labels = [
't-shirt', 'trouser', 'pullover', 'dress,', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'
]
return [text_labels[int(i)] for i in label]

看下数据集长啥样

data,label = fashion_train[5:19]
show_images(data)
print(get_text_labels(label))

['coat', 'coat', 'sandal', 'coat', 'bag', 't-shirt', 'bag', 'ankle boot', 't-shirt', 'pullover', 'pullover', 'ankle boot', 'dress,', 'dress,']

3. 精度计算函数

def accuracy(output, label):
return nd.mean(output.argmax(axis=1)==label).asscalar() def evaluate_accuracy(data_iterator, net):
acc = 0.
for data, label in data_iterator:
output = net(nd.transpose(data,(0,3,1,2)))
acc += accuracy(output, label)
return acc / len(data_iterator)

4. 定义网络

4.1 自己定义的层

Gluon模型转到Symbol下只能用HybridSequential模式,HybridSequential是静态图,会对计算有优化,不过HybridSequentialSequential可以很方便的转换,确切的就是一行代码的事。同样自定义的网络,要使用HybridBlock,和Block没有多大区别

class MyDense(nn.HybridBlock):
def __init__(self,**kwargs):
super(MyDense,self).__init__(**kwargs)
with self.name_scope():
self.dense0 = nn.Dense(256)
self.dense1 = nn.Dense(10)
def hybrid_forward(self,F,x): # 这里要使用hybrid_forward而不是forward,并且多了个参数F
return self.dense1(F.relu(self.dense0(x))) # F的作用就是替代 nd,如果是静态图,就是用 sym,否则使用 nd

4.2 使用自定义的层和自带的层组成完整的网络

网络定义和动态图一样,只不过把Sequential替换成了HybridSequential,在最后使用hybridize()会对静态图进行优化

net = nn.HybridSequential()
with net.name_scope():
net.add(gluon.nn.Conv2D(channels=20, kernel_size=5, activation='relu'))
net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
net.add(gluon.nn.Conv2D(channels=50, kernel_size=3, activation='relu'))
net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
net.add(gluon.nn.Flatten())
net.add(MyDense())
net.initialize(init=mx.init.Xavier())
net.hybridize()
net
HybridSequential(
(0): Conv2D(20, kernel_size=(5, 5), stride=(1, 1))
(1): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False)
(2): Conv2D(50, kernel_size=(3, 3), stride=(1, 1))
(3): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False)
(4): Flatten
(5): MyDense(
(dense0): Dense(256, linear)
(dense1): Dense(10, linear)
)
)

5. 训练

使用Adam优化算法,训练的速度会快点

softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(net.collect_params(), 'Adam', {'learning_rate': 0.008})
for epoch in range(5):
train_loss = 0.
train_acc = 0.
test_acc = 0.
for data, label in train_data:
data = nd.transpose(data,(0,3,1,2))
with autograd.record():
output = net(data)
loss = softmax_cross_entropy(output, label)
loss.backward()
trainer.step(batch_size) train_loss += nd.mean(loss).asscalar()
train_acc += accuracy(output, label) test_acc = evaluate_accuracy(test_data, net)
print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % (
epoch, train_loss/len(train_data), train_acc/len(train_data), test_acc))
Epoch 0. Loss: 0.498041, Train acc 0.817226, Test acc 0.865459
Epoch 1. Loss: 0.312128, Train acc 0.884813, Test acc 0.894265
Epoch 2. Loss: 0.274009, Train acc 0.898454, Test acc 0.898604
Epoch 3. Loss: 0.247741, Train acc 0.906521, Test acc 0.914910
Epoch 4. Loss: 0.226967, Train acc 0.913736, Test acc 0.914334

6. 保存成Symbol格式的网络和参数(重点)

要注意保存网络参数的时候,需要net.collect_params().save()这样保存,而不是net.save_params()保存

最新版的mxnet已经有可以导出到symbol格式下的接口了。需要mxnet版本在20171015以上

下面示例代码也已经改成新版的保存,加载方式

#新版本的保存方式
net.export('Gluon_FashionMNIST')

7. 使用Symbol加载网络并绑定

symnet = mx.symbol.load('Gluon_FashionMNIST-symbol.json')
mod = mx.mod.Module(symbol=symnet, context=mx.cpu())
mod.bind(data_shapes=[('data', (1, 1, 28, 28))])
mod.load_params('Gluon_FashionMNIST-0000.params')
Batch = namedtuple('Batch', ['data'])

8. 预测试试看效果

img,label = fashion_test[random.randint(0, 60000)]
data = img.transpose([2,0,1])
data = data.reshape([1,1,28,28])
mod.forward(Batch([data]))
out = mod.get_outputs()
prob = out[0]
predicted_labels = prob.argmax(axis=1) plt.imshow(img.reshape((28, 28)).asnumpy())
plt.axis('off')
plt.show()
print('predicted labels:',get_text_labels(predicted_labels.asnumpy())) print('true labels:',get_text_labels([label]))

predicted labels: ['pullover']
true labels: ['pullover']

MxNet新前端Gluon模型转换到Symbol的更多相关文章

  1. 使用MxNet新接口Gluon提供的预训练模型进行微调

    1. 导入各种包 from mxnet import gluon import mxnet as mx from mxnet.gluon import nn from mxnet import nda ...

  2. 前端MVVM框架avalon - 模型转换1

    轻量级前端MVVM框架avalon - 模型转换(一) 接上一章 ViewModel modelFactory工厂是如何加工用户定义的VM? 附源码 洋洋洒洒100多行内部是魔幻般的实现 1: fun ...

  3. 轻量级前端MVVM框架avalon - 模型转换

    接上一章 ViewModel modelFactory工厂是如何加工用户定义的VM? 附源码 洋洋洒洒100多行内部是魔幻般的实现 1: function modelFactory(scope) { ...

  4. 混合前端seq2seq模型部署

    混合前端seq2seq模型部署 本文介绍,如何将seq2seq模型转换为PyTorch可用的前端混合Torch脚本.要转换的模型来自于聊天机器人教程Chatbot tutorial. 1.混合前端 在 ...

  5. 【模型推理】Tengine 模型转换及量化

      欢迎关注我的公众号 [极智视界],回复001获取Google编程规范   O_o   >_<   o_O   O_o   ~_~   o_O   本文介绍一下 Tengine 模型转换 ...

  6. 将List 中的ConvertAll的使用:List 中的元素转换,List模型转换, list模型转数组

    一,直接入代码 using System; using System.Collections.Generic; using System.Linq; using System.Web; using S ...

  7. Blazor——Asp.net core的新前端框架

    原文:Blazor--Asp.net core的新前端框架 Blazor是微软在Asp.net core 3.0中推出的一个前端MVVM模型,它可以利用Razor页面引擎和C#作为脚本语言来构建WEB ...

  8. 看JQ时代过来的前端,如何转换思路用Vue打造选项卡组件

    前言 在Vue还未流行的时候,我们都是用JQuery来封装一个选项卡插件,如今Vue当道,让我们一起来看看从JQ时代过来的前端是如何转换思路,用数据驱动DOM的思想打造一个Vue选项卡组件. 接下来, ...

  9. tensorflow,object,detection,在model zoom,新下载的模型,WARNING:root:Variable [resnet_v1_50/block1/unit_3/bottleneck_v1/conv3/BatchNorm/gamma] is not available in checkpoint

    现象: WARNING:root:Variable [resnet_v1_50/block1/unit_1/bottleneck_v1/conv1/BatchNorm/beta] is not ava ...

随机推荐

  1. JSONP(Json with padding)

    JSONP:一种非官方跨域数据交互协议 JSONP怎么产生的 JSONP的原理 看上面的来源加以理解 上面说过了,script是不受跨域影响的 那么我们可以在我们代码中引用B服务器的文件 <sc ...

  2. 团队作业4——第一次项目冲刺 SeCOnd DaY

    项目冲刺--Double Kill 喂喂喂,你好你好,听得见吗?这里是天霸动霸.tua广播站,我是主播小学生¥-¥ 第一次敏捷冲刺平稳的度过了第一天,第一天的任务大家也圆满完成啦[拍手庆祝],那么今天 ...

  3. 团队作业8----第二次项目冲刺(Beta阶段) 第六天

    BETA阶段冲刺第六天 1.小会议ing 2.每个人的工作 (1)昨天已完成的工作 重复部分可以用红色字体显示 (2) 今天计划完成的工作 (3) 工作中遇到的困难: 尤少辉:在测试的时候,当队友提出 ...

  4. 201521123023《Java程序设计》第6周学习总结

    1. 本周学习总结 2. 书面作业 1.clone方法 1.1 Object对象中的clone方法是被protected修饰,在自定义的类中覆盖clone方法时需要注意什么? 实现克隆必须实现Clon ...

  5. java课程设计-算术运算测试

    1. 团队名称.团队成员介绍 团队名称:cococo 团队成员 组长:网络1514叶城龙 201521123109 组员:网络1514余腾鑫 201521123108 2. 项目git地址 http: ...

  6. latch session allocation

    应用反馈上午10点左右出现大量应用连接数据库报错 采集9点-10点和10点-11点的AWR报告进行分析 DB时间明显差异,再继续分析等待事件 可以看出有session相关的Latch等待事件,查看相关 ...

  7. 深度学习(一)cross-entropy softmax overfitting regularization dropout

    一.Cross-entropy 我们理想情况是让神经网络学习更快 假设单模型: 只有一个输入,一个神经元,一个输出   简单模型: 输入为1时, 输出为0 神经网络的学习行为和人脑差的很多, 开始学习 ...

  8. SSH第一篇【整合SSH步骤、OpenSessionInView】

    前言 到目前为止,Struts2.Hibernate.Spring框架都过了一遍了.也写过了Spring怎么与Struts2整合,Spring与Hibernate整合-本博文主要讲解SSH的整合 整合 ...

  9. 都是Javascript的作用域惹得祸

    案件重现 今天有位然之OA 系统的定制开发用户咨询了个问题,他想在新加的功能模块的操作面板中,实现用户点击删除按钮时提示友好提醒,如下: 问题很简单,虽然他自己最终达到目的效果了,但不知道起初问题出在 ...

  10. Vuforia开发完全指南---License Manager和Target Manager详解

    License Manager和Target Manager License Manager 对于每一个用Vuforia开发的AR程序来说,都有一个唯一的license key,在Unity中必须首先 ...