MXNET:权重衰减
权重衰减是应对过拟合问题的常用方法。
\(L_2\)范数正则化
在深度学习中,我们常使用L2范数正则化,也就是在模型原先损失函数基础上添加L2范数惩罚项,从而得到训练所需要最小化的函数。
L2范数惩罚项指的是模型权重参数每个元素的平方和与一个超参数的乘积。如:\(w_1\),\(w_2\)是权重参数,b是偏差参数,带\(L_2\)范数惩罚项的新损失函数为:
\]
\(\lambda\)调节了惩罚项的比重。
有了\(L_2\)范数后,在随机梯度下降中,以单层神经网络为例,权重的迭代公式变更为:
\]
\]
\(\eta\)为学习率,\(\mathcal{B}\)为样本数目,可见:\(L_2\)范数正则化令权重的每一步更新分别添加了\(-\lambda w_1\)和\(\lambda w_2\)。这就是\(L_2\)范数正则化被称为权重衰减(weight decay)的原因。
在实际中,我们有时也在惩罚项中添加偏差元素的平方和。
假设神经网络中某一个神经元的输入是\(x_1,x_2\),使用激活函数\(\phi\)并输出\(\phi(x_1w_1+x_2w_2+b)\)。假设激活函数\(\phi\)是ReLU、tanh或sigmoid,如果\(w_1,w_2,b\)都非常接近0,那么输出也接近0。也就是说,这个神经元的作用比较小,甚至就像是令神经网络少了一个神经元一样。这样便有效降低了模型的复杂度,降低了过拟合。
高维线性回归实验
我们通过高维线性回归为例来引入一个过拟合问题,并使用L2范数正则化来试着应对过拟合。
生成数据集
设数据样本特征的维度为p。对于训练数据集和测试数据集中特征为\(x_1,x_2,…,x_n\)的任一样本,我们使用如下的线性函数来生成该样本的标签:
\]
其中噪音项ϵ服从均值为0和标准差为0.1的正态分布。
为了较容易地观察过拟合,我们考虑高维线性回归问题,例如设维度p=200;同时,我们特意把训练数据集的样本数设低,例如20。
n_train = 20
n_test = 100
num_inputs = 200
true_w = nd.ones((num_inputs, 1)) * 0.01
true_b = 0.05
features = nd.random.normal(shape=(n_train+n_test, num_inputs))
labels = nd.dot(features, true_w) + true_b
labels += nd.random.normal(scale=0.01, shape=labels.shape)
train_features, test_features = features[:n_train, :], features[n_train:, :]
train_labels, test_labels = labels[:n_train], labels[n_train:]
初始化模型参数
def init_params():
w = nd.random.normal(scale=1, shape=(num_inputs, 1))
b = nd.zeros(shape=(1,))
params = [w, b]
for param in params:
param.attach_grad()
return params
定义L2范数惩罚项
def l2_penalty(w):
return (w**2).sum() / 2
定义训练和测试
batch_size = 1
num_epochs = 10
lr = 0.003
net = gb.linreg
loss = gb.squared_loss
def fit_and_plot(lambd):
w, b = params = init_params()
train_ls = []
test_ls = []
for _ in range(num_epochs):
for X, y in gb.data_iter(batch_size, n_train, features, labels):
with autograd.record():
# 添加了 L2 范数惩罚项。
l = loss(net(X, w, b), y) + lambd * l2_penalty(w)
l.backward()
gb.sgd(params, lr, batch_size)
train_ls.append(loss(net(train_features, w, b),
train_labels).mean().asscalar())
test_ls.append(loss(net(test_features, w, b),
test_labels).mean().asscalar())
gb.semilogy(range(1, num_epochs+1), train_ls, 'epochs', 'loss',
range(1, num_epochs+1), test_ls, ['train', 'test'])
return 'w[:10]:', w[:10].T, 'b:', b
设置lambd=0,训练误差远小于测试(泛化)误差,这是典型的过拟合现象。
fit_and_plot(lambd=0)
# output
('w[:10]:',
[[ 0.30343655 -0.08110731 0.64756584 -1.51627898 0.16536537 0.42101485
0.41159022 0.8322348 -0.66477555 3.56285167]]
<NDArray 1x10 @cpu(0)>, 'b:',
[ 0.12521751]
<NDArray 1 @cpu(0)>)
使用正则化,过拟合现象得到一定程度上的缓解。然而依然没有学出较准确的模型参数。这主要是因为训练数据集的样本数相对维度来说太小。
fit_and_plot(lambd=5)
# output
('w[:10]:',
[[ 0.01602661 -0.00279179 0.03075662 -0.07356022 0.01006496 0.02420521
0.02145572 0.04235912 -0.03388886 0.17112994]]
<NDArray 1x10 @cpu(0)>, 'b:',
[ 0.08771407]
<NDArray 1 @cpu(0)>)
MXNET:权重衰减的更多相关文章
- MXNET:权重衰减-gluon实现
构建数据集 # -*- coding: utf-8 -*- from mxnet import init from mxnet import ndarray as nd from mxnet.gluo ...
- 调参过程中的参数 学习率,权重衰减,冲量(learning_rate , weight_decay , momentum)
无论是深度学习还是机器学习,大多情况下训练中都会遇到这几个参数,今天依据我自己的理解具体的总结一下,可能会存在错误,还请指正. learning_rate , weight_decay , momen ...
- 权重衰减(weight decay)与学习率衰减(learning rate decay)
本文链接:https://blog.csdn.net/program_developer/article/details/80867468“微信公众号” 1. 权重衰减(weight decay)L2 ...
- 从头学pytorch(六):权重衰减
深度学习中常常会存在过拟合现象,比如当训练数据过少时,训练得到的模型很可能在训练集上表现非常好,但是在测试集上表现不好. 应对过拟合,可以通过数据增强,增大训练集数量.我们这里先不介绍数据增强,先从模 ...
- MXNET:丢弃法
除了前面介绍的权重衰减以外,深度学习模型常常使用丢弃法(dropout)来应对过拟合问题. 方法与原理 为了确保测试模型的确定性,丢弃法的使用只发生在训练模型时,并非测试模型时.当神经网络中的某一层使 ...
- mxnet深度学习实战学习笔记-9-目标检测
1.介绍 目标检测是指任意给定一张图像,判断图像中是否存在指定类别的目标,如果存在,则返回目标的位置和类别置信度 如下图检测人和自行车这两个目标,检测结果包括目标的位置.目标的类别和置信度 因为目标检 ...
- [深度学习] pytorch学习笔记(3)(visdom可视化、正则化、动量、学习率衰减、BN)
一.visdom可视化工具 安装:pip install visdom 启动:命令行直接运行visdom 打开WEB:在浏览器使用http://localhost:8097打开visdom界面 二.使 ...
- 小白学习之pytorch框架(6)-模型选择(K折交叉验证)、欠拟合、过拟合(权重衰减法(=L2范数正则化)、丢弃法)、正向传播、反向传播
下面要说的基本都是<动手学深度学习>这本花书上的内容,图也采用的书上的 首先说的是训练误差(模型在训练数据集上表现出的误差)和泛化误差(模型在任意一个测试数据集样本上表现出的误差的期望) ...
- face recognition[variations of softmax][ArcFace]
本文来自<ArcFace: Additive Angular Margin Loss for Deep Face Recognition>,时间线为2018年1月.是洞见的作品,一作目前在 ...
随机推荐
- 不一样的go语言创世
在这之前,我是一名Java程序员,但最近我却已经好几个月没写Java代码了,因为我已经敲了好几个月的go,这是我连续最长的一段时间在写go.陆陆续续地算下来,也有快一年多的时间在与go打交道.期间 ...
- SQL try
BEGIN TRY -- Generate a constraint violation error. DELETE FROM Production.Product ; END TRY BEGIN C ...
- Mac电脑 阿里云ECS(ContentOS) Apache+vsftpd+nodejs+mongodb建站过程总结
简介:我这里采用的阿里云免费提供的6个月ECS服务器:制作了一个简单的爬虫程序:里面很多功能还么做:搜索里面功能回去的数据未做处理会崩溃(大家不要点搜索功能):地址:http://loldragon. ...
- python在使用MySQLdb模块时报Can't extract file(s) to egg cacheThe following error occurred while trying to extract file(s) to the Python eggcache的错误。
这个是因为python使用MySQLdb模块与mysql数据库交互时需要一个地方作为cache放置暂存的数据,但是调用python解释器的用户(常常是服务器如apache的www用户)对于cache所 ...
- 一个新的Android Studio 2.3.3可以在稳定的频道中使用。A new Android Studio 2.3.3 is available in the stable channel.
作者:韩梦飞沙 Author:han_meng_fei_sha 邮箱:313134555@qq.com E-mail: 313134555 @qq.com 一个新的Android Studio 2.3 ...
- CF 1131 E. String Multiplication
E. String Multiplication 题意 分析: 从后往前考虑字符串变成什么样子. 设$S_i = p_1 \cdot p_2 \dots p_{i}$,最后一定是$S_{n - 1} ...
- webpack 练习笔记
1, 创建项目 webpack mkdir webpack 2, 初始化项目 npm init 3, 全局安装webpack npm install webpack -g 4, 使用 // 创建静态文 ...
- LOJ6070 基因 分块+回文自动机
这个在翁文涛的论文里有讲到 大概的就是一个子串的回文自动机是原串回文自动机的子图 于是每隔$\sqrt n$重新跑一个$(k \times \sqrt n,n)$的回文自动机 记录回文串个数和位置 并 ...
- Mac下如何设置Eclipse默认浏览器为chrome
http://stackoverflow.com/questions/5322574/how-can-i-set-chrome-as-default-external-browser-in-eclip ...
- C#中使用 SendMessage 向非顶端窗体发送组合键
开门见山,不废话了, 直接举例说明一下: 比如发送ALT + F 以下是 用spy++截取的消息内容 <00001> 000310DC P WM_SYSKEYDOWN nVirtKey:V ...