系列文章:《机器学习实战》学习笔记

本篇文章使用到的完整代码:Here

决策树

  • 优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据。
  • 缺点:可能会产生过度匹配问题。
  • 适用数据类型:离散型和连续型

\[QAQ
\]

我们经常使用决策树处理分类问题,它的过程类似二十个问题的游戏:参与游戏的一方在脑海里想某个事物,其他参与者向他提出问题,只允许提20个问题,问题的答案也只能用对或错回答。问问题的人通过推断分解,逐步缩小带猜测事物的范围。如图1所示的流程图就是一个决策树,长方形代表判断模块(decision block),椭圆形代表终止模块(terminating block),表示已经得出结论,可以终止运行。从判断模块引出的左右箭头称作分支(branch),它可以到达另一个判断模块或终止模块。

图1构造了一个假象的邮件分类系统,它首先检测发送邮件域名地址。如果地址为myEmployer.com,则将其放在分类"无聊时需要阅读的邮件"中。如果邮件不是来自这个域名,则检查内容是否包括单词曲棍球,如果包含则将邮件归类到"需要及时处理的朋友邮件",否则将邮件归类到"无须阅读的垃圾邮件"。

第2章介绍的k-近邻算法可以完成很多分类任务,但是它最大的缺点就是无法给出数据的内在含义,决策树的主要优势就在于数据形式非常容易理解。

本章构造的决策树算法能够读取数据集合,构建类似图1的决策树。决策树可以在数据集合中提取出一系列规则,规则创建的过程就是机器学习的过程。现在我们已经大致了解决策树可以完成哪些任务,接下来我们将学习如何从一堆原始数据中构造决策树。首先我们讨论构造决策树的方法,以及如何编写构造树的Python代码;接着提出一些度量算法成功率的方法;最后使用递归建立分类器。

一、决策树的构造

在构造决策树时,我们需要解决的第一个问题就是,当前数据集上哪个特征在划分数据分类时起决定性作用。为了找到决定性的特征,划分出最好的结果,我们必须评估每个特征。我们假设已经根据一定的方法选取了待划分的特征,则原始数据集将根据这个特征被划分为几个数据子集。这数据子集会分布在决策点(关键特征)的所有分支上。如果某个分支下的数据属于同一类型,则无需进一步对数据集进行分割。如果数据子集内的数据不属于同一类型,则需要递归地重复划分数据子集的过程,直到每个数据子集内的数据类型相同。

创建分支的过程用伪代码表示如下:

检测数据集中的每个子项是否属于同一类型:
  如果是,则返回类型标签
  否则:
    寻找划分数据集的最好特征
    划分数据集
    创建分支节点
    对划分的每个数据子集:
      递归调用本算法并添加返回结果到分支节点中
    返回分支节点

注:伪代码是一个递归函数。

决策树的一般流程:

  1. 收集数据:可以使用任何方法。
  2. 准备数据:树构造算法只适用于标称数据,因此数值型数据必须离散化。
  3. 分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期。
  4. 训练算法:构造树的数据结构。
  5. 测试算法:使用经验树计算错误率。
  6. 使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在含义。

一些决策树算法使用二分法划分数据,本书并不采用这种方法。如果依据某个属性划分数据将会产生4个可能的值,我们将把数据划分成四块,并创建四个不同的分支。

本书将使用ID3算法划分数据集,该算法处理如何划分数据集,何时停止划分数据集(进一步的信息可以参见http://en.wikipedia.org/wiki/ID3_algorithm)。每次划分数据集我们只选取一个特征属性,那么应该选择哪个特征作为划分的参考属性呢?

表1的数据包含5个海洋动物,特征包括:不浮出水面是否可以生存,以及是否有脚噗。我们可以将这些动物分成两类:鱼类和非鱼类。

表1 海洋生物数据

不浮出水面是否可以生存 是否有脚蹼 属于鱼类
1
2
3
4
5

1.1 信息增益

划分数据集的大原则是:将无序的数据变得更加有序。我们可以使用多种方法划分数据集,但是每种方法都有各自的优缺点。组织杂乱无章数据的一种方法就是使用信息论度量信息,信息论是量化处理信息的分支科学。我们可以在划分数据之前或之后使用信息论量化度量信息的内容。

在划分数据集之前之后信息发生的变化成为信息增益,我们可以计算每个特征划分数据集获得的信息增益,获得信息增益最高的特征就是最好的选择。

对于某件事情

不确定性越大,熵越大,确定该事所需的信息量也越大;

不确定性越小,熵越小,确定该事所需的信息量也越小。

个人理解:将乱序数据转化为有序数据前后变化为信息增益,数据的信息的混乱程度叫)。

