决策树分类算法是一种监督学习算法,它的基本原理是将数据集通过一系列的问题进行拆分,这些问题被视为决策树的叶子节点和内部节点。
决策树的每个分支代表一个可能的决策结果,而每个叶子节点代表一个最终的分类结果。

决策树分类算法的历史可以追溯到1980年代初,当时研究者开始探索用机器学习来解决分类问题。
在1981年,J.Ross Quinlan开发了ID3算法,该算法使用信息增益来选择决策树的最佳划分属性。
后来,在1986年,J.Ross Quinlan提出了C4.5算法,该算法引入了剪枝技术,以防止过拟合,该算法还引入了处理连续属性、缺失数据和多值属性等新特性。
在1998年,Jerome Friedman等人提出了CART算法Classification and Regression Trees),该算法采用了二叉树,使得决策树更加简洁和易于解释。

1. 算法概述

决策树不仅可以用在分类问题上,也可以用在回归问题上。
关于决策树回归问题上的应用,可以参考:TODO

回到决策树分类算法上来,构建决策树的有三种算法:

1.1. ID3

ID3算法的完整名称是Iterative Dichotomiser 3,即迭代二叉树3代。
ID3算法的核心思想是以信息增益来度量属性的选择,选择分裂后信息增益最大的属性进行分裂。

对于任意样本数据 \(x(x_1,x_2,...,x_n)\),它的信息熵定义为:
\(entropy(x) = -\sum_{i=1}^n p_i\log_2(p_i)\)

基于信息熵,信息增益的公式为:
\(IG(T) = entropy(S) - \sum_{value(T)}\frac{|S_x|}{|S|}entropy(S_x)\)
其中:

  • \(S\) 表示全部样本的集合
  • \(|S|\) 表示\(S\)中样本数量
  • \(T\) 表示样本的某个特征
  • \(value(T)\) 表示特征\(T\)所有的取值集合
  • \(S_x\) 是\(S\)中特征\(T\)的值为\(x\)的样本的集合
  • \(|S_x|\) 表示\(S_x\)中样本数量

1.2. C4.5

C4.5算法是以ID3算法为基础的,它改为使用信息增益率来作为决策树分裂的依据。
这样,就克服了ID3算法中信息增益选择属性时偏向选择取值多的属性的不足。

C4.5算法中引入了一个分裂信息(split information)的项来惩罚取值较多的特征:
\(SI(T) = - \sum_{value(T)}\frac{|S_x|}{|S|}\log\frac{|S_x|}{|S|}\)

基于此,信息增益率的公式为:
\(gainRatio(T)=\frac{IG(T)}{SI(T)}\)
\(IG(T)\)就是上一节ID3算法中的信息增益公式。

1.3. CART

CART算法全称是 classification and regression tree(分类与回归树)。
这个算法既可以用来分类,也可以用来回归,在回归问题上的介绍可以参考。

CART算法是根据基尼系数(Gini)来划分特征的,每次选择基尼系数最小的特征作为最优切分点。
其中基尼系数的计算方法:\(gini(p) = \sum_{i=1}^n p_i(1-p_i)=1-\sum_{i=1}^n p_i^2\)

2. 创建样本数据

scikit-learn中的样本生成器make_classification来生成分类用的样本数据。

import matplotlib.pyplot as plt
from sklearn.datasets import make_classification # 分类数据的样本生成器
X, y= make_classification(n_samples=1000, n_classes=4, n_clusters_per_class=1, n_informative=6)
plt.scatter(X[:, 0], X[:, 1], marker="o", c=y, s=25) plt.show()


关于样本生成器的详细内容,请参考:TODO

3. 模型训练

首先,分割训练集测试集

from sklearn.model_selection import train_test_split

# 分割训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

这次按照8:2的比例来划分训练集和测试集。

然后用不同的算法来训练决策树模型:

from sklearn.tree import DecisionTreeClassifier

