前一节从宏观角度给大家介绍了Spark ML的设计框架(链接:http://www.cnblogs.com/jicanghai/p/8570805.html),本节我们将介绍,Spark ML中,机器学习问题从单机到分布式转换的核心方法。
        单机时代,如果我们想解决一个机器学习的优化问题,最重要的就是根据训练数据,计算损失函数和梯度。由于是单机环境,什么都好说,只要公式推导没错,浮点数计算溢出问题解决好,就好了。但是,当我们的训练数据量足够大,大到单机根本存储不下的时候,对分布式学习的需求就出现了。比如电商数据,动辄上亿的训练数据量,单机望尘莫及,只能求助于分布式计算。
        那么问题来了,在分布式计算中,怎样计算得到损失函数的值,以及它的梯度值呢?这就涉及到Spark ml的一个核心,用八个字概括就是,模型集中,计算分布。具体来说,比如我们要学习一个逻辑回归模型,它的训练数据可能是存储在成百上千台服务器上,但具体的模型,只集中于一台服务器上。每次迭代时,我们现在训练数据所在的服务器上,并行的计算出,每个服务器包含的训练数据,所对应的损失函数值和梯度值,然后把这些信息集中在模型所在的机器上,进行合并,总结出所有训练数据的损失函数值和梯度值,然后对所学习的参数进行迭代,并把参数分发给拥有训练数据的服务器,并进入下一个迭代循环,直到模型收敛。
        如此看来,分布式机器学习也没有什么特别的,核心问题就在于,怎样把每个服务器上计算的损失函数值和地图值集中到模型所在的服务器上,除此之外,跟单机的机器学习问题并没有什么不同。
        这一步,在Spark ML中是如何实现的呢?这里要隆重介绍一个函数,treeAggregate,在我看来,这个函数是从单机到分布式机器学习的核心,理解了这个函数,分布式机器学习问题,就理解大半了。
        treeAggregate函数主要做什么呢?它负责把每一台服务器上的信息进行聚合,然后汇总给模型所在的服务器。拥有训练数据的服务器,可能动辄成千上万,这么多数据怎样聚合起来呢?其实函数名字已经有暗示了,它用的是树形聚合方法。假设我们有32台服务器,如果使用线性聚合,也就是说,1跟2合并,结果再跟3合并,这样一共需要进行31次合并,而且每次合并还不能并行进行,因此treeAggregate采用的方法是,把32个节点分配到一颗二叉树的32个叶子节点,然后从叶子节点开始一层一层的聚合,这样只需要5次聚合就可以了。
        具体的,使用treeAggregate函数需要定义两种运算,分别是seqOp和combOp,前者的作用是,把一个训练样本加入已有的统计,即对损失函数值和梯度进行更新,后者的作用是,把两个统计信息合并起来,可以这样理解,前者主要在单机上的统计计算时起作用,后者主要是在不同服务器进行数据合并时起作用。
        有了这些核心概念,就可以进入optim目录去一探究竟了,optim目录是Spark ML跟优化相关内容的代码库,它主要包含三部分,一是aggregator目录,二是loss目录,三是根目录,下面我们逐一介绍。
        aggregator目录下存放的是,聚合相关的代码。我们知道在机器学习任务中,不同的任务需要聚合的信息是不一样的。这里就为我们实现了几个最基本的聚合操作。其中,DifferentiableLossAggregator是基类,顾名思义,实现了最基本的可微损失函数的聚合,实际上的聚合操作都是由它的子类完成的,基类中定义了通用的merge操作,具体的add操作由各子类自己定义,代码实现都比较直接,就不一一介绍了,感兴趣的朋友可以直接读源码。
        loss目录下存放的是,损失函数相关的代码。其实,最一般性的损失函数是在breeze库中定义的,这个等我们在介绍breeze库的时候再细说。loss目录下有两个文件,一个是DifferentiableRegularization.scala,这里是把正则也当作一种损失,主要包含L2正则,另一个是RDDLossFunction.scala,这个就非常重要了,它就是应用treeAggregate函数,从单机的损失+梯度,汇总到分布式版的损失+梯度的函数,它主要应用了aggregate目录下的聚合类实现分布式的聚合运算。
        根目录下主要包含了几个优化问题的解法,最基础的是NormalEquationSolver.scala,它主要描述了一个最小二乘的标准解法,也就是正规方程的解法,其次是WeightedLeastSquares.scala,它解决了一个带权值的最小二乘问题,利用了正规方程解法,最后是IterativelyReweightedLeastSquares.scala,这是在解逻辑斯蒂回归等一大类一般性线性回归问题中常用的IRLS算法,利用了带权值的最小二乘解法。
        好,今天的介绍就到这里了。作者也是初学者,欢迎大家批评指正。

Spark ML源码分析之二 从单机到分布式的更多相关文章

  1. Spark ML源码分析之一 设计框架解读

    本博客为作者原创,如需转载请注明参考           在深入理解Spark ML中的各类算法之前,先理一下整个库的设计框架,是非常有必要的,优秀的框架是对复杂问题的抽象和解剖,对这种抽象的学习本身 ...

  2. Spark ML源码分析之四 树

            之前我们讲过,在Spark ML中所有的机器学习模型都是以参数作为划分的,树相关的参数定义在treeParams.scala这个文件中,这里构建一个关于树的体系结构.首先,以Decis ...

  3. Spark ML源码分析之三 分类器

            前面跟大家扯了这么多废话,终于到具体的机器学习模型了.大部分机器学习的教程,总要从监督学习开始讲起,而监督学习的众多算法当中,又以分类算法最为基础,原因在于分类问题非常的单纯直接,几乎 ...

  4. spark 源码分析之二十一 -- Task的执行流程

    引言 在上两篇文章 spark 源码分析之十九 -- DAG的生成和Stage的划分 和 spark 源码分析之二十 -- Stage的提交 中剖析了Spark的DAG的生成,Stage的划分以及St ...

  5. 第十一篇:Spark SQL 源码分析之 External DataSource外部数据源

    上周Spark1.2刚发布,周末在家没事,把这个特性给了解一下,顺便分析下源码,看一看这个特性是如何设计及实现的. /** Spark SQL源码分析系列文章*/ (Ps: External Data ...

  6. 第十篇:Spark SQL 源码分析之 In-Memory Columnar Storage源码分析之 query

    /** Spark SQL源码分析系列文章*/ 前面讲到了Spark SQL In-Memory Columnar Storage的存储结构是基于列存储的. 那么基于以上存储结构,我们查询cache在 ...

  7. 第九篇:Spark SQL 源码分析之 In-Memory Columnar Storage源码分析之 cache table

    /** Spark SQL源码分析系列文章*/ Spark SQL 可以将数据缓存到内存中,我们可以见到的通过调用cache table tableName即可将一张表缓存到内存中,来极大的提高查询效 ...

  8. 第七篇:Spark SQL 源码分析之Physical Plan 到 RDD的具体实现

    /** Spark SQL源码分析系列文章*/ 接上一篇文章Spark SQL Catalyst源码分析之Physical Plan,本文将介绍Physical Plan的toRDD的具体实现细节: ...

  9. 第一篇:Spark SQL源码分析之核心流程

    /** Spark SQL源码分析系列文章*/ 自从去年Spark Submit 2013 Michael Armbrust分享了他的Catalyst,到至今1年多了,Spark SQL的贡献者从几人 ...

随机推荐

  1. Python接口测试,Requests模块讲解:GET、POST、Cookies、Session等

    文章最下方有对应课程的视频链接哦^_^ 一.安装.GET,公共方法 二.POST 三.Cookies 四.Session 五.认证 六.超时配置.代理.事件钩子 七.错误异常

  2. 1.5 PCI-X总线简介

    PCI-X总线仍采用并行总线技术.PCI-X总线使用的大多数总线事务基于PCI总线,但是在实现细节上略有不同.PCI-X总线将工作频率提高到533MHz,并首先引入了PME(Power Managem ...

  3. 用DDK开发的9054驱动 .

    和S5933比较起来,开发PLX9054比较不幸,可能是第一次开发PCI的缘故吧.因为,很多PCI的例子都是对S5933,就连微软出版的<Programming the Microsoft Wi ...

  4. java注解之二

    从JDK5开始,Java增加了Annotation(注解),Annotation是代码里的特殊标记,这些标记可以在编译.类加载.运行时被读取,并执行相应的处理.通过使用Annotation,开发人员可 ...

  5. memcache 查看memcache的运行状态

    memcache的运行状态可以方便的用 stats 命令显示. 首先用telnet 127.0.0.1 11211这样的命令连接上memcache,然后直接输入stats就可以得到当前memcache ...

  6. C#:几种数据库的大数据批量插入

    在之前只知道SqlServer支持数据批量插入,殊不知道Oracle.SQLite和MySql也是支持的,不过Oracle需要使用Orace.DataAccess驱动,今天就贴出几种数据库的批量插入解 ...

  7. (八)java垃圾回收和收尾

    垃圾回收机制:当一个对象不再被引用时,或者说当一个对象的引用不存在时,我们就认为该对象不再被需要,它所占用的内存就会被释放掉.     垃圾回收只是在程序执行过程中偶尔发生,java不同的运行时刻会产 ...

  8. Linux显示所有可更新的软件清单命令

    Linux显示所有可更新的软件清单命令 youhaidong@youhaidong-ThinkPad-Edge-E545:~$ yum check-update 程序"yum"尚未 ...

  9. Error:dojo.data.ItemFileWriteStore:Invalid item argument

    1.错误描述 dijit.form.ComboBox TypeError:_4e is undefined                                            Sea ...

  10. web开发性能优化---安全篇

    1.权限管理 从模块.表单.数据审核.功能按钮全面数据安全验证及管理. 2.ip验证 数据接口访问进行IP校验 3.登录.操作日志.程序安全日志  系统所有用户登录.操作全部日志记录. 程序安全日志操 ...