概述

今天要说一下机器学习中大多数书籍第一个讲的(有的可能是KNN)模型-线性回归。说起线性回归,首先要介绍一下机器学习中的两个常见的问题:回归任务和分类任务。那什么是回归任务和分类任务呢?简单的来说,在监督学习中(也就是有标签的数据中),标签值为连续值时是回归任务,标志值是离散值时是分类任务。而线性回归模型就是处理回归任务的最基础的模型。

形式

在只有一个变量的情况下,线性回归可以用方程:y = ax+b 表示。而如果有多个变量,也就是n元线性回归的形式如下:

n元线性回归

在这里我们将截断b用θ0代替,同时数据集X也需要添加一列1用于与θ0相乘,表示+b。最后写成矩阵的形式就是θ的转置乘以x。其中如果数据集有n个特征,则θ就是n+1维的向量并非矩阵,其中包括截断b。

目的

线性回归的目的就是求解出合适的θ,在一元的情况下拟合出一条直线(多元情况下是平面或者曲面),可以近似的代表各个数据样本的标签值。所以最好的直线要距离各个样本点都很接近,而如何求出这条直线就是本篇文章重点要将的内容。

一元线性回归拟合数据

最小二乘法

求解线性回归模型的方法叫做最小二乘法,最小二乘法的核心就是保证所有数据偏差的平方和最小。它的具体形式是:

 

其中hθ(x^(i))代表每个样本通过我们模型的预测值,y^(i)代表每个样本标签的真实值,m为样本个数。因为模型预测值和真实值间存在误差e,可以写作:

 

根据中心极限定理,e^(i)是独立同分布的(IID),服从均值为0,方差为某定值σ的平方的正太分布。具体推导过程如下:

最小二乘法推导

求解最小二乘法:

我们要求得就是当θ取某个值时使J(θ)最小,求解最小二乘法的方法一般有两种方法:矩阵式和梯度下降法。

矩阵式求解:

当我们的数据集含有m个样本,每个样本有n个特征时,数据x可以写成m*(n+1)维的矩阵(+1是添加一列1,用于与截断b相乘),θ则为n+1维的列向量(+1是截断b),y为m维的列向量代表每m个样本结果的预测值。则矩阵式的推导如下所示:

 

因为X^tX为方阵,如果X^tX是可逆的,则参数θ得解析式可以写成:

 

如果X的特征数n不是很大,通常情况下X^tX是可以求逆的,但是如果n非常大,X^tX不可逆,则用梯度下降法求解参数θ。

梯度下降法(GD):

在一元函数中叫做求导,在多元函数中就叫做求梯度。梯度下降是一个最优化算法,通俗的来讲也就是沿着梯度下降的方向来求出一个函数的极小值。比如一元函数中,加速度减少的方向,总会找到一个点使速度达到最小。通常情况下,数据不可能完全符合我们的要求,所以很难用矩阵去求解,所以机器学习就应该用学习的方法,因此我们采用梯度下降,不断迭代,沿着梯度下降的方向来移动,求出极小值。梯度下降法包括批量梯度下降法和随机梯度下降法(SGD)以及二者的结合mini批量下降法(通常与SGD认为是同一种,常用于深度学习中)。

梯度下降法的一般过程如下:

1)初始化θ(随机)

2)求J(θ)对θ的偏导:

 

3)更新θ

 

其中α为学习率,调节学习率这个超参数也是建模中的一个重要内容。因为J(θ)是凸函数,所以GD求出的最优解是全局最优解。

批量梯度下降法是求出整个数据集的梯度,再去更新θ,所以每次迭代都是在求全局最优解。

 

而随机梯度下降法是求一个样本的梯度后就去跟新θ,所以每次迭代都是求局部最优解,但是总是朝着全局最优解前进,最后总会到达全局最优解。

 

其他线性回归模型:

在机器学习中,有时为了防止模型太复杂容易过拟合,通常会在模型上加入正则项,抑制模型复杂度,防止过拟合。在线性回归中有两种常用的正则,一个是L1正则,一个是L2正则,加入L1正则的称为Lasso回归,加入L2正则的成为Ridge回归也叫岭回归。

Lasso回归

岭回归

以下是个人所写的线性回归代码:

 

各个回归模型参数与结果对比以及与真实值的图像

 

待更新。

详细代码可参考GitHub:代码链接

