本文来自《机器学习实战》(Peter Harrington)第九章“树回归”部分,代码使用python3.5,并在jupyter notebook环境中测试通过,推荐clone仓库后run cell all就可以了。

github地址:https://github.com/gshtime/machinelearning-in-action-python3

转载请标明原文链接

1 原理

CART(Classification and Regression Trees,分类回归树)是决策树算法的一种,这种树构建算法既可以用于分类也可以用于回归。

它采用一种递归二元分割(recursive binary splitting)的技术,分割方法采用基于最小距离的基尼指数(分类树中)或最小平方残差(回归树中)等方法来估计函数的不纯度,从而将当前的样本集分为两个子样本集,使得生成的的每个非叶子节点都有两个分支。因此,CART算法生成的决策树是结构简洁的二叉树。

因此,CART的目标是:选择输入变量和那些变量上的分割点,直到创建出适当的树。在这个过程中,使用贪婪算法(greedy algorithm)选择使用哪个输入变量和分割点,以使成本函数(cost function)最小化。

1.1 CART回归树的原理

本文主要讲解CART回归树的原理及实现

现在关注一下回归树的 CART 算法的细节。简要来说,创建一个决策树包含两步:

  1. 把预测器空间,即一系列可能值 \(X_1,X_2,...,X_p\) 分成 \(J\) 个不同的且非重叠的区域 \(R_1,R_2,...,R_J\)。

  2. 对进入区域 \(R_J\) 的每一个样本观测值都进行相同的预测,该预测就是 \(R_J\) 中训练样本预测值的均值。

为了创建 \(J\) 个区域 \(R_1,R_2,...,R_J\),预测器区域被分为高维度的矩形或盒形。其目的在于通过下列式子找到能够使 \(RSS\) 最小化的盒形区域 \(R_1,R_2,...,R_J\),

\[\sum_{j=1}^{J} \sum_{i \in R_j} \big(y_i - \hat{y}_{R_j}\big)^2
\]

其中,\(\hat{y}_{R_j}\) 即是第 \(j\) 个盒形中训练观测的平均预测值。

鉴于这种空间分割在计算上是不可行的,因此我们常使用贪婪方法(greedy approach)来划分区域,叫做递归二元分割(recursive binary splitting)。

它是贪婪的(greedy),这是因为在创建树过程中的每一步骤,最佳分割都会在每个特定步骤选定,而不是对未来进行预测,并选取一个将会在未来步骤中出现且有助于创建更好的树的分割。注意所有的划分区域 \(R_j,∀j∈[1,J]\) 都是矩形。为了进行递归二元分割,首先选取预测器 \(X_j\) (即数据集中的一个特征)和切割点 \(s\)(即该特征下某一个数据的值),递归遍历该特征下面所有的值作为二元分割的切割点,对预测器(特征)下的数据分割到不同的区域,即:\(R_1(j,s)=\big\{ X|Xj < s \big\} 和 R_2(j,s)=\big\{ X|Xj \ge s \big\}\),使得代价函数RSS得到最大程度的下降。从数学上讲,就是要寻找区域数J(我理解为叶节点数量)和分割点s,使分割后的代价函数最小化:

​​

\[\sum_{i: x_i \in R_1(j,s)} \big(y_i-\hat{y}_{R_1}\big)^2 + \sum_{i: x_i \in R_2(j,s)} \big(y_i-\hat{y}_{R_2}\big)^2
\]

其中 \(\hat{y}_{R_1}\) 为区域 \(R_1(j,s)\) 中观察样本的平均预测值,\(\hat{y}_{R_2}\) 为区域 \(R_2(j,s)\) 的观察样本预测均值。这一过程不断重复以搜寻最好的预测器和切分点,并进一步分隔数据以使每一个子区域内的 RSS 最小化。然而,我们不会分割整个预测器空间,我们只会分割一个或两个前面已经认定的区域。这一过程会一直持续,直到达到停止准则,例如我们可以设定停止准则为每一个区域最多包含 m 个观察样本。一旦我们创建了区域 \(R_1、R_2、...、R_J\),给定一个测试样本,我们就可以用该区域所有训练样本的平均预测值来预测该测试样本的值。

2 代码

2.1 CART回归树实现

代码比较长,不知道cnblogs中是否能折叠,为了方便复制,还是都放在一块吧,github中的代码是分开的,有需要可以去看。

原书regTrees.py部分的代码如下

