L2范数惩罚项,高维线性回归
%matplotlib inline
import mxnet
from mxnet import nd,autograd
from mxnet import gluon,init
from mxnet.gluon import data as gdata,loss as gloss,nn
import gluonbook as gb n_train, n_test, num_inputs = 20,100,200 true_w = nd.ones((num_inputs, 1)) * 0.01
true_b = 0.05 features = nd.random.normal(shape=(n_train+n_test, num_inputs))
labels = nd.dot(features,true_w) + true_b
labels += nd.random.normal(scale=0.01, shape=labels.shape) train_feature = features[:n_train,:]
test_feature = features[n_train:,:]
train_labels = labels[:n_train]
test_labels = labels[n_train:] #print(features,train_feature,test_feature) # 初始化模型参数
def init_params():
w = nd.random.normal(scale=1, shape=(num_inputs, 1))
b = nd.zeros(shape=(1,))
w.attach_grad()
b.attach_grad()
return [w,b] # 定义,训练,测试 batch_size = 1
num_epochs = 100
lr = 0.03 train_iter = gdata.DataLoader(gdata.ArrayDataset(train_feature,train_labels),batch_size=batch_size,shuffle=True) # 定义网络
def linreg(X, w, b):
return nd.dot(X,w) + b # 损失函数
def squared_loss(y_hat, y):
"""Squared loss."""
return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2 # L2 范数惩罚
def l2_penalty(w):
return (w**2).sum() / 2 def sgd(params, lr, batch_size):
for param in params:
param[:] = param - lr * param.grad / batch_size def fit_and_plot(lambd):
w, b = init_params()
train_ls, test_ls = [], []
for _ in range(num_epochs):
for X, y in train_iter:
with autograd.record():
# 添加了 L2 范数惩罚项。
l = squared_loss(linreg(X, w, b), y) + lambd * l2_penalty(w)
l.backward()
sgd([w, b], lr, batch_size)
train_ls.append(squared_loss(linreg(train_feature, w, b),
train_labels).mean().asscalar())
test_ls.append(squared_loss(linreg(test_feature, w, b),
test_labels).mean().asscalar())
gb.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',
range(1, num_epochs + 1), test_ls, ['train', 'test'])
print('L2 norm of w:', w.norm().asscalar())
fit_and_plot(0)
fit_and_plot(3)
训练集太少,容易出现过拟合,即训练集loss远小于测试集loss,解决方案,权重衰减——(L2范数正则化)
例如线性回归:
loss(w1,w2,b) = 1/n * sum(x1w1 + x2w2 + b - y)^2 /2 ,平方损失函数。
权重参数 w = [w1,w2],
新损失函数 loss(w1,w2,b) += lambd / 2n *||w||^2
迭代方程:
L2范数惩罚项,高维线性回归的更多相关文章
- 小白学习之pytorch框架(6)-模型选择(K折交叉验证)、欠拟合、过拟合(权重衰减法(=L2范数正则化)、丢弃法)、正向传播、反向传播
下面要说的基本都是<动手学深度学习>这本花书上的内容,图也采用的书上的 首先说的是训练误差(模型在训练数据集上表现出的误差)和泛化误差(模型在任意一个测试数据集样本上表现出的误差的期望) ...
- 机器学习中的范数规则化 L0、L1与L2范数 核范数与规则项参数选择
http://blog.csdn.net/zouxy09/article/details/24971995 机器学习中的范数规则化之(一)L0.L1与L2范数 zouxy09@qq.com http: ...
- 机器学习中的范数规则化之 L0、L1与L2范数、核范数与规则项参数选择
装载自:https://blog.csdn.net/u012467880/article/details/52852242 今天我们聊聊机器学习中出现的非常频繁的问题:过拟合与规则化.我们先简单的来理 ...
- 《机器学习实战》学习笔记第八章 —— 线性回归、L1、L2范数正则项
相关笔记: 吴恩达机器学习笔记(一) —— 线性回归 吴恩达机器学习笔记(三) —— Regularization正则化 ( 问题遗留: 小可只知道引入正则项能降低参数的取值,但为什么能保证 Σθ2 ...
- deep learning (五)线性回归中L2范数的应用
cost function 加一个正则项的原因是防止产生过拟合现象.正则项有L1,L2 等范数,我看过讲的最好的是这个博客上的:机器学习中的范数规则化之(一)L0.L1与L2范数.看完应该就答题明白了 ...
- paper 126:[转载] 机器学习中的范数规则化之(一)L0、L1与L2范数
机器学习中的范数规则化之(一)L0.L1与L2范数 zouxy09@qq.com http://blog.csdn.net/zouxy09 今天我们聊聊机器学习中出现的非常频繁的问题:过拟合与规则化. ...
- 机器学习中的范数规则化之(一)L0、L1与L2范数(转)
http://blog.csdn.net/zouxy09/article/details/24971995 机器学习中的范数规则化之(一)L0.L1与L2范数 zouxy09@qq.com http: ...
- L0、L1与L2范数、核范数(转)
L0.L1与L2范数.核范数 今天我们聊聊机器学习中出现的非常频繁的问题:过拟合与规则化.我们先简单的来理解下常用的L0.L1.L2和核范数规则化.最后聊下规则化项参数的选择问题.这里因为篇幅比较庞大 ...
- 机器学习中的范数规则化之(一)L0、L1与L2范数 非常好,必看
机器学习中的范数规则化之(一)L0.L1与L2范数 zouxy09@qq.com http://blog.csdn.net/zouxy09 今天我们聊聊机器学习中出现的非常频繁的问题:过拟合与规则化. ...
随机推荐
- 引导篇之web结构组件
web结构组件有如下几种: 代理 HTTP代理服务器,是Web安全.应用集成以及性能优化的重要组成模块.代理位于客户端和服务器之间,接收所有客户端的HTTP请求,并将这些请求转发给服务器(可能会对请求 ...
- DP Intro - Tree DP Examples
因为上次比赛sb地把一道树形dp当费用流做了,受了点刺激,用一天时间稍微搞一下树形DP,今后再好好搞一下) 基于背包原理的树形DP poj 1947 Rebuilding Roads 题意:给你一棵树 ...
- 里氏替换原则(Liskov Substitution Principle) LSP
using System; using System.Collections.Generic; using System.Text; namespace LiskovSubstitutionPrinc ...
- PyCharm常见用法
1.设置python运行版本: File-->Setting-->Project-->Project Interpreter 2.代码批量左移/右移一个tab: 鼠标选中行,Tab右 ...
- 牛客网Java刷题知识点之拥塞发生的主要原因、TCP拥塞控制、TCP流量控制、TCP拥塞控制的四大过程(慢启动、拥塞避免、快速重传、快速恢复)
不多说,直接上干货! 福利 => 每天都推送 欢迎大家,关注微信扫码并加入我的4个微信公众号: 大数据躺过的坑 Java从入门到架构师 人工智能躺过的坑 ...
- js判断文件是否存在的方法
在做电力监控项目的时候,有一个需求就是左右布局的框架,点击左边的图形文件地址,然后去文件夹中找到文件,再在右边出现对应的图形文件,但是有些文件可能是配置的时候有问题,找不到文件,所以js需要判断,以下 ...
- opensuse12.3 桌面设置备忘录
工作空间外观 窗口装饰 ghost deco2.2 光标主题 ringG 桌面主题 ghost 欢迎屏幕 login-scan-splash-cg 应用程序外观 风格 部件样式 Oxygen 颜色 g ...
- 关于GitHub在VS中出现“已经存在master版本,无法……”的错误解决方案
引用:http://www.cnblogs.com/SmallZL/p/3637613.html(这篇已经很详细说明如何使用Vs+GitHub),我这里做补充: VS2013已经集成了Git一部分控件 ...
- php二维数组排序方法(array_multisort usort)
一维数组排序可以使用asort.ksort等一些方法进程排序,相对来说比较简单.二维数组的排序怎么实现呢?使用array_multisort和usort可以实现 例如像下面的数组: $users = ...
- 深入理解JavaScript系列(28):设计模式之工厂模式
介绍 与创建型模式类似,工厂模式创建对象(视为工厂里的产品)时无需指定创建对象的具体类. 工厂模式定义一个用于创建对象的接口,这个接口由子类决定实例化哪一个类.该模式使一个类的实例化延迟到了子类.而子 ...