MLlib之NaiveBayes算法源码学习
package org.apache.spark.mllib.classification
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
import org.apache.spark.{SparkException, Logging}
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
/**
* Model for Naive Bayes Classifiers.
*
* @param labels list of labels
* @param pi log of class priors, whose dimension is C, number of labels
* @param theta log of class conditional probabilities, whose dimension is C-by-D,
* where D is number of features
*/
class NaiveBayesModel private[mllib] (
val labels: Array[Double],
val pi: Array[Double],
val theta: Array[Array[Double]]) extends ClassificationModel with Serializable {
private val brzPi = new BDV[Double](pi)
private val brzTheta = new BDM[Double](theta.length, theta(0).length)
{
// Need to put an extra pair of braces to prevent Scala treating `i` as a member.
var i = 0
while (i < theta.length) {
var j = 0
while (j < theta(i).length) {
brzTheta(i, j) = theta(i)(j)
j += 1
}
i += 1
}
}
override def predict(testData: RDD[Vector]): RDD[Double] = {
val bcModel = testData.context.broadcast(this)
testData.mapPartitions { iter =>
val model = bcModel.value
iter.map(model.predict)
}
}
override def predict(testData: Vector): Double = {
labels(brzArgmax(brzPi + brzTheta * testData.toBreeze))
}
}
/**
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
*
* This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
* discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
* document classification. By making every vector a 0-1 vector, it can also be used as
* Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative.
*/
class NaiveBayes private (private var lambda: Double) extends Serializable with Logging {
def this() = this(1.0)
/** Set the smoothing parameter. Default: 1.0. */
def setLambda(lambda: Double): NaiveBayes = {
this.lambda = lambda
this
}
/**
* Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries.
*
* @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
*/
def run(data: RDD[LabeledPoint]) = {
val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
val values = v match {
case sv: SparseVector =>
sv.values
case dv: DenseVector =>
dv.values
}
if (!values.forall(_ >= 0.0)) {
throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.")
}
}
// Aggregates term frequencies per label.
// TODO: Calling combineByKey and collect creates two stages, we can implement something
// TODO: similar to reduceByKeyLocally to save one stage.
val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])](
createCombiner = (v: Vector) => {
requireNonnegativeValues(v)
(1L, v.toBreeze.toDenseVector)
},
mergeValue = (c: (Long, BDV[Double]), v: Vector) => {
requireNonnegativeValues(v)
(c._1 + 1L, c._2 += v.toBreeze)
},
mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) =>
(c1._1 + c2._1, c1._2 += c2._2)
).collect()
val numLabels = aggregated.length
var numDocuments = 0L
aggregated.foreach { case (_, (n, _)) =>
numDocuments += n
}
val numFeatures = aggregated.head match { case (_, (_, v)) => v.size }
val labels = new Array[Double](numLabels)
val pi = new Array[Double](numLabels)
val theta = Array.fill(numLabels)(new Array[Double](numFeatures))
val piLogDenom = math.log(numDocuments + numLabels * lambda)
var i = 0
aggregated.foreach { case (label, (n, sumTermFreqs)) =>
labels(i) = label
val thetaLogDenom = math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
pi(i) = math.log(n + lambda) - piLogDenom
var j = 0
while (j < numFeatures) {
theta(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom
j += 1
}
i += 1
}
new NaiveBayesModel(labels, pi, theta)
}
}
/**
* Top-level methods for calling naive Bayes.
*/
object NaiveBayes {
/**
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
*
* This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
* discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
* document classification. By making every vector a 0-1 vector, it can also be used as
* Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
*
* This version of the method uses a default smoothing parameter of 1.0.
*
* @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
* vector or a count vector.
*/
def train(input: RDD[LabeledPoint]): NaiveBayesModel = {
new NaiveBayes().run(input)
}
/**
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
*
* This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
* discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
* document classification. By making every vector a 0-1 vector, it can also be used as
* Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
*
* @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
* vector or a count vector.
* @param lambda The smoothing parameter
*/
def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
new NaiveBayes(lambda).run(input)
}
}
package org.apache.spark.mllib.classification import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD /**
* :: Experimental ::
* Represents a classification model that predicts to which of a set of categories an example
* belongs. The categories are represented by double values: 0.0, 1.0, 2.0, etc.
*/
@Experimental
trait ClassificationModel extends Serializable {
/**
* Predict values for the given data set using the model trained.
*
* @param testData RDD representing data points to be predicted
* @return an RDD[Double] where each entry contains the corresponding prediction
*/
def predict(testData: RDD[Vector]): RDD[Double] /**
* Predict values for a single data point using the model trained.
*
* @param testData array representing a single data point
* @return predicted category from the trained model
*/
def predict(testData: Vector): Double /**
* Predict values for examples stored in a JavaRDD.
* @param testData JavaRDD representing data points to be predicted
* @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction
*/
def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
}
朴素贝叶斯分类算法
何为分类算法?简单来说,就是将具有某些特性的物体归类对应到一个已知的类别集合中的某个类别上。从数学角度来说,可以做如下定义:
已知集合: C={y1,y2,..,yn} 和 I={x1,x2,..,xm,..} ,确定映射规则 y=f(x),使得任意 xi∈I 有且仅有一个 yj∈C 使得 yj=f(xi) 成立。
其中,C为类别集合,I为待分类的物体,f则为分类器,分类算法的主要任务就是构造分类器f。
分类算法的构造通常需要一个已知类别的集合来进行训练,通常来说训练出来的分类算法不可能达到100%的准确率。分类器的质量往往与训练数据、验证数据、训练数据样本大小等因素相关。
举个例子,我们日常生活中看到一个陌生人,要做的第一件事情就是判断其性别,判断性别的过程就是一个分类的过程。根据以往的生活经验,通常经过头发长短、服饰和体型这三个要素就能判断出来一个人的性别。这里的“生活经验”就是一个训练好的关于性别判断的模型,其训练数据是日常生活中遇到的形形色色的人。突然有一天,一个娘炮走到了你面前,长发飘飘,穿着紧身的衣裤,可是体型却很man,于是你就疑惑了,根据以往的经验——也就是已经训练好的模型,无法判断这个人的性别。于是你学会了通过喉结来判断其性别,这样你的模型被训练的质量更高了。但不可否认的是,永远会出现一个让你无法判断性别的人。所以模型永远无法达到100%的准确,只会随着训练数据的不断增多而无限接近100%的准确。
贝叶斯公式
贝叶斯公式,或者叫做贝叶斯定理,是贝叶斯分类的基础。而贝叶斯分类是一类分类算法的统称,这一类算法的基础都是贝叶斯公式。目前研究较多的四种贝叶斯分类算法有:Naive Bayes、TAN、BAN和GBN。
理工科的学生在大学应该都学过概率论,其中最重要的几个公式中就有贝叶斯公式——用来描述两个条件概率之间的关系,比如P(A|B)和P(B|A)。如何在已知事件A和B分别发生的概率,和事件B发生时事件A发生的概率,来求得事件A发生时事件B发生的概率,这就是贝叶斯公式的作用。其表述如下:
P(B|A)=P(A|B)×P(B)P(A)
朴素贝叶斯分类
朴素贝叶斯分类,Naive Bayes,你也可以叫它NB算法。其核心思想非常简单:对于某一预测项,分别计算该预测项为各个分类的概率,然后选择概率最大的分类为其预测分类。就好像你预测一个娘炮是女人的可能性是40%,是男人的可能性是41%,那么就可以判断他是男人。
Naive Bayes的数学定义如下:
- 设 x={a1,a2,..,am} 为一个待分类项,而每个 ai 为 x 的一个特征属性
- 已知类别集合 C={y1,y2,..,yn}
- 计算 x 为各个类别的概率: P(y1|x),P(y2|x),..,P(yn|x)
- 如果 P(yk|x)=max{P(y1|x),P(y2|x),..,P(yn|x)} ,则 x 的类别为 yk
如何获取第四步中的最大值,也就是如何计算第三步中的各个条件概率最为重要。可以采用如下做法:
- 获取训练数据集,即分类已知的数据集
- 统计得到在各类别下各个特征属性的条件概率估计,即:P(a1|y1),P(a2|y1),...,P(am|y1);P(a1|y2),P(a2|y2),...,P(am|y2);...;P(a1|yn),P(a2|yn),...,P(am|yn),其中的数据可以是离散的也可以是连续的
- 如果各个特征属性是条件独立的,则根据贝叶斯定理有如下推导: P(yi|x)=P(x|yi)P(yi)P(x)
对于某x来说,分母是固定的,所以只要找出分子最大的即为条件概率最大的。又因为各特征属性是条件独立的,所以有: P(x|yi)P(yi)=P(a1|yi)P(a2|yi)...P(am|yi)P(yi)=P(yi)∏mj=1P(aj|yi)
Additive smoothing
Additive smoothing,又叫Laplacian smoothing或Lidstone smoothing。
当某个类别下某个特征项划分没有出现时, P(ai|yj)=0 ,这样最后乘出来的结果会让精确度大大的降低,所以引入Additive smoothing来解决这个问题。其思想是对于这样等于0的情况,将其计数值加1,这样如果训练样本集数量充分大时,并不会对结果产生影响,并且解决了上述频率为0的尴尬局面。
MLlib之NaiveBayes算法源码学习的更多相关文章
- MLlib之LR算法源码学习
/** * :: DeveloperApi :: * GeneralizedLinearModel (GLM) represents a model trained using * Generaliz ...
- [算法1-排序](.NET源码学习)& LINQ & Lambda
[算法1-排序](.NET源码学习)& LINQ & Lambda 说起排序算法,在日常实际开发中我们基本不在意这些事情,有API不用不是没事找事嘛.但必要的基础还是需要了解掌握. 排 ...
- [算法2-数组与字符串的查找与匹配] (.NET源码学习)
[算法2-数组与字符串的查找与匹配] (.NET源码学习) 关键词:1. 数组查找(算法) 2. 字符串查找(算法) 3. C#中的String(源码) 4. 特性Attribute 与内 ...
- Java集合专题总结(1):HashMap 和 HashTable 源码学习和面试总结
2017年的秋招彻底结束了,感觉Java上面的最常见的集合相关的问题就是hash--系列和一些常用并发集合和队列,堆等结合算法一起考察,不完全统计,本人经历:先后百度.唯品会.58同城.新浪微博.趣分 ...
- Redis源码学习:字符串
Redis源码学习:字符串 1.初识SDS 1.1 SDS定义 Redis定义了一个叫做sdshdr(SDS or simple dynamic string)的数据结构.SDS不仅用于 保存字符串, ...
- 基于jdk1.8的HashMap源码学习笔记
作为一种最为常用的容器,同时也是效率比较高的容器,HashMap当之无愧.所以自己这次jdk源码学习,就从HashMap开始吧,当然水平有限,有不正确的地方,欢迎指正,促进共同学习进步,就是喜欢程序员 ...
- Vue源码学习1——Vue构造函数
Vue源码学习1--Vue构造函数 这是我第一次正式阅读大型框架源码,刚开始的时候完全不知道该如何入手.Vue源码clone下来之后这么多文件夹,Vue的这么多方法和概念都在哪,完全没有头绪.现在也只 ...
- zookeeper集群搭建及Leader选举算法源码解析
第一章.zookeeper概述 一.zookeeper 简介 zookeeper 是一个开源的分布式应用程序协调服务器,是 Hadoop 的重要组件. zooKeeper 是一个分布式的,开放源码的分 ...
- Vue2.1.7源码学习
原本文章的名字叫做<源码解析>,不过后来想想,还是用“源码学习”来的合适一点,在没有彻底掌握源码中的每一个字母之前,“解析”就有点标题党了.建议在看这篇文章之前,最好打开2.1.7的源码对 ...
随机推荐
- Sping Cloud hystrix.stream 自动发现-监控
相关组件安装脚本 [root@java_gateway4 java_tps]# cat cront_install.sh #!/bin/bashyum install jq -ymkdir /home ...
- AltiumDesigner PCB中栅格与格点的切换
PCB中通过快捷键Ctrl+G,进入设置界面. 在弹出的对话框中,在Display,Coarse选择Lines为栅格,Dots为格点,Do Not Draw为无任何显示.
- 151. Reverse Words in a String翻转一句话中的单词
[抄题]: Given an input string, reverse the string word by word. Example: Input: "the sky is blue& ...
- [leetcode]80. Remove Duplicates from Sorted Array II有序数组去重(单个元素可出现两次)
Given a sorted array nums, remove the duplicates in-place such that duplicates appeared at most twic ...
- 完全理解 Python 迭代对象、迭代器、生成器
完全理解 Python 迭代对象.迭代器.生成器 2017/05/29 · 基础知识 · 9 评论 · 可迭代对象, 生成器, 迭代器 分享到: 原文出处: liuzhijun 本文源自RQ作者 ...
- python历史与基本类型
前言 我自学的方式主要是看文档,看视频,第一次做写博客这么神圣的事情,内心是忐忑的,写的东西比较杂,路过的小伙伴不要嘲笑我,主要是记录一日所学,顺便锻炼一下语言组织能力吧,anyway,这些都不重要, ...
- input 随笔
1,input 点击出现蓝色外边框 解决:outline:none
- Python中单引号,双引号,3个单引号及3个双引号的区别
单引号和双引号在Python中我们都知道单引号和双引号都可以用来表示一个字符串,比如 str1 = 'python' str2 = "python" str1和str2是没有任何区 ...
- oracle 区分大小写遇到的坑
1. oracle 字段是区分大小写的 ..在navicat 中使用查询 select REMAIN_PRINCIPAl from T_NF_PROJECT; navicat 默认会把 REMA ...
- background-clip 和 background-origin 的区别
background-origin:指定绘制背景图片的起点. background-clip:是对背景图片的剪裁,指定背景图片的显示范围. 1.background-origin:padding | ...