之前有文章介绍过决策树(ID3)。简单回顾一下:ID3每次选取最佳特征来分割数据,这个最佳特征的判断原则是通过信息增益来实现的。按照某种特征切分数据后,该特征在以后切分数据集时就不再使用,因此存在切分过于迅速的问题。ID3算法还不能处理连续性特征。
下面简单介绍一下其他算法:

CART 分类回归树

CART是Classification And Regerssion Trees的缩写,既能处理分类任务也能做回归任务。

CART树的典型代表时二叉树,根据不同的条件将分类。

CART树构建算法
与ID3决策树的构建方法类似,直接给出CART树的构建过程。首先与ID3类似采用字典树的数据结构,包含以下4中元素:

  • 待切分的特征
  • 待切分的特征值
  • 右子树。当不再需要切分的时候,也可以是单个值
  • 左子树,类似右子树。

过程如下:

  1. 寻找最合适的分割特征
  2. 如果不能分割数据集,该数据集作为一个叶子节点。
  3. 对数据集进行二分割
  4. 对分割的数据集1重复1, 2,3 步,创建右子树。
  5. 对分割的数据集2重复1, 2,3 步,创建左子树。

明显的递归算法。

通过数据过滤的方式分割数据集,返回两个子集。

def splitDatas(rows, value, column):
# 根据条件分离数据集(splitDatas by value, column)
# return 2 part(list1, list2) list1 = []
list2 = [] if isinstance(value, int) or isinstance(value, float):
for row in rows:
if row[column] >= value:
list1.append(row)
else:
list2.append(row)
else:
for row in rows:
if row[column] == value:
list1.append(row)
else:
list2.append(row)
return list1, list2
复制代码

划分数据点

创建二进制决策树本质上就是递归划分输入空间的过程。

代码如下:

# gini()
def gini(rows):
# 计算gini的值(Calculate GINI) length = len(rows)
results = calculateDiffCount(rows)
imp = 0.0
for i in results:
imp += results[i] / length * results[i] / length
return 1 - imp
复制代码

构建树

def buildDecisionTree(rows, evaluationFunction=gini):
# 递归建立决策树, 当gain=0,时停止回归
# build decision tree bu recursive function
# stop recursive function when gain = 0
# return tree
currentGain = evaluationFunction(rows)
column_lenght = len(rows[0])
rows_length = len(rows) best_gain = 0.0
best_value = None
best_set = None # choose the best gain
for col in range(column_lenght - 1):
col_value_set = set([x[col] for x in rows])
for value in col_value_set:
list1, list2 = splitDatas(rows, value, col)
p = len(list1) / rows_length
gain = currentGain - p * evaluationFunction(list1) - (1 - p) * evaluationFunction(list2)
if gain > best_gain:
best_gain = gain
best_value = (col, value)
best_set = (list1, list2)
dcY = {'impurity': '%.3f' % currentGain, 'sample': '%d' % rows_length}
#
# stop or not stop if best_gain > 0:
trueBranch = buildDecisionTree(best_set[0], evaluationFunction)
falseBranch = buildDecisionTree(best_set[1], evaluationFunction)
return Tree(col=best_value[0], value = best_value[1], trueBranch = trueBranch, falseBranch=falseBranch, summary=dcY)
else:
return Tree(results=calculateDiffCount(rows), summary=dcY, data=rows)
复制代码

上面代码的功能是先找到数据集切分的最佳位置和分割数据集。之后通过递归构建出上面图片的整棵树。

剪枝

在决策树的学习中,有时会造成决策树分支过多,这是就需要去掉一些分支,降低过度拟合。通过决策树的复杂度来避免过度拟合的过程称为剪枝。
后剪枝需要从训练集生成一棵完整的决策树,然后自底向上对非叶子节点进行考察。利用测试集判断是否将该节点对应的子树替换成叶节点。
代码如下:

def prune(tree, miniGain, evaluationFunction=gini):
# 剪枝 when gain < mini Gain, 合并(merge the trueBranch and falseBranch)
if tree.trueBranch.results == None:
prune(tree.trueBranch, miniGain, evaluationFunction)
if tree.falseBranch.results == None:
prune(tree.falseBranch, miniGain, evaluationFunction) if tree.trueBranch.results != None and tree.falseBranch.results != None:
len1 = len(tree.trueBranch.data)
len2 = len(tree.falseBranch.data)
len3 = len(tree.trueBranch.data + tree.falseBranch.data) p = float(len1) / (len1 + len2) gain = evaluationFunction(tree.trueBranch.data + tree.falseBranch.data) - p * evaluationFunction(tree.trueBranch.data) - (1 - p) * evaluationFunction(tree.falseBranch.data) if gain < miniGain:
tree.data = tree.trueBranch.data + tree.falseBranch.data
tree.results = calculateDiffCount(tree.data)
tree.trueBranch = None
tree.falseBranch = None
复制代码

当节点的gain小于给定的 mini Gain时则合并这两个节点.。

最后是构建树的代码:

if __name__ == '__main__':
dataSet = loadCSV()
decisionTree = buildDecisionTree(dataSet, evaluationFunction=gini)
prune(decisionTree, 0.4)
test_data = [5.9,3,4.2,1.5]
r = classify(test_data, decisionTree)
print(r)
复制代码

可以打印decisionTree可以构建出如如上的图片中的决策树。
后面找一组数据测试看能否得到正确的分类。

完整代码和数据集请查看:
github:CART

总结:

  • CART决策树
  • 分割数据集
  • 递归创建树

参考文章:
CART分类回归树分析与python实现
CART决策树(Decision Tree)的Python源码实现

