本文来自《机器学习实战》(Peter Harrington)第九章“树回归”部分,代码使用python3.5,并在jupyter notebook环境中测试通过,推荐clone仓库后run cell all就可以了。

github地址:https://github.com/gshtime/machinelearning-in-action-python3

转载请标明原文链接

1 原理

CART(Classification and Regression Trees,分类回归树)是决策树算法的一种,这种树构建算法既可以用于分类也可以用于回归。

它采用一种递归二元分割(recursive binary splitting)的技术,分割方法采用基于最小距离的基尼指数(分类树中)或最小平方残差(回归树中)等方法来估计函数的不纯度,从而将当前的样本集分为两个子样本集,使得生成的的每个非叶子节点都有两个分支。因此,CART算法生成的决策树是结构简洁的二叉树。

因此,CART的目标是:选择输入变量和那些变量上的分割点,直到创建出适当的树。在这个过程中,使用贪婪算法(greedy algorithm)选择使用哪个输入变量和分割点,以使成本函数(cost function)最小化。

1.1 CART回归树的原理

本文主要讲解CART回归树的原理及实现

现在关注一下回归树的 CART 算法的细节。简要来说,创建一个决策树包含两步:

  1. 把预测器空间,即一系列可能值 \(X_1,X_2,...,X_p\) 分成 \(J\) 个不同的且非重叠的区域 \(R_1,R_2,...,R_J\)。

  2. 对进入区域 \(R_J\) 的每一个样本观测值都进行相同的预测,该预测就是 \(R_J\) 中训练样本预测值的均值。

为了创建 \(J\) 个区域 \(R_1,R_2,...,R_J\),预测器区域被分为高维度的矩形或盒形。其目的在于通过下列式子找到能够使 \(RSS\) 最小化的盒形区域 \(R_1,R_2,...,R_J\),

\[\sum_{j=1}^{J} \sum_{i \in R_j} \big(y_i - \hat{y}_{R_j}\big)^2
\]

其中,\(\hat{y}_{R_j}\) 即是第 \(j\) 个盒形中训练观测的平均预测值。

鉴于这种空间分割在计算上是不可行的,因此我们常使用贪婪方法(greedy approach)来划分区域,叫做递归二元分割(recursive binary splitting)。

它是贪婪的(greedy),这是因为在创建树过程中的每一步骤,最佳分割都会在每个特定步骤选定,而不是对未来进行预测,并选取一个将会在未来步骤中出现且有助于创建更好的树的分割。注意所有的划分区域 \(R_j,∀j∈[1,J]\) 都是矩形。为了进行递归二元分割,首先选取预测器 \(X_j\) (即数据集中的一个特征)和切割点 \(s\)(即该特征下某一个数据的值),递归遍历该特征下面所有的值作为二元分割的切割点,对预测器(特征)下的数据分割到不同的区域,即:\(R_1(j,s)=\big\{ X|Xj < s \big\} 和 R_2(j,s)=\big\{ X|Xj \ge s \big\}\),使得代价函数RSS得到最大程度的下降。从数学上讲,就是要寻找区域数J(我理解为叶节点数量)和分割点s,使分割后的代价函数最小化:

​​

\[\sum_{i: x_i \in R_1(j,s)} \big(y_i-\hat{y}_{R_1}\big)^2 + \sum_{i: x_i \in R_2(j,s)} \big(y_i-\hat{y}_{R_2}\big)^2
\]

其中 \(\hat{y}_{R_1}\) 为区域 \(R_1(j,s)\) 中观察样本的平均预测值,\(\hat{y}_{R_2}\) 为区域 \(R_2(j,s)\) 的观察样本预测均值。这一过程不断重复以搜寻最好的预测器和切分点,并进一步分隔数据以使每一个子区域内的 RSS 最小化。然而,我们不会分割整个预测器空间,我们只会分割一个或两个前面已经认定的区域。这一过程会一直持续,直到达到停止准则,例如我们可以设定停止准则为每一个区域最多包含 m 个观察样本。一旦我们创建了区域 \(R_1、R_2、...、R_J\),给定一个测试样本,我们就可以用该区域所有训练样本的平均预测值来预测该测试样本的值。

2 代码

2.1 CART回归树实现

代码比较长,不知道cnblogs中是否能折叠,为了方便复制,还是都放在一块吧,github中的代码是分开的,有需要可以去看。

原书regTrees.py部分的代码如下

