本文始发于个人公众号:TechFlow,原创不易,求个关注

今天是机器学习专题的第22篇文章,我们继续决策树的话题。

上一篇文章当中介绍了一种最简单构造决策树的方法——ID3算法,也就是每次选择一个特征进行拆分数据。这个特征有多少个取值那么就划分出多少个分叉,整个建树的过程非常简单。如果错过了上篇文章的同学可以从下方传送门去回顾一下:

如果你还不会决策树,那你一定要进来看看

既然我们已经有了ID3算法可以实现决策树,那么为什么还需要新的算法?显然一定是做出了一些优化或者是进行了一些改进,不然新算法显然是没有意义的。所以在我们学习新的算法之前,需要先搞明白,究竟做出了什么改进,为什么要做出这些改进。

一般来说,改进都是基于缺点和不足的,所以我们先来看看ID3算法的一些问题。

其中最大的问题很明显,就是它无法处理连续性的特征。不能处理的原因也很简单,因为ID3在每次在切分数据的时候,选择的不是一个特征的取值,而是一个具体的特征。这个特征下有多少种取值就会产生多少个分叉,如果使用连续性特征的话,比如说我们把西瓜的直径作为特征的话。那么理论上来说每个西瓜的直径都是不同的,这样的数据丢进ID3算法当中就会产生和样本数量相同的分叉,这显然是没有意义的。

其实还有一个问题,藏得会比较深一点,是关于信息增益的。我们用划分前后的信息熵的差作为信息增益,然后我们选择带来最大信息增益的划分。这里就有一个问题了,这会导致模型在选择的时候,倾向于选择分叉比较多的特征。极端情况下,就比如说是连续性特征好了,每个特征下都只有一个样本,那么这样算出来得到的信息熵就是0,这样得到的信息增益也就非常大。这是不合理的,因为分叉多的特征并不一定划分效果就好,整体来看并不一定是有利的。

针对这两个问题,提出了改进方案,也就是说C4.5算法。严格说起来它并不是独立的算法,只是ID3算法的改进版本。

下面我们依次来看看C4.5算法究竟怎么解决这两个问题。

信息增益比

首先,我们来看信息增益的问题。前面说了,如果我们单纯地用信息增益去筛选划分的特征,那么很容易陷入陷阱当中,选择了取值更多的特征。

针对这个问题,我们可以做一点调整,我们把信息增益改成信息增益比。所谓的信息增益比就是用信息增益除以我们这个划分本身的信息熵,从而得到一个比值。对于分叉很多的特征,它的自身的信息熵也会很大。因为分叉多,必然导致纯度很低。所以我们这样可以均衡一下特征分叉带来的偏差,从而让模型做出比较正确的选择。

我们来看下公式,真的非常简单:

这里的D就是我们的训练样本集,a是我们选择的特征,IV(a)就是这个特征分布的信息熵

我们再来看下IV的公式:

解释一下这里的值,这里的V是特征a所有取值的集合。自然就是每一个v对应的占比,所以这就是一个特征a的信息熵公式。

处理连续值

C4.5算法对于连续值同样进行了优化,支持了连续值,支持的方式也非常简单,对于特征a的取值集合V来说,我们选择一个阈值t进行划分,将它划分成小于t的和大于t的两个部分。

也就是说C4.5算法对于连续值的切分和离散值是不同的,对于离散值变量,我们是对每一种取值进行切分,而对于连续值我们只切成两份。其实这个设计非常合理,因为对于大多数情况而言,每一条数据的连续值特征往往都是不同的。而且我们也没有办法很好地确定对于连续值特征究竟分成几个部分比较合理,所以比较直观的就是固定切分成两份,总比无法用上好。

在极端情况下,连续值特征的取值数量等于样本条数,那么我们怎么选择这个阈值呢?即使我们遍历所有的切分情况,也有n-1种,这显然是非常庞大的,尤其在样本数量很大的情况下。

针对这个问题,也有解决的方法,就是按照特征值排序,选择真正意义上的切分点。什么意思呢,我们来看一份数据:

直径 是否甜
3
4
5 不甜
6 不甜
7
8 不甜
9 不甜

这份数据是我们队西瓜直径这个特征排序之后的结果,我们可以看出来,训练目标改变的值其实只有3个,分别是直径5,7还有8的时候,我们只需要考虑这三种情况就好了,其他的情况可以不用考虑。

