spark机器学习从0到1决策树(六)

一、概念
决策树及其集合是分类和回归的机器学习任务的流行方法。 决策树被广泛使用,因为它们易于解释,处理分类特征,扩展到多类分类设置,不需要特征缩放,并且能够捕获非线性和特征交互。 诸如随机森林和增强的树集合算法是分类和回归任务的最佳表现者。
决策树(decision tree)是一种基本的分类与回归方法,这里主要介绍用于分类的决策树。决策树模式呈树形结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一种类别。学习时利用训练数据,根据损失函数最小化的原则建立决策树模型;预测时,对新的数据,利用决策树模型进行分类。
二、基本原理
决策树学习通常包含三个方面:特征选择、决策树生成和决策树剪枝。决策树学习思想主要来源于:Quinlan在1986年提出的ID算法、在1993年提出的C4.5算法和Breiman等人在1984年提出的CART算法。
2.1、特征选择
特征选择在于选取对训练数据具有分类能力的特征,这样可以提高决策树学习的效率。通常特征选择的准则是信息增益(或信息增益比、基尼指数等),每次计算每个特征的信息增益,并比较它们的大小,选择信息增益最大(信息增益比最大、基尼指数最小)的特征。
那么问题来了:怎么找到这样的最优划分特征呢?如何来衡量最优?
什么是最优特征,通俗的理解是对训练数据具有很强的分类能力的特征,比如要看相亲的男女是否合适,他们的年龄差这个特征就远比他们的出生地重要,因为年龄差能更好得对相亲是否成功这个分类问题具有更强的分类能力。
但是计算机并不知道哪些特征是最优的,因此,就要找一个衡量特征是不是最优的指标,使得决策树在每一个分支上的数据尽可能属于同一类别的数据,即样本纯度最高。
我们用熵来衡量样本集合的纯度。
熵
这是概率统计与信息论中的一个概念,定义为:

其中p(x)=pi表示随机变量X发生概率。
我们可以从两个角度理解这个概念。
第一就是不确定度的一个度量,我们的目标是为了找到一颗树,使得每个分枝上都代表一个分类,也就是说我们希望这个分枝上的不确定性最小,即确定性最大,也就是这些数据都是同一个类别的。熵越小,代表这些数据是同一类别的越多。
第二个角度就是从纯度理解。因为熵是不确定度的度量,如果他们不确定度越小,意味着这个群体的差异很小,也就是它的纯度很高。比如,在明大的某富翁聚会上,来的人大多是某总,普通工薪白领就会很少,如果新来了一个刘总,他是富翁的确定性就很大,不确定性就很小,同时这个群体的纯度很大。总结来说就是熵越小,纯度越大,而我们希望的就是纯度越大越好。
信息增益
我们用信息熵来衡量一个分支的纯度,以及哪个特征是最优的特征
在决策树学习中应用信息增益准则来选择最优特征。信息增益定义如下:

特征A对训练数据集D的信息增益g(D,A) 等于D的不确定度H(D) 减去给定条件A下D的不确定度H(D|A),可以理解为由于特征A使得对数据集D的分类的不确定性减少的程度,信息增益大的特征具有更强的分类能力。
信息增益率
信息增益选择特征倾向于选择取值较多的特征,假设某个属性存在大量的不同值,决策树在选择属性时,将偏向于选择该属性,但这肯定是不正确(导致过拟合)的。因此有必要使用一种更好的方法,那就是信息增益率(Info Gain Ratio)来矫正这一问题。
其公式为:

其中

,n为特征A取值的个数
基尼指数
概率分布的基尼指数定义为

其中K表示分类问题中类别的个数
2.2、决策树的生成
从根结点开始,对结点计算所有可能的特征的信息增益,选择信息增益最大的特征作为结点的特征,由该特征的不同取值建立子结点,再对子结点递归地调用以上方法,构建决策树;直到所有特征的信息增均很小或没有特征可以选择为止,最后得到一个决策树。
决策树需要有停止条件来终止其生长的过程。一般来说最低的条件是:当该节点下面的所有记录都属于同一类,或者当所有的记录属性都具有相同的值时。这两种条件是停止决策树的必要条件,也是最低的条件。在实际运用中一般希望决策树提前停止生长,限定叶节点包含的最低数据量,以防止由于过度生长造成的过拟合问题。
2.3、决策树的剪枝
决策树生成只考虑了通过信息增益或信息增益比来对训练数据更好的拟合,但没有考虑到如果模型过于复杂,会导致过拟合的产生。而剪枝就是缓解过拟合的一种手段,单纯的决策树生成学习局部的模型,而剪枝后的决策树会生成学习整体的模型,因为剪枝的过程中,通过最小化损失函数,可以平衡决策树的对训练数据的拟合程度和整个模型的复杂度。
决策树的损失函数定义如下:

其中,

