构建数据集

# -*- coding: utf-8 -*-
from mxnet import init
from mxnet import ndarray as nd
from mxnet.gluon import loss as gloss
import gb n_train = 20
n_test = 100 num_inputs = 200
true_w = nd.ones((num_inputs, 1)) * 0.01
true_b = 0.05
features = nd.random.normal(shape=(n_train+n_test, num_inputs))
labels = nd.dot(features, true_w) + true_b
labels += nd.random.normal(scale=0.01, shape=labels.shape)
train_features, test_features = features[:n_train, :], features[n_train:, :]
train_labels, test_labels = labels[:n_train], labels[n_train:]

数据迭代器

from mxnet import autograd
from mxnet.gluon import data as gdata batch_size = 1
num_epochs = 10
learning_rate = 0.003 train_iter = gdata.DataLoader(gdata.ArrayDataset(
train_features, train_labels), batch_size, shuffle=True)
loss = gloss.L2Loss()

训练并展示结果

gb.semilogy函数:绘制训练和测试数据的loss

from mxnet import gluon
from mxnet.gluon import nn def fit_and_plot(weight_decay):
net = nn.Sequential()
net.add(nn.Dense(1))
net.initialize(init.Normal(sigma=1))
# 对权重参数做 L2 范数正则化,即权重衰减。
trainer_w = gluon.Trainer(net.collect_params('.*weight'), 'sgd', {
'learning_rate': learning_rate, 'wd': weight_decay})
# 不对偏差参数做 L2 范数正则化。
trainer_b = gluon.Trainer(net.collect_params('.*bias'), 'sgd', {
'learning_rate': learning_rate})
train_ls = []
test_ls = []
for _ in range(num_epochs):
for X, y in train_iter:
with autograd.record():
l = loss(net(X), y)
l.backward()
# 对两个 Trainer 实例分别调用 step 函数。
trainer_w.step(batch_size)
trainer_b.step(batch_size)
train_ls.append(loss(net(train_features),
train_labels).mean().asscalar())
test_ls.append(loss(net(test_features),
test_labels).mean().asscalar())
gb.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',
range(1, num_epochs + 1), test_ls, ['train', 'test'])
return 'w[:10]:', net[0].weight.data()[:, :10], 'b:', net[0].bias.data()
print fit_and_plot(5)
  • 使用 Gluon 的 wd 超参数可以使用权重衰减来应对过拟合问题。
  • 我们可以定义多个 Trainer 实例对不同的模型参数使用不同的迭代方法。

