IMPLEMENTED IN PYTHON +1 | CART生成树
Introduction:
分类与回归树(classification and regression tree, CART)模型由Breiman等人在1984年提出,CART同样由特征选择、树的生成及剪枝组成,既可以用于分类也可以用于回归,以下简要讨论树生成部分,在随后的博文中再探讨树剪枝的问题。
Algorithm:
step 1. 对每个特征$A_i$的每个可能取值$j$,计算按"$A_i=j$"与"$A_i\neq j$"二分样本后的基尼指数$Gini(D,A_i=j)$;
step 2. 在所有特征及其取值中,选择基尼指数最小者作为最优特征与最优切分点:若$Gini(D,A_i=j)$最小,则$A_i=j$作为最优切分点,$A_i$作为根节点;
step 3. 在剩余的特征中重复step 1和2,获取最优特征及最优切分点,直至所有特征用尽或所有样本都已归类。对于本文使用的示例数据,最后生成的决策树与ID3算法生成的完全一致。
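Formula:
样本集合$D$的基尼指数(其中$C_k$为$D$中属于第$k$类的样本子集,$K$为类别数):
$$Gini(D)=\sum_{k=1}^{K}\frac{|C_k|}{|D|}\left(1-\frac{|C_k|}{|D|}\right)=1-\sum_{k=1}^{K}\left(\frac{|C_k|}{|D|}\right)^2$$
按"$A_i=j$"将$D$切分为$D_1=\{(x,y)\in D \mid A_i(x)=j\}$与$D_2=D-D_1$后,在该切分下$D$的基尼指数:
$$Gini(D,A_i=j)=\frac{|D_1|}{|D|}Gini(D_1)+\frac{|D_2|}{|D|}Gini(D_2)$$
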
Code:
"""
Created on Thu Jan 30 15:36:39 2014 @filename: test.py
""" import cart c = cart.Cart()
c.trainDecisionTree('decision_tree_text.txt')
print c.trainresult
view test.py
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 29 18:05:22 2014 @filename: cart.py

Greedy CART (Gini index) learner for tab-separated categorical data.

Every column of the input file is a categorical feature, except the last,
which is the class label.  ``Cart.trainDecisionTree`` grows a chain of
binary splits, prints each chosen split, stores the resulting nested-dict
tree on ``trainresult`` and plots it with ``plottree``.
"""
import numpy as np

import plottree

FILENAME = 'decision_tree_text.txt'
MAXDEPTH = 10


class Cart(object):
    """CART tree builder: call ``trainDecisionTree`` then read ``trainresult``."""

    def __init__(self):
        # Placeholder message until trainDecisionTree() replaces it.
        self.trainresult = 'WARNING : please trainDecisionTree first!'

    def trainDecisionTree(self, filename):
        """Load ``filename``, then grow, store and plot the tree."""
        self.__loadDataSet(filename)
        self.__optimalTree(self.__datamat)

    def __loadDataSet(self, filename):
        """Read the tab-separated file and build the numeric code matrix.

        Sets ``__dataset`` (raw text cells), ``__textdic`` (one text -> code
        dict per column) and ``__datamat`` (same shape as ``__dataset``, each
        cell replaced by its float code).  Blank lines are skipped.
        """
        with open(filename) as fread:
            rows = [line.strip().split('\t') for line in fread if line.strip()]
        self.__dataset = np.array(rows)
        # Codes are assigned per column (0.0, 1.0, ... in order of first
        # appearance).  A single dict shared by all columns could hand two
        # different values of the same column the same code.
        self.__textdic = []
        for column in self.__dataset.T:
            codes = {}
            for cell in column:
                if cell not in codes:
                    codes[cell] = float(len(codes))
            self.__textdic.append(codes)
        self.__datamat = np.array(
            [[self.__textdic[j][cell] for j, cell in enumerate(row)]
             for row in self.__dataset])

    def __getSampleCount(self, setd, col=-1, s=None):
        """Count occurrences of each distinct value.

        Without ``s``: counts of the values in column ``col`` of ``setd``
        (default: the label column).  With ``s``: counts of the labels among
        the rows whose column ``col`` equals ``s``.  Counts are floats so the
        callers' divisions are true divisions on Python 2 as well.
        """
        if s is not None:
            # __getSampleMat returns (matching, other); only the matching
            # rows' labels are wanted.
            newset = self.__getSampleMat(setd, col, s)[0][:, -1]
        else:
            newset = setd[:, col]
        dic = {}
        for cell in newset:
            dic[cell] = dic.get(cell, 0.0) + 1
        return dic

    def __getSampleMat(self, setd, col, s):
        """Split ``setd`` into (rows with ``row[col] == s``, all other rows).

        Both parts stay 2-D (possibly with zero rows) and keep row order, so
        column slicing such as ``part[:, -1]`` works even on an empty side.
        """
        mask = setd[:, col] == s
        return setd[mask], setd[~mask]

    def __getGiniD(self, setd):
        """Return the Gini impurity of the labels (last column) of ``setd``.

        An empty ``setd`` has impurity 0.
        """
        total = float(len(setd))
        gini = 0
        for count in self.__getSampleCount(setd).values():
            p = count / total
            gini += p * (1 - p)
        return gini

    def __getGiniDA(self, setd, a):
        """Score every binary split "column ``a`` == value" of ``setd``.

        Returns ``((best_value, best_gini), {value: gini})``.  The best
        entry has the lowest Gini; ties go to the lowest value code.
        """
        total = float(len(setd))
        dic = {}
        for value, count in self.__getSampleCount(setd, a).items():
            part_a, part_b = self.__getSampleMat(setd, a, value)
            # Weights |D1|/|D| and |D2|/|D| with |D1| = count.  Written
            # symmetrically so both values of a binary feature score exactly
            # the same (float addition is commutative).
            dic[value] = (count * self.__getGiniD(part_a) +
                          (total - count) * self.__getGiniD(part_b)) / total
        # Order by Gini first; plain min(dic.items()) compares the keys and so
        # picked the smallest value code instead of the best split.
        best = min(dic.items(), key=lambda item: (item[1], item[0]))
        return best, dic

    def __optimalNode(self, setd):
        """Return ``(column, value, gini)`` of the best binary split of ``setd``.

        Every feature column (all but the last) is scored with
        ``__getGiniDA``.  The lowest Gini wins, and ties keep the earlier
        column.  Columns holding a single value are skipped, because
        splitting on them would leave one side empty.

        Raises:
            ValueError: if no feature column can split ``setd``.
        """
        best = None
        for col in range(setd.shape[1] - 1):
            (value, gini), dic = self.__getGiniDA(setd, col)
            if len(dic) < 2:
                continue
            if best is None or gini < best[2]:
                best = (col, value, gini)
        if best is None:
            raise ValueError('no feature column can split the samples')
        return best

    def __optimalNodeText(self, col, value):
        """Return the original text whose code in column ``col`` is ``value``.

        Returns None if no row of that column carries the code.
        """
        for row, cell in enumerate(self.__dataset[:, col]):
            if self.__datamat[row, col] == value:
                return cell
        return None

    def __optimalTree(self, setd):
        """Grow a chain of splits on ``setd``, then store and plot the result.

        Each round picks the best (column, value) split of the remaining
        data.  The side whose labels are pure is recorded as a leaf via
        ``__buildList``.  The other side, with the split column removed, is
        carried into the next round.  The loop stops when a split has Gini 0
        or only one feature is left; the split list then goes to
        ``trainresult`` and ``__plotTree``.  At most ``MAXDEPTH - 1`` rounds
        run; if that budget is exhausted, ``trainresult`` is left unchanged.
        """
        arr = setd
        # Original column index of every column still present in ``arr``.
        features = np.arange(arr.shape[1])
        lst = []
        defaultc = None  # label set of the leaf recorded in the previous round
        for _ in range(MAXDEPTH - 1):
            ginicol, value, gini = self.__optimalNode(arr)
            parts = self.__getSampleMat(arr, ginicol, value)
            args = [np.unique(part[:, -1]) for part in parts]
            realvalues = [np.unique(part[:, ginicol])[0] for part in parts]
            realcol = features[ginicol]
            features = np.delete(features, ginicol)
            if gini == 0 or arr.shape[1] == 2:
                # array_equal rather than ``==``: the latter is element-wise
                # and its truth value is ambiguous for multi-label sets.
                if np.array_equal(args[0], defaultc):
                    value = realvalues[0]
                else:
                    value = realvalues[1]
                self.trainresult = self.__buildList(lst, realcol, value, gini)
                self.__plotTree(self.trainresult)
                return
            # NOTE(review): when side 0 is impure, side 1 is recorded as the
            # leaf without checking that it is pure.
            leaf = 0 if len(args[0]) == 1 else 1
            defaultc = args[leaf]
            self.__buildList(lst, realcol, realvalues[leaf], gini)
            arr = np.delete(parts[1 - leaf], ginicol, axis=1)

    def __plotTree(self, lst):
        """Fold the split list into a nested dict, store it and plot it.

        ``lst`` holds the deepest split first (``__buildList`` prepends), so
        each later entry wraps the dict built so far as its 'ELSE' branch.
        """
        dic = {}
        for col, label in lst:
            if not dic:
                dic = {col: {label: 'c1', 'ELSE': 'c2'}}
            else:
                dic = {col: {label: 'c1', 'ELSE': dic}}
        tree = plottree.retrieveTree(dic)
        self.trainresult = tree
        plottree.createPlot(tree)

    def __buildList(self, lst, col, value, gini):
        """Print a chosen split and prepend ``[col, 'text:code']`` to ``lst``.

        ``lst`` is modified in place and also returned.
        """
        text = self.__optimalNodeText(col, value)
        # One pre-formatted argument prints the same on Python 2 and 3.
        print('feature col: %s  feature val: %s  Gini: %s \n' % (col, text, gini))
        lst.insert(0, [col, str(text) + ':' + str(value)])
        return lst


if __name__ == '__main__':
    cart = Cart()
view cart.py
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 29 11:45:18 2014 @filename: plottree.py

Draw a nested-dict decision tree with matplotlib annotations.

A tree is a one-entry dict ``{node_label: {branch_label: child}}``.  Each
child is either another such dict (a subtree) or any non-dict value (a leaf).
"""

# Box and arrow styles passed to ``Axes.annotate``.
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="1.0")
arrow_args = dict(arrowstyle="<-")


def _firstKey(tree):
    """Return the root label of a one-entry tree dict.

    Works on Python 2 and 3 (``dict.keys()[0]`` is Python 2 only).

    Raises:
        ValueError: if ``tree`` is empty.
    """
    for key in tree:
        return key
    raise ValueError('empty tree dict has no root node')


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw ``nodeTxt`` in a ``nodeType`` box at ``centerPt``, linked to ``parentPt``.

    Draws on ``createPlot.ax1``, so ``createPlot`` must have set it first.
    """
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)


def getNumLeafs(myTree):
    """Return the number of leaves (non-dict values) under ``myTree``'s root."""
    numLeafs = 0
    for child in myTree[_firstKey(myTree)].values():
        # isinstance instead of ``type(x).__name__ is 'dict'``: ``is`` tests
        # object identity and only worked by accident of string interning.
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """Return the number of branch levels below ``myTree``'s root (0 if none)."""
    maxDepth = 0
    for child in myTree[_firstKey(myTree)].values():
        if isinstance(child, dict):
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        maxDepth = max(maxDepth, thisDepth)
    return maxDepth


def retrieveTree(dic=None):
    """Return ``dic``, or a freshly built sample tree when it is omitted.

    The sample used to be the default argument itself.  That single dict was
    shared by every call, so a caller mutating it changed later results.
    """
    if dic is None:
        dic = {'have house': {'yes': 'c1',
                              'no': {'have job': {'yes': 'c1', 'no': 'c2'}}}}
    return dic


def plotMidText(centrPt, parentPt, txtString):
    """Write ``txtString`` halfway between ``centrPt`` and ``parentPt``."""
    xMid = (parentPt[0] - centrPt[0]) / 2.0 + centrPt[0]
    yMid = (parentPt[1] - centrPt[1]) / 2.0 + centrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw ``myTree`` below ``parentPt``, labelling the edge ``nodeTxt``.

    Layout state lives on function attributes set up by ``createPlot``:
    ``totalW`` / ``totalD`` (leaf count / depth of the whole tree) and the
    running cursor ``xOff`` / ``yOff`` in axes-fraction coordinates.
    """
    numLeafs = getNumLeafs(myTree)
    firstStr = _firstKey(myTree)
    # Centre this node horizontally over the span its leaves will occupy.
    centrPt = [plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
               plotTree.yOff]
    plotMidText(centrPt, parentPt, nodeTxt)
    plotNode(firstStr, centrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            plotTree(child, centrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), centrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), centrPt, str(key))
    # Restore the vertical cursor for this node's siblings.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    """Open figure 1, lay out and draw ``inTree``, then show it."""
    # Imported here rather than at module level, so the pure tree helpers
    # above can be used (and tested) without matplotlib installed.
    import matplotlib.pyplot as plt
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


if __name__ == '__main__':
    myTree = retrieveTree()
    createPlot(myTree)
view plottree.py
输入数据

输出结果
feature col: 2  feature val: 是  Gini: 0.266666666667
feature col: 1  feature val: 是  Gini: 0.0

Reference:
Harrington P. Machine Learning in Action
李航. 统计学习方法
IMPLEMENTED IN PYTHON +1 | CART生成树的更多相关文章
- Python实现CART(基尼指数)
Python实现CART(基尼指数) 运行环境 Python3 treePlotter模块(画图所需,不画图可不必) matplotlib(如果使用上面的模块必须) 计算过程 st=>start ...
- 机器学习之分类回归树(python实现CART)
之前有文章介绍过决策树(ID3).简单回顾一下:ID3每次选取最佳特征来分割数据,这个最佳特征的判断原则是通过信息增益来实现的.按照某种特征切分数据后,该特征在以后切分数据集时就不再使用,因此存在切分 ...
- Algorithm: quick sort implemented in python 算法导论 快速排序
import random def partition(A, lo, hi): pivot_index = random.randint(lo, hi) pivot = A[pivot_index] ...
- leetcode-happy number implemented in python
视频分析: http://v.youku.com/v_show/id_XMTMyODkyNDA0MA==.html?from=y1.7-1.2 class Solution(object): def ...
- Awesome Python
Awesome Python A curated list of awesome Python frameworks, libraries, software and resources. Insp ...
- Python开源框架、库、软件和资源大集合
A curated list of awesome Python frameworks, libraries, software and resources. Inspired by awesome- ...
- Python 库汇总英文版
Awesome Python A curated list of awesome Python frameworks, libraries, software and resources. Insp ...
- Python框架、库以及软件资源汇总
转自:http://developer.51cto.com/art/201507/483510.htm 很多来自世界各地的程序员不求回报的写代码为别人造轮子.贡献代码.开发框架.开放源代码使得分散在世 ...
- Python Scopes and Namespaces
Before introducing classes, I first have to tell you something about Python's scope rules. Class def ...
随机推荐
- ubuntu 连接 mssql <转>
转自 http://www.sendong.net/thread-90941-1-1.html 在linux下连接MSSQL,因为微软同志没有提供任何接口给开发人员,大约他们认为要用MSSQL的,只 ...
- 2015 UESTC Winter Training #10【Northeastern Europe 2009】
2015 UESTC Winter Training #10 Northeastern Europe 2009 最近集训都不在状态啊,嘛,上午一直在练车,比赛时也是刚吃过午饭,状态不好也难免,下次比赛 ...
- HDU 2639 (01背包第k优解)
/* 01背包第k优解问题 f[i][j][k] 前i个物品体积为j的第k优解 对于每次的ij状态 记下之前的两种状态 i-1 j-w[i] (选i) i-1 j (不选i) 分别k个 然后归并排序并 ...
- DataSet离线数据集实例
using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.W ...
- OD: Peimei & Versioning Analysis
PE 中漫步—“白眉” 指令追踪技术与 Paimei 程序异常发生的位置通常离漏洞函数很远,当溢出发生时,栈帧往往也会遭到破坏,给动态调试制造很大的困难. 指令追踪最大限度地结合了动态分析和静态分析的 ...
- 谷歌postman插件用不了的命令行指令
谷歌postman插件用不了,想测试通过post提交传过来的参数测试,打开所在目录,shift右键,打开命令窗口,输入一下指令,即可获取到提交的值curl -X POST -d 'name=kpf&a ...
- Sqlite 错误码
#define SQLITE_OK 0 /* 成功 | Successful result */ /* 错误码开始 */ #define SQLITE_ERROR 1 /* SQL错误 或 丢失数据库 ...
- JS屏蔽右键菜单,复制,粘帖xxxxx........
//屏蔽右键菜单 document.oncontextmenu = function (event) { if (window.event) { event = window.event; } try ...
- [C++] namespace相关语法
本段测试代码包括如下内容: (1) 如何访问namespace中声明的名称:(2) namespace导致的相关冲突:(3) namespace可嵌套:(4) 可以在namespace中使用using ...
- 【USACO 2.3.3】零数列
[题目描述] 请考虑一个由1到N(N=3, 4, 5 ... 9)的数字组成的递增数列:1 2 3 ... N. 现在请在数列中插入“+”表示加,或者“-”表示减,“ ”表示空白(例如1-2 3就等于 ...