reg_names = [
"ID3算法",
"C4.5算法",
"CART算法",
] # 定义
regs = [
DecisionTreeClassifier(criterion="entropy"),
DecisionTreeClassifier(criterion="log_loss"),
DecisionTreeClassifier(criterion="gini"),
] # 训练模型
for reg in regs:
reg.fit(X_train, y_train) # 在测试集上进行预测
y_preds = []
for reg in regs:
y_pred = reg.predict(X_test)
y_preds.append(y_pred) for i in range(len(y_preds)):
correct_pred = np.sum(y_preds[i] == y_test)
print("【{}】 预测正确率:{:.2f}%".format(reg_names[i], correct_pred / len(y_pred) * 100)) # 运行结果
【ID3算法】 预测正确率:71.50%
【C4.5算法】 预测正确率:72.50%
【CART算法】 预测正确率:75.00%

算法的正确率差别不是特别大。
感兴趣的朋友,可以尝试调整样本生成器部分,生成一些特征较多的数据来看看算法之间的性能差别。

4. 总结

决策树分类算法广泛应用于图像识别、文本分类、语音识别、信用评分、疾病诊断等众多领域。
例如,在电商平台上,可以通过决策树分类算法对用户的行为数据进行挖掘和分析,实现对用户的精准推荐;
在医疗领域,可以通过对医学数据的分析,辅助医生进行疾病诊断和治疗方案制定。

决策树分类算法的优势有:

  1. 易于理解和解释,直观地展示出分类的过程
  2. 对于数据集可以进行并行处理,提高了算法的效率
  3. 对于缺失数据和非数值属性有很好的处理能力
  4. 可以处理多分类问题

决策树分类算法也存在一些劣势

  1. 可能存在过拟合,需要使用剪枝技术来控制
  2. 可能存在偏向性,需要使用加权投票来处理
  3. 对于连续属性和多值属性处理起来比较复杂,需要额外的处理方法
  4. 大规模数据集处理起来比较耗时,需要优化算法或者使用分布式计算等方法

