一、概念

决策树及其集合是分类和回归的机器学习任务的流行方法。 决策树被广泛使用,因为它们易于解释,处理分类特征,扩展到多类分类设置,不需要特征缩放,并且能够捕获非线性和特征交互。 诸如随机森林和增强的树集合算法是分类和回归任务的最佳表现者。

决策树(decision tree)是一种基本的分类与回归方法,这里主要介绍用于分类的决策树。决策树模式呈树形结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一种类别。学习时利用训练数据,根据损失函数最小化的原则建立决策树模型;预测时,对新的数据,利用决策树模型进行分类。

二、基本原理

决策树学习通常包含三个方面:特征选择、决策树生成和决策树剪枝。决策树学习思想主要来源于:Quinlan在1986年提出的ID算法、在1993年提出的C4.5算法和Breiman等人在1984年提出的CART算法。

2.1、特征选择

特征选择在于选取对训练数据具有分类能力的特征,这样可以提高决策树学习的效率。通常特征选择的准则是信息增益(或信息增益比、基尼指数等),每次计算每个特征的信息增益,并比较它们的大小,选择信息增益最大(信息增益比最大、基尼指数最小)的特征。

那么问题来了:怎么找到这样的最优划分特征呢?如何来衡量最优?

什么是最优特征,通俗的理解是对训练数据具有很强的分类能力的特征,比如要看相亲的男女是否合适,他们的年龄差这个特征就远比他们的出生地重要,因为年龄差能更好得对相亲是否成功这个分类问题具有更强的分类能力。

但是计算机并不知道哪些特征是最优的,因此,就要找一个衡量特征是不是最优的指标,使得决策树在每一个分支上的数据尽可能属于同一类别的数据,即样本纯度最高。

我们用熵来衡量样本集合的纯度。

这是概率统计与信息论中的一个概念,定义为:

 

其中p(x)=pi表示随机变量X发生概率。

我们可以从两个角度理解这个概念。

第一就是不确定度的一个度量,我们的目标是为了找到一颗树,使得每个分枝上都代表一个分类,也就是说我们希望这个分枝上的不确定性最小,即确定性最大,也就是这些数据都是同一个类别的。熵越小,代表这些数据是同一类别的越多。

第二个角度就是从纯度理解。因为熵是不确定度的度量,如果他们不确定度越小,意味着这个群体的差异很小,也就是它的纯度很高。比如,在明大的某富翁聚会上,来的人大多是某总,普通工薪白领就会很少,如果新来了一个刘总,他是富翁的确定性就很大,不确定性就很小,同时这个群体的纯度很大。总结来说就是熵越小,纯度越大,而我们希望的就是纯度越大越好。

信息增益

我们用信息熵来衡量一个分支的纯度,以及哪个特征是最优的特征
在决策树学习中应用信息增益准则来选择最优特征。信息增益定义如下:

 
信息增益

特征A对训练数据集D的信息增益g(D,A) 等于D的不确定度H(D) 减去给定条件A下D的不确定度H(D|A),可以理解为由于特征A使得对数据集D的分类的不确定性减少的程度,信息增益大的特征具有更强的分类能力。

信息增益率

信息增益选择特征倾向于选择取值较多的特征,假设某个属性存在大量的不同值,决策树在选择属性时,将偏向于选择该属性,但这肯定是不正确(导致过拟合)的。因此有必要使用一种更好的方法,那就是信息增益率(Info Gain Ratio)来矫正这一问题。
其公式为:

 
信息增益率

其中

 
训练数据集D关于特征A的值的熵

,n为特征A取值的个数

基尼指数

概率分布的基尼指数定义为

 
基尼指数

其中K表示分类问题中类别的个数

2.2、决策树的生成

从根结点开始,对结点计算所有可能的特征的信息增益,选择信息增益最大的特征作为结点的特征,由该特征的不同取值建立子结点,再对子结点递归地调用以上方法,构建决策树;直到所有特征的信息增均很小或没有特征可以选择为止,最后得到一个决策树。

决策树需要有停止条件来终止其生长的过程。一般来说最低的条件是:当该节点下面的所有记录都属于同一类,或者当所有的记录属性都具有相同的值时。这两种条件是停止决策树的必要条件,也是最低的条件。在实际运用中一般希望决策树提前停止生长,限定叶节点包含的最低数据量,以防止由于过度生长造成的过拟合问题。

2.3、决策树的剪枝

决策树生成只考虑了通过信息增益或信息增益比来对训练数据更好的拟合,但没有考虑到如果模型过于复杂,会导致过拟合的产生。而剪枝就是缓解过拟合的一种手段,单纯的决策树生成学习局部的模型,而剪枝后的决策树会生成学习整体的模型,因为剪枝的过程中,通过最小化损失函数,可以平衡决策树的对训练数据的拟合程度和整个模型的复杂度。

决策树的损失函数定义如下:

 
损失函数

其中,

 
图1