集合信息的度量方式成为香农熵或者简称为

熵定义为信息的期望值。我们先确定信息的定义:

如果待分类的事务可能划分在多个分类之中,则符号 \(x_i\) 定义为:

\[l(X_i) = -log_2\ p(x_i)
\]

其中 \(p(x_i)\) 是选择该分类的概率。

为了计算熵,我们需要计算所有类型所有可能值包含的信息的期望值,通过下面的公式得到:

\[H(x) = -\sum_{i = 1}^nP(X_i)log_2\ P(X_i)
\]

其中 \(n\) 是分类的数目。

下面给出计算信息熵的 Python 函数,创建名为 trees.py 文件,添加如下代码:

from math import log

# H(x) = -\sum_{i = 1}^nP(X_i)log_2P(X_i)
def calsShannonEnt(dataSet):
numEntries = len(dataSet)
labelCounts = {} # 为所有可能的字创建字典
for dataVec in dataSet:
label = dataVec[-1]
if label not in labelCounts.keys(): # 为所有可能分类创建字典
labelCounts[label] = 0
labelCounts[label] += 1 shannonEnt = 0.0
for key in labelCounts.keys():
prob = float(labelCounts[key]) / numEntries
shannonEnt -= prob * log(prob, 2) # 以2为底求对数
return shannonEnt

代码说明:

  • 首先,计算数据集中实例的总数。我们可以在需要时再计算这个值,但是由于代码中多次用到这个值,为了提高代码效率,我们显式地声明一个变量保存实例总数。
  • 然后,创建一个数据字典,它的键值是最后一列的数值。如果当前键值不存在,则扩展字典并将当前键值加入字典。每个键值都记录了当前类别出现的粗疏。
  • 最后,使用所有类标签的发生频率计算类别出现的概率。我们将用这个概率计算香农熵,统计所有类标签发生的次数。

trees.py 文件中,我们利用 createDateSet() 函数得到一些样例数据:

def creatDataSet():
dataSet = [
[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no'],
]
labels = ['no surfacng', 'flippers']
return dataSet, labels

熵越高,则混合的数据也越多。得到熵之后,我们就可以按照获得最大信息增益的方法划分数据集。

另一个度量集合无序程度的方法是基尼不纯度(Gini impurity),简单地说就是从一个数据集中随机选取子项,度量其被错误分类到其他分组里的概率。

1.2 划分数据集

我们将对每个特征划分数据集的结果计算一次信息熵,然后判断按照哪个特征划分数据集市最好的划分方法。

添加划分数据集的代码:

def splitDataSet(dataSet, axis, value):
retDataSet = [] # 创建新的list对象
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis + 1:])
retDataSet.append(reducedFeatVec) # 抽取
return retDataSet

该函数使用了三个输入参数:带划分的数据集、划分数据集的特征、需要返回的特征的值。函数先选取数据集中第axis个特征值为value的数据,从这部分数据中去除第axis个特征,并返回。

测试这个函数,效果如下:

>>> import trees
>>> myDat, labels = trees.createDataSet()
>>> myDat
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
>>> trees.splitDataSet(myDat,0,1)
[[1, 'yes'], [1, 'yes'], [0, 'no']]
>>> trees.splitDataSet(myDat,0,0)
[[1, 'no'], [1, 'no']]

接下来我们将遍历整个数据集,循环计算香农熵和 splitDataSet() 函数,找到最好的特征划分方式。

def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1
baseEntropy = calsShannonEnt(dataSet)
bestInfoGain = 0.0
bestFeature = -1
for i in range(numFeatures):
featList = [example[i] for example in dataSet]
uniqueVals = set(featList)
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet) / float(len(dataSet))
newEntropy += prob * calsShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if infoGain > bestInfoGain:
bestInfoGain = infoGain
bestFeature = i
return bestFeature

