Introduction:

分类与回归树(classification and regression tree, CART)模型由Breiman等人在1984年提出,CART同样由特征选择、树的生成及剪枝组成,既可以用于分类也可以用于回归,以下简要讨论树生成部分,在随后的博文中再探讨树剪枝的问题。

Algorithm:

step . 分别计算所有特征中各个分类的基尼系数

step 2. 选择有最小基尼系数的特征作为最优切分点,因$Gini(D,A_i=j)$最小,所以$A_i=j$作为最优切割点,$A_i$作为根节点 

step 3. 在剩余的特征中重复step 1和2,获取最优特征及最优切割点,直至所有特征用尽或者是所有值都一一归类,最后所生成的决策树与ID3算法所生成的完全一致

Formula:

Code:

 """
Created on Thu Jan 30 15:36:39 2014 @filename: test.py
""" import cart c = cart.Cart()
c.trainDecisionTree('decision_tree_text.txt')
print c.trainresult

view test.py

 # -*- coding: utf-8 -*-
"""
Created on Wed Jan 29 18:05:22 2014 @filename: cart.py
"""
FILENAME = 'decision_tree_text.txt'
MAXDEPTH = 10 import numpy as np
import plottree class Cart():
def __init__(self):
self.trainresult = 'WARNING : please trainDecisionTree first!'
pass def trainDecisionTree(self, filename):
self.__loadDataSet(filename)
self.__optimalTree(self.__datamat) def __loadDataSet(self, filename):
fread = open(filename)
self.__dataset = np.array([row.strip().split('\t') \
for row in fread.readlines()])
self.__textdic = {}
for col in self.__dataset.T:
i = .0
for cell in col:
if not self.__textdic.has_key(cell):
self.__textdic[cell] = i
i += 1
self.__datamat = np.array([np.array([(lambda cell:self.__textdic[cell])(cell) \
for cell in row]) \
for row in self.__dataset]) def __getSampleCount(self, setd, col = -1, s = None):
dic = {} if s is not None:
newset = self.__getSampleMat(setd,col,s)[:,-1]
else:
newset = setd[:,col] for cell in newset:
if not dic.has_key(cell):
dic[cell] = 1.
else:
dic[cell] += 1
return dic def __getSampleMat(self, setd, col, s):
lista = []; listb = []
for row in setd:
if row[col] == s:
lista.append(row)
else:
listb.append(row)
return np.array(lista), np.array(listb) def __getGiniD(self, setd):
sample_count = self.__getSampleCount(setd)
gini = 0
for item in sample_count.items():
gini += item[1]/len(setd) * (1- item[1]/len(setd))
return gini def __getGiniDA(self, setd, a):
sample_count = self.__getSampleCount(setd, a)
dic = {}
for item in sample_count.items():
setd_part_a, setd_part_b = self.__getSampleMat(setd, a, item[0])
gini = item[1]/len(setd) * self.__getGiniD(setd_part_a) + \
(1- item[1]/len(setd)) * self.__getGiniD(setd_part_b)
dic[item[0]]=gini
return min(dic.items()), dic def __optimalNode(self, setd):
coln = 0
ginicol = 0
mingini = {1:1}
for col in setd[:,:-1].T:
gini, dic = self.__getGiniDA(setd, coln)
if gini[1] < mingini[1]:
mingini = gini
ginicol = coln
coln += 1
return ginicol, mingini[0], mingini[1] def __optimalNodeText(self, col, value):
row = 0
tex = None
for cell in self.__dataset.T[col]:
if self.__datamat[row,col] == value:
tex = cell
break
row += 1
return tex def __optimalTree(self, setd):
arr = setd
count = MAXDEPTH-1
features = np.array(range(len(arr.T)))
lst = []
defaultc = None
while count > 0:
count -= 1
ginicol, value, gini = self.__optimalNode(arr)
parts = self.__getSampleMat(arr, ginicol, value)
args = [np.unique(part[:,-1]) for part in parts]
realvalues = [np.unique(part[:,ginicol])[0] for part in parts]
realcol = features[ginicol]
features = np.delete(features, ginicol)
if gini == 0 or len(arr.T) == 2:
if args[0] == defaultc:
value = realvalues[0]
else:
value = realvalues[1]
self.trainresult = self.__buildList(lst, realcol, value, gini)
self.__plotTree(self.trainresult)
return
if len(args[0]) == 1:
defaultc = args[0]
self.__buildList(lst, realcol, realvalues[0], gini)
arr = np.concatenate((parts[1][:,:ginicol], \
parts[1][:,ginicol+1:]), axis=1)
else:
defaultc = args[1]
self.__buildList(lst, realcol, realvalues[1], gini)
arr = np.concatenate((parts[0][:,:ginicol], \
parts[0][:,ginicol+1:]), axis=1) def __plotTree(self, lst):
dic = {}
for item in lst:
if dic == {}:
dic[item[0]] = {item[1]:'c1','ELSE':'c2'}
else:
dic = {item[0]:{item[1]:'c1','ELSE':dic}}
tree = plottree.retrieveTree(dic)
self.trainresult = tree
plottree.createPlot(tree) def __buildList(self, lst, col, value, gini):
print 'feature col:', col, \
' feature val:', self.__optimalNodeText(col, value), \
' Gini:', gini, '\n'
lst.insert(0,[col,str(self.__optimalNodeText(col, \
value))+':'+str(value)])
return lst if __name__ == '__main__':
cart = Cart()

