构建数据集

# -*- coding: utf-8 -*-
from mxnet import init
from mxnet import ndarray as nd
from mxnet.gluon import loss as gloss
import gb n_train = 20
n_test = 100 num_inputs = 200
true_w = nd.ones((num_inputs, 1)) * 0.01
true_b = 0.05
features = nd.random.normal(shape=(n_train+n_test, num_inputs))
labels = nd.dot(features, true_w) + true_b
labels += nd.random.normal(scale=0.01, shape=labels.shape)
train_features, test_features = features[:n_train, :], features[n_train:, :]
train_labels, test_labels = labels[:n_train], labels[n_train:]

数据迭代器

from mxnet import autograd
from mxnet.gluon import data as gdata batch_size = 1
num_epochs = 10
learning_rate = 0.003 train_iter = gdata.DataLoader(gdata.ArrayDataset(
train_features, train_labels), batch_size, shuffle=True)
loss = gloss.L2Loss()

训练并展示结果

gb.semilogy函数:绘制训练和测试数据的loss

from mxnet import gluon
from mxnet.gluon import nn def fit_and_plot(weight_decay):
net = nn.Sequential()
net.add(nn.Dense(1))
net.initialize(init.Normal(sigma=1))
# 对权重参数做 L2 范数正则化,即权重衰减。
trainer_w = gluon.Trainer(net.collect_params('.*weight'), 'sgd', {
'learning_rate': learning_rate, 'wd': weight_decay})
# 不对偏差参数做 L2 范数正则化。
trainer_b = gluon.Trainer(net.collect_params('.*bias'), 'sgd', {
'learning_rate': learning_rate})
train_ls = []
test_ls = []
for _ in range(num_epochs):
for X, y in train_iter:
with autograd.record():
l = loss(net(X), y)
l.backward()
# 对两个 Trainer 实例分别调用 step 函数。
trainer_w.step(batch_size)
trainer_b.step(batch_size)
train_ls.append(loss(net(train_features),
train_labels).mean().asscalar())
test_ls.append(loss(net(test_features),
test_labels).mean().asscalar())
gb.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',
range(1, num_epochs + 1), test_ls, ['train', 'test'])
return 'w[:10]:', net[0].weight.data()[:, :10], 'b:', net[0].bias.data()
print fit_and_plot(5)
  • 使用 Gluon 的 wd 超参数可以使用权重衰减来应对过拟合问题。
  • 我们可以定义多个 Trainer 实例对不同的模型参数使用不同的迭代方法。

MXNET:权重衰减-gluon实现的更多相关文章

  1. MXNET:权重衰减

    权重衰减是应对过拟合问题的常用方法. \(L_2\)范数正则化 在深度学习中,我们常使用L2范数正则化,也就是在模型原先损失函数基础上添加L2范数惩罚项,从而得到训练所需要最小化的函数. L2范数惩罚 ...

  2. 调参过程中的参数 学习率,权重衰减,冲量(learning_rate , weight_decay , momentum)

    无论是深度学习还是机器学习,大多情况下训练中都会遇到这几个参数,今天依据我自己的理解具体的总结一下,可能会存在错误,还请指正. learning_rate , weight_decay , momen ...

  3. 权重衰减(weight decay)与学习率衰减(learning rate decay)

    本文链接:https://blog.csdn.net/program_developer/article/details/80867468“微信公众号” 1. 权重衰减(weight decay)L2 ...

  4. 从头学pytorch(六):权重衰减

    深度学习中常常会存在过拟合现象,比如当训练数据过少时,训练得到的模型很可能在训练集上表现非常好,但是在测试集上表现不好. 应对过拟合,可以通过数据增强,增大训练集数量.我们这里先不介绍数据增强,先从模 ...

  5. MxNet新前端Gluon模型转换到Symbol

    1. 导入各种包 from mxnet import gluon from mxnet.gluon import nn import matplotlib.pyplot as plt from mxn ...

  6. 使用MxNet新接口Gluon提供的预训练模型进行微调

    1. 导入各种包 from mxnet import gluon import mxnet as mx from mxnet.gluon import nn from mxnet import nda ...

  7. MXNET:丢弃法

    除了前面介绍的权重衰减以外,深度学习模型常常使用丢弃法(dropout)来应对过拟合问题. 方法与原理 为了确保测试模型的确定性,丢弃法的使用只发生在训练模型时,并非测试模型时.当神经网络中的某一层使 ...

  8. MXNET:监督学习

    线性回归 给定一个数据点集合 X 和对应的目标值 y,线性模型的目标就是找到一条使用向量 w 和位移 b 描述的线,来尽可能地近似每个样本X[i] 和 y[i]. 数学公式表示为\(\hat{y}=X ...

  9. mxnet深度学习实战学习笔记-9-目标检测

    1.介绍 目标检测是指任意给定一张图像,判断图像中是否存在指定类别的目标,如果存在,则返回目标的位置和类别置信度 如下图检测人和自行车这两个目标,检测结果包括目标的位置.目标的类别和置信度 因为目标检 ...

随机推荐

  1. NumPy学习(让数据处理变简单)

    NumPy学习(一) NumPy数组创建 NumPy数组属性 NumPy数学算术与算数运算 NumPy数组创建 NumPy 中定义的最重要的对象是称为 ndarray 的 N 维数组类型. 它描述相同 ...

  2. P1939【模板】矩阵加速(数列)

    P1939[模板]矩阵加速(数列)难受就难受在a[i-3],这样的话让k=3就好了. #include<iostream> #include<cstdio> #include& ...

  3. JAVA "GMT+10" 和 "GMT+0010"

    可以使用 getAvailableIDs 方法来对所有受支持的时区 ID 进行迭代.可以选择受支持的 ID 来获得 TimeZone.如果想要的时区无法用受支持的 ID 之一表示,那么可以指定自定义时 ...

  4. Orleans逐步教程

    参考文档:https://dotnet.github.io/orleans/Tutorials/index.html 一.通过模板创建Orleans ①下载vs插件:https://marketpla ...

  5. mysql 5.7 离线安装

    删除原有的mariadbrpm -qa|grep mariadbrpm -e --nodeps mariadb-libs 安装rpm -ivh mysql-community-common-5.7.2 ...

  6. C++ - 定义无双引号的字符串宏

    在某些特殊场合下,我们可能需要定义一个字符串宏,但又不能用双引号 比如像这样 #define HELLO hello world 如果我们只是简单的展开HELLO,肯定会无法编译 std::cout ...

  7. hadoop from rookie to ninja - 1. Basic Architecture(基础架构)

    1. Daemons(守护进程) 新老架构 老的: Apache Hadoop 1.x (MRv1)   新的: Apache Hadoop 2.x (YARN)-Yet Another Resour ...

  8. MDX 查询原型

    本篇文章记录 SBS 中 MDX 查询原型,可以根据这些查询原型来解决实际项目中的问题. 1. 查询在 2004年1月2日 - 2004年3月1日之间购买过 Bikes 产品的用户. SELECT ( ...

  9. Nginx配置静态资源

    静态服务器 静态服务器概念非常简单:当用户请求静态资源时,把文件内容回复给用户. 但是,要把静态服务做到极致,需要考虑的方面非常多: 正确书写header:设置content-type.过期时间等 效 ...

  10. Java定时任务示例

    package com.my.timer; import java.util.Date; import java.util.TimerTask; public class myTask extends ...