函数选取了第一个特征用于划分。

1.3 递归构建决策树

构造决策树所需的子功能模块已经介绍完毕,构建决策树的算法流程如下:

  1. 得到原始数据集,
  2. 基于最好的属性值划分数据集,由于特征值可能多于两个,因此可能存在大于两个分支的数据集划分。
  3. 第一次划分之后,数据将被向下传递到树分支的下一个节点,在这个节点上,我们可以再次划分数据。我们可以采用递归的原则处理数据集。
  4. 递归结束的条件是,程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分类。

参加图2所示:

trees.py 中添加下面的程序代码:

import operator

def majority(classList):
classCount = {}
for vote in classList:
if vote not in classCount.key(): classCount[vote] = 0
classCount[vote] += 1
sortedclassCount = sorted(classCount.iteritems(),
key=operator.itemgetter(1),
reverse=True)
return sortedclassCount[0][0] # 创建树
def createTree(dataSet, labels):
classList = [example[-1] for example in dataSet]
# 类型完全相同则停止继续划分
if classList.count(classList[0]) == len(classList):
return classList[0]
# 遍历完所有特征时返回出现次数最多的
if len(dataSet[0]) == 1:
return majority(classList)
bestFeat = chooseBestFeatureToSplit(dataSet=dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel: {}}
del (labels[bestFeat])
# 得到列表包含的所有属性值
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
sublabels = labels[:] # 复制labels列表
myTree[bestFeatLabel][value] = createTree(
splitDataSet(dataSet, bestFeat, value), sublabels) # 递归构造子树
return myTree

majorityCnt 函数统计 classList 列表中每个类型标签出现频率,返回出现次数最多的分类名称。

createTree 函数使用两个输入参数:数据集 dataSet 和标签列表 labels

标签列表包含了数据集中所有特征的标签,算法本身并不需要这个变量,但是为了给出数据明确的含义,我们将它作为一个输入参数提供。

上述代码首先创建了名为 classList 的列表变量,其中包含了数据集的所有类标签。列表变量classList 包含了数据集的所有类标签。递归函数的第一个停止条件是所有类标签完全相同,则直接返回该类标签。递归函数的第二个停止条件是使用完了所有特征,仍然不能将数据集划分成仅包含唯一类别的分组。这里使用 majorityCnt 函数挑选出现次数最多的类别作为返回值。

下一步程序开始创建树,这里直接使用 Python 的字典类型存储树的信息。字典变量 myTree 存储树的所有信息。当前数据集选取的最好特征存储在变量 bestFeat 中,得到列表中包含的所有属性值。

最后代码遍历当前选择特征包含的所有属性值,在每个数据集划分上递归待用函数 createTree() ,得到的返回值将被插入到字典变量 myTree 中,因此函数终止执行时,字典中将会嵌套很多代表叶子节点信息的字典数据。

注意其中的 subLabels = labels[:] 复制了类标签,因为在递归调用 createTree 函数中会改变标签列表的值。

测试这些函数:

