决策树算法是一种既可以用于分类,也可以用于回归的算法。

决策树回归是通过对输入特征的不断划分来建立一棵决策树,每一步划分都基于当前数据集的最优划分特征。
它的目标是最小化总体误差或最大化预测精度,其构建通常采用自上而下的贪心搜索方式,通过比较不同划分标准来选择最优划分。

决策树回归广泛应用于各种回归问题,如预测房价、股票价格、客户流失等。

1. 算法概述

决策树相关的诸多算法之中,有一种CART算法,全称是 classification and regression tree(分类与回归树)。
顾名思义,这个算法既可以用来分类,也可以用来回归,本篇主要介绍其在回归问题上的应用。

决策树算法的核心在于生成一棵决策树过程中,如何划分各个特征到树的不同分支上去。
CART算法是根据基尼系数(Gini)来划分特征的,每次选择基尼系数最小的特征作为最优切分点。

其中基尼系数的计算方法:\(gini(p) = \sum_{i=1}^n p_i(1-p_i)=1-\sum_{i=1}^n p_i^2\)

2. 创建样本数据

这次的回归样本数据,我们用 scikit-learn 自带的玩具数据集中的糖尿病数据集
关于玩具数据集的内容,可以参考:TODO

from sklearn.datasets import load_diabetes

# 糖尿病数据集
diabetes = load_diabetes()
X = diabetes.data
y = diabetes.target

这个数据集中大约有400多条数据。

3. 模型训练

训练之前,为了减少算法误差,先对数据进行标准化处理。

from sklearn import preprocessing as pp

# 数据标准化
X = pp.scale(X)
y = pp.scale(y)

接下来分割训练集测试集

from sklearn.model_selection import train_test_split

# 分割训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)

然后用scikit-learn中的DecisionTreeRegressor模型来训练:

from sklearn.tree import DecisionTreeRegressor

# 定义决策树回归模型
reg = DecisionTreeRegressor(max_depth=2) # 训练模型
reg.fit(X_train, y_train) # 在测试集上进行预测
y_pred = reg.predict(X_test)

DecisionTreeRegressor的主要参数包括:

  1. criterion:用于衡量节点划分质量的指标。可以选择的值有'mse'(均方误差)或'mae'(平均绝对误差)。默认值为'mse',适用于大多数情况。
  2. splitter:用于决定节点如何进行划分的策略。可以选择的值有'best'(选择最佳划分)或'random'(随机划分)。默认值为'best'。
  3. max_depth:决策树的最大深度。默认值为None,表示不限制最大深度。增加最大深度有助于更好地拟合训练数据,但可能导致过拟合。
  4. random_state:用于设置随机数生成器的种子。默认值为None,表示使用随机数生成器。
  5. ccp_alpha:用于控制正则化强度的参数。默认值为None,表示不进行正则化。
  6. max_samples:用于控制每个节点最少需要多少样本才能进行分裂。默认值为None,表示使用整个数据集。
  7. min_samples_split:用于控制每个节点最少需要多少样本才能进行分裂。默认值为2,表示每个节点至少需要2个样本才能进行分裂。
  8. min_samples_leaf:用于控制每个叶子节点最少需要多少样本才能停止分裂。默认值为1,表示每个叶子节点至少需要1个样本才能停止分裂。
  9. min_weight_fraction_leaf:用于控制每个叶子节点最少需要多少样本的权重才能停止分裂。默认值为0.0,表示每个叶子节点至少需要0个样本的权重才能停止分裂。
  10. max_features:用于控制每个节点最多需要考虑多少个特征进行分裂。默认值为None,表示使用所有特征。
  11. max_leaf_nodes:用于控制决策树最多有多少个叶子节点。默认值为None,表示不限制叶子节点的数量。
  12. min_impurity_decrease:用于控制每个节点最少需要减少多少不纯度才能进行分裂。默认值为0.0,表示每个节点至少需要减少0个不纯度才能进行分裂。
  13. min_impurity_split:用于控制每个叶子节点最少需要减少多少不纯度才能停止分裂。默认值为None,表示使用min_impurity_decrease参数。
  14. class_weight:用于设置类别权重的字典或方法。默认值为None,表示使用均匀权重。

最后验证模型的训练效果:

from sklearn import metrics

# 在测试集上进行预测
y_pred = reg.predict(X_test) mse, r2, m_error = 0.0, 0.0, 0.0
y_pred = reg.predict(X_test)
mse = metrics.mean_squared_error(y_test, y_pred)
r2 = metrics.r2_score(y_test, y_pred)
m_error = metrics.median_absolute_error(y_test, y_pred) print("均方误差:{}".format(mse))
print("复相关系数:{}".format(r2))
print("中位数绝对误差:{}".format(m_error)) # 运行结果
均方误差:0.5973573097746598
复相关系数:0.5153160857515913
中位数绝对误差:0.5496418600646286

从预测的误差来看,训练的效果还不错
这里用DecisionTreeRegressor训练模型时使用了参数max_depth=2
我从max_depth=1逐个尝试到了max_depth=10,发现max_depth=2时误差最小。

4. 总结

决策树回归具有直观、易于理解、易于实现等优点。
生成的决策树可以直观地展示出输入特征与输出结果之间的关系,因此对于非专业人士来说也易于理解。
此外,决策树回归算法相对简单,易于实现,且对数据的预处理要求较低。

然而,决策树回归也存在一些缺点。
首先,它容易过拟合训练数据,特别是当训练数据量较小时;
其次,决策树的性能受划分标准选择的影响较大,不同的划分标准可能会导致生成的决策树性能差异较大;
此外,决策树回归在处理大规模数据时可能会比较耗时,因为需要遍历整个数据集进行训练和预测。

