在上次的代码重写中使用了sklearn.LinearRegression 类进行了线性回归之后猜测其使用的是常用的梯度下降+反向传播算法实现,所以今天来学习它的源码实现。但是在看到源码的一瞬间突然有种怀疑人生的感觉,我是谁?我在哪?果然大佬的代码只能让我膜拜。

在一目十行地看完代码之后,我发现了一个问题,梯度的单词是gradient,一般在代码中会使用缩写grad 来表示梯度,而在这个代码中除了Gram 之外竟然没有一个以'g' 开头的单词,更不用说gradient 了。那么代码中包括注释压根没提到过梯度,是不是说明这里根本没有使用梯度下降算法呢,换言之就是是否还有其他方法来实现最小二乘法的线性回归呢?带着这个疑问,我开始仔细阅读LinearRegression.fit() 函数。

首先都是参数处理,这里虽然看不太懂但是也能大概知道他在做什么,所以可以跳过。

然后来到了核心代码,核心代码中使用的几个判断:

 1 if self.positive:
2 if y.ndim < 2:
3 pass
4 else:
5 pass
6 elif sp.issparse(X):
7 if y.ndim < 2:
8 pass
9 else:
10 pass
11 else:
12 pass

self.positive 是在使用密集矩阵的时候设置的参数,y.ndim 表示y 的维度,简单来说就是y 中有几个[],所以大概能知道代码将密集矩阵与稀疏矩阵区分开,并且将一维矩阵与多维矩阵区分开,意味着不同的类别使用不同的方法。

if self.positive 分段解析

1 if self.positive:
2 if y.ndim < 2:
3 self.coef_, self._residues = optimize.nnls(X, y)
4 else:
5 # scipy.optimize.nnls cannot handle y with shape (M, K)
6 outs = Parallel(n_jobs=n_jobs_)(
7 delayed(optimize.nnls)(X, y[:, j])
8 for j in range(y.shape[1]))
9 self.coef_, self._residues = map(np.vstack, zip(*outs))

可以看出y 的维度小于2 的话使用optimize.nnls() 方法,否则进行其他处理,因为“scipy.optimize.nnls cannot handle y with shape (M, K)”,但看到之后也调用了optimize.nnls,所以应该是将矩阵处理成可以使用的样子。

并且值得注意的是这里使用了Parallel(n_jobs=n_jobs_)(delayed(optimize.nnls)(X, y[:, j])for j in range(y.shape[1])) 的调用方式,也就是形如fun(x)(y) 的方式,这意味着函数内定义了另一个函数,第一个括号是fun 的参数,第二个括号是给fun 函数内定义的函数的参数。

elif sp.issparse(X) 分段解析

 1 elif sp.issparse(X):
2 X_offset_scale = X_offset / X_scale
3
4
5 def matvec(b):
6 return X.dot(b) - b.dot(X_offset_scale)
7
8
9 def rmatvec(b):
10 return X.T.dot(b) - X_offset_scale * np.sum(b)
11
12
13 X_centered = sparse.linalg.LinearOperator(shape=X.shape,
14 matvec=matvec,
15 rmatvec=rmatvec)
16
17 if y.ndim < 2:
18 out = sparse_lsqr(X_centered, y)
19 self.coef_ = out[0]
20 self._residues = out[3]
21 else:
22 # sparse_lstsq cannot handle y with shape (M, K)
23 outs = Parallel(n_jobs=n_jobs_)(
24 delayed(sparse_lsqr)(X_centered, y[:, j].ravel())
25 for j in range(y.shape[1]))
26 self.coef_ = np.vstack([out[0] for out in outs])
27 self._residues = np.vstack([out[3] for out in outs])

可以看到先是对数据进行了处理,然后调用了sparse_lsqr() 函数。

剩余分段解析

1 else:
2 self.coef_, self._residues, self.rank_, self.singular_ = \
3 linalg.lstsq(X, y)
4 self.coef_ = self.coef_.T
5
6 if y.ndim == 1:
7 self.coef_ = np.ravel(self.coef_)
8 self._set_intercept(X_offset, y_offset, X_scale)

使用了linalg.lstsq() 函数。

以上我们可以看到代码中一共使用了3 个方法来实现线性回归:optimize.nnls()、sparse_lsqr()、linalg.lstsq()

optimize.nnls() 分析