三、代码实现

我们以iris数据集(https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data)为例进行分析。iris以鸢尾花的特征作为数据来源,数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性,是在数据挖掘、数据分类中非常常用的测试集、训练集。

3.1、读取数据

首先,读取文本文件;然后,通过map将每行的数据用“,”隔开,在我们的数据集中,每行被分成了5部分,前4部分是鸢尾花的4个特征,最后一部分是鸢尾花的分类。把这里我们用LabeledPoint来存储标签列和特征列。LabeledPoint在监督学习中常用来存储标签和特征,其中要求标签的类型是double,特征的类型是Vector。所以,我们把莺尾花的分类进行了一下改变,”Iris-setosa”对应分类0,”Iris-versicolor”对应分类1,其余对应分类2;然后获取莺尾花的4个特征,存储在Vector中。

import java.util.HashMap;
import java.util.Map;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import scala.Tuple2;
SparkConf conf = new SparkConf().setAppName("decisionTree").setMaster("local");
JavaSparkContext sc = new JavaSparkContext(conf); /**
* 读取数据
* 转化成 LabeledPoint类型
*/
JavaRDD<String> source = sc.textFile("data/mllib/iris.data");
JavaRDD<LabeledPoint> data = source.map(line->{
String[] parts = line.split(",");
double label = 0.0;
if(parts[4].equals("Iris-setosa")) {
label = 0.0;
}else if(parts[4].equals("Iris-versicolor")) {
label = 1.0;
}else {
label = 2.0;
}
return new LabeledPoint(label,Vectors.dense(Double.parseDouble(parts[0]),
Double.parseDouble(parts[1]),
Double.parseDouble(parts[2]),
Double.parseDouble(parts[3])));
});

3.2、划分数据集

接下来,首先进行数据集的划分,这里划分70%的训练集和30%的测试集:

JavaRDD<LabeledPoint>[] splits =  data.randomSplit(new double[] {0.7,0.3});
JavaRDD<LabeledPoint> trainingData = splits[0];
JavaRDD<LabeledPoint> testData = splits[1];

3.3、构建模型

调用决策树的trainClassifier方法构建决策树模型,设置参数,比如分类数、信息增益的选择、树的最大深度等:

int numClasses = 3;//分类数
int maxDepth = 5; //树的最大深度
int maxBins = 30;//离散连续特征时使用的bin数。增加maxBins允许算法考虑更多的分割候选者并进行细粒度的分割决策。
String impurity = "gini";
Map<Integer,Integer> categoricalFeaturesInfo = new HashMap<Integer,Integer>();//空的categoricalFeaturesInfo表示所有功能都是连续的。
DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins);

3.4、模型预测

接下来我们调用决策树模型的predict方法对测试数据集进行预测,并把模型结构打印出来:

JavaPairRDD<Double, Double> predictionAndLabel =  testData.mapToPair(point->{
return new Tuple2<>(model.predict(point.features()),point.label());
});
//打印预测和实际结果
predictionAndLabel.foreach(x->{
System.out.println("predictionAndLabel:"+x);
});
System.out.println("Learned classification tree model:"+model.toDebugString());
/**
*控制台输出结果:
-----------------------
Learned classification tree model:DecisionTreeModel classifier of depth 5 with 15 nodes
If (feature 2 <= 2.45)
Predict: 0.0
Else (feature 2 > 2.45)
If (feature 2 <= 4.75)
Predict: 1.0
Else (feature 2 > 4.75)
If (feature 2 <= 4.95)
If (feature 0 <= 6.25)
If (feature 1 <= 3.05)
Predict: 2.0
Else (feature 1 > 3.05)
Predict: 1.0
Else (feature 0 > 6.25)
Predict: 1.0
Else (feature 2 > 4.95)
If (feature 3 <= 1.7000000000000002)
If (feature 0 <= 6.05)
Predict: 1.0
Else (feature 0 > 6.05)
Predict: 2.0
Else (feature 3 > 1.7000000000000002)
Predict: 2.0
------------------------
**/

3.5、准确性评估

最后,我们把模型预测的准确性打印出来:

double testErr = predictionAndLabel.filter(pl ->  !pl._1().equals(pl._2())).count() / (double)  testData.count();
System.out.println("Test Error:"+testErr);
/**
*控制台输出结果:
------------------------------
Test Error:0.06976744186046512
------------------------------
**/