机器学习-线性回归LinearRegression的更多相关文章

  1. 机器学习之路: python 线性回归LinearRegression, 随机参数回归SGDRegressor 预测波士顿房价

    python3学习使用api 线性回归,和 随机参数回归 git: https://github.com/linyi0604/MachineLearning from sklearn.datasets ...

  2. python机器学习---线性回归案例和KNN机器学习案例

    散点图和KNN预测 一丶案例引入 # 城市气候与海洋的关系研究 # 导包 import numpy as np import pandas as pd from pandas import Serie ...

  3. 机器学习03-sklearn.LinearRegression 源码学习

    在上次的代码重写中使用了sklearn.LinearRegression 类进行了线性回归之后猜测其使用的是常用的梯度下降+反向传播算法实现,所以今天来学习它的源码实现.但是在看到源码的一瞬间突然有种 ...

  4. 机器学习|线性回归算法详解 (Python 语言描述)

    原文地址 ? 传送门 线性回归 线性回归是一种较为简单,但十分重要的机器学习方法.掌握线性的原理及求解方法,是深入了解线性回归的基本要求.除此之外,线性回归也是监督学习回归部分的基石. 线性回归介绍 ...

  5. 机器学习---线性回归(Machine Learning Linear Regression)

    线性回归是机器学习中最基础的模型,掌握了线性回归模型,有利于以后更容易地理解其它复杂的模型. 线性回归看似简单,但是其中包含了线性代数,微积分,概率等诸多方面的知识.让我们先从最简单的形式开始. 一元 ...

  6. 机器学习——线性回归-KNN-决策树(实例)

    导入类库 import numpy as np import pandas as pd from sklearn.linear_model import LinearRegression from s ...

  7. 吴裕雄 python 机器学习——线性回归模型

    import numpy as np from sklearn import datasets,linear_model from sklearn.model_selection import tra ...

  8. 机器学习之LinearRegression与Logistic Regression逻辑斯蒂回归(三)

    一 评价尺度 sklearn包含四种评价尺度 1 均方差(mean-squared-error) 2 平均绝对值误差(mean_absolute_error) 3 可释方差得分(explained_v ...

  9. 线性回归 - LinearRegression - 预测糖尿病 - 量化预测的质量

    线性回归是分析一个变量与另外一个或多个变量(自变量)之间,关系强度的方法. 线性回归的标志,如名称所暗示的那样,即自变量与结果变量之间的关系是线性的,也就是说变量关系可以连城一条直线. 模型评估:量化 ...

随机推荐

  1. SSH整合(一)

    一.ssh原始整合方式 不需要任何整合包,就是简单的将三个框架集合到一起 hibernate        导入jar包:            hibernate-release-5.0.7.Fin ...

  2. MySQL---视图、触发器

    一.视图 视图是一个虚拟表(非真实存在),其本质是[根据SQL语句获取动态的数据集,并为其命名],用户使用时只需使用[名称]即可获取结果集,并可以将其当作表来使用. SELECT * FROM ( S ...

  3. JavaSE环境下的shiro(源自腾讯课堂)

    Shiro作用: 认证(登录).授权(鉴权).加密(用户名/密码加密).会话管理(session).Web集成.缓存 apache官网可以下载 图一 图二 图三 图一 .二是配置文件内容,对于图三的: ...

  4. [Java]Java 9运行Spring Boot项目报错的解决办法

    简介 为了学习和尽快掌握 Java 9 的模块化(Module System)新特性,最近安装了 JDK 9,新建了一个 Spring Boot 进行尝试, 过程中遇到了一下报错问题,写下此文谨作为个 ...

  5. 如何快速生成数据库字典(thinkphp5.0)

    本教程将教你快速生成数据库字典 示例代码使用PHP框架:Thinkphp5.0 PHP代码: /** * 生成数据库字典html * 可直接另存为再copy到word文档中使用 * * @return ...

  6. Oracle_11g桌面版 中解决被锁定的scott 教学数据库的方法

    Oracle 11g中修改被锁定的用户:scott 在安装完Oracle10g和创建完oracle数据库之后,想用数据库自带的用户scott登录,看看连接是否成功. 在cmd命令中,用“sqlplus ...

  7. 帆软SQL报异常:多表连接的时候出现错误:未明确定义列

    我刚开始的代码: select dm_veh_jdcgz_mx.DAY_ID ,--日期 dm_veh_jdcgz_mx.GLBM ,--管理部门ID dm_veh_jdcgz_mx.SFZMHM , ...

  8. 【Keil】Keil5的安装和破...

    档案的话网上很多的,另外要看你开发的是哪种内核的芯片 如果是STC的,就安装C51 如果是STM的,就安装MDK 当然市面上有很多芯片的,我也没用过那么多种,这里也就不列举了 至于注册机,就是...恩 ...

  9. 踩坑留印,启动进程遇到报错:/proc/self/fd/9: 2: ulimit: bad number

    启动进程,遇到报错: /proc/self/fd/9: 2: ulimit: bad number 分析配置文件内容没有错误. 怀疑可能是文件格式问题,在IDE里面查看,果然是windows格式.ID ...

  10. 北京优步UBER司机B组最新奖励政策、高峰翻倍奖励、行程奖励、金牌司机奖励【每周更新】

    滴快车单单2.5倍,注册地址:http://www.udache.com/ 如何注册Uber司机(全国版最新最详细注册流程)/月入2万/不用抢单:http://www.cnblogs.com/mfry ...