NNLS 即非负正则化最小二乘法,代码实现由scipy.optimize.nnls提供,这里只是将其封装起来,在源码的注释中提到该算法的FORTRAN 代码在Charles L. Lawson 与Richard J. Hanson 两位教授于1987 年所著的《Solving Least Squares Problems》中发布。“The algorithm is an active set method. It solves the KKT (Karush-Kuhn-Tucker) conditions for the non-negative least squares problem.” 可惜由于本人水平有限,并不能从书中或者此处的代码中学到该算法的精髓,只能先挖一个坑,以后有所提高了再来研究该算法。

sparse_lsqr() 分析

LSQR 即最小二乘QR分解算法,代码实现由scipy.sparse.linalg.lsqr 类提供,这里只是将其封装起来,在文档中可以看到:

LSQR uses an iterative method to approximate the solution.  The number of iterations required to reach a certain accuracy depends strongly on the scaling of the problem.  Poor scaling of the rows or columns of A should therefore be avoided where possible.

同时也给出了参考文献:

[1] C. C. Paige and M. A. Saunders (1982a). "LSQR: An algorithm for sparse linear equations and sparse least squares", ACM TOMS 8(1), 43-71.
[2] C. C. Paige and M. A. Saunders (1982b). "Algorithm 583. LSQR: Sparse linear equations and least squares problems", ACM TOMS 8(2), 195-209.
[3] M. A. Saunders (1995). "Solution of sparse rectangular systems using LSQR and CRAIG", BIT 35, 588-604.

可以看到LSQR 算法是Paige 和Saunders 于1982 年提出的一种方法,但水平有限,暂时并不清楚其中原理。

linalg.lstsq() 分析

LSTSQ 是 LeaST SQuare (最小二乘)的意思,也就是普通的最小二乘法。代码实现由scipy.linalg.lstsq 提供,这里只是将其封装起来。通过官方文档提供的源代码链接,我找到了lstsq 函数的源代码,注释中提到了:

Which LAPACK driver is used to solve the least-squares problem.
Options are ``'gelsd'``, ``'gelsy'``, ``'gelss'``. Default(``'gelsd'``) is a good choice. However, ``'gelsy'`` can be slightly faster on many problems. ``'gelss'`` was used historically. It is generally slow but uses less memory.

也就是说有三个选项:gelsd(默认推荐)、gelsy(可能稍快)、gelss(使用内存少),那么来看看他们分别使用什么方法来解决最小二乘法吧。

 1 if driver in ('gelss', 'gelsd'):
2 if driver == 'gelss':
3 lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
4 v, x, s, rank, work, info = lapack_func(a1, b1, cond, lwork,
5 overwrite_a=overwrite_a,
6 overwrite_b=overwrite_b)
7
8 elif driver == 'gelsd':
9 if real_data:
10 lwork, iwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
11 x, s, rank, info = lapack_func(a1, b1, lwork,
12 iwork, cond, False, False)
13 else: # complex data
14 lwork, rwork, iwork = _compute_lwork(lapack_lwork, m, n,
15 nrhs, cond)
16 x, s, rank, info = lapack_func(a1, b1, lwork, rwork, iwork,
17 cond, False, False)
18 elif driver == 'gelsy':
19 lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
20 jptv = np.zeros((a1.shape[1], 1), dtype=np.int32)
21 v, x, j, rank, info = lapack_func(a1, b1, jptv, cond,
22 lwork, False, False)

看到调用了lapack_func,但是找了一下spicy 并没有发现这个函数,于是搜索lapack_func 找到了定义:

lapack_func, lapack_lwork = get_lapack_funcs((driver,'%s_lwork' % driver),(a1, b1))

原来它调用了get_lapack_funcs 函数,于是去找该函数的文档,从文档中看到了该方法的用处:“This routine automatically chooses between Fortran/C interfaces. Fortran code is used whenever possible for arrays with column major order. In all other cases, C code is preferred.”,原来这是一个Fortran 语言和C 语言的接口且首选C 语言,并且从源码中我看到了它是调用_get_funcs 实现,那3 个单词是C 语言的函数名,所幸我对C 语言的熟悉程度打过Python,于是去找C 语言API 进行学习。

gelsdComputes the minimum-norm solution to a linear least squares problem using the singular value decomposition of A and a divide and conquer method. 分治法的奇异值分解找出最优解。

gelssComputes the minimum-norm solution to a linear least squares problem using the singular value decomposition of A. 使用奇异值分解找出最优解。

gelsy:Computes the minimum-norm solution to a linear least squares problem using a complete orthogonal factorization of A. 使用完全正交分解的方式找出最优解。