MXNET:权重衰减-gluon实现的更多相关文章

  1. MXNET:权重衰减

    权重衰减是应对过拟合问题的常用方法. \(L_2\)范数正则化 在深度学习中,我们常使用L2范数正则化,也就是在模型原先损失函数基础上添加L2范数惩罚项,从而得到训练所需要最小化的函数. L2范数惩罚 ...

  2. 调参过程中的参数 学习率,权重衰减,冲量(learning_rate , weight_decay , momentum)

    无论是深度学习还是机器学习,大多情况下训练中都会遇到这几个参数,今天依据我自己的理解具体的总结一下,可能会存在错误,还请指正. learning_rate , weight_decay , momen ...

  3. 权重衰减(weight decay)与学习率衰减(learning rate decay)

    本文链接:https://blog.csdn.net/program_developer/article/details/80867468“微信公众号” 1. 权重衰减(weight decay)L2 ...

  4. 从头学pytorch(六):权重衰减

    深度学习中常常会存在过拟合现象,比如当训练数据过少时,训练得到的模型很可能在训练集上表现非常好,但是在测试集上表现不好. 应对过拟合,可以通过数据增强,增大训练集数量.我们这里先不介绍数据增强,先从模 ...

  5. MxNet新前端Gluon模型转换到Symbol

    1. 导入各种包 from mxnet import gluon from mxnet.gluon import nn import matplotlib.pyplot as plt from mxn ...

  6. 使用MxNet新接口Gluon提供的预训练模型进行微调

    1. 导入各种包 from mxnet import gluon import mxnet as mx from mxnet.gluon import nn from mxnet import nda ...

  7. MXNET:丢弃法

    除了前面介绍的权重衰减以外,深度学习模型常常使用丢弃法(dropout)来应对过拟合问题. 方法与原理 为了确保测试模型的确定性,丢弃法的使用只发生在训练模型时,并非测试模型时.当神经网络中的某一层使 ...

  8. MXNET:监督学习

    线性回归 给定一个数据点集合 X 和对应的目标值 y,线性模型的目标就是找到一条使用向量 w 和位移 b 描述的线,来尽可能地近似每个样本X[i] 和 y[i]. 数学公式表示为\(\hat{y}=X ...

  9. mxnet深度学习实战学习笔记-9-目标检测

    1.介绍 目标检测是指任意给定一张图像,判断图像中是否存在指定类别的目标,如果存在,则返回目标的位置和类别置信度 如下图检测人和自行车这两个目标,检测结果包括目标的位置.目标的类别和置信度 因为目标检 ...

随机推荐

  1. hdu 2063 过山车【匈牙利算法】(经典)

    <题目链接> RPG girls今天和大家一起去游乐场玩,终于可以坐上梦寐以求的过山车了.可是,过山车的每一排只有两个座位,而且还有条不成文的规矩,就是每个女生必须找个个男生做partne ...

  2. QT学习之windows下安装配置PyQt5

    windows下安装配置PyQt5 目录 为什么要学习QT 命令行安装PyQt5以及PyQt5-tools 配置QtDesigner.PyUIC及PyRcc 为什么要学习QT python下与界面开发 ...

  3. Nmap的详细使用

    Nmap的详细使用 介绍常用参数选项主机发现端口扫描服务和版本探测操作系统探测性能优化防火墙/IDS 躲避和哄骗输出 (一)介绍 Nmap — 网络探测工具和安全/端口扫描器. Nmap (“Netw ...

  4. Cube Stack

    Cube Stack 有一点lazy思想,设三个数组cnt代表它以下的有多少个元素(直到栈底),top[x]代表x所在栈的栈顶元素,dad[x]代表x所在栈的栈底元素,先寻找父亲,然后递归更新累加cn ...

  5. 算法进阶面试题06——实现LFU缓存算法、计算带括号的公式、介绍和实现跳表结构

    接着第四课的内容,主要讲LFU.表达式计算和跳表 第一题 上一题实现了LRU缓存算法,LFU也是一个著名的缓存算法 自行了解之后实现LFU中的set 和 get 要求:两个方法的时间复杂度都为O(1) ...

  6. 仙剑奇侠传 游戏 开发 教程 Xianjian qixia development Game development tutorial

    仙剑奇侠传 开发  游戏 开发 教程 Xianjian qixia development Game development tutorial 作者:韩梦飞沙 Author:han_meng_fei_ ...

  7. HDU.1529.Cashier Employment(差分约束 最长路SPFA)

    题目链接 \(Description\) 给定一天24h 每小时需要的员工数量Ri,有n个员工,已知每个员工开始工作的时间ti(ti∈[0,23]),每个员工会连续工作8h. 问能否满足一天的需求.若 ...

  8. BZOJ.1022.[SHOI2008]小约翰的游戏John(博弈论 Anti-Nim)

    题目链接 Anti-Nim游戏: 先手必胜当且仅当: 1.所有堆的石子数为1,且异或和为0 2.至少有一堆石子数>1,且异或和不为0 简要证明: 对于1:若异或和为1,则有奇数堆:异或和为0,则 ...

  9. 浅表拷贝vs深度拷贝

    浅表复制,只是创建所有的值类型,所有的引用类型还是会指向被复制的对象的引用. 故,当被复制的对象的引用类型发生改变的同事,复制的对象相应的 引用类型的值也是会发生改变的. 所以事件字段也是一个引用类型 ...

  10. 不涉及框架纯java实现将图片裁成圆形

    package com.wtsrui.utils;import java.awt.Color;  import sun.misc.BASE64Encoder;import java.awt.Graph ...