三、代码实现
我们以iris数据集(https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data)为例进行分析。iris以鸢尾花的特征作为数据来源,数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性,是在数据挖掘、数据分类中非常常用的测试集、训练集。
3.1、读取数据
首先,读取文本文件;然后,通过map将每行的数据用“,”隔开,在我们的数据集中,每行被分成了5部分,前4部分是鸢尾花的4个特征,最后一部分是鸢尾花的分类。把这里我们用LabeledPoint来存储标签列和特征列。LabeledPoint在监督学习中常用来存储标签和特征,其中要求标签的类型是double,特征的类型是Vector。所以,我们把莺尾花的分类进行了一下改变,”Iris-setosa”对应分类0,”Iris-versicolor”对应分类1,其余对应分类2;然后获取莺尾花的4个特征,存储在Vector中。
import java.util.HashMap;
import java.util.Map;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import scala.Tuple2;
SparkConf conf = new SparkConf().setAppName("decisionTree").setMaster("local");
JavaSparkContext sc = new JavaSparkContext(conf);
/**
* 读取数据
* 转化成 LabeledPoint类型
*/
JavaRDD<String> source = sc.textFile("data/mllib/iris.data");
JavaRDD<LabeledPoint> data = source.map(line->{
String[] parts = line.split(",");
double label = 0.0;
if(parts[4].equals("Iris-setosa")) {
label = 0.0;
}else if(parts[4].equals("Iris-versicolor")) {
label = 1.0;
}else {
label = 2.0;
}
return new LabeledPoint(label,Vectors.dense(Double.parseDouble(parts[0]),
Double.parseDouble(parts[1]),
Double.parseDouble(parts[2]),
Double.parseDouble(parts[3])));
});
3.2、划分数据集
接下来,首先进行数据集的划分,这里划分70%的训练集和30%的测试集:
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[] {0.7,0.3});
JavaRDD<LabeledPoint> trainingData = splits[0];
JavaRDD<LabeledPoint> testData = splits[1];
3.3、构建模型
调用决策树的trainClassifier方法构建决策树模型,设置参数,比如分类数、信息增益的选择、树的最大深度等:
int numClasses = 3;//分类数
int maxDepth = 5; //树的最大深度
int maxBins = 30;//离散连续特征时使用的bin数。增加maxBins允许算法考虑更多的分割候选者并进行细粒度的分割决策。
String impurity = "gini";
Map<Integer,Integer> categoricalFeaturesInfo = new HashMap<Integer,Integer>();//空的categoricalFeaturesInfo表示所有功能都是连续的。
DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins);
3.4、模型预测
接下来我们调用决策树模型的predict方法对测试数据集进行预测,并把模型结构打印出来:
JavaPairRDD<Double, Double> predictionAndLabel = testData.mapToPair(point->{
return new Tuple2<>(model.predict(point.features()),point.label());
});
//打印预测和实际结果
predictionAndLabel.foreach(x->{
System.out.println("predictionAndLabel:"+x);
});
System.out.println("Learned classification tree model:"+model.toDebugString());
/**
*控制台输出结果:
-----------------------
Learned classification tree model:DecisionTreeModel classifier of depth 5 with 15 nodes
If (feature 2 <= 2.45)
Predict: 0.0
Else (feature 2 > 2.45)
If (feature 2 <= 4.75)
Predict: 1.0
Else (feature 2 > 4.75)
If (feature 2 <= 4.95)
If (feature 0 <= 6.25)
If (feature 1 <= 3.05)
Predict: 2.0
Else (feature 1 > 3.05)
Predict: 1.0
Else (feature 0 > 6.25)
Predict: 1.0
Else (feature 2 > 4.95)
If (feature 3 <= 1.7000000000000002)
If (feature 0 <= 6.05)
Predict: 1.0
Else (feature 0 > 6.05)
Predict: 2.0
Else (feature 3 > 1.7000000000000002)
Predict: 2.0
------------------------
**/
3.5、准确性评估
最后,我们把模型预测的准确性打印出来:
double testErr = predictionAndLabel.filter(pl -> !pl._1().equals(pl._2())).count() / (double) testData.count();
System.out.println("Test Error:"+testErr);
/**
*控制台输出结果:
------------------------------
Test Error:0.06976744186046512
------------------------------
**/
spark机器学习从0到1决策树(六)的更多相关文章
- spark机器学习从0到1介绍入门之(一)
一.什么是机器学习 机器学习(Machine Learning, ML)是一门多领域交叉学科,涉及概率论.统计学.逼近论.凸分析.算法复杂度理论等多门学科.专门研究计算机怎样模拟或实现人类的学习行 ...
- spark机器学习从0到1特征提取 TF-IDF(十二)
一.概念 “词频-逆向文件频率”(TF-IDF)是一种在文本挖掘中广泛使用的特征向量化方法,它可以体现一个文档中词语在语料库中的重要程度. 词语由t表示,文档由d表示,语料库由D表示.词频TF ...
- spark机器学习从0到1特征变换-标签和索引的转化(十六)
一.原理 在机器学习处理过程中,为了方便相关算法的实现,经常需要把标签数据(一般是字符串)转化成整数索引,或是在计算结束后将整数索引还原为相应的标签. Spark ML 包中提供了几个相关的转换器 ...
- spark机器学习从0到1机器学习工作流 (十一)
一.概念 一个典型的机器学习过程从数据收集开始,要经历多个步骤,才能得到需要的输出.这非常类似于流水线式工作,即通常会包含源数据ETL(抽取.转化.加载),数据预处理,指标提取,模型训练与交叉 ...
- spark机器学习从0到1特征选择-卡方选择器(十五)
一.公式 卡方检验的基本公式,也就是χ2的计算公式,即观察值和理论值之间的偏差 卡方检验公式 其中:A 为观察值,E为理论值,k为观察值的个数,最后一个式子实际上就是具体计算的方法了 n 为总 ...
- spark机器学习从0到1奇异值分解-SVD (七)
降维(Dimensionality Reduction) 是机器学习中的一种重要的特征处理手段,它可以减少计算过程中考虑到的随机变量(即特征)的个数,其被广泛应用于各种机器学习问题中,用于消除噪声 ...
- spark机器学习从0到1基本的统计工具之(三)
给定一个数据集,数据分析师一般会先观察一下数据集的基本情况,称之为汇总统计或者概要性统计.一般的概要性统计用于概括一系列观测值,包括位置或集中趋势(比如算术平均值.中位数.众数和四分位均值),展型 ...
- spark机器学习从0到1基本数据类型之(二)
MLlib支持存储在单个机器上的局部向量和矩阵,以及由一个或多个RDD支持的分布式矩阵. 局部向量和局部矩阵是用作公共接口的简单数据模型. 底层线性代数操作由Breeze提供. 在监督学习中使 ...
- spark机器学习从0到1特征抽取–Word2Vec(十四)
一.概念 Word2vec是一个Estimator,它采用一系列代表文档的词语来训练word2vecmodel.该模型将每个词语映射到一个固定大小的向量.word2vecmodel使用文档中每个词 ...
随机推荐
- Python selenium Chrome正在受到自动软件的控制 disable-infobars无效 的解决方法
问题解决 前两天更新了google浏览器版本,今天运行以前的脚本,发现options一个参数的配置不生效了. 运行了几次都发现该参数没有生效,也检查了自己的代码参数,没有写错,于是就有了这一波“网中寻 ...
- java 8 lambda表达式中的异常处理
目录 简介 处理Unchecked Exception 处理checked Exception 总结 java 8 lambda表达式中的异常处理 简介 java 8中引入了lambda表达式,lam ...
- Shutdown SpringBoot App
文章目录 Shutdown Endpoint close Application Context 退出SpringApplication 从外部程序kill App Shutdown SpringBo ...
- css之颜色表示法
css之颜色表示法 十六进制颜色 所有浏览器都支持十六进制颜色值. 十六进制颜色是这样规定的:#RRGGBB,其中的 RR(红色).GG(绿色).BB(蓝色)十六进制整数规定了颜色的成分.所有值必须介 ...
- Red 编程语言 2019 开发计划:全速前进!
开发四年只会写业务代码,分布式高并发都不会还做程序员? >>> Red 编程语言开发团队昨日发布了一篇 "Full steam ahead" 的文章,对其 2 ...
- SQL Server遍历表中记录的2种方法
SQL Server遍历表一般都要用到游标,SQL Server中可以很容易的用游标实现循环,实现SQL Server遍历表中记录.本文将介绍利用使用表变量和游标实现数据库中表的遍历. 表变量来实现表 ...
- 一个简单的wed服务器SHTTPD(1)————命令行和文件配置解析
开始学习<LInux网络编程>中的综合案例,虽然代码书上有,还是自己打一下加深理解和印象. 主要有两个函数,完成命令行的解析,另一个实现配置文件的解析,注释还是比较丰富的哦. //star ...
- Northwestern European Regional Contest 2014 Gym - 101482
Gym 101482C Cent Savings 简单的dp #include<bits/stdc++.h> #define inf 0x3f3f3f3f #define inf64 0x ...
- LTE网络概述
LTE主要由两部分组成:无线接入技术演进(E-UTRAN)+系统架构演进(SAE):其中,SAE主要含有的是演进型分组交换核心网(EPC),其控制处理部分为移动性管理实体(MME),数据承载部分称为业 ...
- 一步一步教你PowerBI数据分析:制作客户RFM数据分析
客户分析就是根据客户信息数据来分析客户特征,评估客户价值,从而为客户制订相应的营销策略与资源配置.通过合理.系统的客户分析,企业可以知道不同的客户有着什么样的需求,分析客户消费特征与商务效益的关系,使 ...