sklearn中的LinearRegression

  • 函数原型:class sklearn.linear_model.LinearRegression(fit_intercept=True,normalize=False,copy_X=True,n_jobs=1)

  • fit_intercept:模型是否存在截距

  • normalize:模型是否对数据进行标准化(在回归之前,对X减去平均值再除以二范数),如果fit_intercept被设置为False时,该参数将忽略。

    该函数有属性:coef_可供查看模型训练后得到的估计系数,如果获取的估计系数太大,说明模型有可能过拟合。

    使用样例:

      >>>from sklearn import linear_model
    >>>clf = linear_model.LinearRegression()
    X = [[0,0],[1,1],[2,2]]
    y = [0,1,2]
    >>>clf.fit(X,y)
    >>>print(clf.coef_)
    [ 0.5 0.5]
    >>>print(clf.intercept_)
    1.11022302463e-16

源码分析

在github可以找到LinearRegression的源码:LinearRegression

  • 主要思想:sklearn.linear_model.LinearRegression求解线性回归方程参数时,首先判断训练集X是否是稀疏矩阵,如果是,就用Golub&Kanlan双对角线化过程方法来求解;否则调用C库中LAPACK中的用基于分治法的奇异值分解来求解。在sklearn中并不是使用梯度下降法求解线性回归,而是使用最小二乘法求解。

    sklearn.LinearRegression的fit()方法:

      if sp.issparse(X):#如果X是稀疏矩阵
    if y.ndim < 2:
    out = sparse_lsqr(X, y)
    self.coef_ = out[0]
    self._residues = out[3]
    else:
    # sparse_lstsq cannot handle y with shape (M, K)
    outs = Parallel(n_jobs=n_jobs_)(
    delayed(sparse_lsqr)(X, y[:, j].ravel())
    for j in range(y.shape[1]))
    self.coef_ = np.vstack(out[0] for out in outs)
    self._residues = np.vstack(out[3] for out in outs)
    else:
    self.coef_, self._residues, self.rank_, self.singular_ = \
    linalg.lstsq(X, y)
    self.coef_ = self.coef_.T

几个有趣的点:

  • 如果y的维度小于2,并没有并行操作。
  • 如果训练集X是稀疏矩阵,就用sparse_lsqr()求解,否则使用linalg.lstsq()

linalg.lstsq()

scipy.linalg.lstsq()方法就是用来计算X为非稀疏矩阵时的模型系数。这是使用普通的最小二乘OLS法来求解线性回归参数的。

  • scipy.linalg.lstsq()方法源码

    scipy提供了三种方法来求解least-squres problem最小均方问题,即模型优化目标。其提供了三个选项gelsd,gelsy,geless,这些参数传入了get_lapack_funcs()。这三个参数实际上是C函数名,函数是从LAPACK(Linear Algebra PACKage)中获得的。

    gelsd:它是用singular value decomposition of A and a divide and conquer method方法来求解线性回归方程参数的。

    gelsy:computes the minimum-norm solution to a real/complex linear least squares problem

    gelss:Computes the minimum-norm solution to a linear least squares problem using the singular value decomposition of A.

    scipy.linalg.lstsq()方法使用gelsd求解(并没有为用户提供选项)。

sparse_lsqr()方法源码

sqarse_lsqr()方法用来计算X是稀疏矩阵时的模型系数。sparse_lsqr()就是不同版本的scipy.sparse.linalg.lsqr(),参考自论文C. C. Paige and M. A. Saunders (1982a). "LSQR: An algorithm for sparse linear equations and sparse least squares", ACM TOMS实现。

相关源码如下:

    if sp_version < (0, 15):
# Backport fix for scikit-learn/scikit-learn#2986 / scipy/scipy#4142
from ._scipy_sparse_lsqr_backport import lsqr as sparse_lsqr
else:
from scipy.sparse.linalg import lsqr as sparse_lsqr