至此对于所有函数的分析基本上是结束了,密集矩阵使用的是NNLS 算法,稀疏矩阵使用的是LSQR 算法,其他的使用的则是最常用的最小二乘法算法。回到开头的问题,为什么没有使用梯度下降呢?难道梯度下降并不是解决最小二乘或者说是线性回归的算法吗?在网上查阅了很多材料之后我发现自己的问题:我陷入了误区。

先来说说最小二乘法,最小二乘法在我初高中的时候学习简单线性回归的时候就接触了,根据百度百科的词条,最小二乘法(又称最小平方法)是一种数学优化技术。它通过最小化误差的平方和寻找数据的最佳函数匹配。利用最小二乘法可以简便地求得未知的数据,并使得这些求得的数据与实际数据之间误差的平方和为最小,也就是说最小二乘法的公式是目标函数 = MIN( ∑ (预测值 – 实际值)² )。这就意味着如果人工计算的话就需要穷举所有的函数,计算他们的损失然后找出最小的那个函数。

再来说说梯度下降,梯度下降是迭代法的一种,通过更新梯度来找到损失最小的那个函数,不知你发现问题了吗?最小二乘法与梯度下降是在做同一件事情,也就是最优化问题,两个是并行的关系,并不存在谁解决谁。百度百科中关于梯度下降的词条中提到:“梯度下降是迭代法的一种,可以用于求解最小二乘问题(线性和非线性都可以)。在求解机器学习算法的模型参数,即无约束优化问题时,梯度下降(Gradient Descent)是最常采用的方法之一,另一种常用的方法是最小二乘法。”这里很清晰的指出了最小二乘法和梯度下降法的关系。

再来说说另一个我之前并不知道的概念:最小二乘准则。百度百科中提到这是一种对于偏差程度的评估准则,与上两者不同,上述的算法都是基于最小二乘准则提出的对于最小二乘法优化问题的解决方案,也就是如果不穷举的话如何找到最小二乘法的最优解。

列出我查阅的资料:

1、知乎-最小二乘法和梯度下降法有哪些区别?

2、百度百科-梯度下降

3、百度百科-最小二乘法

总结

1、虽然找到了sklearn.LinearRegression 类中对于线性回归的算法及实现,但发现并没有使用到梯度下降法,而是使用最小二乘法找到最优解,解开了我对最小二乘法与梯度下降到误解,但由于之前并未从事过算法研究与数学分析,对相应的算法一知半解,所以这里的代码难以看懂,只能就此作罢,学习了相应的算法之后再来学习代码实现。
2、在学习源码的过程中以及写这篇文章的过程中发现对于python的有些概念还是不太清晰,比如函数和方法还有fun(x)(y) 的调用方式,所以能看到上文中有些使用“函数”有些使用“方法”,可能并不对应,但之后熟悉了才能修改。
3、源码中的注释充斥着许多数学词汇,读起来让我异常头疼,几乎都要使用翻译软件才能理解,同时有些平常使用的词汇我也不懂,这个时候英语的作用就十分必要了。
4、总的来说,基础不扎实,水平不高,所以对于其中的精髓难以理解,同时可能文章中错漏百出,但我并未发现,这就是目前的问题或者说困境,勤加学习才能脱离。