# -*- coding: utf-8 -*-
import numpy as np def loadDataSet(fileName):
'''
read the data file using TAB as separator,and store the data in float list
'''
dataMat = []
fr = open(fileName)
for line in fr.readlines():
curLine = line.strip().split('\t')
fltLine = list(map(float, curLine))
dataMat.append(fltLine)
return dataMat def binSplitDataSet(dataSet, feature, value):
mat0 = dataSet[np.nonzero(dataSet[:,feature] > value)[0],:]
mat1 = dataSet[np.nonzero(dataSet[:,feature] <= value)[0],:]
return mat0, mat1 def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
feat, val =chooseBestSplit(dataSet, leafType, errType, ops)
if feat == None: return val
retTree = {}
retTree['spInd'] = feat
retTree['spVal'] = val
lSet, rSet = binSplitDataSet(dataSet, feat, val)
retTree['left'] = createTree(lSet, leafType, errType, ops)
retTree['right'] = createTree(rSet, leafType, errType, ops)
return retTree def regLeaf(dataSet):
return np.mean(dataSet[:, -1]) def regErr(dataSet):
return np.var(dataSet[:,-1]) * np.shape(dataSet)[0] #choose the best feature and splitting value
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
tolS = ops[0] #tolerant value of S decilne
tolN = ops[1] #min number of samples to be splitted
if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
return None, leafType(dataSet)
m,n = np.shape(dataSet)
S = errType(dataSet)
bestS = np.inf;
bestIndex= 0;
bestValue = 0
for featIndex in range(n-1):
for splitVal in set(dataSet[:,featIndex].T.tolist()[0]):
mat0, mat1 = binSplitDataSet(dataSet,featIndex, splitVal)
if(np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN): continue
newS = errType(mat0) + errType(mat1)
if newS < bestS:
bestIndex = featIndex
bestValue = splitVal
bestS = newS
#verdict whether the deciline of S reach the tolS or not
if (S - bestS) < tolS:
return None, leafType(dataSet)
mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
if(np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
return None, leafType(dataSet)
return bestIndex, bestValue def isTree(obj):
return (type(obj).__name__=='dict') def getMean(tree):
if isTree(tree['right']): tree['right'] = getMean(tree['right'])
if isTree(tree['left']) : tree['left'] = getMean(tree['left'])
return (tree['left'] + tree['right'])/2.0 def prune(tree, testData):
if np.shape(testData)[0] == 0: return getMean(tree)
if(isTree(tree['right']) or isTree(tree['left'])):
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
if isTree(tree['right']):tree['right']= prune(tree['right'],rSet)
if not isTree(tree['left']) and not isTree(tree['right']):
lSet, rSet = binSplitDataSet(testData, tree['spInd'],tree['spVal'])
errorNoMerge = np.sum(np.power(lSet[:,-1] - tree['left'], 2)) + np.sum(np.power(rSet[:,-1] - tree['right'], 2))
treeMean = (tree['left']+tree['right'])/2.0
errorMerge = np.sum(np.power(testData[:,-1] - treeMean, 2))
if errorMerge < errorNoMerge:
print("merging")
return treeMean
else: return tree
else: return tree def linearSolve(dataSet):
m,n = np.shape(dataSet)
X = np.mat(np.ones((m,n)))
Y = np.mat(np.ones((m,1)))
X[:,1:n] = dataSet[:,0:n-1]
Y = dataSet[:,-1]
xTx = X.T*X
if np.linalg.det(xTx) == 0.0:
raise NameError("This matrix is singular, cannot do inverse,\ntry increasing the second value of ops")
ws = xTx.I * (X.T*Y)
return ws, X , Y
def modelLeaf(dataSet):
ws, X, Y = linearSolve(dataSet)
return ws def modelErr(dataSet):
ws, X, Y = linearSolve(dataSet)
yHat = X * ws
return np.sum(np.power(Y - yHat, 2)) def regTreeEval(model, inDat):
return float(model) def modelTreeEval(model, inDat):
n = np.shape(inDat)[1]
X = np.mat(np.ones((1,n+1)))
X[:,1:n+1] = inDat
return float(X*model) def treeForecast(tree, inData, modelEval=regTreeEval):
if not isTree(tree): return modelEval(tree, inData)
if inData[tree['spInd']] > tree['spVal']:
if isTree(tree['left']):
return treeForecast(tree['left'], inData, modelEval)
else:
return modelEval(tree['left'], inData)
else:
if isTree(tree['right']):
return treeForecast(tree['right'], inData, modelEval)
else:
return modelEval(tree['right'], inData) def createForecast(tree, testData, modelEval=regTreeEval):
m = len(testData)
yHat = np.mat(np.zeros((m,1)))
for i in range(m):
yHat[i,0] = treeForecast(tree, np.mat(testData[i]), modelEval)
return yHat

2.2 使用python3的tkinter库创建GUI

python 2to3

原书的代码使针对python2.x环境构建的,在python2.x中应该import Tkinter,而在python3.x中,应该import tkinter才能正常导入Tkinter库

代码

# -*- coding:utf-8 -*-
import tkinter as tk import matplotlib matplotlib.use('TkAgg')
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure def reDraw(tolS, tolN):
reDraw.f.clf()
reDraw.a = reDraw.f.add_subplot(111) if chkBtnVar.get():
if tolN < 2: tolN = 2
myTree = createTree(reDraw.rawDat, modelLeaf, modelErr, (tolS,tolN))
yHat = createForecast(myTree, reDraw.testDat, modelTreeEval)
else:
myTree = createTree(reDraw.rawDat, ops=(tolS, tolN))
yHat = createForecast(myTree, reDraw.testDat) reDraw.a.scatter(reDraw.rawDat[:,0].A, reDraw.rawDat[:,1].A, s=5)
reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0) reDraw.canvas.show() def getInput():
try:
tolN = int(tolNentry.get())
except:
tolN = 10
print("enter Integet for tolN")
tolNentry.delete(0, tk.END)
tolNentry.insert(0, "10")
try:
tolS = float(tolSentry.get())
except:
tolS = 1.0
print("enter Integet for tolS")
tolNentry.delete(0, tk.END)
tolNentry.insert(0, "1.0")
return tolN, tolS def drawNewTree():
tolN, tolS = getInput()
reDraw(tolS, tolN) root = tk.Tk() #tk.Label(root, text="Plot Place Holder").grid(row=0, columnspan=3) reDraw.f = Figure(figsize=(5,4), dpi=100)
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
reDraw.canvas.show()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3) tk.Label(root, text="tolN").grid(row=1, column=0)
tolNentry = tk.Entry(root)
tolNentry.grid(row=1, column=1)
tolNentry.insert(0, '10')
tk.Label(root, text="tolS").grid(row=2, column=0)
tolSentry = tk.Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0, '1.0')
tk.Button(root, text="ReDraw", command=drawNewTree).grid(row=1,column=2, rowspan=3) chkBtnVar = tk.IntVar()
chkBtn = tk.Checkbutton(root, text="Model Tree", variable= chkBtnVar)
chkBtn.grid(row=3, column=0, columnspan=2) reDraw.rawDat = np.mat(loadDataSet('./data/sine.txt'))
reDraw.testDat = np.arange(np.min(reDraw.rawDat[:,0]), np.max(reDraw.rawDat[:,0]), 0.01) reDraw(1.0, 10) root.mainloop()