>>> import trees
>>> myDat, labels = trees.createDataSet()
>>> myTree = trees.createTree(myDat,labels)
>>> myTree
{'no surfacng': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

二、使用 Matplotlib 注解绘制树形图

上节我们已经学习了如何从数据集中创建树,然而字典的表示形式非常不易于理解,而且直 接绘制图形也比较困难。本节我们将使用 Matplotlib 库创建树形图。决策树的主要优点就是直观 易于理解,如果不能将其直观地显示出来,就无法发挥其优势。虽然前面章节我们使用的图形库 已经非常强大,但是Python并没有提供绘制树的工具,因此我们必须自己绘制树形图。本节我们 将学习如何编写代码绘制如 图3 所示的决策树。

2.1 Matplotlib 注解

Matplotlib 提供了一个注解工具 annotations,非常有用,它可以在数据图形上添加文本注 释。注解通常用于解释数据的内容。由于数据上面直接存在文本描述非常丑陋,因此工具内嵌支 持带箭头的划线工具,使得我们可以在其他恰当的地方指向数据位置,并在此处添加描述信息, 解释数据内容。如图4所示,在坐标 \((0.2, 0.1)\) 的位置有一个点,我们将对该点的描述信息放在 \((0.35, 0.3)\) 的位置,并用箭头指向数据点 \((0.2, 0.1)\)​。

使用 Matplotlib 的注解功能绘制树形图,它可以对文字着色并提供多种形状以供选择, 而且我们还可以反转箭头,将它指向文本框而不是数据点。打开文本编辑器,创建名为 treePlotter.py 的新文件,然后输入下面的程序代码。

import matplotlib.pyplot as plt

# 定义文本框和箭头格式
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-") # 绘制带箭头的注解
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.axl.annotate(nodeTxt,
xy=parentPt,
xycoords='axes fraction',
xytext=centerPt,
textcoords='axes fraction',
va="center",
ha="center",
bbox=nodeType,
arrowprops=arrow_args) # createPlot 版本一
def createPlot():
fig = plt.figure(1, facecolor='white')
fig.clf() # 清空绘图区
createPlot.axl = plt.subplot(111, frameon=False)
plotNode(U'决策节点', (0.5, 0.1), (0.1, 0.5), decisionNode)
plotNode(U'叶节点', (0.8, 0.1), (0.3, 0.8), leafNode)
plt.show() createPlot()

基于这个例子,现在开始学习绘制整棵树。

2.2 构造注解树

绘制一棵完整的树需要一些技巧。我们虽然有 \(x,y\) 坐标,但是如何放置所有的树节点却是个问题。我们必须知道有多少个叶节点,以便可以正确确定 \(x\) 轴的长度;我们还需要知道树有多少层,以便可以正确确定 \(y\)​ 轴的高度。这里我们定义两个新函数 getNumLeafs()getTreeDepth() ,来 获取叶节点的数目和树的层数,参见下面程序,并将这两个函数添加到文件 treePlotter.py 中。

这段代码有与原书不一样之处,原因在于Python版本不同。主要是以下两个方面:

  1. 1.firstStr 的创建不同:具体问题请点击:(firstStr创建问题)
  2. if判断语句不同:具体问题请点击:(if判断语句不同)
# 获取叶节点个数
def getNumLeafs(myTree):
numLeafs = 0
firstSides = list(myTree.keys())
firstStr = firstSides[0] # 找到输入的第一个元素
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]) == dict:
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs # 获取树的层数
def getTreeDepth(myTree):
maxDepth = 0
firstSides = list(myTree.keys())
firstStr = firstSides[0] # 找到输入的第一个元素
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]) == dict:
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth:
maxDepth = thisDepth
return maxDepth

上述程序中的两个函数具有相同的结构,后面我们也将使用到这两个函数。

这里使用的数据结构说明了如何在 Python 字典类型中存储树信息。第一个关键字是第一次划分数据集的类别标签,附带的数值表示子节点的取值。从第一个关键字出发,我们可以遍历整棵树的所有子节点。 使用Python提供的type()函数可以判断子节点是否为字典类型 。如果子节点是字典类型,则该节点也是一个判断节点,需要递归调用 getNumLeafs() 函数。getNumLeafs()函数遍历整棵树,累计叶子节点的个数,并返回该数值。第2个函数 getTreeDepth() 计算遍历过程中遇到判断节点的个数。该函数的终止条件是叶子节点,一旦到达叶子节点,则从递归调用中返回,并将计算树深度的变量加一。为了节省大家的时间,函数 retrieveTree 输出预先存储的树信息,避 免了每次测试代码时都要从数据中创建树的麻烦。 添加下面的代码到文件 treePlotter.py 中:

