1. 导入各种包

from mxnet import gluon
import mxnet as mx
from mxnet.gluon import nn
from mxnet import ndarray as nd
import matplotlib.pyplot as plt
import cv2
from mxnet import image
from mxnet import autograd

2. 导入数据

我使用cifar10这个数据集,使用gluon自带的模块下载到本地并且为了配合后面的网络,我将大小调整到224*224

def transform(data, label):
data = image.imresize(data, 224, 224)
return data.astype('float32'), label.astype('float32')
cifar10_train = gluon.data.vision.CIFAR10(root='./',train=True, transform=transform)
cifar10_test = gluon.data.vision.CIFAR10(root='./',train=False, transform=transform)
batch_size = 64
train_data = gluon.data.DataLoader(cifar10_train, batch_size, shuffle=True)
test_data = gluon.data.DataLoader(cifar10_test, batch_size, shuffle=False)

3. 加载预训练模型

gluon提供的很多预训练模型,我选择一个简单的模型AlexNet

首先下载AlexNet模型和模型参数

使用下面的代码会获取AlexNet的模型并且加载预训练好的模型参数,但是鉴于网络的原因,我提前下好了

alexnet = mx.gluon.model_zoo.vision.alexnet(pretrained=True)#如果pretrained值为True,则会下载预训练参数,否则是空模型

获取模型并从本地加载参数

alexnet = mx.gluon.model_zoo.vision.alexnet()
alexnet.load_params('alexnet-44335d1f.params',ctx=mx.gpu())

看下AlexNet网络结构,发现分为两部分,features,classifier,而features正好是需要的

