IMPLEMENTED IN PYTHON +1 | CART生成树
Introduction:
分类与回归树(classification and regression tree, CART)模型由Breiman等人在1984年提出,CART同样由特征选择、树的生成及剪枝组成,既可以用于分类也可以用于回归,以下简要讨论树生成部分,在随后的博文中再探讨树剪枝的问题。
Algorithm:
step 1. 分别计算所有特征的每个取值$A_i=j$对应的基尼系数$Gini(D,A_i=j)$ step 2. 选择基尼系数最小的特征及其取值作为最优特征与最优切分点:若$Gini(D,A_i=j)$最小,则$A_i=j$作为最优切分点,$A_i$作为根节点 step 3. 在剩余的特征中重复step 1和2,获取最优特征及最优切分点,直至所有特征用尽或者所有样本都已归类。在本文的示例数据上,最后所生成的决策树与ID3算法所生成的完全一致
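Formula:
$$Gini(D)=\sum_{k=1}^{K}p_k(1-p_k)=1-\sum_{k=1}^{K}\left(\frac{|C_k|}{|D|}\right)^2$$
$$Gini(D,A=a)=\frac{|D_1|}{|D|}Gini(D_1)+\frac{|D_2|}{|D|}Gini(D_2)$$
其中$C_k$为$D$中属于第$k$类的样本子集,$D_1=\{x\in D \mid A(x)=a\}$,$D_2=D-D_1$。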
Code:
"""
Created on Thu Jan 30 15:36:39 2014 @filename: test.py
""" import cart c = cart.Cart()
c.trainDecisionTree('decision_tree_text.txt')
print c.trainresult
view test.py
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 29 18:05:22 2014 @filename: cart.py

Grow (but do not prune) a CART classification tree from a tab-separated
text file of categorical features whose last column is the class label.
"""
import numpy as np

FILENAME = 'decision_tree_text.txt'
MAXDEPTH = 10


class Cart():
    """CART tree builder for categorical (text) features.

    ``trainresult`` holds a warning string until trainDecisionTree() has
    produced a tree; afterwards it holds the nested-dict tree that was
    handed to plottree for drawing.
    """

    def __init__(self):
        self.trainresult = 'WARNING : please trainDecisionTree first!'

    def trainDecisionTree(self, filename):
        """Load *filename* and grow the tree; the result is stored in ``trainresult``."""
        self.__loadDataSet(filename)
        self.__optimalTree(self.__datamat)

    def __loadDataSet(self, filename):
        """Read the tab-separated file and encode every column as float codes.

        ``__dataset`` keeps the raw text cells; ``__datamat`` is the same
        table with each cell replaced by a code that is unique within its
        column (0., 1., ... in order of first appearance). Blank lines are
        skipped.

        Raises:
            ValueError: if the file contains no data rows.
        """
        with open(filename) as fread:
            rows = [line.strip().split('\t') for line in fread if line.strip()]
        if not rows:
            raise ValueError('no data rows in %r' % (filename,))
        self.__dataset = np.array(rows)
        # One code table per column, so two different texts in the same
        # column can never share a code, even if one of them also occurs in
        # another column.
        self.__textdic = []
        for column in self.__dataset.T:
            codes = {}
            for cell in column:
                if cell not in codes:
                    codes[cell] = float(len(codes))
            self.__textdic.append(codes)
        self.__datamat = np.array([[self.__textdic[j][cell] for j, cell in enumerate(row)]
                                   for row in self.__dataset])

    def __getSampleCount(self, setd, col=-1, s=None):
        """Count how often each value occurs.

        Without *s*: counts the values of column *col* of *setd*.
        With *s*: counts the labels (last column) of the rows of *setd*
        whose column *col* equals *s*.
        Counts are floats so that later ratios use true division on Python 2.
        """
        if s is not None:
            newset = self.__getSampleMat(setd, col, s)[0][:, -1]
        else:
            newset = setd[:, col]
        dic = {}
        for cell in newset:
            dic[cell] = dic.get(cell, 0.) + 1
        return dic

    def __getSampleMat(self, setd, col, s):
        """Split *setd* into (rows with row[col] == s, all other rows).

        Both halves stay 2-D even when empty, so column slicing on them is
        always valid.
        """
        mask = setd[:, col] == s
        return setd[mask], setd[~mask]

    def __getGiniD(self, setd):
        """Return Gini(D) = sum_k p_k * (1 - p_k) over the labels (last column).

        An empty subset has Gini 0.
        """
        total = float(len(setd))
        if total == 0:
            return 0
        gini = 0
        for count in self.__getSampleCount(setd).values():
            p = count / total
            gini += p * (1 - p)
        return gini

    def __getGiniDA(self, setd, a):
        """Compute Gini(D, A=v) for every value v of feature column *a*.

        Returns:
            ((best_value, best_gini), {value: gini}) where best_value has
            the smallest Gini index (ties go to the smallest value code).
        """
        total = float(len(setd))
        dic = {}
        for value in self.__getSampleCount(setd, a):
            part_a, part_b = self.__getSampleMat(setd, a, value)
            # Weighting each half by its size keeps the two candidates of a
            # binary feature exactly equal (float addition is commutative).
            dic[value] = (len(part_a) * self.__getGiniD(part_a) +
                          len(part_b) * self.__getGiniD(part_b)) / total
        # Minimise over the Gini value, not over the feature value.
        best = min(dic.items(), key=lambda item: (item[1], item[0]))
        return best, dic

    def __optimalNode(self, setd):
        """Pick the (feature column, value) split with the smallest Gini index.

        Only the feature columns are scanned (the last column is the label).

        Returns:
            (column index, value code, gini).

        Raises:
            ValueError: if *setd* has no feature columns.
        """
        best_col, best = 0, None
        for col in range(setd.shape[1] - 1):
            candidate, _ = self.__getGiniDA(setd, col)
            if best is None or candidate[1] < best[1]:
                best_col, best = col, candidate
        if best is None:
            raise ValueError('no feature columns left to split on')
        return best_col, best[0], best[1]

    def __optimalNodeText(self, col, value):
        """Map code *value* of column *col* back to its original text (None if absent)."""
        for text, code in zip(self.__dataset[:, col], self.__datamat[:, col]):
            if code == value:
                return text
        return None

    def __optimalTree(self, setd):
        """Greedily grow a chain-shaped tree and hand it to __plotTree.

        At each level the best (feature, value) split is chosen. If the
        matching side is pure it becomes a leaf and the search continues on
        the other side; otherwise the roles are swapped. The used feature
        column is then removed. Growth stops when a split is perfect
        (Gini 0) or only one feature column remains.

        The plotted leaves are labelled 'c1'/'c2', so this presumably assumes
        two classes. If MAXDEPTH - 1 levels pass without stopping,
        ``trainresult`` is left unchanged.
        """
        arr = setd
        # Original column index of every column still present in arr.
        features = np.arange(arr.shape[1])
        lst = []
        defaultc = None  # label set of the most recent pure branch
        for _ in range(MAXDEPTH - 1):
            ginicol, value, gini = self.__optimalNode(arr)
            parts = self.__getSampleMat(arr, ginicol, value)
            args = [np.unique(part[:, -1]) for part in parts]
            realvalues = [np.unique(part[:, ginicol])[0] for part in parts]
            realcol = features[ginicol]
            features = np.delete(features, ginicol)
            if gini == 0 or arr.shape[1] == 2:
                # Record the side whose labels match the previous pure branch.
                # array_equal avoids the ambiguous truth value of an
                # element-wise == between arrays of different lengths.
                if np.array_equal(args[0], defaultc):
                    value = realvalues[0]
                else:
                    value = realvalues[1]
                self.trainresult = self.__buildList(lst, realcol, value, gini)
                self.__plotTree(self.trainresult)
                return
            pure = 0 if len(args[0]) == 1 else 1
            defaultc = args[pure]
            self.__buildList(lst, realcol, realvalues[pure], gini)
            # Continue on the other side, without the column just used.
            arr = np.delete(parts[1 - pure], ginicol, axis=1)

    def __plotTree(self, lst):
        """Turn the rule list into a nested dict, store it in trainresult and draw it."""
        # Imported here so the Gini helpers work without matplotlib installed.
        import plottree
        dic = {}
        for col, label in lst:
            if not dic:
                dic = {col: {label: 'c1', 'ELSE': 'c2'}}
            else:
                dic = {col: {label: 'c1', 'ELSE': dic}}
        tree = plottree.retrieveTree(dic)
        self.trainresult = tree
        plottree.createPlot(tree)

    def __buildList(self, lst, col, value, gini):
        """Print the chosen split and prepend [col, 'text:code'] to *lst*.

        *lst* is mutated in place and also returned.
        """
        text = self.__optimalNodeText(col, value)
        print('feature col: %s  feature val: %s  Gini: %s \n' % (col, text, gini))
        lst.insert(0, [col, '%s:%s' % (text, value)])
        return lst


if __name__ == '__main__':
    cart = Cart()
view cart.py
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 29 11:45:18 2014 @filename: plottree.py

Draw a nested-dict decision tree with matplotlib annotations.

A tree is a dict with a single key (the node label) whose value maps each
branch label either to a leaf label or to another tree of the same shape.
"""

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="1.0")
arrow_args = dict(arrowstyle="<-")


def _rootLabel(myTree):
    """Return the top-level key (root node label) of *myTree*.

    Raises:
        ValueError: if *myTree* is empty.
    """
    for label in myTree:
        return label
    raise ValueError('empty tree: expected a single-key dict')


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw box *nodeTxt* at *centerPt* with an arrow from *parentPt* (axes-fraction coords)."""
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,
                            xycoords='axes fraction', xytext=centerPt,
                            textcoords='axes fraction', va="center",
                            ha="center", bbox=nodeType, arrowprops=arrow_args)


def getNumLeafs(myTree):
    """Count the leaves (non-dict branch targets) of *myTree*."""
    branches = myTree[_rootLabel(myTree)]
    numLeafs = 0
    for child in branches.values():
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """Return the number of decision levels (a single split has depth 1)."""
    branches = myTree[_rootLabel(myTree)]
    maxDepth = 0
    for child in branches.values():
        if isinstance(child, dict):
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def retrieveTree(dic=None):
    """Return *dic*, or a freshly built sample tree when it is omitted.

    The sample is rebuilt on every call so that callers mutating the
    returned dict cannot corrupt later calls.
    """
    if dic is None:
        dic = {'have house': {'yes': 'c1', 'no': {'have job':
                                                  {'yes': 'c1', 'no': 'c2'}}}}
    return dic


def plotMidText(centrPt, parentPt, txtString):
    """Write *txtString* halfway along the edge between *parentPt* and *centrPt*."""
    xMid = (parentPt[0] - centrPt[0]) / 2.0 + centrPt[0]
    yMid = (parentPt[1] - centrPt[1]) / 2.0 + centrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw *myTree* below *parentPt*, labelling the incoming edge *nodeTxt*.

    Layout state lives in function attributes set by createPlot():
    totalW/totalD (leaf count and depth of the whole tree) and the running
    cursor xOff/yOff.
    """
    numLeafs = getNumLeafs(myTree)
    firstStr = _rootLabel(myTree)
    centrPt = [plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
               plotTree.yOff]
    plotMidText(centrPt, parentPt, nodeTxt)
    plotNode(firstStr, centrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            plotTree(child, centrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff),
                     centrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), centrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    """Open figure 1, draw *inTree* on a frameless axes and show it."""
    # Imported here so the pure tree helpers above work without matplotlib.
    import matplotlib.pyplot as plt
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


if __name__ == '__main__':
    myTree = retrieveTree()
    createPlot(myTree)
view plottree.py
输入数据
输出结果
feature col: 2 feature val: 是 Gini: 0.266666666667 feature col: 1 feature val: 是 Gini: 0.0
Reference:
Harrington P. Machine Learning in Action
李航. 统计学习方法
IMPLEMENTED IN PYTHON +1 | CART生成树的更多相关文章
- Python实现CART(基尼指数)
Python实现CART(基尼指数) 运行环境 Pyhton3 treePlotter模块(画图所需,不画图可不必) matplotlib(如果使用上面的模块必须) 计算过程 st=>start ...
- 机器学习之分类回归树(python实现CART)
之前有文章介绍过决策树(ID3).简单回顾一下:ID3每次选取最佳特征来分割数据,这个最佳特征的判断原则是通过信息增益来实现的.按照某种特征切分数据后,该特征在以后切分数据集时就不再使用,因此存在切分 ...
- Algorithm: quick sort implemented in python 算法导论 快速排序
import random def partition(A, lo, hi): pivot_index = random.randint(lo, hi) pivot = A[pivot_index] ...
- leetcode-happy number implemented in python
视频分析: http://v.youku.com/v_show/id_XMTMyODkyNDA0MA==.html?from=y1.7-1.2 class Solution(object): def ...
- Awesome Python
Awesome Python A curated list of awesome Python frameworks, libraries, software and resources. Insp ...
- Python开源框架、库、软件和资源大集合
A curated list of awesome Python frameworks, libraries, software and resources. Inspired by awesome- ...
- Python 库汇总英文版
Awesome Python A curated list of awesome Python frameworks, libraries, software and resources. Insp ...
- Python框架、库以及软件资源汇总
转自:http://developer.51cto.com/art/201507/483510.htm 很多来自世界各地的程序员不求回报的写代码为别人造轮子.贡献代码.开发框架.开放源代码使得分散在世 ...
- Python Scopes and Namespaces
Before introducing classes, I first have to tell you something about Python's scope rules. Class def ...
随机推荐
- AsyncTask 解析
[转载自 http://blog.csdn.net/yanbober ] 1 背景 Android异步处理机制一直都是Android的一个核心,也是应用工程师面试的一个知识点.前面我们分析了Handl ...
- 零基础学习云计算及大数据DBA集群架构师【企业级运维技术及实践项目2015年1月29日周五】
LNMP/LEMP项目搭建 { 项目框架 # Linux_____WEB_____PHP_____DB # rhel7_____apache__-(libphp5.so)-__php__-(php-m ...
- Android Xutils 框架
XUtils是git上比较活跃 功能比较完善的一个框架,是基于afinal开发的,比afinal稳定性提高了不少,下面是介绍: 鉴于大家的热情,我又写了一篇Android 最火框架XUtils之注解机 ...
- jQuery中在当前页面弹出一个新的界面
W.$.dialog({ content:'url:wswgrkbillController.do?snh&id='+b+'&bh='+c+'&ck='+d+'&sl= ...
- JavaBean学习--练习示例
初识Javabean,没感觉这鸟东西有什么好用的,一定是我太笨了 自己用jsp测试了下,这里用application作用域做个示例 <%@ page language="java&qu ...
- php5.2通过saprfc扩展远程连接sap730成功案例
公司刚上sap系统,由于资金有限,sap与其它系统的数据交换需要公司内部实现.于是,领导决定入库申请流程需要在sap与OA系统里实现电子签核流,重担果然落到我的身上.好在我只负责OA,还一位同事负责s ...
- javascript innerHTML、outerHTML、innerText、outerText的区别
1.功能讲解: innerHTML 设置或获取位于对象起始和结束标签内的 HTML outerHTML 设置或获取对象及其内容的 HTML 形式 innerText 设置或获取位于对象起始和结束标签内 ...
- sql语句分页代码
SET ANSI_NULLS ON GO SET QUOTED_IDENTIFIER ON GO alter proc sp_SelectInfomationByKeyWord--创建一个存储过程 - ...
- java中集合类的简介
结构 collection(接口) List(接口) LinkedList(类) ArrayList(类) Vector(类) Stack(类) Set(接口) Map(接口) Hashtable(类 ...
- python自学笔记
python自学笔记 python自学笔记 1.输出 2.输入 3.零碎 4.数据结构 4.1 list 类比于java中的数组 4.2 tuple 元祖 5.条件判断和循环 5.1 条件判断 5.2 ...