机器学习03-sklearn.LinearRegression 源码学习
在上次的代码重写中使用了sklearn.LinearRegression 类进行了线性回归之后猜测其使用的是常用的梯度下降+反向传播算法实现,所以今天来学习它的源码实现。但是在看到源码的一瞬间突然有种怀疑人生的感觉,我是谁?我在哪?果然大佬的代码只能让我膜拜。
在一目十行地看完代码之后,我发现了一个问题,梯度的单词是gradient,一般在代码中会使用缩写grad 来表示梯度,而在这个代码中除了Gram 之外竟然没有一个以'g' 开头的单词,更不用说gradient 了。那么代码中包括注释压根没提到过梯度,是不是说明这里根本没有使用梯度下降算法呢,换言之就是是否还有其他方法来实现最小二乘法的线性回归呢?带着这个疑问,我开始仔细阅读LinearRegression.fit() 函数。
首先都是参数处理,这里虽然看不太懂但是也能大概知道他在做什么,所以可以跳过。
然后来到了核心代码,核心代码中使用的几个判断:
1 if self.positive:
2 if y.ndim < 2:
3 pass
4 else:
5 pass
6 elif sp.issparse(X):
7 if y.ndim < 2:
8 pass
9 else:
10 pass
11 else:
12 pass
self.positive 是在使用密集矩阵的时候设置的参数,y.ndim 表示y 的维度,简单来说就是y 中有几个[],所以大概能知道代码将密集矩阵与稀疏矩阵区分开,并且将一维矩阵与多维矩阵区分开,意味着不同的类别使用不同的方法。
if self.positive 分段解析
1 if self.positive:
2 if y.ndim < 2:
3 self.coef_, self._residues = optimize.nnls(X, y)
4 else:
5 # scipy.optimize.nnls cannot handle y with shape (M, K)
6 outs = Parallel(n_jobs=n_jobs_)(
7 delayed(optimize.nnls)(X, y[:, j])
8 for j in range(y.shape[1]))
9 self.coef_, self._residues = map(np.vstack, zip(*outs))
可以看出y 的维度小于2 的话使用optimize.nnls() 方法,否则进行其他处理,因为“scipy.optimize.nnls cannot handle y with shape (M, K)”,但看到之后也调用了optimize.nnls,所以应该是将矩阵处理成可以使用的样子。
并且值得注意的是这里使用了Parallel(n_jobs=n_jobs_)(delayed(optimize.nnls)(X, y[:, j])for j in range(y.shape[1])) 的调用方式,也就是形如fun(x)(y) 的方式,这意味着函数内定义了另一个函数,第一个括号是fun 的参数,第二个括号是给fun 函数内定义的函数的参数。
elif sp.issparse(X) 分段解析
1 elif sp.issparse(X):
2 X_offset_scale = X_offset / X_scale
3
4
5 def matvec(b):
6 return X.dot(b) - b.dot(X_offset_scale)
7
8
9 def rmatvec(b):
10 return X.T.dot(b) - X_offset_scale * np.sum(b)
11
12
13 X_centered = sparse.linalg.LinearOperator(shape=X.shape,
14 matvec=matvec,
15 rmatvec=rmatvec)
16
17 if y.ndim < 2:
18 out = sparse_lsqr(X_centered, y)
19 self.coef_ = out[0]
20 self._residues = out[3]
21 else:
22 # sparse_lstsq cannot handle y with shape (M, K)
23 outs = Parallel(n_jobs=n_jobs_)(
24 delayed(sparse_lsqr)(X_centered, y[:, j].ravel())
25 for j in range(y.shape[1]))
26 self.coef_ = np.vstack([out[0] for out in outs])
27 self._residues = np.vstack([out[3] for out in outs])
可以看到先是对数据进行了处理,然后调用了sparse_lsqr() 函数。
剩余分段解析
1 else:
2 self.coef_, self._residues, self.rank_, self.singular_ = \
3 linalg.lstsq(X, y)
4 self.coef_ = self.coef_.T
5
6 if y.ndim == 1:
7 self.coef_ = np.ravel(self.coef_)
8 self._set_intercept(X_offset, y_offset, X_scale)
使用了linalg.lstsq() 函数。
以上我们可以看到代码中一共使用了3 个方法来实现线性回归:optimize.nnls()、sparse_lsqr()、linalg.lstsq()
optimize.nnls() 分析
NNLS 即非负正则化最小二乘法,代码实现由scipy.optimize.nnls提供,这里只是将其封装起来,在源码的注释中提到该算法的FORTRAN 代码在Charles L. Lawson 与Richard J. Hanson 两位教授于1987 年所著的《Solving Least Squares Problems》中发布。“The algorithm is an active set method. It solves the KKT (Karush-Kuhn-Tucker) conditions for the non-negative least squares problem.” 可惜由于本人水平有限,并不能从书中或者此处的代码中学到该算法的精髓,只能先挖一个坑,以后有所提高了再来研究该算法。
sparse_lsqr() 分析
LSQR 即最小二乘QR分解算法,代码实现由scipy.sparse.linalg.lsqr 类提供,这里只是将其封装起来,在文档中可以看到:
LSQR uses an iterative method to approximate the solution. The number of iterations required to reach a certain accuracy depends strongly on the scaling of the problem. Poor scaling of the rows or columns of A should therefore be avoided where possible.
同时也给出了参考文献:
[1] C. C. Paige and M. A. Saunders (1982a). "LSQR: An algorithm for sparse linear equations and sparse least squares", ACM TOMS 8(1), 43-71.
[2] C. C. Paige and M. A. Saunders (1982b). "Algorithm 583. LSQR: Sparse linear equations and least squares problems", ACM TOMS 8(2), 195-209.
[3] M. A. Saunders (1995). "Solution of sparse rectangular systems using LSQR and CRAIG", BIT 35, 588-604.
可以看到LSQR 算法是Paige 和Saunders 于1982 年提出的一种方法,但水平有限,暂时并不清楚其中原理。
linalg.lstsq() 分析
LSTSQ 是 LeaST SQuare (最小二乘)的意思,也就是普通的最小二乘法。代码实现由scipy.linalg.lstsq 提供,这里只是将其封装起来。通过官方文档提供的源代码链接,我找到了lstsq 函数的源代码,注释中提到了:
Which LAPACK driver is used to solve the least-squares problem.
Options are ``'gelsd'``, ``'gelsy'``, ``'gelss'``. Default(``'gelsd'``) is a good choice. However, ``'gelsy'`` can be slightly faster on many problems. ``'gelss'`` was used historically. It is generally slow but uses less memory.
也就是说有三个选项:gelsd(默认推荐)、gelsy(可能稍快)、gelss(使用内存少),那么来看看他们分别使用什么方法来解决最小二乘法吧。
1 if driver in ('gelss', 'gelsd'):
2 if driver == 'gelss':
3 lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
4 v, x, s, rank, work, info = lapack_func(a1, b1, cond, lwork,
5 overwrite_a=overwrite_a,
6 overwrite_b=overwrite_b)
7
8 elif driver == 'gelsd':
9 if real_data:
10 lwork, iwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
11 x, s, rank, info = lapack_func(a1, b1, lwork,
12 iwork, cond, False, False)
13 else: # complex data
14 lwork, rwork, iwork = _compute_lwork(lapack_lwork, m, n,
15 nrhs, cond)
16 x, s, rank, info = lapack_func(a1, b1, lwork, rwork, iwork,
17 cond, False, False)
18 elif driver == 'gelsy':
19 lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
20 jptv = np.zeros((a1.shape[1], 1), dtype=np.int32)
21 v, x, j, rank, info = lapack_func(a1, b1, jptv, cond,
22 lwork, False, False)
看到调用了lapack_func,但是找了一下spicy 并没有发现这个函数,于是搜索lapack_func 找到了定义:
lapack_func, lapack_lwork = get_lapack_funcs((driver,'%s_lwork' % driver),(a1, b1))
原来它调用了get_lapack_funcs 函数,于是去找该函数的文档,从文档中看到了该方法的用处:“This routine automatically chooses between Fortran/C interfaces. Fortran code is used whenever possible for arrays with column major order. In all other cases, C code is preferred.”,原来这是一个Fortran 语言和C 语言的接口且首选C 语言,并且从源码中我看到了它是调用_get_funcs 实现,那3 个单词是C 语言的函数名,所幸我对C 语言的熟悉程度打过Python,于是去找C 语言API 进行学习。
gelsd:Computes the minimum-norm solution to a linear least squares problem using the singular value decomposition of A and a divide and conquer method. 分治法的奇异值分解找出最优解。
gelss:Computes the minimum-norm solution to a linear least squares problem using the singular value decomposition of A. 使用奇异值分解找出最优解。
gelsy:Computes the minimum-norm solution to a linear least squares problem using a complete orthogonal factorization of A. 使用完全正交分解的方式找出最优解。
至此对于所有函数的分析基本上是结束了,密集矩阵使用的是NNLS 算法,稀疏矩阵使用的是LSQR 算法,其他的使用的则是最常用的最小二乘法算法。回到开头的问题,为什么没有使用梯度下降呢?难道梯度下降并不是解决最小二乘或者说是线性回归的算法吗?在网上查阅了很多材料之后我发现自己的问题:我陷入了误区。
先来说说最小二乘法,最小二乘法在我初高中的时候学习简单线性回归的时候就接触了,根据百度百科的词条,最小二乘法(又称最小平方法)是一种数学优化技术。它通过最小化误差的平方和寻找数据的最佳函数匹配。利用最小二乘法可以简便地求得未知的数据,并使得这些求得的数据与实际数据之间误差的平方和为最小,也就是说最小二乘法的公式是目标函数 = MIN( ∑ (预测值 – 实际值)² )。这就意味着如果人工计算的话就需要穷举所有的函数,计算他们的损失然后找出最小的那个函数。
再来说说梯度下降,梯度下降是迭代法的一种,通过更新梯度来找到损失最小的那个函数,不知你发现问题了吗?最小二乘法与梯度下降是在做同一件事情,也就是最优化问题,两个是并行的关系,并不存在谁解决谁。百度百科中关于梯度下降的词条中提到:“梯度下降是迭代法的一种,可以用于求解最小二乘问题(线性和非线性都可以)。在求解机器学习算法的模型参数,即无约束优化问题时,梯度下降(Gradient Descent)是最常采用的方法之一,另一种常用的方法是最小二乘法。”这里很清晰的指出了最小二乘法和梯度下降法的关系。
再来说说另一个我之前并不知道的概念:最小二乘准则。百度百科中提到这是一种对于偏差程度的评估准则,与上两者不同,上述的算法都是基于最小二乘准则提出的对于最小二乘法优化问题的解决方案,也就是如果不穷举的话如何找到最小二乘法的最优解。
列出我查阅的资料:
总结
机器学习03-sklearn.LinearRegression 源码学习的更多相关文章
- 【iScroll源码学习03】iScroll事件机制与滚动条的实现
前言 想不到又到周末了,周末的时间要抓紧学习才行,前几天我们学习了iScroll几点基础知识: 1. [iScroll源码学习02]分解iScroll三个核心事件点 2. [iScroll源码学习01 ...
- Qt Creator 源码学习笔记03,大型项目如何管理工程
阅读本文大概需要 6 分钟 一个项目随着功能开发越来越多,项目必然越来越大,工程管理成本也越来越高,后期维护成本更高.如何更好的组织管理工程,是非常重要的 今天我们来学习下 Qt Creator 是如 ...
- 【iScroll源码学习04】分离IScroll核心
前言 最近几天我们前前后后基本将iScroll源码学的七七八八了,文章中未涉及的各位就要自己去看了 1. [iScroll源码学习03]iScroll事件机制与滚动条的实现 2. [iScroll源码 ...
- (转)干货|这篇TensorFlow实例教程文章告诉你GANs为何引爆机器学习?(附源码)
干货|这篇TensorFlow实例教程文章告诉你GANs为何引爆机器学习?(附源码) 该博客来源自:https://mp.weixin.qq.com/s?__biz=MzA4NzE1NzYyMw==& ...
- Vue源码学习(一):调试环境搭建
最近开始学习Vue源码,第一步就是要把调试环境搭好,这个过程遇到小坑着实费了点功夫,在这里记下来 一.调试环境搭建过程 1.安装node.js,具体不展开 2.下载vue项目源码,git或svn等均可 ...
- JDK1.8源码分析01之学习建议(可以延伸其他源码学习)
序言:目前有个计划就是准备看一下源码,来提升自己的技术实力.同时现在好多面试官都喜欢问源码,问你是否读过JDK源码等等? 针对如何阅读源码,也请教了我的老师.下面就先来看看老师的回答,也许会有帮助呢. ...
- Mybatis源码学习第八天(总结)
源码学习到这里就要结束了; 来总结一下吧 Mybatis的总体架构 这次源码学习我们,学习了重点的模块,在这里我想说一句,源码的学习不是要所有的都学,一行一行的去学,这是错误的,我们只需要学习核心,专 ...
- Mybatis源码学习第六天(核心流程分析)之Executor分析
今Executor这个类,Mybatis虽然表面是SqlSession做的增删改查,其实底层统一调用的是Executor这个接口 在这里贴一下Mybatis查询体系结构图 Executor组件分析 E ...
- [阿里DIN]从论文源码学习 之 embedding_lookup
[阿里DIN]从论文源码学习 之 embedding_lookup 目录 [阿里DIN]从论文源码学习 之 embedding_lookup 0x00 摘要 0x01 DIN代码 1.1 Embedd ...
随机推荐
- 区分函数防抖&函数节流
1. 概念区分 函数防抖:触发事件后,在n秒内函数只能执行一次,如果触发事件后在n秒内又触发了事件,则会重新计算函数延执行时间. 简单说: 频繁触发, 但只在特定的时间内才执行一次代码,如果特定时间内 ...
- fastjson 漏洞利用 命令执行
目录 1. 准备一个Payload 2. 服务器上启动 rmi 3. 向目标注入payload 参考 如果你已经用DNSLog之类的工具,探测到了某个url有fastjson问题,那么接着可以试试能不 ...
- 企业安全_检测SQL注入的一些方式探讨
目录 寻找SQL注入点的 way MySQL Inject 入门案例 自动化审计的尝试之旅 人工审计才能保证精度 寻找SQL注入点的 way 在企业中有如下几种方式可以选择: 自动化 - 白盒基于源码 ...
- 【HTB系列】 靶机Swagshop的渗透测试详解
出品|MS08067实验室(www.ms08067.com) 本文作者:是大方子(Ms08067实验室核心成员) 总结与反思 使用vi提权 magento漏洞的利用 magescan 工具的使用 靶机 ...
- 使用 SVG transform rotate 解决画框中的数字跟随旋转的问题
问题描述 在图片上画框标注数字,旋转画布后,数字随之旋转,可读性不强,要求修改成无论画布怎么旋转,数字都是正向显示~ 原交互图示: 解决方案 先看下 dom 的结构 然后看下下面简单的代码 // 获取 ...
- Python flask-restful框架讲解
Restful 是 Flask 的扩展,增加了对快速构建 REST api 的支持.它是一个轻量级的概念,与您现有的 ORM/librarie 一起工作.Restful 鼓励最小化设置的最佳实践.如果 ...
- MySQL 多表查询与事务的操作
表连接查询 什么是多表查询 # 数据准备 # 多表查询的作用 * 比如:我们想查询孙悟空的名字和他所在的部门的名字,则需要使用多表查询 # 如果一条 SQL 语句查询多张表,因为查询结果在多张不同的表 ...
- Python开发环境从零搭建-02-代码编辑器Sublime
想要从零开始搭建一个Python的开发环境说容易也容易 说难也能难倒一片开发人员,在接下来的一系列视频中,会详细的讲解如何一步步搭建python的开发环境 本文章是搭建环境的第2篇 讲解的内容是:安装 ...
- LAB1 启动操作系统
从机器上电到运行OS发生了什么? 在电脑主板上有一个Flash块,存放了BIOS的可执行代码.它是ROM,断电不会丢掉数据.在机器上电的时候,CPU要求内存控制器从0地址读取数据(程序第一条指令)的时 ...
- 使用nodejs进行了简单的文件分卷工具
关键词:node fs readline generator (在这之前需要声明的是这篇博客的应用范围应该算是相当狭隘,写出来主要也就是给自己记录一下临时兴起写的一个小工具,仅从功能需求上来说我相信是 ...