#输出预先存储的树信息,避免每次测试代码都从数据中创建树的麻烦
def retrieveTree(i):
listOfTrees = [{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},
{'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}
]
return listOfTrees[i] print('retrieveTree(0) : \n{}'.format(retrieveTree(0)))
print('retrieveTree(1) : \n{}'.format(retrieveTree(1))) myTree = retrieveTree(0)
print('树的叶子结点个数为:\n{}'.format(getNumLeafs(myTree)))
print('树的深度为: \n{}'.format(getTreeDepth(myTree)))

2.3.构造注解树

#在父子节点间填充文本信息
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
createPlot.axl.text(xMid, yMid, txtString) #画一棵树
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree) #计算树的宽
depth = getTreeDepth(myTree) #计算树的高
firstStr = list(myTree.keys())[0]
plotTree.totalW = float(getNumLeafs(myTree)) #存储树的宽度
plotTree.totalD = float(getTreeDepth(myTree)) #存储树的深度
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
plotTree.yOff)
#cntrPt = (plotTree.xOff + (0.5/plotTree.totalW + float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt) #标记子节点属性值
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]) == dict:
plotTree(secondDict[key], cntrPt, str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt,
leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD #createPlot 版本二
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axpropps = dict(xticks=[], yticks=[])
createPlot.axl = plt.subplot(111, frameon=False, **axpropps)
plotTree.totalW = float(getNumLeafs(inTree)) #存储树的宽度
plotTree.totalD = float(getTreeDepth(inTree)) #存储树的深度
plotTree.xOff = -0.5 / plotTree.totalW #xOff 与 yOff追踪已经绘制的节点位置以及下一个节点的恰当位置。
plotTree.yOff = 1.0
plotTree(inTree, (0.5, 1.0), '')
plt.show() myTree = retrieveTree(0)
createPlot(myTree)

注:我在执行过程中发现,图像无法完全展示,所以我点击设置调整了图形大小及位置,调正后如下图。

2.4.变更字典

#在父子节点间填充文本信息
def plotMidText(cntrPt,parentPt,txtString):
xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
createPlot.axl.text(xMid,yMid,txtString) #画一棵树
def plotTree(myTree,parentPt,nodeTxt):
numLeafs = getNumLeafs(myTree) #计算树的宽
depth = getTreeDepth(myTree) #计算树的高
firstStr = list(myTree.keys())[0]
plotTree.totalW = float(getNumLeafs(myTree)) #存储树的宽度
plotTree.totalD = float(getTreeDepth(myTree)) #存储树的深度
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
#cntrPt = (plotTree.xOff + (0.5/plotTree.totalW + float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
plotMidText(cntrPt,parentPt,nodeTxt) #标记子节点属性值
plotNode(firstStr,cntrPt,parentPt,decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]) == dict:
plotTree(secondDict[key],cntrPt,str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD #createPlot 版本二
def createPlot(inTree):
fig = plt.figure(1,facecolor='white')
fig.clf()
axpropps = dict(xticks = [],yticks = [])
createPlot.axl = plt.subplot(111, frameon = False, **axpropps)
plotTree.totalW = float(getNumLeafs(inTree)) #存储树的宽度
plotTree.totalD = float(getTreeDepth(inTree)) #存储树的深度
plotTree.xOff = -0.5/plotTree.totalW #xOff 与 yOff追踪已经绘制的节点位置以及下一个节点的恰当位置。
plotTree.yOff = 1.0
plotTree(inTree,(0.5,1.0),'')
plt.show()
myTree = retrieveTree(0)
myTree['no surfacing'][3] = 'maybe'
print('myTree : \n{}'.format(myTree))
createPlot(myTree)

运行结果如下

myTree :
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}, 3: 'maybe'}}

三、测试和存储分类器

3.1 测试算法:使用决策树进行分类

依靠训练数据构造了决策树之后,我们可以将它用于实际数据的分类。在执行数据分类时,需要决策树以及用于决策树的标签向量。然后,程序比较测试数据与决策树上的数值,递归执行该过程直到进入叶子结点;最后将测试数据定义为叶子结点所属的类型。

使用决策树分类的函数:添加进 trees.py

# 使用决策树
def classify(inputTree, featLabels, testVec):
firstStr = list(inputTree.keys())[0]
secondDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr)
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__ == 'dict':
classLabel = classify(secondDict[key], featLabels, testVec)
else:
classLabel = secondDict[key]
return classLabel mydata, labels = createDataset()
mytree = createTree(mydata, labels)
print(classify(mytree, labels, [1, 1])) # 程序报错ValueError: 'no surfacing' is not in list
# 因为createTree()函数中删除了最佳划分特征的标签 del(labels[bestFeat])
# 把 del(labels[bestFeat]) 注释掉便可以输出 yes