view cart.py

 # -*- coding: utf-8 -*-
"""
Created on Wed Jan 29 11:45:18 2014 @filename: plottree.py
""" import matplotlib.pyplot as plt decisionNode = dict(boxstyle = "sawtooth", fc = "0.8")
leafNode = dict(boxstyle = "round4", fc = "1.0")
arrow_args = dict(arrowstyle = "<-") def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy = parentPt, \
xycoords = 'axes fraction', xytext = centerPt, \
textcoords = 'axes fraction', va= "center",\
ha = "center", bbox = nodeType, arrowprops = arrow_args) def getNumLeafs(myTree):
numLeafs = 0
firstStr = myTree.keys()[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ is 'dict':
numLeafs += getNumLeafs(secondDict[key])
else: numLeafs += 1
return numLeafs def getTreeDepth(myTree):
maxDepth = 0
firstStr = myTree.keys()[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else: thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth def retrieveTree(dic = {'have house': {'yes': 'c1', 'no':{'have job': \
{'yes': 'c1','no': 'c2'}}}}):
return dic def plotMidText(centrPt, parentPt, txtString):
xMid = (parentPt[0] - centrPt[0]) /2.0 + centrPt[0]
yMid = (parentPt[1] - centrPt[1]) /2.0 + centrPt[1]
createPlot.ax1.text(xMid, yMid, txtString) def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree)
firstStr = myTree.keys()[0]
centrPt = [plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, \
plotTree.yOff]
plotMidText(centrPt, parentPt, nodeTxt)
plotNode(firstStr, centrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
plotTree(secondDict[key], centrPt, str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), \
centrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), centrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD def createPlot(inTree):
fig = plt.figure(1, facecolor = 'white')
fig.clf()
axprops = dict(xticks = [], yticks = [])
createPlot.ax1 = plt.subplot(111, frameon = False, **axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5/plotTree.totalW
plotTree.yOff = 1.0
plotTree(inTree, (0.5, 1.0), '')
plt.show() if __name__ == '__main__':
myTree = retrieveTree()
createPlot(myTree)

view plottree.py

输入数据

输出结果

feature col: 2  feature val: 是  Gini: 0.266666666667 

feature col: 1  feature val: 是  Gini: 0.0 

Reference:

Harrington P. Machine Learning in Action

李航. 统计学习方法

IMPLEMENTED IN PYTHON +1 | CART生成树的更多相关文章

  1. Python实现CART(基尼指数)

    Python实现CART(基尼指数) 运行环境 Pyhton3 treePlotter模块(画图所需,不画图可不必) matplotlib(如果使用上面的模块必须) 计算过程 st=>start ...

  2. 机器学习之分类回归树(python实现CART)

    之前有文章介绍过决策树(ID3).简单回顾一下:ID3每次选取最佳特征来分割数据,这个最佳特征的判断原则是通过信息增益来实现的.按照某种特征切分数据后,该特征在以后切分数据集时就不再使用,因此存在切分 ...

  3. Algorithm: quick sort implemented in python 算法导论 快速排序

    import random def partition(A, lo, hi): pivot_index = random.randint(lo, hi) pivot = A[pivot_index] ...

  4. leetcode-happy number implemented in python

    视频分析: http://v.youku.com/v_show/id_XMTMyODkyNDA0MA==.html?from=y1.7-1.2 class Solution(object): def ...

  5. Awesome Python

    Awesome Python  A curated list of awesome Python frameworks, libraries, software and resources. Insp ...

  6. Python开源框架、库、软件和资源大集合

    A curated list of awesome Python frameworks, libraries, software and resources. Inspired by awesome- ...

  7. Python 库汇总英文版

    Awesome Python  A curated list of awesome Python frameworks, libraries, software and resources. Insp ...

  8. Python框架、库以及软件资源汇总

    转自:http://developer.51cto.com/art/201507/483510.htm 很多来自世界各地的程序员不求回报的写代码为别人造轮子.贡献代码.开发框架.开放源代码使得分散在世 ...

  9. Python Scopes and Namespaces

    Before introducing classes, I first have to tell you something about Python's scope rules. Class def ...

随机推荐

  1. yii columns value and type and checkbox columns

    value  I am here type  I am here checkbox columns   useage

  2. Linux read/write fread/fwrite两者区别

    Linux read/write fread/fwrite两者区别 1,fread是带缓冲的,read不带缓冲. 2,fopen是标准c里定义的,open是POSIX中定义的. 3,fread可以读一 ...

  3. Day3 - Python基础3 函数、递归、内置函数

    Python之路,Day3 - Python基础3   本节内容 1. 函数基本语法及特性 2. 参数与局部变量 3. 返回值 嵌套函数 4.递归 5.匿名函数 6.函数式编程介绍 7.高阶函数 8. ...

  4. retrofit2 okhttp3 RxJava butterknife 示例

    eclipse的jar包配置 eclipse中貌似用不了butterknife buildToolsVersion "23.0.2" defaultConfig { applica ...

  5. Xcode 的正确打开方式——Debugging(转)

    转自CocoaChina http://www.cocoachina.com/ios/20150225/11190.html 程序员日常开发中有大量时间都会花费在 debug 上,从事 iOS 开发不 ...

  6. 网络流转换为Byte数组

    /// <summary> /// 网络流转换为Byte数组 /// </summary> /// <param name="stream">& ...

  7. java_reflect_02

    按我们所知道的.对于类中的method,constructor,field如果访问属性是private的情况下我们是访问不了的,但通过反射就可以做到 仔细分析api发现Method,Construct ...

  8. Windows Phone中使用Storyboard做类似 IOS 屏幕小白点的效果

    windows phone中做动画其实很方便的,可以使用Blend拖来拖去就做出一个简单的动画,下面做了一个 ios屏幕小白点的拖动效果,包括速度判断移动 使用Blend生成以下代码 <Stor ...

  9. 自己总结python用xlrd\xlwt读写excel

    1.首先安装xlrd\xlwt模块 xlrd模块下载地址: https://pypi.python.org/pypi/xlrd xlwt模块下载地址: https://pypi.python.org/ ...

  10. c# sqlserver备份还原(转)

    WinForm c# 备份 还原 数据库 其实是个非常简单的问题,一个Form,一个Button,一个OpenFileDialog,一个SaveFileDialog.下面给出备份与还原类 using ...