spark机器学习从0到1决策树(六)的更多相关文章

  1. spark机器学习从0到1介绍入门之(一)

      一.什么是机器学习 机器学习(Machine Learning, ML)是一门多领域交叉学科,涉及概率论.统计学.逼近论.凸分析.算法复杂度理论等多门学科.专门研究计算机怎样模拟或实现人类的学习行 ...

  2. spark机器学习从0到1特征提取 TF-IDF(十二)

        一.概念 “词频-逆向文件频率”(TF-IDF)是一种在文本挖掘中广泛使用的特征向量化方法,它可以体现一个文档中词语在语料库中的重要程度. 词语由t表示,文档由d表示,语料库由D表示.词频TF ...

  3. spark机器学习从0到1特征变换-标签和索引的转化(十六)

      一.原理 在机器学习处理过程中,为了方便相关算法的实现,经常需要把标签数据(一般是字符串)转化成整数索引,或是在计算结束后将整数索引还原为相应的标签. Spark ML 包中提供了几个相关的转换器 ...

  4. spark机器学习从0到1机器学习工作流 (十一)

        一.概念 一个典型的机器学习过程从数据收集开始,要经历多个步骤,才能得到需要的输出.这非常类似于流水线式工作,即通常会包含源数据ETL(抽取.转化.加载),数据预处理,指标提取,模型训练与交叉 ...

  5. spark机器学习从0到1特征选择-卡方选择器(十五)

      一.公式 卡方检验的基本公式,也就是χ2的计算公式,即观察值和理论值之间的偏差   卡方检验公式 其中:A 为观察值,E为理论值,k为观察值的个数,最后一个式子实际上就是具体计算的方法了 n 为总 ...

  6. spark机器学习从0到1奇异值分解-SVD (七)

      降维(Dimensionality Reduction) 是机器学习中的一种重要的特征处理手段,它可以减少计算过程中考虑到的随机变量(即特征)的个数,其被广泛应用于各种机器学习问题中,用于消除噪声 ...

  7. spark机器学习从0到1基本的统计工具之(三)

      给定一个数据集,数据分析师一般会先观察一下数据集的基本情况,称之为汇总统计或者概要性统计.一般的概要性统计用于概括一系列观测值,包括位置或集中趋势(比如算术平均值.中位数.众数和四分位均值),展型 ...

  8. spark机器学习从0到1基本数据类型之(二)

        MLlib支持存储在单个机器上的局部向量和矩阵,以及由一个或多个RDD支持的分布式矩阵. 局部向量和局部矩阵是用作公共接口的简单数据模型. 底层线性代数操作由Breeze提供. 在监督学习中使 ...

  9. spark机器学习从0到1特征抽取–Word2Vec(十四)

      一.概念 Word2vec是一个Estimator,它采用一系列代表文档的词语来训练word2vecmodel.该模型将每个词语映射到一个固定大小的向量.word2vecmodel使用文档中每个词 ...

随机推荐

  1. 2019-2020-1 20199329《Linux内核原理与分析》第二周作业

    <Linux内核原理与分析>第二周作业 一.上周问题总结: 未能及时整理笔记 Linux还需要多用 markdown格式不熟练 发布博客时间超过规定期限 二.本周学习内容: <庖丁解 ...

  2. Python 基础教程(第二版)笔记 (1)

    P22 除非对 input 有特别的需要,否则应该尽可能使用 raw_input 函数. 长字符串,跨多行,用三个引号代替普通引号.并且不需要使用反斜线进行转义. P23 原始字符串 print r' ...

  3. Linux中的常用符号

    >, 1>     输出重定向符stdout,代码为1,重定向内容到文件,清除已有的内容,然后加入新内容,如果文件不存在还会创建文件 >>, 1>>   追加输出重 ...

  4. 《Android游戏开发详解》一1.7 控制流程第1部分——if和else语句

    本节书摘来异步社区<Android游戏开发详解>一书中的第1章,第1.7节,译者: 李强 责编: 陈冀康,更多章节内容可以访问云栖社区"异步社区"公众号查看. 1.7 ...

  5. Vue项目中jQuery的引入

    1.安装jQuery依赖 npm install jquery --save-dev 2.在webpack.base.conf.js头部加入如下代码 var webpack = require(&qu ...

  6. 图论-网络流-Dinic (邻接表版)

    //RQ的板子真的很好用 #include<cstdio> #include<cstring> #include<queue> #define INF 1e9 us ...

  7. Centos7.x 装机优化

    Linux 服务器装机后优化 参考 https://blog.csdn.net/u010133338/article/details/81055475 优化初始化脚本 vim init_optimiz ...

  8. CSS躬行记(9)——网格布局

    网格布局(Grid Layout)也叫栅格布局,与表格布局类似,也依赖行和列.但与之不同的是,网格布局能直接控制HTML文档中元素的顺序.位置和大小等,而不用再借助辅助元素. 一.术语 下图展示了CS ...

  9. LateX的简单字体设置(颜色,居中,大小等)

    \(\color{red}{Ⅰ.文本单行居中}\) $$\text{我是蒟蒻}$$ \[\text{我是蒟蒻} \] \(\color{Black}{Ⅱ.设置字体颜色}\) $$\color{Purp ...

  10. [js进阶1]-数据类型

    基本数据类型 js 总的有7中数据类型,包括基本类型和引用类型 基本类型 6 种 number boolean string null undefiend symbol 前5种类型统称为原始类型 sy ...