print(alexnet)
AlexNet(
(features): HybridSequential(
(0): Conv2D(64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
(1): MaxPool2D(size=(3, 3), stride=(2, 2), padding=(0, 0), ceil_mode=False)
(2): Conv2D(192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(3): MaxPool2D(size=(3, 3), stride=(2, 2), padding=(0, 0), ceil_mode=False)
(4): Conv2D(384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(5): Conv2D(256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): Conv2D(256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): MaxPool2D(size=(3, 3), stride=(2, 2), padding=(0, 0), ceil_mode=False)
(8): Flatten
)
(classifier): HybridSequential(
(0): Dense(4096, Activation(relu))
(1): Dropout(p = 0.5)
(2): Dense(4096, Activation(relu))
(3): Dropout(p = 0.5)
(4): Dense(1000, linear)
)
)

4. 组合新的网络

截取想要的features,并且固定参数。这样防止训练的时候把预训练好的参数给搞坏了

featuresnet = alexnet.features
for _, w in featuresnet.collect_params().items():
w.grad_req = 'null'

自己定义后面的网络,因为数据集是10类,就把最后的输出从1000改成了10。

def Classifier():
net = nn.HybridSequential()
net.add(nn.Dense(4096, activation="relu"))
net.add(nn.Dropout(.5))
net.add(nn.Dense(4096, activation="relu"))
net.add(nn.Dropout(.5))
net.add(nn.Dense(10))
return net

接着需要把两部分组合起来,并且对第二部分机进行初始化

net = nn.HybridSequential()
with net.name_scope():
net.add(featuresnet)
net.add(Classifier())
net[1].collect_params().initialize(init=mx.init.Xavier(),ctx=mx.gpu())
net.hybridize()

5. 训练

最后就是训练了,看看效果如何

#定义准确率函数
def accuracy(output, label):
return nd.mean(output.argmax(axis=1)==label).asscalar()
def evaluate_accuracy(data_iterator, net, ctx=mx.gpu()):
acc = 0.
for data, label in data_iterator:
data = data.transpose([0,3,1,2])
data = data/255
output = net(data.as_in_context(ctx))
acc += accuracy(output, label.as_in_context(ctx))
return acc / len(data_iterator)
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(
net.collect_params(), 'sgd', {'learning_rate': 0.01})
for epoch in range(1):
train_loss = 0.
train_acc = 0.
test_acc = 0.
for data, label in train_data:
label = label.as_in_context(mx.gpu())
data = data.transpose([0,3,1,2])
data = data/255
with autograd.record():
output = net(data.as_in_context(mx.gpu()))
loss = softmax_cross_entropy(output, label)
loss.backward()
trainer.step(batch_size) train_loss += nd.mean(loss).asscalar()
train_acc += accuracy(output, label)
test_acc = evaluate_accuracy(test_data, net)
print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % (
epoch, train_loss/len(train_data),
train_acc/len(train_data),test_acc))
Epoch 0. Loss: 1.249197, Train acc 0.558764, Test acc 0.696756

使用MxNet新接口Gluon提供的预训练模型进行微调的更多相关文章

  1. MXNet的新接口Gluon

    为什么要开发Gluon的接口 在MXNet中我们可以通过Sybmol模块来定义神经网络,并组通过Module模块提供的一些上层API来简化整个训练过程.那MXNet为什么还要重新开发一套Python的 ...

  2. MxNet新前端Gluon模型转换到Symbol

    1. 导入各种包 from mxnet import gluon from mxnet.gluon import nn import matplotlib.pyplot as plt from mxn ...

  3. Paddle预训练模型应用工具PaddleHub

    Paddle预训练模型应用工具PaddleHub 本文主要介绍如何使用飞桨预训练模型管理工具PaddleHub,快速体验模型以及实现迁移学习.建议使用GPU环境运行相关程序,可以在启动环境时,如下图所 ...

  4. 预训练模型时代:告别finetune, 拥抱adapter

    NLP论文解读 原创•作者 |FLIPPED 研究背景 随着计算算力的不断增加,以transformer为主要架构的预训练模型进入了百花齐放的时代.BERT.RoBERTa等模型的提出为NLP相关问题 ...

  5. 【转载】最强NLP预训练模型!谷歌BERT横扫11项NLP任务记录

    本文介绍了一种新的语言表征模型 BERT--来自 Transformer 的双向编码器表征.与最近的语言表征模型不同,BERT 旨在基于所有层的左.右语境来预训练深度双向表征.BERT 是首个在大批句 ...

  6. PyTorch-网络的创建,预训练模型的加载

    本文是PyTorch使用过程中的的一些总结,有以下内容: 构建网络模型的方法 网络层的遍历 各层参数的遍历 模型的保存与加载 从预训练模型为网络参数赋值 主要涉及到以下函数的使用 add_module ...

  7. 【翻译】OpenVINO Pre-Trained 预训练模型介绍

    OpenVINO 系列软件包预训练模型介绍 本文翻译自 Intel OpenVINO 的  "Overview of OpenVINO Toolkit Pre-Trained Models& ...

  8. dropzonejs中文翻译手册 DropzoneJS是一个提供文件拖拽上传并且提供图片预览的开源类库.

    http://wxb.github.io/dropzonejs.com.zh-CN/dropzonezh-CN/ 由于项目需要,完成一个web的图片拖拽上传,也就顺便学习和了解了一下前端的比较新的技术 ...

  9. 微信小程序语音识别服务搭建全过程解析(https api开放,支持新接口mp3录音、老接口silk录音)

    silk v3(或新录音接口mp3)录音转olami语音识别和语义处理的api服务(ubuntu16.04服务器上实现) 重要的写在前面 重要事项一: 所有相关更新,我优先更新到我个人博客中,其它地方 ...

随机推荐

  1. 201521123069 《Java程序设计》 第4周学习总结

    1. 本周学习总结 1.1 尝试使用思维导图总结有关继承的知识点. 1.2 使用常规方法总结其他上课内容. (1)文档注释(类注释.方法注释.属性注释.通用注释.) (2)多态性.instanceof ...

  2. 201521123113 《Java程序设计》第4周学习总结

    1.本周学习总结 1.1 尝试使用思维导图总结有关继承的知识点. 1.2 使用常规方法总结其他上课内容. 设计类的技巧:类名和方法名要能够体现他们的职责,类名首字母要大写 如何识别一个类 方法.属性的 ...

  3. 201521123088 《Java程序设计》第1周学习总结

    第1周学习总结 1.本周学习总结本周我们正式开始了对一门新的编程语言java的学习.Java是一门面向对象编程语言,不仅吸收了C++语言的各种优点,还摒弃了C++里难以理解的多继承.指针等概念,因此J ...

  4. evak购物车--课程设计(201521123037邱晓娴)

    1. 团队课程设计博客链接 团队博客 2. 个人负责模块或任务说明 1.Java (1)编写用户类Users (2)编写DBConnection类,连接数据库 (3)编写GoodsDAO类,从数据库中 ...

  5. JAVA课程设计 猜数游戏 团队

    团队名称,成员介绍 名称: 猜数游戏 成员: 网络1514 201521123086 周颖强 网络1514 201521123087蒋勃超 项目git地址 git.oschina.net/jbc113 ...

  6. 201521123106 《Java程序设计》第11周学习总结

    1. 本周学习总结 1.1 以你喜欢的方式(思维导图或其他)归纳总结多线程相关内容. 2. 书面作业 本次PTA作业题集多线程 互斥访问与同步访问 完成题集4-4(互斥访问)与4-5(同步访问) 1. ...

  7. 关于SVM数学细节逻辑的个人理解(三) :SMO算法理解

    第三部分:SMO算法的个人理解 接下来的这部分我觉得是最难理解的?而且计算也是最难得,就是SMO算法. SMO算法就是帮助我们求解: s.t.   这个优化问题的. 虽然这个优化问题只剩下了α这一个变 ...

  8. java:java构造器和java方法的区别

    构造函数(构造器)是一种特殊的函数.其主要功能是用来在创建对象时初始化对象, 即为对象成员变量赋初始值,总与new运算符一起使用在创建对象的语句中.构造函数与类名相同,可重载多个不同的构造函数.在JA ...

  9. 银河麒麟操作系统打开VMware报vmmon无法编译

    使用银河麒麟操作系统打开VMware可能会报vmmon无法编译 这个时候... 将/usr/src/linux-headers-xxx/include/miscdevice.h第71行void改为in ...

  10. SpringMVC第六篇【校验、统一处理异常】

    Validation 在我们的Struts2中,我们是继承ActionSupport来实现校验的-它有两种方式来实现校验的功能 手写代码 XML配置 这两种方式也是可以特定处理方法或者整个Action ...