测试代码

测试的代码都在书里,我的github仓库里也有,有空我再放这儿吧

注意

有时候运行tkinter的时候,可能python会无限地崩溃,可以试一下重装matplotlib库来解决

参考资料

  1. https://zhuanlan.zhihu.com/p/28217071

    这是一篇文章的中文翻译,推荐大家看看该文章的英文原文,这篇文章我觉得写得很棒,对了解CART有很大帮助,文中给出了借助sklearn库的CART实现方法,比较简单,另外作者给了其他决策树算法的文章链接。总之很推荐。
  2. http://blog.csdn.net/u014568921/article/details/45082197

写得比较仓促,自己也在理解和学习中,如果有不对的地方,还请多多指正。现在时间晚了,回头有空把这篇文章写得更全一点

【机器学习实战 第九章】树回归 CART算法的原理与实现 - python3的更多相关文章

  1. 《机器学习实战》学习笔记第九章 —— 决策树之CART算法

    相关博文: <机器学习实战>学习笔记第三章 —— 决策树 主要内容: 一.CART算法简介 二.分类树 三.回归树 四.构建回归树 五.回归树的剪枝 六.模型树 七.树回归与标准回归的比较 ...

  2. 【机器学习实战】第9章 树回归(Tree Regression)

    第9章 树回归 <script type="text/javascript" src="http://cdn.mathjax.org/mathjax/latest/ ...

  3. 机器学习实战 - 读书笔记(12) - 使用FP-growth算法来高效发现频繁项集

    前言 最近在看Peter Harrington写的"机器学习实战",这是我的学习心得,这次是第12章 - 使用FP-growth算法来高效发现频繁项集. 基本概念 FP-growt ...

  4. 机器学习实战 - 读书笔记(11) - 使用Apriori算法进行关联分析

    前言 最近在看Peter Harrington写的"机器学习实战",这是我的学习心得,这次是第11章 - 使用Apriori算法进行关联分析. 基本概念 关联分析(associat ...

  5. 机器学习实战笔记(Python实现)-01-K近邻算法(KNN)

    --------------------------------------------------------------------------------------- 本系列文章为<机器 ...

  6. 机器学习实战笔记5(logistic回归)

    1:简单概念描写叙述 如果如今有一些数据点,我们用一条直线对这些点进行拟合(改线称为最佳拟合直线),这个拟合过程就称为回归.训练分类器就是为了寻找最佳拟合參数,使用的是最优化算法. 基于sigmoid ...

  7. 《机器学习实战》-逻辑(Logistic)回归

    目录 Logistic 回归 本章内容 回归算法 Logistic 回归的一般过程 Logistic的优缺点 基于 Logistic 回归和 Sigmoid 函数的分类 Sigmoid 函数 Logi ...

  8. 机器学习实战读书笔记(二)k-近邻算法

    knn算法: 1.优点:精度高.对异常值不敏感.无数据输入假定 2.缺点:计算复杂度高.空间复杂度高. 3.适用数据范围:数值型和标称型. 一般流程: 1.收集数据 2.准备数据 3.分析数据 4.训 ...

  9. DirectX12 3D 游戏开发与实战第九章内容(上)

    仅供个人学习使用,请勿转载. 9.纹理贴图 学习目标: 学习如何将局部纹理映射到网格三角形上 探究如何创建和启用纹理 学会如何通过纹理过滤来创建更加平滑的图像 探索如何使用寻址模式来进行多次纹理贴图 ...

随机推荐

  1. Gym 101673F Keeping On Track

    原题传送门 题意:给定一颗n+1(n≤10000)个结点的树(当然有n条边辣),定义一个结点为critical,当且仅当除去这个点及与其相连的边后,所有不相连的结点对数量最多.现在保证有且仅有一个结点 ...

  2. 【Html5】-- 塔台管制

    想做这个游戏已久,今天终于初步完成,先解释下,这是一个模拟机场塔台管制指挥的游戏,飞机从不同的方向飞入管制空域,有不同的目的地,飞机名称最后一个字母表示飞机要到达的目的地,分ABCD和R.A-D表示四 ...

  3. 【JavaScript for循环实例】

    1.大马驮2石粮食,中马驮1石粮食,两头小马驮一石粮食,要用100匹马,驮100石粮食,该如何调配? //驮100石粮食,大马需要50匹 for(var a=0;a<=50;a++){ //驮1 ...

  4. linux下,文件的权限和数字对应关系详解

    命令 chmod ABC file 其中A.B.C各为一个数字,分别表示User.Group.及Other的权限. A.B.C这三个数字如果各自转换成由"0"."1&qu ...

  5. 重新学习一次javascript;

    每次有项目的时候,总觉得自己什么都不会做,然后做的时候又很简单,一会就做完了,啪啪打脸: 每次别人问的时候,我知道怎么做,但是不知道具体原理,觉得瞬间low了: 想要好好的吧基础掌握一下: 这几天空闲 ...

  6. TypeScript装饰器(decorators)

    装饰器是一种特殊类型的声明,它能够被附加到类声明,方法, 访问符,属性或参数上,可以修改类的行为. 装饰器使用 @expression这种形式,expression求值后必须为一个函数,它会在运行时被 ...

  7. UE4 unreliable 同步问题

    今天发现了一个问题,标记为unreliable的函数从来不执行,但是官方文档上的说明是只有在网络负载重时才不执行此类函数,哎哎哎.

  8. 读懂_countof,可以懂得什么

    在c++开发中数组是我们经常使用存储结构,而于此同时"数组越界"是每个c++程序员不能不提防陷阱. 还好,我们有预定义宏_countof. 一.在visual c++开发环境下,它 ...

  9. 从零开始学习前端开发 — 6、CSS布局模型

    一.css布局模型 1.流动模型(Flow) 元素在不设置css样式时的布局模型,是块元素就独占一行,是内联元素就在一行逐个进行显示 2.浮动模型(Float) 使用float属性来进行网页布局,给元 ...

  10. final、finally、finalize

    final是一个修饰词.可以修饰变量.方法.类 final修饰变量时分为两种 )1.修饰成员变量:该成员变量不可以被二次赋值.也就是说成员变量无法改变.且该成员变量要么在定义时初始化,要么在构造器中进 ...