# -*- coding: utf-8 -*-
import numpy as np def loadDataSet(fileName):
'''
read the data file using TAB as separator,and store the data in float list
'''
dataMat = []
fr = open(fileName)
for line in fr.readlines():
curLine = line.strip().split('\t')
fltLine = list(map(float, curLine))
dataMat.append(fltLine)
return dataMat def binSplitDataSet(dataSet, feature, value):
mat0 = dataSet[np.nonzero(dataSet[:,feature] > value)[0],:]
mat1 = dataSet[np.nonzero(dataSet[:,feature] <= value)[0],:]
return mat0, mat1 def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
feat, val =chooseBestSplit(dataSet, leafType, errType, ops)
if feat == None: return val
retTree = {}
retTree['spInd'] = feat
retTree['spVal'] = val
lSet, rSet = binSplitDataSet(dataSet, feat, val)
retTree['left'] = createTree(lSet, leafType, errType, ops)
retTree['right'] = createTree(rSet, leafType, errType, ops)
return retTree def regLeaf(dataSet):
return np.mean(dataSet[:, -1]) def regErr(dataSet):
return np.var(dataSet[:,-1]) * np.shape(dataSet)[0] #choose the best feature and splitting value
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
tolS = ops[0] #tolerant value of S decilne
tolN = ops[1] #min number of samples to be splitted
if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
return None, leafType(dataSet)
m,n = np.shape(dataSet)
S = errType(dataSet)
bestS = np.inf;
bestIndex= 0;
bestValue = 0
for featIndex in range(n-1):
for splitVal in set(dataSet[:,featIndex].T.tolist()[0]):
mat0, mat1 = binSplitDataSet(dataSet,featIndex, splitVal)
if(np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN): continue
newS = errType(mat0) + errType(mat1)
if newS < bestS:
bestIndex = featIndex
bestValue = splitVal
bestS = newS
#verdict whether the deciline of S reach the tolS or not
if (S - bestS) < tolS:
return None, leafType(dataSet)
mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
if(np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
return None, leafType(dataSet)
return bestIndex, bestValue def isTree(obj):
return (type(obj).__name__=='dict') def getMean(tree):
if isTree(tree['right']): tree['right'] = getMean(tree['right'])
if isTree(tree['left']) : tree['left'] = getMean(tree['left'])
return (tree['left'] + tree['right'])/2.0 def prune(tree, testData):
if np.shape(testData)[0] == 0: return getMean(tree)
if(isTree(tree['right']) or isTree(tree['left'])):
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
if isTree(tree['right']):tree['right']= prune(tree['right'],rSet)
if not isTree(tree['left']) and not isTree(tree['right']):
lSet, rSet = binSplitDataSet(testData, tree['spInd'],tree['spVal'])
errorNoMerge = np.sum(np.power(lSet[:,-1] - tree['left'], 2)) + np.sum(np.power(rSet[:,-1] - tree['right'], 2))
treeMean = (tree['left']+tree['right'])/2.0
errorMerge = np.sum(np.power(testData[:,-1] - treeMean, 2))
if errorMerge < errorNoMerge:
print("merging")
return treeMean
else: return tree
else: return tree def linearSolve(dataSet):
m,n = np.shape(dataSet)
X = np.mat(np.ones((m,n)))
Y = np.mat(np.ones((m,1)))
X[:,1:n] = dataSet[:,0:n-1]
Y = dataSet[:,-1]
xTx = X.T*X
if np.linalg.det(xTx) == 0.0:
raise NameError("This matrix is singular, cannot do inverse,\ntry increasing the second value of ops")
ws = xTx.I * (X.T*Y)
return ws, X , Y
def modelLeaf(dataSet):
ws, X, Y = linearSolve(dataSet)
return ws def modelErr(dataSet):
ws, X, Y = linearSolve(dataSet)
yHat = X * ws
return np.sum(np.power(Y - yHat, 2)) def regTreeEval(model, inDat):
return float(model) def modelTreeEval(model, inDat):
n = np.shape(inDat)[1]
X = np.mat(np.ones((1,n+1)))
X[:,1:n+1] = inDat
return float(X*model) def treeForecast(tree, inData, modelEval=regTreeEval):
if not isTree(tree): return modelEval(tree, inData)
if inData[tree['spInd']] > tree['spVal']:
if isTree(tree['left']):
return treeForecast(tree['left'], inData, modelEval)
else:
return modelEval(tree['left'], inData)
else:
if isTree(tree['right']):
return treeForecast(tree['right'], inData, modelEval)
else:
return modelEval(tree['right'], inData) def createForecast(tree, testData, modelEval=regTreeEval):
m = len(testData)
yHat = np.mat(np.zeros((m,1)))
for i in range(m):
yHat[i,0] = treeForecast(tree, np.mat(testData[i]), modelEval)
return yHat

2.2 使用python3的tkinter库创建GUI

python 2to3

原书的代码使针对python2.x环境构建的,在python2.x中应该import Tkinter,而在python3.x中,应该import tkinter才能正常导入Tkinter库

代码

# -*- coding:utf-8 -*-
import tkinter as tk import matplotlib matplotlib.use('TkAgg')
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure def reDraw(tolS, tolN):
reDraw.f.clf()
reDraw.a = reDraw.f.add_subplot(111) if chkBtnVar.get():
if tolN < 2: tolN = 2
myTree = createTree(reDraw.rawDat, modelLeaf, modelErr, (tolS,tolN))
yHat = createForecast(myTree, reDraw.testDat, modelTreeEval)
else:
myTree = createTree(reDraw.rawDat, ops=(tolS, tolN))
yHat = createForecast(myTree, reDraw.testDat) reDraw.a.scatter(reDraw.rawDat[:,0].A, reDraw.rawDat[:,1].A, s=5)
reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0) reDraw.canvas.show() def getInput():
try:
tolN = int(tolNentry.get())
except:
tolN = 10
print("enter Integet for tolN")
tolNentry.delete(0, tk.END)
tolNentry.insert(0, "10")
try:
tolS = float(tolSentry.get())
except:
tolS = 1.0
print("enter Integet for tolS")
tolNentry.delete(0, tk.END)
tolNentry.insert(0, "1.0")
return tolN, tolS def drawNewTree():
tolN, tolS = getInput()
reDraw(tolS, tolN) root = tk.Tk() #tk.Label(root, text="Plot Place Holder").grid(row=0, columnspan=3) reDraw.f = Figure(figsize=(5,4), dpi=100)
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
reDraw.canvas.show()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3) tk.Label(root, text="tolN").grid(row=1, column=0)
tolNentry = tk.Entry(root)
tolNentry.grid(row=1, column=1)
tolNentry.insert(0, '10')
tk.Label(root, text="tolS").grid(row=2, column=0)
tolSentry = tk.Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0, '1.0')
tk.Button(root, text="ReDraw", command=drawNewTree).grid(row=1,column=2, rowspan=3) chkBtnVar = tk.IntVar()
chkBtn = tk.Checkbutton(root, text="Model Tree", variable= chkBtnVar)
chkBtn.grid(row=3, column=0, columnspan=2) reDraw.rawDat = np.mat(loadDataSet('./data/sine.txt'))
reDraw.testDat = np.arange(np.min(reDraw.rawDat[:,0]), np.max(reDraw.rawDat[:,0]), 0.01) reDraw(1.0, 10) root.mainloop()