机器学习之分类回归树(python实现CART)的更多相关文章

  1. 分类-回归树模型(CART)在R语言中的实现

    分类-回归树模型(CART)在R语言中的实现 CART模型 ,即Classification And Regression Trees.它和一般回归分析类似,是用来对变量进行解释和预测的工具,也是数据 ...

  2. 秒懂机器学习---分类回归树CART

    秒懂机器学习---分类回归树CART 一.总结 一句话总结: 用决策树来模拟分类和预测,那些人还真是聪明:其实也还好吧,都精通的话想一想,混一混就好了 用决策树模拟分类和预测的过程:就是对集合进行归类 ...

  3. 机器学习技法-决策树和CART分类回归树构建算法

    课程地址:https://class.coursera.org/ntumltwo-002/lecture 重要!重要!重要~ 一.决策树(Decision Tree).口袋(Bagging),自适应增 ...

  4. 分类回归树(CART)

    概要 本部分介绍 CART,是一种非常重要的机器学习算法.   基本原理   CART 全称为 Classification And Regression Trees,即分类回归树.顾名思义,该算法既 ...

  5. CART(分类回归树)

    1.简单介绍 线性回归方法可以有效的拟合所有样本点(局部加权线性回归除外).当数据拥有众多特征并且特征之间关系十分复杂时,构建全局模型的想法一个是困难一个是笨拙.此外,实际中很多问题为非线性的,例如常 ...

  6. 连续值的CART(分类回归树)原理和实现

    上一篇我们学习和实现了CART(分类回归树),不过主要是针对离散值的分类实现,下面我们来看下连续值的cart分类树如何实现 思考连续值和离散值的不同之处: 二分子树的时候不同:离散值需要求出最优的两个 ...

  7. 利用CART算法建立分类回归树

    常见的一种决策树算法是ID3,ID3的做法是每次选择当前最佳的特征来分割数据,并按照该特征所有可能取值来切分,也就是说,如果一个特征有四种取值,那么数据将被切分成4份,一旦按某特征切分后,该特征在之后 ...

  8. CART决策树(分类回归树)分析及应用建模

    一.CART决策树模型概述(Classification And Regression Trees)   决策树是使用类似于一棵树的结构来表示类的划分,树的构建可以看成是变量(属性)选择的过程,内部节 ...

  9. 决策树的剪枝,分类回归树CART

    决策树的剪枝 决策树为什么要剪枝?原因就是避免决策树“过拟合”样本.前面的算法生成的决策树非常的详细而庞大,每个属性都被详细地加以考虑,决策树的树叶节点所覆盖的训练样本都是“纯”的.因此用这个决策树来 ...

随机推荐

  1. Python魔法缓存,以数字开始

    Python魔法缓存,以数字开始 众所周知,Python是弱类型的脚本语言,变量的定义是不用声明类型的. a = 1 Python所有数字的本质都是对象, 他们是不可改变的数据类型,这意味着改变数字数 ...

  2. Linux服务器 上传/下载 文档/目录

    1.从服务器上下载文件 scp username@servername:/path/filename /var/www/local_dir(本地目录) 例如scp root@192.168.0.101 ...

  3. Java队列学习第一篇之列介绍

    Java并发之显式锁和隐式锁的区别 在面试的过程中有可能会问到:在Java并发编程中,锁有两种实现:使用隐式锁和使用显示锁分别是什么?两者的区别是什么?所谓的显式锁和隐式锁的区别也就是说说Synchr ...

  4. 对象中属性 加锁 用:volatile 关键词修饰 而 不用 synchronized 加锁

    一个对象中有一个状态 属性,现在业务需求 存在多线程来修改 和 拿去 这个状态 的值,这种情况如果加锁怎么加? 一种是 在 set 和get 这个状态的 方法那加 synchronized . 还有一 ...

  5. 数据结构和算法(Golang实现)(21)排序算法-插入排序

    插入排序 插入排序,一般我们指的是简单插入排序,也可以叫直接插入排序.就是说,每次把一个数插到已经排好序的数列里面形成新的排好序的数列,以此反复. 插入排序属于插入类排序算法. 除了我以外,有些人打扑 ...

  6. Android 开发小零碎

    1.EditText默认就会自动获取焦点, 如何让EditText不自动获取焦点? 解决之道:在EditText的父级控件中找一个,设置成 android:focusable="true&q ...

  7. Java方法的重点

    方法就是完成功能一个语句集合体 使用方法的原则:方法的原子性,一个方法只实现一个功能. 方法的重载 1.函数名必须相同 2.形参列表必须不同(可以是个数不同,类型不同,不然完全一样) 3.返回值可以相 ...

  8. L19深度学习中的优化问题和凸性介绍

    优化与深度学习 优化与估计 尽管优化方法可以最小化深度学习中的损失函数值,但本质上优化方法达到的目标与深度学习的目标并不相同. 优化方法目标:训练集损失函数值 深度学习目标:测试集损失函数值(泛化性) ...

  9. C++学习--编译优化

    常量折叠 把常量表达式的值求出来作为常量嵌在最终生成的代码中. 疑问:对于一个很复杂的常量表达式,编译器会算出结果再编译吗?亦或者是把这个表达式完全翻译成机器码,最终留给程序去解决? 分情况: 涉及的 ...

  10. Windows环境下搭建Cocos2d-x3.2环境并配置android交叉编译环境

    一.软件 1)VS2012(C++11特性在VS2012以上可以使用):传送门: 2)Cocos2d-x官网源码:传送门:http://cocos2d-x.org/download 3)JDK:传送门 ...