sklearn中LinearRegression使用及源码解读的更多相关文章

  1. 【原】Spark中Job的提交源码解读

    版权声明:本文为原创文章,未经允许不得转载. Spark程序程序job的运行是通过actions算子触发的,每一个action算子其实是一个runJob方法的运行,详见文章 SparkContex源码 ...

  2. HttpServlet中service方法的源码解读

    前言     最近在看<Head First Servlet & JSP>这本书, 对servlet有了更加深入的理解.今天就来写一篇博客,谈一谈Servlet中一个重要的方法-- ...

  3. 【原】 Spark中Task的提交源码解读

    版权声明:本文为原创文章,未经允许不得转载. 复习内容: Spark中Stage的提交 http://www.cnblogs.com/yourarebest/p/5356769.html Spark中 ...

  4. 【原】Spark中Stage的提交源码解读

    版权声明:本文为原创文章,未经允许不得转载. 复习内容: Spark中Job如何划分为Stage http://www.cnblogs.com/yourarebest/p/5342424.html 1 ...

  5. 【原】Spark不同运行模式下资源分配源码解读

    版权声明:本文为原创文章,未经允许不得转载. 复习内容: Spark中Task的提交源码解读 http://www.cnblogs.com/yourarebest/p/5423906.html Sch ...

  6. AbstractCollection类中的 T[] toArray(T[] a)方法源码解读

    一.源码解读 @SuppressWarnings("unchecked") public <T> T[] toArray(T[] a) { //size为集合的大小 i ...

  7. go中panic源码解读

    panic源码解读 前言 panic的作用 panic使用场景 看下实现 gopanic gorecover fatalpanic 总结 参考 panic源码解读 前言 本文是在go version ...

  8. go 中 sort 如何排序,源码解读

    sort 包源码解读 前言 如何使用 基本数据类型切片的排序 自定义 Less 排序比较器 自定义数据结构的排序 分析下源码 不稳定排序 稳定排序 查找 Interface 总结 参考 sort 包源 ...

  9. Mybatis源码解读-SpringBoot中配置加载和Mapper的生成

    本文mybatis-spring-boot探讨在springboot工程中mybatis相关对象的注册与加载. 建议先了解mybatis在spring中的使用和springboot自动装载机制,再看此 ...

随机推荐

  1. 【t073】&&【t015】魔法物品

    Time Limit: 1 second Memory Limit: 128 MB [问题描述] 有两种类型的物品:普通物品和魔法物品.每种普通物品有一个价值P,但每种魔法物品有两种价值:鉴定前的价值 ...

  2. spring sts4 如何添加tomcat 服务

    spring sts4 ide中已经没有集成tomcat运行服务器了,需要到点击Help-->Eclipse Marketplace中安装 Eclipse JST Server Adapters ...

  3. 十分钟了解 spring cloud

    1 为什么需要服务发现 简单来说,服务化的核心就是将传统的一站式应用根据业务拆分成一个一个的服务,而微服务在这个基础上要更彻底地去耦合(不再共享DB.KV,去掉重量级ESB),并且强调DevOps和快 ...

  4. Python3.7环境配置

    1.官网下载 https://www.python.org/ 我这是3.7.0 for windows executable installer Download Windows x86 web-ba ...

  5. android Notification分析—— 您可能会遇到各种问题

    使用的各种总结上线通知,csdn还有一个非常到位的总结,不这样做,反复总结,学生需要能够搜索自己或参考下面给出的链接. 研究开始时仔细阅读一些,今天,功能开发,一些问题和经验自己最近的遭遇给大家分享. ...

  6. [UWP]在应用开发中安全使用文件资源

    原文:[UWP]在应用开发中安全使用文件资源 在WPF或者UWP应用开发中,有时候会不可避免的需要操作文件系统(创建文件/目录),这时候有几个坑是需要大家注意下的. 创建文件或目录时的非法字符检测 在 ...

  7. python 合并两个排序的链表

    题目描述 输入两个单调递增的链表,输出两个链表合成后的链表,当然我们需要合成后的链表满足单调不减规则.   样例 给出 1->3->8->11->15->null,2-& ...

  8. c语言学习笔记(9)——指针

    指针是c语言的灵魂 ----------------------------------------------------------------------------- # include &l ...

  9. VS2015编译环境下CUDA安装配置

    CUDA下载 CUDA是NVIDIA推出的通用并行计算架构,该架构使GPU能够解决复杂的计算问题,CUDA只支持NVIDIA自家的显卡,过旧的版本型号也不被支持. 下载地址:https://devel ...

  10. Linux性能测试 iostat命令

    Linux系统出现了性能问题,一般我们可以通过top.iostat.free.vmstat等命令 来查看初步定位问题.其中iostat可以给我们提供丰富的IO状态数据.iostat 由 Red Hat ...