MXNet的新接口Gluon
为什么要开发Gluon的接口
在MXNet中我们可以通过Sybmol
模块来定义神经网络,并组通过Module
模块提供的一些上层API来简化整个训练过程。那MXNet为什么还要重新开发一套Python的API呢,是否是重复造轮子呢?答案是否定的,Gluon主要是学习了Keras、Pytorch等框架的优点,支持动态图(Imperative)编程,更加灵活且方便调试。而原来MXNet基于Symbol来构建网络的方法是像TF、Caffe2一样静态图的编程方法。同时Gluon也继续了MXNet在静态图上的一些优化,比如节省显存,并行效率高等,运行起来比Pytorch更快。
更加简洁的接口
我们先来看一下用Gluon的接口,如果创建并组训练一个神经网络的,我们以mnist数据集为例:
import mxnet as mx
import mxnet.ndarray as nd
from mxnet import gluon
import mxnet.gluon.nn as nn
数据的读取
首先我们利用Gluon的data模块来读取mnist数据集
def transform(data, label):
return data.astype('float32') / 255, label.astype('float32')
minist_train_dataset = gluon.data.vision.MNIST(train=True, transform=transform)
minist_test_dataset = gluon.data.vision.MNIST(train=False, transform=transform)
batch_size = 64
train_data = gluon.data.DataLoader(dataset=minist_train_dataset, shuffle=True, batch_size=batch_size)
test_data = gluon.data.DataLoader(dataset=minist_train_dataset, shuffle=False, batch_size=batch_size)
num_examples = len(train_data)
print(num_examples)
训练模型
这里我们使用Gluon来定义一个LeNet
# Step1 定义模型
lenet = nn.Sequential()
with lenet.name_scope():
lenet.add(nn.Conv2D(channels=20, kernel_size=5, activation='relu'))
lenet.add(nn.MaxPool2D(pool_size=2, strides=2))
lenet.add(nn.Conv2D(channels=50, kernel_size=5, activation='relu'))
lenet.add(nn.MaxPool2D(pool_size=2, strides=2))
lenet.add(nn.Flatten())
lenet.add(nn.Dense(128, activation='relu'))
lenet.add(nn.Dense(10))
# Step2 初始化模型参数
lenet.initialize(ctx=mx.gpu())
# Step3 定义loss
softmax_loss = gluon.loss.SoftmaxCrossEntropyLoss()
# Step4 优化
trainer = gluon.Trainer(lenet.collect_params(), 'sgd', {'learning_rate': 0.5})
def accuracy(output, label):
return nd.mean(output.argmax(axis=1)==label).asscalar()
def evaluate_accuracy(net, data_iter):
acc = 0
for data, label in data_iter:
data = data.transpose((0,3,1,2))
data = data.as_in_context(mx.gpu())
label = label.as_in_context(mx.gpu())
output = net(data)
acc += accuracy(output, label)
return acc / len(data_iter)
import mxnet.autograd as ag
epochs = 5
for e in range(epochs):
total_loss = 0
for data, label in train_data:
data = data.transpose((0,3,1,2))
data = data.as_in_context(mx.gpu())
label = label.as_in_context(mx.gpu())
with ag.record():
output = lenet(data)
loss = softmax_loss(output, label)
loss.backward()
trainer.step(batch_size)
total_loss += nd.mean(loss).asscalar()
print("Epoch %d, test accuracy: %f, average loss: %f" % (e, evaluate_accuracy(lenet, test_data), total_loss/num_examples))
背后的英雄 nn.Block
我们前面使用了nn.Sequential
来定义一个模型,但是没有仔细介绍它,它其实是nn.Block
的一个简单的形式。而nn.Block
是一个一般化的部件。整个神经网络可以是一个nn.Block
,单个层也是一个nn.Block
。我们可以(近似)无限地嵌套nn.Block
来构建新的nn.Block
。nn.Block
主要提供3个方向的功能:
- 存储参数
- 描述
forward
如何执行 - 自动求导
所以nn.Sequential
是一个nn.Block
的容器,它通过add
来添加nn.Block
。它自动生成forward()
函数。一个简单实现看起来如下:
class Sequential(nn.Block):
def __init__(self, **kwargs):
super(Sequential, self).__init__(**kwargs)
def add(self, block):
self._children.append(block)
def forward(self, x):
for block in self._children:
x = block(x)
return x
知道了nn.Block
里的魔法后,我们就可以自定我们自己的nn.Block
了,来实现不同的深度学习应用可能遇到的一些新的层。
在nn.Block
中参数都是以一种Parameter
的对象,通过这个对象的data()
和grad()
来访问对应的数据和梯度。
my_param = gluon.Parameter('my_params', shape=(3,3))
my_param.initialize()
(my_param.data(), my_param.grad())
每个nn.Block
里都有一个类型为ParameterDict
类型的成员变量params
来保存所有这个层的参数。它其际上是一个名称到参数映射的字典。
pd = gluon.ParameterDict(prefix='custom_layer_name')
pd.get('custom_layer_param1', shape=(3,3))
pd
自义我们自己的全连接层
当我们要实现的功能在Gluon.nn模块中找不到对应的实现时,我们可以创建自己的层,它实际也就是一个nn.Block
对象。要自定义一个nn.Block
以,只需要继承nn.Block
,如果该层需要参数,则在初始化函数中做好对应参数的初始化(实际只是分配的形状),然后再实现一个forward()
函数来描述计算过程。
class MyDense(nn.Block):
def __init__(self, units, in_units, **kwargs):
super(MyDense, self).__init__(**kwargs)
with self.name_scope():
self.weight = self.params.get(
'weight', shape=(in_units, units))
self.bias = self.params.get('bias', shape=(units,))
def forward(self, x):
linear = nd.dot(x, self.weight.data()) + self.bias.data()
return nd.relu(linear)
审视模型的参数
我们将从下面三个方面来详细讲解如何操作gluon定义的模型的参数。
- 初始化
- 读取参数
- 参数的保存与加载
从上面我们们在mnist训练一个模型的步骤中可以看出,当我们定义好模型后,第一步就是需要调用initialize()
对模型进行参数初始化。
def get_net():
net = nn.Sequential()
with net.name_scope():
net.add(nn.Dense(4, activation='relu'))
net.add(nn.Dense(2))
return net
net = get_net()
net.initialize()
我们一直使用默认的initialize
来初始化权重。实际上我们可以指定其他初始化的方法,mxnet.initializer
模块中提供了大量的初始化权重的方法。比如非常流行的Xavier
方法。
#net.initialize(init=mx.init.Xavier())
x = nd.random.normal(shape=(3,4))
net(x)
我们可以weight
和bias
来访问Dense的参数,它们是Parameter
对象。
w = net[0].weight
b = net[0].bias
print('weight:', w.data())
print('weight gradient', w.grad())
print('bias:', b.data())
print('bias gradient', b.grad())
我们也可以通过collect_params
来访问Block
里面所有的参数(这个会包括所有的子Block)。它会返回一个名字到对应Parameter
的dict。既可以用正常[]
来访问参数,也可以用get()
,它不需要填写名字的前缀。
params = net.collect_params()
print(params)
print(params['sequential18_dense0_weight'].data())
print(params.get('dense0_bias').data()) #不需要名字的前缀
延后的初始化
如果我们仔细分析过整个网络的初始化,我们会有发现,当我们没有给网络真正的输入数据时,网络中的很多参数是无法确认形状的。
net = get_net()
net.collect_params()
net.initialize()
net.collect_params()
我们注意到参数中的weight
的形状的第二维都是0, 也就是说还没有确认。那我们可以肯定的是这些参数肯定是还没有分配内存的。
net(x)
net.collect_params()
当我们给这个网络一个输入数据后,网络中的数据参数的形状就固定下来了。而这个时候,如果我们给这个网络一个不同shape的输入数据,那运行中就会出现崩溃的问题。
模型参数的保存与加载
gluon.Sequential
模块提供了save
和load
接口来方便我们对一个网络的参数进行保存与加载。
filename = "mynet.params"
net.save_params(filename)
net2 = get_net()
net2.load_params(filename, mx.cpu())
Hybridize
从上面我们使用gluon来训练mnist,可以看出,我们使用的是一种命令式的编程风格。大部分的深度学习框架只在命令式与符号式间二选一。那我们能不能拿到两种泛式全部的优点呢,事实上这一点可以做到。在MXNet的GluonAPI中,我们可以使用HybridBlock
或者HybridSequential
来构建网络。默认他们跟Block
和Sequential
一样是命令式的。但当我们调用.hybridize()
后,系统会转撚成符号式来执行。
def get_net():
net = nn.HybridSequential()
with net.name_scope():
net.add(
nn.Dense(256, activation="relu"),
nn.Dense(128, activation="relu"),
nn.Dense(2)
)
net.initialize()
return net
x = nd.random.normal(shape=(1, 512))
net = get_net()
net(x)
net.hybridize()
net(x)
注意到只有继承自HybridBlock的层才会被优化。HybridSequential和Gluon提供的层都是它的子类。如果一个层只是继承自Block,那么我们将跳过优化。我们可以将符号化的模型的定义保存下来,在其他语言API中加载。
x = mx.sym.var('data')
y = net(x)
print(y.tojson())
可以看出,对于HybridBlock
的模块,既可以把NDArray作为输入,也可以把Symbol
对象作为输入。当以Symbol
作为输出时,它的结果就是一个Symbol
对象。
MXNet的新接口Gluon的更多相关文章
- 使用MxNet新接口Gluon提供的预训练模型进行微调
1. 导入各种包 from mxnet import gluon import mxnet as mx from mxnet.gluon import nn from mxnet import nda ...
- AMD正式公布第七代桌面级APU AM4新接口
导读 本月5日,AMD正式公布了入门级的第七代桌面级APU为Bristol Ridge,在性能和能效方面较上一代产品拥有显著提升.AMD同时确认Zen处理器和新APU(Bristol Ridge)都将 ...
- UWP: 体验应用内购新接口——StoreContext类
Windows 1607 版本(内部版本 14393)之后,微软在 SDK 添加了一些与应用商店相关的新接口,像应用试用与购买.应用内购等.这些接口相对于原来的接口要方便很多.就拿应用内购来说,以前的 ...
- 微信小程序语音识别服务搭建全过程解析(https api开放,支持新接口mp3录音、老接口silk录音)
silk v3(或新录音接口mp3)录音转olami语音识别和语义处理的api服务(ubuntu16.04服务器上实现) 重要的写在前面 重要事项一: 所有相关更新,我优先更新到我个人博客中,其它地方 ...
- 新接口注册LED字符驱动设备
#include <linux/init.h> // __init __exit #include <linux/module.h> // module_init module ...
- MxNet新前端Gluon模型转换到Symbol
1. 导入各种包 from mxnet import gluon from mxnet.gluon import nn import matplotlib.pyplot as plt from mxn ...
- MXNet源码分析 | Gluon接口分布式训练流程
本文主要基于MXNet1.6.0版本,对Gluon接口的分布式训练过程进行简要分析. 众所周知,KVStore负责MXNet分布式训练过程中参数的同步,那么它究竟是如何应用在训练中的呢?下面我们将从G ...
- ios7新特性1-UI变化、UIKit动态行为支持与Text Kit新接口
iOS 7.0新特性1 iOS 7的UI经过了重新设计.另外,iOS7中引入了新的动画系统,便于创建2D和2.5D的游戏.多任务支持提升,点对点通讯以及其他重要的特征使iOS7相对于以往的SDK来说发 ...
- MXNET:权重衰减-gluon实现
构建数据集 # -*- coding: utf-8 -*- from mxnet import init from mxnet import ndarray as nd from mxnet.gluo ...
随机推荐
- node重新加载模块
delete require.cache[require.resolve('module name')]; var my_module = require('module name');
- [leetcode]149. Max Points on a Line多点共线
Given n points on a 2D plane, find the maximum number of points that lie on the same straight line. ...
- django xadmin拓展User模型
django提供四种拓展模型的方法: 1.代理模型 2.Profile拓展模型User 3.AbstractBaseUser拓展模型User 4.AbstractUser拓展模型 之前想通过第四种方法 ...
- Vue的从入门到放弃
此贴仅记录vue学习路程中遇见的大大小小,形形色色的问题 1. vue自动打开浏览器配置: 当使用vue 脚手架搭建项目后启动npm run dev,会出现 但是不会自动打开浏览器的,这时候去con ...
- RestSharp发送请求得到Json数据
NUGET安装:RestSharp code: public string Post(string url, string content) { string contentType = " ...
- docker常用操作备忘
一.docker安装 参考资料:阿里云镜像加速1. 安装/升级Docker客户端 curl -fsSL https://get.docker.com | bash -s docker --mirror ...
- oracle分区表的使用和查询
本文参考了 https://blog.csdn.net/mzglzzc/article/details/46300645 一 创建和使用分区表 1.范围分区(RANGE)范围分区将数据基于范围映射到 ...
- 让Spring Boot项目启动时可以根据自定义配置决定初始化哪些Bean
让Spring Boot项目启动时可以根据自定义配置决定初始化哪些Bean 问题描述 实现思路 思路一 [不符合要求] 思路二[满足要求] 思路三[未试验] 问题描述 目前我工作环境下,后端主要的框架 ...
- [swarthmore cs75] Lab 1 — OCaml Tree Programming
课程回顾 Swarthmore学院16年开的编译系统课,总共10次大作业.本随笔记录了相关的课堂笔记以及第2大次作业. 比较两个lists的逻辑: let rec cmp l ll = match ( ...
- 回顾django内容
回顾: 1 HTTP协议:(重点) -请求 -请求首行 -GET /index HTTP/1.1 -请求头部(在django框架中,可以从META中取出来) -key:value------>\ ...