1. 导入各种包

from mxnet import gluon
from mxnet.gluon import nn
import matplotlib.pyplot as plt
from mxnet import autograd as autograd
from mxnet import nd
import mxnet as mx
from collections import namedtuple
import random

2. 准备数据

使用和mnist很像的FashionMNIST数据集,使用Gluon下载

def transform(data,label):
return data.astype('float32')/255,label.astype('float32')
fashion_train = gluon.data.vision.FashionMNIST(root='./',train=True,transform=transform)
fashion_test = gluon.data.vision.FashionMNIST(root='./',train=True, transform=transform)
batch_size = 256
train_data = gluon.data.DataLoader(fashion_train,batch_size,shuffle=True)
test_data = gluon.data.DataLoader(fashion_test,batch_size,shuffle=True)

用于显示图像和标签

def show_images(images):
n = images.shape[0]
_, figs = plt.subplots(1, n, figsize=(15, 15))
for i in range(n):
figs[i].imshow(images[i].reshape((28, 28)).asnumpy())
figs[i].axes.get_xaxis().set_visible(False)
figs[i].axes.get_yaxis().set_visible(False)
plt.show()
def get_text_labels(label):
text_labels = [
't-shirt', 'trouser', 'pullover', 'dress,', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'
]
return [text_labels[int(i)] for i in label]

看下数据集长啥样

data,label = fashion_train[5:19]
show_images(data)
print(get_text_labels(label))

['coat', 'coat', 'sandal', 'coat', 'bag', 't-shirt', 'bag', 'ankle boot', 't-shirt', 'pullover', 'pullover', 'ankle boot', 'dress,', 'dress,']

3. 精度计算函数

def accuracy(output, label):
return nd.mean(output.argmax(axis=1)==label).asscalar() def evaluate_accuracy(data_iterator, net):
acc = 0.
for data, label in data_iterator:
output = net(nd.transpose(data,(0,3,1,2)))
acc += accuracy(output, label)
return acc / len(data_iterator)

4. 定义网络

4.1 自己定义的层

Gluon模型转到Symbol下只能用HybridSequential模式,HybridSequential是静态图,会对计算有优化,不过HybridSequentialSequential可以很方便的转换,确切的就是一行代码的事。同样自定义的网络,要使用HybridBlock,和Block没有多大区别

class MyDense(nn.HybridBlock):
def __init__(self,**kwargs):
super(MyDense,self).__init__(**kwargs)
with self.name_scope():
self.dense0 = nn.Dense(256)
self.dense1 = nn.Dense(10)
def hybrid_forward(self,F,x): # 这里要使用hybrid_forward而不是forward,并且多了个参数F
return self.dense1(F.relu(self.dense0(x))) # F的作用就是替代 nd,如果是静态图,就是用 sym,否则使用 nd

4.2 使用自定义的层和自带的层组成完整的网络

网络定义和动态图一样,只不过把Sequential替换成了HybridSequential,在最后使用hybridize()会对静态图进行优化

net = nn.HybridSequential()
with net.name_scope():
net.add(gluon.nn.Conv2D(channels=20, kernel_size=5, activation='relu'))
net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
net.add(gluon.nn.Conv2D(channels=50, kernel_size=3, activation='relu'))
net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
net.add(gluon.nn.Flatten())
net.add(MyDense())
net.initialize(init=mx.init.Xavier())
net.hybridize()
net
HybridSequential(
(0): Conv2D(20, kernel_size=(5, 5), stride=(1, 1))
(1): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False)
(2): Conv2D(50, kernel_size=(3, 3), stride=(1, 1))
(3): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False)
(4): Flatten
(5): MyDense(
(dense0): Dense(256, linear)
(dense1): Dense(10, linear)
)
)

5. 训练

使用Adam优化算法,训练的速度会快点

softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(net.collect_params(), 'Adam', {'learning_rate': 0.008})
for epoch in range(5):
train_loss = 0.
train_acc = 0.
test_acc = 0.
for data, label in train_data:
data = nd.transpose(data,(0,3,1,2))
with autograd.record():
output = net(data)
loss = softmax_cross_entropy(output, label)
loss.backward()
trainer.step(batch_size) train_loss += nd.mean(loss).asscalar()
train_acc += accuracy(output, label) test_acc = evaluate_accuracy(test_data, net)
print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % (
epoch, train_loss/len(train_data), train_acc/len(train_data), test_acc))
Epoch 0. Loss: 0.498041, Train acc 0.817226, Test acc 0.865459
Epoch 1. Loss: 0.312128, Train acc 0.884813, Test acc 0.894265
Epoch 2. Loss: 0.274009, Train acc 0.898454, Test acc 0.898604
Epoch 3. Loss: 0.247741, Train acc 0.906521, Test acc 0.914910
Epoch 4. Loss: 0.226967, Train acc 0.913736, Test acc 0.914334

6. 保存成Symbol格式的网络和参数(重点)

要注意保存网络参数的时候,需要net.collect_params().save()这样保存,而不是net.save_params()保存

最新版的mxnet已经有可以导出到symbol格式下的接口了。需要mxnet版本在20171015以上

下面示例代码也已经改成新版的保存,加载方式

#新版本的保存方式
net.export('Gluon_FashionMNIST')

7. 使用Symbol加载网络并绑定

symnet = mx.symbol.load('Gluon_FashionMNIST-symbol.json')
mod = mx.mod.Module(symbol=symnet, context=mx.cpu())
mod.bind(data_shapes=[('data', (1, 1, 28, 28))])
mod.load_params('Gluon_FashionMNIST-0000.params')
Batch = namedtuple('Batch', ['data'])

8. 预测试试看效果

img,label = fashion_test[random.randint(0, 60000)]
data = img.transpose([2,0,1])
data = data.reshape([1,1,28,28])
mod.forward(Batch([data]))
out = mod.get_outputs()
prob = out[0]
predicted_labels = prob.argmax(axis=1) plt.imshow(img.reshape((28, 28)).asnumpy())
plt.axis('off')
plt.show()
print('predicted labels:',get_text_labels(predicted_labels.asnumpy())) print('true labels:',get_text_labels([label]))

predicted labels: ['pullover']
true labels: ['pullover']

