IMPLEMENTED IN PYTHON +1 | CART生成树
Introduction:
分类与回归树(classification and regression tree, CART)模型由Breiman等人在1984年提出,CART同样由特征选择、树的生成及剪枝组成,既可以用于分类也可以用于回归,以下简要讨论树生成部分,在随后的博文中再探讨树剪枝的问题。
Algorithm:
step 1. 分别计算所有特征中各个分类(取值)的基尼系数
step 2. 选择有最小基尼系数的特征及其取值作为最优切分点:若$Gini(D,A_i=j)$最小,则$A_i=j$作为最优切割点,$A_i$作为根节点
step 3. 在剩余的特征中重复step 1和2,获取最优特征及最优切割点,直至所有特征用尽或者是所有值都一一归类;对本文的示例数据而言,最后所生成的决策树与ID3算法所生成的完全一致
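Formula:
数据集$D$的基尼系数:$Gini(D)=\sum_{k=1}^{K}\frac{|C_k|}{|D|}\left(1-\frac{|C_k|}{|D|}\right)=1-\sum_{k=1}^{K}\left(\frac{|C_k|}{|D|}\right)^2$,其中$C_k$为$D$中属于第$k$类的样本子集。
按$A_i=j$把$D$切分为$D_1$($A_i=j$)和$D_2$($A_i\neq j$)两部分后的基尼系数:$Gini(D,A_i=j)=\frac{|D_1|}{|D|}Gini(D_1)+\frac{|D_2|}{|D|}Gini(D_2)$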
Code:
"""
Created on Thu Jan 30 15:36:39 2014

@filename: test.py

Driver script: train a CART tree from 'decision_tree_text.txt' and print
the value stored in ``trainresult`` afterwards.
"""
import cart

c = cart.Cart()
c.trainDecisionTree('decision_tree_text.txt')
# Single-argument call form: prints the same text under Python 2 and 3.
print(c.trainresult)
view test.py
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 29 18:05:22 2014

@filename: cart.py

CART tree generation based on the Gini index.

The training file is tab-separated text, one sample per line; the last
column is treated as the class label and every other column as a
categorical feature.
"""
import numpy as np

FILENAME = 'decision_tree_text.txt'
MAXDEPTH = 10


class Cart():
    """Grow a decision list/tree by repeatedly choosing the binary split
    ``feature == value`` with the smallest Gini index.

    After ``trainDecisionTree`` finishes, ``trainresult`` holds the nested
    dict that was handed to ``plottree.createPlot``.
    """

    def __init__(self):
        self.trainresult = 'WARNING : please trainDecisionTree first!'

    def trainDecisionTree(self, filename):
        """Load *filename*, grow the tree, then plot it via ``plottree``."""
        self.__loadDataSet(filename)
        self.__optimalTree(self.__datamat)

    def __loadDataSet(self, filename):
        """Read the tab-separated file and encode every cell as a float.

        Sets ``__dataset`` (2-D array of the raw texts), ``__datamat``
        (same shape, float codes) and ``__textdic`` (one ``text -> code``
        dict per column).

        Codes are assigned per column, in order of first appearance.  A
        single dictionary shared by all columns could hand two different
        texts of one column the same code whenever one of them had
        already been numbered in an earlier column.

        Raises:
            ValueError: if the file holds no rows or the rows do not all
                have the same number of fields.
        """
        # 'with' guarantees the handle is closed; blank lines are skipped.
        with open(filename) as fread:
            rows = [row.strip().split('\t') for row in fread if row.strip()]
        widths = set(len(row) for row in rows)
        if len(widths) != 1:
            raise ValueError('%s must contain at least one row and all rows '
                             'must have the same number of fields' % filename)
        self.__dataset = np.array(rows)
        self.__textdic = []
        columns = []
        for col in self.__dataset.T:
            codes = {}
            for cell in col:
                if cell not in codes:
                    codes[cell] = float(len(codes))
            self.__textdic.append(codes)
            columns.append([codes[cell] for cell in col])
        self.__datamat = np.array(columns).T

    def __getSampleCount(self, setd, col=-1, s=None):
        """Count occurrences of each distinct value.

        With ``s`` left as None the values of column *col* are counted.
        Otherwise the rows with ``setd[:, col] == s`` are selected first
        and the values of their last (label) column are counted.

        Returns:
            dict mapping value -> count (counts are floats).
        """
        if s is not None:
            # __getSampleMat returns (matching, rest); only the matching
            # rows are wanted here.
            newset = self.__getSampleMat(setd, col, s)[0][:, -1]
        else:
            newset = setd[:, col]
        dic = {}
        for cell in newset:
            dic[cell] = dic.get(cell, 0.) + 1
        return dic

    def __getSampleMat(self, setd, col, s):
        """Split *setd* into (rows where column *col* equals *s*, all other
        rows), keeping the row order.

        Both results are always 2-D, so an empty side has shape
        ``(0, ncols)`` and can still be sliced by column.
        """
        mask = setd[:, col] == s
        return setd[mask], setd[~mask]

    def __getGiniD(self, setd):
        """Return the Gini index of the label (last) column of *setd*:
        the sum over classes of ``p * (1 - p)``.  An empty set yields 0.
        """
        sample_count = self.__getSampleCount(setd)
        total = float(len(setd))
        gini = 0
        for count in sample_count.values():
            gini += count / total * (1 - count / total)
        return gini

    def __getGiniDA(self, setd, a):
        """Evaluate every binary split ``column a == value``.

        Returns:
            ``((value, gini), dic)`` where ``dic`` maps each distinct value
            of column *a* to the weighted Gini index of that split and the
            first element is the entry with the smallest Gini index (ties
            go to the smallest value).
        """
        total = float(len(setd))
        dic = {}
        for value in self.__getSampleCount(setd, a):
            part_a, part_b = self.__getSampleMat(setd, a, value)
            # Weights come from the partition sizes, so the two
            # complementary splits of a two-valued feature tie exactly.
            dic[value] = (len(part_a) / total * self.__getGiniD(part_a) +
                          len(part_b) / total * self.__getGiniD(part_b))
        # Select by Gini value; a plain min() over the items would compare
        # the feature values first and ignore the Gini index.
        best = min(dic.items(), key=lambda item: (item[1], item[0]))
        return best, dic

    def __optimalNode(self, setd):
        """Find the best split over all feature columns of *setd*.

        Returns:
            ``(column index, split value, gini)`` of the split with the
            smallest Gini index; the first column wins ties.

        Raises:
            ValueError: if no split with a Gini index below 1 exists
                (for instance when *setd* has no feature column).
        """
        ginicol = 0
        mingini = (None, 1)
        for coln in range(setd.shape[1] - 1):
            gini, _ = self.__getGiniDA(setd, coln)
            if gini[1] < mingini[1]:
                mingini = gini
                ginicol = coln
        if mingini[0] is None:
            raise ValueError('no feature column available to split on')
        return ginicol, mingini[0], mingini[1]

    def __optimalNodeText(self, col, value):
        """Return the original text of the first cell in column *col* of
        the full data set whose code equals *value*, or None if absent.
        """
        rows = np.nonzero(self.__datamat[:, col] == value)[0]
        if len(rows) == 0:
            return None
        return self.__dataset[rows[0], col]

    def __optimalTree(self, setd):
        """Grow the tree from the encoded matrix *setd* and plot it.

        At most ``MAXDEPTH - 1`` rounds are run.  Each round picks the best
        split, records it, drops the used column and continues with the
        side of the split that is not yet a single class.  When a split
        reaches Gini 0, or only one feature column is left, the result is
        stored in ``trainresult`` and plotted.  If the loop runs out first,
        ``trainresult`` is left untouched.
        """
        arr = setd
        count = MAXDEPTH - 1
        # features[i] is the column index in the full data set of the
        # i-th column of the shrinking working matrix.
        features = np.arange(arr.shape[1])
        lst = []
        defaultc = None
        while count > 0:
            count -= 1
            ginicol, value, gini = self.__optimalNode(arr)
            parts = self.__getSampleMat(arr, ginicol, value)
            args = [np.unique(part[:, -1]) for part in parts]
            realvalues = [np.unique(part[:, ginicol])[0] for part in parts]
            realcol = features[ginicol]
            features = np.delete(features, ginicol)
            if gini == 0 or arr.shape[1] == 2:
                # array_equal avoids the ambiguous truth value of an
                # element-wise comparison between arrays of unequal length
                # (or against None).
                if defaultc is not None and np.array_equal(args[0], defaultc):
                    value = realvalues[0]
                else:
                    value = realvalues[1]
                self.trainresult = self.__buildList(lst, realcol, value, gini)
                self.__plotTree(self.trainresult)
                return
            if len(args[0]) == 1:
                # The matching side holds one class only: keep splitting
                # the other side.
                defaultc = args[0]
                self.__buildList(lst, realcol, realvalues[0], gini)
                arr = np.delete(parts[1], ginicol, axis=1)
            else:
                defaultc = args[1]
                self.__buildList(lst, realcol, realvalues[1], gini)
                arr = np.delete(parts[0], ginicol, axis=1)

    def __plotTree(self, lst):
        """Turn the split list into a nested dict, store it in
        ``trainresult`` and draw it with ``plottree.createPlot``.

        The list is ordered deepest split first, so each later item wraps
        the dict built so far under its 'ELSE' branch.
        """
        # Imported here so that the numeric part of this module can be
        # imported and used without matplotlib being installed.
        import plottree

        dic = {}
        for item in lst:
            if dic == {}:
                dic[item[0]] = {item[1]: 'c1', 'ELSE': 'c2'}
            else:
                dic = {item[0]: {item[1]: 'c1', 'ELSE': dic}}
        tree = plottree.retrieveTree(dic)
        self.trainresult = tree
        plottree.createPlot(tree)

    def __buildList(self, lst, col, value, gini):
        """Report one chosen split on stdout and insert
        ``[col, 'text:value']`` at the front of *lst*; returns *lst*.
        """
        text = self.__optimalNodeText(col, value)
        # One pre-formatted string keeps the output identical to the old
        # comma-separated print statement under both Python 2 and 3.
        print('feature col: %s  feature val: %s  Gini: %s \n'
              % (col, text, gini))
        lst.insert(0, [col, str(text) + ':' + str(value)])
        return lst


if __name__ == '__main__':
    cart = Cart()
view cart.py
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 29 11:45:18 2014

@filename: plottree.py

Draw a decision tree given as nested dicts of the form
``{node_label: {branch_label: subtree_or_leaf, ...}}``.
"""

# Box and arrow styles passed to matplotlib's annotate().
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="1.0")
arrow_args = dict(arrowstyle="<-")


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one boxed node at *centerPt* with an arrow linking it to
    *parentPt*.  Coordinates are axes fractions; *nodeType* is the bbox
    style dict.  Requires ``createPlot.ax1`` to be set (createPlot does
    that).
    """
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,
                            xycoords='axes fraction', xytext=centerPt,
                            textcoords='axes fraction', va="center",
                            ha="center", bbox=nodeType,
                            arrowprops=arrow_args)


def getNumLeafs(myTree):
    """Return the number of leaves (non-dict values) below the root of
    *myTree*.
    """
    numLeafs = 0
    firstStr = next(iter(myTree))
    for child in myTree[firstStr].values():
        # isinstance replaces the old identity test against the string
        # 'dict', which only worked through string interning.
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """Return the number of decision levels on the longest root-to-leaf
    path of *myTree* (a root with only leaves has depth 1).
    """
    maxDepth = 0
    firstStr = next(iter(myTree))
    for child in myTree[firstStr].values():
        if isinstance(child, dict):
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def retrieveTree(dic=None):
    """Return *dic* unchanged, or a fresh sample tree when it is omitted.

    The sample is rebuilt on every call so that a caller mutating the
    returned dict cannot alter the default seen by later calls.
    """
    if dic is None:
        dic = {'have house': {'yes': 'c1', 'no': {'have job':
               {'yes': 'c1', 'no': 'c2'}}}}
    return dic


def plotMidText(centrPt, parentPt, txtString):
    """Write *txtString* halfway between *centrPt* and *parentPt*."""
    xMid = (parentPt[0] - centrPt[0]) / 2.0 + centrPt[0]
    yMid = (parentPt[1] - centrPt[1]) / 2.0 + centrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw *myTree* below *parentPt*; *nodeTxt* labels the
    incoming branch.

    Layout state lives in function attributes initialised by createPlot:
    ``totalW``/``totalD`` (leaf count and depth of the whole tree) and
    ``xOff``/``yOff`` (current drawing cursor).
    """
    numLeafs = getNumLeafs(myTree)
    firstStr = next(iter(myTree))
    centrPt = [plotTree.xOff +
               (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
               plotTree.yOff]
    plotMidText(centrPt, parentPt, nodeTxt)
    plotNode(firstStr, centrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Step one level down for the children, and back up once they are done.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], centrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
                     centrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), centrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    """Open figure 1, draw *inTree* on a frameless, tickless axes and show
    the window.
    """
    # Imported here, by its only user, so the tree helpers above can be
    # used without matplotlib being installed.
    import matplotlib.pyplot as plt

    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


if __name__ == '__main__':
    myTree = retrieveTree()
    createPlot(myTree)
view plottree.py
输入数据

输出结果
feature col: 2 feature val: 是 Gini: 0.266666666667
feature col: 1 feature val: 是 Gini: 0.0

Reference:
Harrington P. Machine Learning in Action
李航. 统计学习方法
IMPLEMENTED IN PYTHON +1 | CART生成树的更多相关文章
- Python实现CART(基尼指数)
Python实现CART(基尼指数) 运行环境 Python3 treePlotter模块(画图所需,不画图可不必) matplotlib(如果使用上面的模块必须) 计算过程 st=>start ...
- 机器学习之分类回归树(python实现CART)
之前有文章介绍过决策树(ID3).简单回顾一下:ID3每次选取最佳特征来分割数据,这个最佳特征的判断原则是通过信息增益来实现的.按照某种特征切分数据后,该特征在以后切分数据集时就不再使用,因此存在切分 ...
- Algorithm: quick sort implemented in python 算法导论 快速排序
import random def partition(A, lo, hi): pivot_index = random.randint(lo, hi) pivot = A[pivot_index] ...
- leetcode-happy number implemented in python
视频分析: http://v.youku.com/v_show/id_XMTMyODkyNDA0MA==.html?from=y1.7-1.2 class Solution(object): def ...
- Awesome Python
Awesome Python A curated list of awesome Python frameworks, libraries, software and resources. Insp ...
- Python开源框架、库、软件和资源大集合
A curated list of awesome Python frameworks, libraries, software and resources. Inspired by awesome- ...
- Python 库汇总英文版
Awesome Python A curated list of awesome Python frameworks, libraries, software and resources. Insp ...
- Python框架、库以及软件资源汇总
转自:http://developer.51cto.com/art/201507/483510.htm 很多来自世界各地的程序员不求回报的写代码为别人造轮子.贡献代码.开发框架.开放源代码使得分散在世 ...
- Python Scopes and Namespaces
Before introducing classes, I first have to tell you something about Python's scope rules. Class def ...
随机推荐
- 什么是AAC音频格式 AAC-LC 和 AAC-HE的区别是什么
Advanced Audio Coding(高级音频解码),是一种由MPEG-4标准定义的有损音频压缩格式,由Fraunhofer发展,Dolby, Sony和AT&T是主要的贡献者. 在使用 ...
- 转载:C#中事件和委托的编译代码
接上文转载:C#中事件的由来,这时候,我们注释掉编译错误的行,然后重新进行编译,再借助Reflector来对 event的声明语句做一探究,看看为什么会发生这样的错误: public event Gr ...
- css sprint 生成工具 bg2css
今天需要改个css sprint,之前使用过一个工具蛮好使,但是就是想不起叫什么名字,网上搜了很久,才再次找到,原来是bg2css,今天记录下,以备以后不时之需. 下载地址:http://www.cs ...
- 【原】AVAudio录制,播放 (解决真机播放音量太小)
原文链接:http://www.cnblogs.com/A--G/p/4624526.html 最近学习AVFoundation里的audio操作,最基本的录制和播放,参考了一个Code4pp的 一个 ...
- Linux shell日常命令和技巧
转自:http://www.vaikan.com/linux-shell-tips-and-tricks/ 原文:http://www.techbar.me/linux-shell-tips/ 使用L ...
- SSM框架入门和搭建 十部曲
又快到毕业设计的时候了,有的学弟说想用ssm做毕业设计,在网上找到资料看不懂,基础差.我就帮他写了一个demo,顺便也整理一下. SSM框架,顾名思义,就是Spring+SpringMVC+mybat ...
- Qt中绘图坐标QPainter,Viewport与Window的关系
在Qt中常常要自己重载一些paintEvent函数,这个时候往往忽略了两个很关键的API,那就是setViewport和setWindow. Viewport,顾名思义,反应的是物理坐标,就是你实际想 ...
- 11_RHEL安装Maya2015
1. 解压 tar -xvf ./Autodesk_Maya_English_2015_Linux_64bit.tgz 2. 运行 ./setup 2.1补充 如果提示缺少 libpng12.so.0 ...
- 初涉JavaScript模式 (2) : 基本技巧
尽量少用全局变量 大量使用全局变量会导致的后果 全局变量创建以后会在整个JavaScript应用和Web页面中共享.所有的全局变量都存在于一个全局命名空间内,很容易发生冲突 不知不觉创建了全局变量 其 ...
- JS笔记2 --定义对象
16.javascript中定义对象的几种方式(javascript中没有类的概念,只有对象): 1)基于已有对象扩充其属性和方法: var object = new Object(); object ...