机器学习03-sklearn.LinearRegression 源码学习的更多相关文章

  1. 【iScroll源码学习03】iScroll事件机制与滚动条的实现

    前言 想不到又到周末了,周末的时间要抓紧学习才行,前几天我们学习了iScroll几点基础知识: 1. [iScroll源码学习02]分解iScroll三个核心事件点 2. [iScroll源码学习01 ...

  2. Qt Creator 源码学习笔记03,大型项目如何管理工程

    阅读本文大概需要 6 分钟 一个项目随着功能开发越来越多,项目必然越来越大,工程管理成本也越来越高,后期维护成本更高.如何更好的组织管理工程,是非常重要的 今天我们来学习下 Qt Creator 是如 ...

  3. 【iScroll源码学习04】分离IScroll核心

    前言 最近几天我们前前后后基本将iScroll源码学的七七八八了,文章中未涉及的各位就要自己去看了 1. [iScroll源码学习03]iScroll事件机制与滚动条的实现 2. [iScroll源码 ...

  4. (转)干货|这篇TensorFlow实例教程文章告诉你GANs为何引爆机器学习?(附源码)

    干货|这篇TensorFlow实例教程文章告诉你GANs为何引爆机器学习?(附源码) 该博客来源自:https://mp.weixin.qq.com/s?__biz=MzA4NzE1NzYyMw==& ...

  5. Vue源码学习(一):调试环境搭建

    最近开始学习Vue源码,第一步就是要把调试环境搭好,这个过程遇到小坑着实费了点功夫,在这里记下来 一.调试环境搭建过程 1.安装node.js,具体不展开 2.下载vue项目源码,git或svn等均可 ...

  6. JDK1.8源码分析01之学习建议(可以延伸其他源码学习)

    序言:目前有个计划就是准备看一下源码,来提升自己的技术实力.同时现在好多面试官都喜欢问源码,问你是否读过JDK源码等等? 针对如何阅读源码,也请教了我的老师.下面就先来看看老师的回答,也许会有帮助呢. ...

  7. Mybatis源码学习第八天(总结)

    源码学习到这里就要结束了; 来总结一下吧 Mybatis的总体架构 这次源码学习我们,学习了重点的模块,在这里我想说一句,源码的学习不是要所有的都学,一行一行的去学,这是错误的,我们只需要学习核心,专 ...

  8. Mybatis源码学习第六天(核心流程分析)之Executor分析

    今Executor这个类,Mybatis虽然表面是SqlSession做的增删改查,其实底层统一调用的是Executor这个接口 在这里贴一下Mybatis查询体系结构图 Executor组件分析 E ...

  9. [阿里DIN]从论文源码学习 之 embedding_lookup

    [阿里DIN]从论文源码学习 之 embedding_lookup 目录 [阿里DIN]从论文源码学习 之 embedding_lookup 0x00 摘要 0x01 DIN代码 1.1 Embedd ...

随机推荐

  1. Go的结构体

    目录 结构体 一.什么是结构体? 二.结构体的声明 三.创建结构体 1.创建有名结构体 2.结构体初始化 2.1 按位置传参 2.2 按关键字传 3.创建匿名结构体 四.结构体的类型 五.结构体的默认 ...

  2. 第43天学习打卡(JVM探究)

    JVM探究 请你谈谈你对JVM的理解?Java8虚拟机和之前的变化更新? 什么是OOM,什么是栈溢出StackOverFlowError? 怎么分析? JVM的常用调优参数有哪些? 内存快照如何抓取, ...

  3. HDOJ-4725(Dijikstra算法+拆点求最短路)

    The Shortest Path in Nya Graph HDOJ-4725 这题是关于最短路的问题,但是和常规的最短路有点不同的就是这里多了层次这一结构. 为了解决这一问题可以把每一层抽象或者划 ...

  4. JVM 中的StringTable

    是什么 字符串常量池是 JVM 中的一个重要结构,用于存储JVM运行时产生的字符串.在JDK7之前在方法区中,存储的是字符串常量.而字符串常量池在 JDK7 开始移入堆中,随之而来的是除了存储字符串常 ...

  5. 【Arduino学习笔记01】关于Arduino引脚的一些笔记

    参考链接:https://www.yiboard.com/thread-831-1-1.html Arduino Uno R3 - 引脚图 Arduino Uno R3 - 详细参数 Arduino ...

  6. swaks制作钓鱼邮件

      一.在网站:https://bccto.me/ 申请一个十分钟的邮箱 二.使用命令行,命令行解释如下: --from hacker@qq.com //发件人的邮箱 --ehlo qq.com // ...

  7. DataFocus小学堂|客户分析之复活客户分析

    复活客户分析 什么是"复活客户"?如何进行"复活客户分析"呢?今天,我们借助DataFocus系统,来了解一种简单的复活客户分析. 1.何为复活客户 复活客户, ...

  8. Codeforces Round #574 (Div. 2) E. OpenStreetMap 【单调队列】

    一.题目 OpenStreetMap 二.分析 对于二维空间找区间最小值,那么一维的很多好用的都无法用了,这里可以用单调队列进行查找. 先固定一个坐标,然后进行一维的单调队列操作,维护一个区间长度为$ ...

  9. Python中面向对象的概念

    1.语言的分类 1)面向机器 抽象成机器指令,机器容易理解.代表:汇编语言. 2)面向过程 做一件事,排除步骤,第一步做什么,第二步做什么,如果出现A问题,做什么处理,出现b问题,做什么处理.问题规模 ...

  10. incubator-dolphinscheduler 如何在不写任何新代码的情况下,能快速接入到prometheus和grafana中进行监控

    一.prometheus和grafana 简介 prometheus是由谷歌研发的一款开源的监控软件,目前已经贡献给了apache 基金会托管. 监控通常分为白盒监控和黑盒监控之分. 白盒监控:通过监 ...