MxNet新前端Gluon模型转换到Symbol的更多相关文章

  1. 使用MxNet新接口Gluon提供的预训练模型进行微调

    1. 导入各种包 from mxnet import gluon import mxnet as mx from mxnet.gluon import nn from mxnet import nda ...

  2. 前端MVVM框架avalon - 模型转换1

    轻量级前端MVVM框架avalon - 模型转换(一) 接上一章 ViewModel modelFactory工厂是如何加工用户定义的VM? 附源码 洋洋洒洒100多行内部是魔幻般的实现 1: fun ...

  3. 轻量级前端MVVM框架avalon - 模型转换

    接上一章 ViewModel modelFactory工厂是如何加工用户定义的VM? 附源码 洋洋洒洒100多行内部是魔幻般的实现 1: function modelFactory(scope) { ...

  4. 混合前端seq2seq模型部署

    混合前端seq2seq模型部署 本文介绍,如何将seq2seq模型转换为PyTorch可用的前端混合Torch脚本.要转换的模型来自于聊天机器人教程Chatbot tutorial. 1.混合前端 在 ...

  5. 【模型推理】Tengine 模型转换及量化

      欢迎关注我的公众号 [极智视界],回复001获取Google编程规范   O_o   >_<   o_O   O_o   ~_~   o_O   本文介绍一下 Tengine 模型转换 ...

  6. 将List 中的ConvertAll的使用:List 中的元素转换,List模型转换, list模型转数组

    一,直接入代码 using System; using System.Collections.Generic; using System.Linq; using System.Web; using S ...

  7. Blazor——Asp.net core的新前端框架

    原文:Blazor--Asp.net core的新前端框架 Blazor是微软在Asp.net core 3.0中推出的一个前端MVVM模型,它可以利用Razor页面引擎和C#作为脚本语言来构建WEB ...

  8. 看JQ时代过来的前端,如何转换思路用Vue打造选项卡组件

    前言 在Vue还未流行的时候,我们都是用JQuery来封装一个选项卡插件,如今Vue当道,让我们一起来看看从JQ时代过来的前端是如何转换思路,用数据驱动DOM的思想打造一个Vue选项卡组件. 接下来, ...

  9. tensorflow,object,detection,在model zoom,新下载的模型,WARNING:root:Variable [resnet_v1_50/block1/unit_3/bottleneck_v1/conv3/BatchNorm/gamma] is not available in checkpoint

    现象: WARNING:root:Variable [resnet_v1_50/block1/unit_1/bottleneck_v1/conv1/BatchNorm/beta] is not ava ...

随机推荐

  1. !JS实战之随机像素图

    JavaScript实例分享之----画随机像素图.随机像素图(作者自己取得名字)指的是一张图片上每一个像素的颜色都是随机的.此时应该能联想到这幅图多么眼花缭乱,好吧,我们用JS来实现它的原因是JS很 ...

  2. ★不容错过的PPT教程!

    IT工程师必须学会的计算机基础应用之一--PPT! 28项大神级PPT制作技术,学会后让你变成PPT高手哦!更多实用教程,请关注@IT工程师 !

  3. 【1414软工助教】团队作业9——测试与发布(Beta版本) 得分榜

    题目 团队作业9--测试与发布(Beta版本) 往期成绩 个人作业1:四则运算控制台 结对项目1:GUI 个人作业2:案例分析 结对项目2:单元测试 团队作业1:团队展示 团队作业2:需求分析& ...

  4. 201521123003《Java程序设计》第7周学习总结

    1. 本周学习总结 以你喜欢的方式(思维导图或其他)归纳总结集合相关内容. 参考资料: XMind 2. 书面作业 Q1.ArrayList代码分析 1.1 解释ArrayList的contains源 ...

  5. 201521123051《Java程序设计》第七周学习总结

    1. 本周学习总结 以你喜欢的方式(思维导图或其他)归纳总结集合相关内容. 使用工具:百度脑图 2. 书面作业 1.ArrayList代码分析 1.1 解释ArrayList的contains源代码 ...

  6. 201521123031 《Java程序设计》第6周学习总结

    1. 本周学习总结 1.1 面向对象学习暂告一段落,请使用思维导图,以封装.继承.多态为核心概念画一张思维导图,对面向对象思想进行一个总结. 注1:关键词与内容不求多,但概念之间的联系要清晰,内容覆盖 ...

  7. 201521123110《Java程序设计》第5周学习总结

    1. 本周学习总结 1.1 尝试使用思维导图总结有关多态与接口的知识点. 2. 书面作业 1.代码阅读:Child压缩包内源代码 1.1 com.parent包中Child.java文件能否编译通过? ...

  8. 201521123088《Java程序设计》第12周学习总结

    1. 本周学习总结 1.1 以你喜欢的方式(思维导图或其他)归纳总结多流与文件相关内容. 2. 书面作业 将Student对象(属性:int id, String name,int age,doubl ...

  9. 子元素设定margin值会影响父元素

    有些情况下,我们设定父元素下的子元素margin值时,父元素会被影响. 这是个常见问题,而且只在标准浏览器下 (FirfFox.Chrome.Opera.Sarfi)产生问题,IE下反而表现良好. 例 ...

  10. Linux 安装 mysql 并配置

    1.下载 下载地址:http://dev.mysql.com/downloads/mysql/5.6.html#downloads 下载版本:我这里选择的5.6.33,通用版,linux下64位 也可 ...