【scikit-learn基础】--『监督学习』之 决策树回归的更多相关文章

  1. Python基础『一』

    内置数据类型 数据名称 例子 数字: Bool,Complex,Float,Integer True/False; z=a+bj; 1.23; 123 字符串: String '123456' 元组: ...

  2. Python基础『二』

    目录 语句,表达式 赋值语句 打印语句 分支语句 循环语句 函数 函数的作用 函数的三要素 函数定义 DEF语句 RETURN语句 函数调用 作用域 闭包 递归函数 匿名函数 迭代 语句,表达式 赋值 ...

  3. 『cs231n』计算机视觉基础

    线性分类器损失函数明细: 『cs231n』线性分类器损失函数 最优化Optimiz部分代码: 1.随机搜索 bestloss = float('inf') # 无穷大 for num in range ...

  4. Scikit Learn: 在python中机器学习

    转自:http://my.oschina.net/u/175377/blog/84420#OSC_h2_23 Scikit Learn: 在python中机器学习 Warning 警告:有些没能理解的 ...

  5. [原创] 【2014.12.02更新网盘链接】基于EasySysprep4.1的 Windows 7 x86/x64 『视频』封装

    [原创] [2014.12.02更新网盘链接]基于EasySysprep4.1的 Windows 7 x86/x64 『视频』封装 joinlidong 发表于 2014-11-29 14:25:50 ...

  6. 『TensorFlow』专题汇总

    TensorFlow:官方文档 TensorFlow:项目地址 本篇列出文章对于全零新手不太合适,可以尝试TensorFlow入门系列博客,搭配其他资料进行学习. Keras使用tf.Session训 ...

  7. 『TensorFlow』批处理类

    『教程』Batch Normalization 层介绍 基础知识 下面有莫凡的对于批处理的解释: fc_mean,fc_var = tf.nn.moments( Wx_plus_b, axes=[0] ...

  8. 『TensorFlow』梯度优化相关

    tf.trainable_variables可以得到整个模型中所有trainable=True的Variable,也是自由处理梯度的基础 基础梯度操作方法: tf.gradients 用来计算导数.该 ...

  9. 『TensorFlow』模型保存和载入方法汇总

    『TensorFlow』第七弹_保存&载入会话_霸王回马 一.TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 ...

  10. 『计算机视觉』Mask-RCNN_从服装关键点检测看KeyPoints分支

    下图Github地址:Mask_RCNN       Mask_RCNN_KeyPoints『计算机视觉』Mask-RCNN_论文学习『计算机视觉』Mask-RCNN_项目文档翻译『计算机视觉』Mas ...

随机推荐

  1. 低代码平台如何借助Nginx实现网关服务

    摘要:本文由葡萄城技术团队于博客园原创并首发.转载请注明出处:葡萄城官网,葡萄城为开发者提供专业的开发工具.解决方案和服务,赋能开发者. 前言 在典型的系统部署架构中,应用服务器是一种软件或硬件系统, ...

  2. Redis系列之——高级用法

    文章目录 一 慢查询 1.1 生命周期 1.2 两个配置 1.2.1 slowlog-max-len 1.2.2 slowlog-max-len 1.2.3 配置方法 1.3 三个命令 1.4 经验 ...

  3. android 中ids.xml资源的使用

    ids.xml 前面我们见识过ids.xml文件,但是这个文件是什么意思呢?我们来看下文档中的介绍: 先看下它给的例子: XML file saved at res/values/ids.xml: 使 ...

  4. 【BUU刷题日记】——第一周

    [BUU刷题日记]--第一周 一.[极客大挑战 2019]PHP1 题目说自己有一个备份网站的习惯,所以要了解一下常见的网站源码备份格式及文件名: 格式:tar.tar.gz.zip.rar 文件名: ...

  5. Shuffle 题解

    Shuffle 题目大意 给定一个长度为 \(n\) 的 01 序列 \(a\),你可以进行至多一次以下操作: 选定 \(a\) 的一个连续段,满足连续段内恰好有 \(k\) 个 \(1\),将该连续 ...

  6. C# -WebAPIOperator.cs

    说明:一个用C#编写的WebAPI操作类,只写了Get Post 部分. using Newtonsoft.Json; using Newtonsoft.Json.Linq; using System ...

  7. NewStarCTF 2023 公开赛道 WEEK4|MISC 部分WP

    R通大残 1.题目信息 R通大残,打了99,补! 2.解题方法 仔细分析题目,联想到隐写的R通道. 首先解释一下:R是储存红色的通道,通道里常见有R(红).G(绿).B(蓝)三个通道,如果关闭了R通道 ...

  8. MySQL索引、事务与存储引擎

    MySQL索引.事务与存储引擎 索引介绍 1.索引的概念 索引是一个排序的列表,在这个列表中存储着索引的值和包含这个值的数据所在行的物理地址(类似于C语言的链表通过指针指向数据记录的内存地址). 使用 ...

  9. CSP-S 考前数学练习

    [HAOI2011] 向量 首先将题目转化,转化为求方程: \(k(a,b)+q(b,a)+w(a,−b)+c(b,−a)=(x,y)\) 将这个方程再次化简,即为: \((k+w)a+(q+c)b= ...

  10. 反转字符串里的单词(leetcode 4.10每日打卡)

    给定一个字符串,逐个翻转字符串中的每个单词.   示例 1: 输入: "the sky is blue"输出: "blue is sky the" 示例 2: ...