我们综合考虑这两点,然后把它们加在之前ID3模型的实现上就好了。

代码实现

光说不练假把式,我们既然搞明白了它的原理,就得自己亲自动手实现以下才算是真的理解,很多地方的坑也才算是真的懂。我们基本上可以沿用之前的代码,不过需要在之前的基础上做一些修改。

首先我们先来改造构造数据的部分,我们依然沿用上次的数据,学生的三门考试等级以及它是否通过达标的数据。我们认为三门成绩在150分以上算是达标,大于70分的课程是2等级,40-70分之间是1等级,40分以下是0等级。在此基础上我们增加了分数作为特征,我们在分数上增加了一个误差,避免模型直接得到结果。

import numpy as np
import math
def create_data():
X1 = np.random.rand(50, 1)*100
X2 = np.random.rand(50, 1)*100
X3 = np.random.rand(50, 1)*100 def f(x):
return 2 if x > 70 else 1 if x > 40 else 0 # 学生的分数作为特征,为了公平,加上了一定噪音
X4 = X1 + X2 + X3 + np.random.rand(50, 1) * 20 y = X1 + X2 + X3
Y = y > 150
Y = Y + 0
r = map(f, X1)
X1 = list(r) r = map(f, X2)
X2 = list(r) r = map(f, X3)
X3 = list(r)
x = np.c_[X1, X2, X3, X4, Y]
return x, ['courseA', 'courseB', 'courseC', 'score']

由于我们需要计算信息增益比,所以需要开发一个专门的函数用来计算信息增益比。由于这一次的数据涉及到了连续型特征,所以我们需要多传递一个阈值,来判断是否是连续性特征。如果是离散型特征,那么阈值为None,否则为具体的值。

def info_gain(dataset, idx):
# 计算基本的信息熵
entropy = calculate_info_entropy(dataset)
m = len(dataset)
# 根据特征拆分数据
split_data, _ = split_dataset(dataset, idx)
new_entropy = 0.0
# 计算拆分之后的信息熵
for data in split_data:
prob = len(data) / m
# p * log(p)
new_entropy += prob * calculate_info_entropy(data)
return entropy - new_entropy def info_gain_ratio(dataset, idx, thred=None):
# 拆分数据,需要将阈值传入,如果阈值不为None直接根据阈值划分
# 否则根据特征值划分
split_data, _ = split_dataset(dataset, idx, thred)
base_entropy = 1e-5
m = len(dataset)
# 计算特征本身的信息熵
for data in split_data:
prob = len(data) / m
base_entropy -= prob * math.log(prob, 2)
return info_gain(dataset, idx) / base_entropy, thred

split_dataset函数也需要修改,因为我们拆分的情况多了一种根据阈值拆分,通过判断阈值是否为None来判断进行阈值划分还是特征划分。

def split_dataset(dataset, idx, thread=None):
splitData = defaultdict(list)
# 如果阈值为None那么直接根据特征划分
if thread is None:
for data in dataset:
splitData[data[idx]].append(np.delete(data, idx))
return list(splitData.values()), list(splitData.keys())
else:
# 否则根据阈值划分,分成两类大于和小于
for data in dataset:
splitData[data[idx] < thread].append(np.delete(data, idx))
return list(splitData.values()), list(splitData.keys())

前面说了我们在选择阈值的时候其实并不一定要遍历所有的取值,因为有些取值并不会引起label分布的变化,对于这种取值我们就可以忽略。所以我们需要一个函数来获取阈值所有的可能性,这个也很简单,我们直接根据阈值排序,然后遍历观察label是否会变化,记录下所有label变化位置的值即可:

def get_thresholds(X, idx):

    # numpy多维索引用法
new_data = X[:, [idx, -1]].tolist()
# 根据特征值排序
new_data = sorted(new_data, key=lambda x: x[0], reverse=True)
base = new_data[0][1]
threads = [] for i in range(1, len(new_data)):
f, l = new_data[i]
# 如果label变化则记录
if l != base:
base = l
threads.append(f) return threads

有了这些方法之后,我们需要开发选择拆分值的函数,也就是计算所有特征的信息增益比,找到信息增益比最大的特征进行拆分。其实我们将前面拆分和获取所有阈值的函数都开发完了之后,要寻找最佳的拆分点就很容易了,基本上就是利用一下之前开发好的代码,然后搜索一下所有的可能性:

def choose_feature_to_split(dataset):
n = len(dataset[0])-1
m = len(dataset)
# 记录最佳增益比、特征和阈值
bestGain = 0.0
feature = -1
thred = None
for i in range(n):
# 判断是否是连续性特征,默认整数型特征不是连续性特征
# 这里只是我个人的判断逻辑,可以自行diy
if not dataset[0][i].is_integer():
threds = get_thresholds(dataset, i)
for t in threds:
# 遍历所有的阈值,计算每个阈值的信息增益比
ratio, th = info_gain_ratio(dataset, i, t)
if ratio > bestGain:
bestGain, feature, thred = ratio, i, t
else:
# 否则就走正常特征拆分的逻辑,计算增益比
ratio, _ = info_gain_ratio(dataset, i)
if ratio > bestGain:
bestGain = ratio
feature, thred = i, None
return feature, thred

到这里,基本方法就开发完了,只剩下建树和预测两个方法了。这两个方法和之前的代码改动都不大,基本上就是细微的变化。我们先来看建树,建树唯一的不同点就是在dict当中需要额外存储一份阈值的信息。如果是None表示离散特征,不为None为连续性特征,其他的逻辑基本不变。

def create_decision_tree(dataset, feature_names):
dataset = np.array(dataset)
# 如果都是一类,那么直接返回类别
counter = Counter(dataset[:, -1])
if len(counter) == 1:
return dataset[0, -1] # 如果只有一个特征了,直接返回占比最多的类别
if len(dataset[0]) == 1:
return counter.most_common(1)[0][0] # 记录最佳拆分的特征和阈值
fidx, th = choose_feature_to_split(dataset)
fname = feature_names[fidx] node = {fname: {'threshold': th}}
feature_names.remove(fname) split_data, vals = split_dataset(dataset, fidx, th)
for data, val in zip(split_data, vals):
node[fname][val] = create_decision_tree(data, feature_names[:])
return node

最后是预测的函数,逻辑和之前一样,只不过加上了阈值是否为None的判断而已,应该非常简单:

def classify(node, feature_names, data):
key = list(node.keys())[0]
node = node[key]
idx = feature_names.index(key) pred = None
thred = node['threshold']
# 如果阈值为None,那么直接遍历dict
if thred is None:
for key in node:
if key != 'threshold' and data[idx] == key:
if isinstance(node[key], dict):
pred = classify(node[key], feature_names, data)
else:
pred = node[key]
else:
# 否则直接访问
if isinstance(node[data[idx] < thred], dict):
pred = classify(node[data[idx] < thred], feature_names, data)
else:
pred = node[data[idx] < thred] # 放置pred为空,挑选一个叶子节点作为替补
if pred is None:
for key in node:
if not isinstance(node[key], dict):
pred = node[key]
break
return pred

总结

到这里整个决策树的C4.5算法就开发完了,整体来说由于加上了信息增益比以及连续性特征的逻辑,所以整体的代码比之前要复杂一些,但是基本上的逻辑和套路都是一脉相承的,基本上没什么太大的变化。

决策树说起原理来非常简单,但是很多细节如果没有亲自做过是意识不到的。比如说连续性特征的阈值集合应该怎么找,比如说连续性特征和离散型的特征混合的情况,怎么在代码当中区分,等等。只有实际动手做过,才能意识到这些问题。虽然平时也用不到决策树这个模型,但是它是很多高级模型的基础,吃透它对后面的学习和进阶非常有帮助,如果有空,推荐大家都亲自试一试。

今天的文章就到这里,原创不易,需要你的一个关注,你的举手之劳对我来说很重要。