3.2 使用算法:决策树的存储

可以使用 Python模块 pickle 序列化对象,参见下面的程序。序列化对象可以在磁盘上保存对象,并在需要的时候读取出来。

def storeTree(inputTree, filename):
import pickle
fw = open(filename, 'w')
pickle.dump(inputTree, fw)
fw.close() def grabTree(filename):
import pickle
fr = open(filename)
return pickle.load(fr)

四、示例:使用决策树预测隐形眼镜类型

隐形眼镜数据集市非常著名的数据集,它包含很多患者眼部状态的观察条件以及医生推荐的因性眼睛类型。隐形眼镜类型包括硬材质、软材质以及不适合佩戴隐形眼镜。数据来源于UCI数据库,为了更容易显示数据,本书对数据做了简单的更改,数据存储在源代码下载路径的文本文件中。

# 实例:使用决策树预测隐形眼镜类型
f = open('D:\Coding\Py\Machine-Learning\Decision-Tree\lenses.txt')
lenses = [line.strip().split('\t') for line in f.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = createTree(lenses, lensesLabels)
print(lenses)
print(lensesLabels)
print(lensesTree)
# 绘制决策树
import treePlotter treePlotter.createPlot(lensesTree)

决策树很好地匹配了实验数据,然而这些匹配选项可能太多了。我们将这种问题称之为过度匹配(overfitting)。为了减少过度匹配问题,我们可以裁剪决策树,去掉一些不必要的叶子结点。如果叶子结点只能增加少许信息,则可以删除该节点,将他并入到其他叶子结点中。第9章将进一步讨论这个问题。

第九章将学习另一个决策树构造算法CART,本章使用的算法成为ID3,它是一个好的算法但并不完美。ID3算法无法直接处理数值型数据,尽管我们可以通过量化的方法将数值型数据转化为标称型数据,但是如果存在太多的特征划分,ID3算法仍然会面临其他问题。


附录:

  1. 关于基尼不纯度(Gini impurity)的更多信息,请参考Pan-Ning Tan, Vipin Kumar and Michael Steinbach, Introduction to Data Mineing. Pearson Eduction (Addison-Wesley, 2005), 158.

  2. 隐形眼镜数据集:The dataset is a modified version of the Lenses dataset retrieved from the UCI Machine Learning Repository November 3, 2001 [http://archive.ics.uci.edu/ml/machine-learning-databases/lenses/]. The source of the data is Jadzia Cendrowska and was originally published in “PRISM: An algorithm for inducing modular rules,” in International Journal of Man-Machine Studies (1987), 27, 349-70. 本书使用的数据的下载链接在:[链接]

  3. Hu_Pengxue机器学习实战学习笔记系列

《机器学习实战》 | 第3章 决策树(含Matplotlib模块介绍)的更多相关文章

  1. 《机器学习实战之第二章k-近邻算法》

    入坑<机器学习实战>: 本书的第一个机器学习算法是k-近邻算法(kNN),它的工作原理是:存在一个样本数据集合,也称作训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一数据 ...

  2. 机器学习实战书-第二章K-近邻算法笔记

    本章介绍第一个机器学习算法:A-近邻算法,它非常有效而且易于掌握.首先,我们将探讨女-近邻算法的基本理论,以及如何使用距离测量的方法分类物品:其次我们将使用?7««^从文本文件中导人并解析数据: 再次 ...

  3. 《机器学习实战第7章:利用AdaBoost元算法提高分类性能》

    import numpy as np import matplotlib.pyplot as plt def loadSimpData(): dataMat = np.matrix([[1., 2.1 ...

  4. 《机器学习实战》---第二章 k近邻算法 kNN

    下面的代码是在python3中运行, # -*- coding: utf-8 -*- """ Created on Tue Jul 3 17:29:27 2018 @au ...

  5. 【机器学习实战学习笔记(1-1)】k-近邻算法原理及python实现

    笔者本人是个初入机器学习的小白,主要是想把学习过程中的大概知识和自己的一些经验写下来跟大家分享,也可以加强自己的记忆,有不足的地方还望小伙伴们批评指正,点赞评论走起来~ 文章目录 1.k-近邻算法概述 ...

  6. 《机器学习实战》——k-近邻算法Python实现问题记录(转载)

    py2.7 : <机器学习实战> k-近邻算法 11.19 更新完毕 原文链接 <机器学习实战>第二章k-近邻算法,自己实现时遇到的问题,以及解决方法.做个记录. 1.写一个k ...

  7. Python 机器学习实战 —— 无监督学习(下)

    前言 在上篇< Python 机器学习实战 -- 无监督学习(上)>介绍了数据集变换中最常见的 PCA 主成分分析.NMF 非负矩阵分解等无监督模型,举例说明使用使用非监督模型对多维度特征 ...

  8. 使用Python的pandas模块、mplfinance模块、matplotlib模块绘制K线图

    目录 pandas模块.mplfinance模块和matplotlib模块介绍 pandas模块 mplfinance模块和matplotlib模块 安装mplfinance模块.pandas模块和m ...

  9. 《机器学习实战》学习笔记第九章 —— 决策树之CART算法

    相关博文: <机器学习实战>学习笔记第三章 —— 决策树 主要内容: 一.CART算法简介 二.分类树 三.回归树 四.构建回归树 五.回归树的剪枝 六.模型树 七.树回归与标准回归的比较 ...

  10. 【机器学习实战】第3章 决策树(Decision Tree)

    第3章 决策树 <script type="text/javascript" src="http://cdn.mathjax.org/mathjax/latest/ ...

随机推荐

  1. Leetcode回文数

    直接上python代码 class Solution: def isPalindrome(self, x: int) -> bool: if x<0: //负数必不是回文数 return ...

  2. 设备唯一标识方法(Unique Identifier):如何在 Windows 系统上获取设备的唯一标识

    原文地址 设备唯一标识方法(Unique Identifier):如何在 Windows 系统上获取设备的唯一标识 zz 唯一的标识一个设备是一个基本功能,可以拥有很多应用场景,比如软件授权(如何保证 ...

  3. AI量化策略会:可以直接上实盘的策略构建方法

    一年一度的培训虽晚但到,这是BigQuant与大家走过的第五个培训年头,在过去的四年里看到很多学员的成长和蜕变,从一开始的懵懂无知,到现在对深度学习的信手拈来,BigQuant与各位学员们一样都收获颇 ...

  4. 安装服务器提示A debugger has been found running in your system. Please, unload it from memory and restart

    ​ 解决方法:运行msconfig,取消调试模式,重启电脑再安装

  5. 基于python的cat1模块的AT指令串口通信解析

    一 前记 使用cat1模块做产品的过程中,遇到了不少问题.其中很重要的一个就是怎么测试单个模块的好坏.这里笔者专门写了一个工具,来测试cat1模块的是否好用,这里做一个分享吧.   二 源码解析 这个 ...

  6. Windows上安装jenkins

    官网下载jenkins https://www.jenkins.io/zh/download/ 选择Windows版本下载,安装 注意,需要java11,17或21才能安装 java下载地址  htt ...

  7. 为什么要重写equals()?

    为什么要重写equals()? Equals和 == 的区别: ==:是个运算符, 判断是否相等,基本数据类型进行判断 也可判断两个对象相等,比较两个对象的哈希码值 Equals:是个Object类的 ...

  8. 【JVM】一文掌握JVM垃圾回收机制

    作为Java程序员,除了业务逻辑以外,随着更深入的了解,都无法避免的会接触到JVM以及垃圾回收相关知识.JVM调优是一个听起来很可怕,实际上很简单的事. 感到可怕,是因为垃圾回收相关机制都在JVM的C ...

  9. 中间件是开箱即用的吗?为什么要开发中间件adapter?

    本文分享自华为云社区<中间件是开箱即用的吗?为什么要开发中间件adapter?>,作者:张俭. 中间件在很多系统中都存在 在一个系统里面,或多或少地都会有中间件的存在,总会有数据库,其他的 ...

  10. H3C 存储换盘操作

    实际存储型号H3C CF8844 环境说明:H3C存储设备存在一个坏盘需要更换. 更换准备 1. 取出备件检查完毕后放置到安全场所(请严格按照<IT产品现场工程师通用服务规(维修篇)>操作 ...