测试代码

测试的代码都在书里,我的github仓库里也有,有空我再放这儿吧

注意

有时候运行tkinter的时候,可能python会无限地崩溃,可以试一下重装matplotlib库来解决

参考资料

  1. https://zhuanlan.zhihu.com/p/28217071

    这是一篇文章的中文翻译,推荐大家看看该文章的英文原文,这篇文章我觉得写得很棒,对了解CART有很大帮助,文中给出了借助sklearn库的CART实现方法,比较简单,另外作者给了其他决策树算法的文章链接。总之很推荐。
  2. http://blog.csdn.net/u014568921/article/details/45082197

写得比较仓促,自己也在理解和学习中,如果有不对的地方,还请多多指正。现在时间晚了,回头有空把这篇文章写得更全一点

【机器学习实战 第九章】树回归 CART算法的原理与实现 - python3的更多相关文章

  1. 《机器学习实战》学习笔记第九章 —— 决策树之CART算法

    相关博文: <机器学习实战>学习笔记第三章 —— 决策树 主要内容: 一.CART算法简介 二.分类树 三.回归树 四.构建回归树 五.回归树的剪枝 六.模型树 七.树回归与标准回归的比较 ...

  2. 【机器学习实战】第9章 树回归(Tree Regression)

    第9章 树回归 <script type="text/javascript" src="http://cdn.mathjax.org/mathjax/latest/ ...

  3. 机器学习实战 - 读书笔记(12) - 使用FP-growth算法来高效发现频繁项集

    前言 最近在看Peter Harrington写的"机器学习实战",这是我的学习心得,这次是第12章 - 使用FP-growth算法来高效发现频繁项集. 基本概念 FP-growt ...

  4. 机器学习实战 - 读书笔记(11) - 使用Apriori算法进行关联分析

    前言 最近在看Peter Harrington写的"机器学习实战",这是我的学习心得,这次是第11章 - 使用Apriori算法进行关联分析. 基本概念 关联分析(associat ...

  5. 机器学习实战笔记(Python实现)-01-K近邻算法(KNN)

    --------------------------------------------------------------------------------------- 本系列文章为<机器 ...

  6. 机器学习实战笔记5(logistic回归)

    1:简单概念描写叙述 如果如今有一些数据点,我们用一条直线对这些点进行拟合(改线称为最佳拟合直线),这个拟合过程就称为回归.训练分类器就是为了寻找最佳拟合參数,使用的是最优化算法. 基于sigmoid ...

  7. 《机器学习实战》-逻辑(Logistic)回归

    目录 Logistic 回归 本章内容 回归算法 Logistic 回归的一般过程 Logistic的优缺点 基于 Logistic 回归和 Sigmoid 函数的分类 Sigmoid 函数 Logi ...

  8. 机器学习实战读书笔记(二)k-近邻算法

    knn算法: 1.优点:精度高.对异常值不敏感.无数据输入假定 2.缺点:计算复杂度高.空间复杂度高. 3.适用数据范围:数值型和标称型. 一般流程: 1.收集数据 2.准备数据 3.分析数据 4.训 ...

  9. DirectX12 3D 游戏开发与实战第九章内容(上)

    仅供个人学习使用,请勿转载. 9.纹理贴图 学习目标: 学习如何将局部纹理映射到网格三角形上 探究如何创建和启用纹理 学会如何通过纹理过滤来创建更加平滑的图像 探索如何使用寻址模式来进行多次纹理贴图 ...

随机推荐

  1. JavaProblem之hashCode详解

    一.HashCode简介 1.1.什么是Hash和Hash表 要想清楚hashCode就要先清楚知道什么是Hash 1)Hash hash是一个函数,该函数中的实现就是一种算法,就是通过一系列的算法来 ...

  2. hihoCoder #1082 : 然而沼跃鱼早就看穿了一切(字符串处理)

    #1082 : 然而沼跃鱼早就看穿了一切 时间限制:1000ms 单点时限:1000ms 内存限制:256MB 描述 fjxmlhx每天都在被沼跃鱼刷屏,因此他急切的找到了你希望你写一个程序屏蔽所有句 ...

  3. [POJ2243]考研路茫茫——单词情结

    又是AC自动机上用矩乘优化DP= = 其实和上一题基本一样...补集转化思想.. 只是要多弄一个小矩阵求(26^1+26^2+....+26^L),并且也要求f的总和(因为是长度<=L) 直接调 ...

  4. 了解 Python 语言中的时间处理

    python 语言对于时间的处理继承了 C语言的传统,时间值是以秒为单位的浮点数,记录的是从1970年1月1日零点到现在的秒数,这个秒数可以转换成我们日常可阅读形式的日期和时间:我们下面首先来看一下p ...

  5. 在 .NET中,一种更方便操作配置项的方法

    在应用程序的开发过程中,我们往往会为软件提供一些配置项,以允许软件根据配置项灵活来做事情,比如配置日志文件路径等,此外,我们还可以用配置项来为用户存储其偏好设置等. .NET 为我们默认提供了配置机制 ...

  6. Deep Learning速成教程

          引言         深度学习,即Deep Learning,是一种学习算法(Learning algorithm),亦是人工智能领域的一个重要分支.从快速发展到实际应用,短短几年时间里, ...

  7. [图像类名词解释][ RGB YUV HSV相关解释说明]

    一.概述 颜色通常用三个独立的属性来描述,三个独立变量综合作用,自然就构成一个空间坐标,这就是颜色空间.但被描述的颜色对象本身是客观的,不同颜色空间只是从不同的角度去衡量同一个对象.颜色空间按照基本机 ...

  8. 久未更 ~ 三之 —— CardView简单记录

    > > > > > 久未更 系列一:CardView 点击涟漪效果实现 //在 cardview 中 实现点击涟漪效果 android:clickable="t ...

  9. ThinkPHP5+小程序商城 网盘视频

    ThinkPHP5+小程序商城   网盘视频  有需要联系我  QQ:1844912514

  10. Yourphp  使用说明

    https://wenku.baidu.com/view/c8d2e667cc1755270722088a.html 这个是站点的配置信息,比如:网站名称. LOGO .网站标题等 推荐位:个别可能用 ...