深入了解机器学习决策树模型——C4.5算法的更多相关文章

  1. 决策树之C4.5算法

    决策树之C4.5算法 一.C4.5算法概述 C4.5算法是最常用的决策树算法,因为它继承了ID3算法的所有优点并对ID3算法进行了改进和补充. 改进有如下几个要点: 用信息增益率来选择属性,克服了ID ...

  2. 决策树之C4.5算法学习

    决策树<Decision Tree>是一种预測模型,它由决策节点,分支和叶节点三个部分组成. 决策节点代表一个样本測试,通常代表待分类样本的某个属性,在该属性上的不同測试结果代表一个分支: ...

  3. 决策树模型 ID3/C4.5/CART算法比较

    决策树模型在监督学习中非常常见,可用于分类(二分类.多分类)和回归.虽然将多棵弱决策树的Bagging.Random Forest.Boosting等tree ensembel 模型更为常见,但是“完 ...

  4. 机器学习之决策树(ID3 、C4.5算法)

    声明:本篇博文是学习<机器学习实战>一书的方式路程,系原创,若转载请标明来源. 1 决策树的基础概念 决策树分为分类树和回归树两种,分类树对离散变量做决策树 ,回归树对连续变量做决策树.决 ...

  5. 机器学习总结(八)决策树ID3,C4.5算法,CART算法

    本文主要总结决策树中的ID3,C4.5和CART算法,各种算法的特点,并对比了各种算法的不同点. 决策树:是一种基本的分类和回归方法.在分类问题中,是基于特征对实例进行分类.既可以认为是if-then ...

  6. 《机器学习实战》学习笔记第三章 —— 决策树之ID3、C4.5算法

    主要内容: 一.决策树模型 二.信息与熵 三.信息增益与ID3算法 四.信息增益比与C4.5算法 五.决策树的剪枝 一.决策树模型 1.所谓决策树,就是根据实例的特征对实例进行划分的树形结构.其中有两 ...

  7. 机器学习(Machine Learning)算法总结-决策树

    一.机器学习基本概念总结 分类(classification):目标标记为类别型的数据(离散型数据)回归(regression):目标标记为连续型数据 有监督学习(supervised learnin ...

  8. DNS通道检测 国内学术界研究情况——研究方法:基于特征或者流量,使用机器学习决策树分类算法居多

    http://xuewen.cnki.net/DownloadArticle.aspx?filename=BMKJ201104017&dbtype=CJFD<浅析基于DNS协议的隐蔽通道 ...

  9. 02-22 决策树C4.5算法

    目录 决策树C4.5算法 一.决策树C4.5算法学习目标 二.决策树C4.5算法详解 2.1 连续特征值离散化 2.2 信息增益比 2.3 剪枝 2.4 特征值加权 三.决策树C4.5算法流程 3.1 ...

随机推荐

  1. LeetCode 56,区间合并问题

    本文始发于个人公众号:TechFlow,原创不易,求个关注 今天是LeetCode专题的第33篇文章,我们一起来看LeetCode的第56题,它的难度是Medium. 题意 这道题的题意也很简单,只有 ...

  2. Android | 带你零代码实现安卓扫码功能

    目录 小序 背景介绍 前期准备 开始搬运 结语 小序   这是一篇纯新手教学,本人之前没有任何安卓开发经验(尴尬),本文也不涉及任何代码就可以使用一个扫码demo,华为scankit真是新手的福音-- ...

  3. thinkphp5.x系列 RCE总结

    Thinkphp  MVC开发模式 执行流程: 首先发起请求->开始路由检测->获取pathinfo信息->路由匹配->开始路由解析->获得模块.控制器.操作方法调度信息 ...

  4. 基于OpenCV的KNN算法实现手写数字识别

    基于OpenCV的KNN算法实现手写数字识别 一.数据预处理 # 导入所需模块 import cv2 import numpy as np import matplotlib.pyplot as pl ...

  5. A - ACM Computer Factory POJ - 3436 网络流

    A - ACM Computer Factory POJ - 3436 As you know, all the computers used for ACM contests must be ide ...

  6. Linux查看redis占用内存的方法

    redis-cli auth 密码info # Memory used_memory:13490096 //数据占用了多少内存(字节) used_memory_human:12.87M //数据占用了 ...

  7. 【Scala】新手入门,基础语法概览

    目录 变量.常量和数据类型 var val 数据类型 条件表达式 块表达式 to循环 for循环 for推导式 scala中的方法和函数 方法的定义 函数的定义 函数和方法的区别 变量.常量和数据类型 ...

  8. ASP.NET Core Blazor 初探之 Blazor Server

    上周初步对Blazor WebAssembly进行了初步的探索(ASP.NET Core Blazor 初探之 Blazor WebAssembly).这次来看看Blazor Server该怎么玩. ...

  9. 帝国cms 批量删除包含关键字的 内容

    删除包含关键字的 内容delete from www_kaifatu_com_ecms_news where playurl like '%关键字%'

  10. SpringBatch异常To use the default BatchConfigurer the context must contain no more thanone DataSource

    SpringBoot整合SpringBatch项目,已将代码开源至github,访问地址:https://github.com/cmlbeliever/SpringBatch 欢迎star or fo ...