【scikit-learn基础】--『监督学习』之 决策树分类的更多相关文章

  1. Python基础『一』

    内置数据类型 数据名称 例子 数字: Bool,Complex,Float,Integer True/False; z=a+bj; 1.23; 123 字符串: String '123456' 元组: ...

  2. Python基础『二』

    目录 语句,表达式 赋值语句 打印语句 分支语句 循环语句 函数 函数的作用 函数的三要素 函数定义 DEF语句 RETURN语句 函数调用 作用域 闭包 递归函数 匿名函数 迭代 语句,表达式 赋值 ...

  3. 『cs231n』计算机视觉基础

    线性分类器损失函数明细: 『cs231n』线性分类器损失函数 最优化Optimiz部分代码: 1.随机搜索 bestloss = float('inf') # 无穷大 for num in range ...

  4. Scikit Learn: 在python中机器学习

    转自:http://my.oschina.net/u/175377/blog/84420#OSC_h2_23 Scikit Learn: 在python中机器学习 Warning 警告:有些没能理解的 ...

  5. [原创] 【2014.12.02更新网盘链接】基于EasySysprep4.1的 Windows 7 x86/x64 『视频』封装

    [原创] [2014.12.02更新网盘链接]基于EasySysprep4.1的 Windows 7 x86/x64 『视频』封装 joinlidong 发表于 2014-11-29 14:25:50 ...

  6. 『TensorFlow』专题汇总

    TensorFlow:官方文档 TensorFlow:项目地址 本篇列出文章对于全零新手不太合适,可以尝试TensorFlow入门系列博客,搭配其他资料进行学习. Keras使用tf.Session训 ...

  7. 『TensorFlow』批处理类

    『教程』Batch Normalization 层介绍 基础知识 下面有莫凡的对于批处理的解释: fc_mean,fc_var = tf.nn.moments( Wx_plus_b, axes=[0] ...

  8. 『TensorFlow』梯度优化相关

    tf.trainable_variables可以得到整个模型中所有trainable=True的Variable,也是自由处理梯度的基础 基础梯度操作方法: tf.gradients 用来计算导数.该 ...

  9. 『TensorFlow』模型保存和载入方法汇总

    『TensorFlow』第七弹_保存&载入会话_霸王回马 一.TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 ...

  10. 『计算机视觉』Mask-RCNN_从服装关键点检测看KeyPoints分支

    下图Github地址:Mask_RCNN       Mask_RCNN_KeyPoints『计算机视觉』Mask-RCNN_论文学习『计算机视觉』Mask-RCNN_项目文档翻译『计算机视觉』Mas ...

随机推荐

  1. C#开源、功能强大、免费的Windows系统优化工具 - Optimizer

    前言 今天给大家推荐一款由C#开源.功能强大.免费的Windows系统优化工具 - Optimizer. 工具介绍 Optimizer是一款功能强大的Windows系统优化工具,可帮助用户提高计算机性 ...

  2. ora2pg使用记录

    ora2pg使用记录 前言 这篇文章是我在学习使用ora2pg过程中的学习记录,以便日后遗忘查阅: 诸君也可跟随我的步伐了解一下ora2pg,或可移步如下官方文档参考学习:Ora2Pg : Migra ...

  3. 文心一言 VS 讯飞星火 VS chatgpt (104)-- 算法导论10.1 2题

    二.用go语言,说明如何在一个数组 A[1..n]中实现两个栈,使得当两个栈的元素个数之和不为 n 时,两者都不会发生上溢.要求 PUSH 和 POP 操作的运行时间为 O(1). 文心一言: 在这个 ...

  4. Hugging Face 分词器新增聊天模板属性

    一个幽灵,格式不正确的幽灵,在聊天模型中游荡! 太长不看版 现存的聊天模型使用的训练数据格式各各不同,我们需要用这些格式将对话转换为单个字符串并传给分词器.如果我们在微调或推理时使用的格式与模型训练时 ...

  5. 实战|如何低成本训练一个可以超越 70B Llama2 的模型 Zephyr-7B

    每一周,我们的同事都会向社区的成员们发布一些关于 Hugging Face 相关的更新,包括我们的产品和平台更新.社区活动.学习资源和内容更新.开源库和模型更新等,我们将其称之为「Hugging Ne ...

  6. P9482 [NOI2023] 字符串 题解

    \(36pts\) \(O(tqn^2)\)暴力即可 \(40pts\) 对于最朴素的暴力优化,从头到尾扫,如果已经当前位字符比出优先级,那么直接能判断了,没必要往后跑了,第15个性质B的也给跑过了, ...

  7. Go 14周年

    原文在这里. 由 Russ Cox, for the Go team 发布于2023年11月10日 今天,我们庆祝Go开源发布的第十四个生日!Go在过去一年里取得了巨大的进展,发布了两个功能丰富的版本 ...

  8. JavaScript高级程序设计笔记06 集合引用类型

    集合引用类型 1. Object(详见c08 p205) 适合存储,在应用程序间交换数据 创建实例: a. 显式构造函数 b. 字面量-->不会调用构造函数(代码更少.更有封装感) 函数:大量参 ...

  9. CodeChef Starters 9 Division 3 (Rated) India Fights Corona

    原题链接 India Fights Corona 题意: 有\(n\)个城市,\(m\)条道路,其中有些城市自己有医院,所以可以在自己城市做核酸检测,那么花费就只有就医费用,而对于那些自己没有医院的城 ...

  10. MySQL大表设计

    存储大规模数据集需要仔细设计数据库模式和索引,以便能够高效地支持各种查询操作.在面对数亿条数据,每条数据包含数百个字段的情况下,以下是我能想到的在设计数据库的时候需要注意的内容,不足之处欢迎各位在评论 ...