更新、更全的《机器学习》的更新网站,更有python、go、数据结构与算法、爬虫、人工智能教学等着你:https://www.cnblogs.com/nickchen121/p/11686958.html

scikit-learn库之决策树

在scikit-learn库中决策树使用的CART算法,因此该决策树既可以解决回归问题又可以解决分类问题,即下面即将讲的DecisionTreeClassifierDecisionTreeRegressor两个模型。

接下来将会讨论这两者的区别,由于是从官方文档翻译而来,翻译会略有偏颇,有兴趣的也可以去scikit-learn官方文档查看https://scikit-learn.org/stable/modules/classes.html#module-sklearn.tree

一、DecisionTreeClassifier

1.1 使用场景

DecisionTreeClassifier模型即CART算法实现的决策树,通常用于解决分类问题。

1.2 代码

from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier
iris = load_iris()
X = iris.data[:, [2, 3]]
y = iris.target
clf = DecisionTreeClassifier(random_state=0)
clf.fit(X, y)
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
max_features=None, max_leaf_nodes=None,
min_impurity_decrease=0.0, min_impurity_split=None,
min_samples_leaf=1, min_samples_split=2,
min_weight_fraction_leaf=0.0, presort=False, random_state=0,
splitter='best')
cross_val_score(clf, iris.data, iris.target, cv=10)
array([1.        , 0.93333333, 1.        , 0.93333333, 0.93333333,
0.86666667, 0.93333333, 1. , 1. , 1. ])

1.3 参数详解

  • criterion:特征选择,str类型。criterion='gini'表示基尼指数;criterion='entropy'表示信息增益,推荐使用'gini'。默认为'gini'。
  • splitter:特征划分点选择,str类型。splitter='best'在特征的所有划分点中找出最优的划分点,适合小样本量;splitter='random'随机的在部分划分点中找到局部最优的划分点,适合大样本量。默认为'best'。
  • max_depth:最大深度,int类型。如果样本特征数较少可以使用默认值,如果样本特征数较多一般用max_depty=10-100限制树的最大深度。默认为None。
  • min_samples_split:内部节点划分需要最少样本数,float类型。限定子树继续划分的条件,如果某节点的样本数少于min_samples_split,则会停止继续划分子树。如果样本数量过大,建议增大该值,否则建议使用默认值。默认为2。
  • min_samples_leaf:叶子节点最少样本数float类型。如果在某次划分叶子节点数目小于样本数,则会和兄弟节点一起剪枝。如果样本数量过大,建议增大该值,否则建议使用默认值。默认为1。
  • min_weight_fraction_leaf:叶子节点最小的样本权重和,float类型。该参数限制了叶子节点所有样本权重和的最小值,如果小于该值,则会和兄弟节点一起剪枝。如果样本有角度的缺失值,或者样本的分布偏差较大,则可以考虑权重问题。默认为0。
  • max_features:划分的最大特征数,str、int、float类型。max_depth='log2'表示最多考虑\(log_2n\)个特征;max_depth={'sqrt','auto'}表示最多考虑\(\sqrt{n}\)个特征;max_depth=int类型,考虑\(|int类型|\)个特征;max_depth=float类型,如0.3,则考虑\(0.3n\)个特征,其中\(n\)为样本总特征数。默认为None,样本特征数不大于50推荐使用默认值。
  • random_state:随机数种子,int类型。random_state=None,不同时刻产生的随机数据是不同的;random_state=int类型,相同随机数种子不同时刻产生的随机数是相同的。默认为None。
  • max_leaf_nodes:最大叶子节点数,int类型。限制最大叶子节点数,可以防止树过深,因此可以防止过拟合。默认为None。
  • min_impurity_decrease:节点减小不纯度,float类型。如果某节点划分会导致不纯度的减少大于min_impurity_decrease,则停止该节点划分。默认为0。
  • min_impurity_split:节点划分最小不纯度,float类型。如果某节点的不纯度小于min_impurity_split,则停止该节点划分,即不生成叶子节点。默认为1e-7(0.0000001)。
  • class_weight:类别权重,dict类型或str类型。对于二元分类问题可以使用class_weight={0:0.9,1:0.1},表示0类别权重为0.9,1类别权重为0.1,str类型即为'balanced',模型将根据训练集自动修改不同类别的权重。默认为None。
  • presort:数据是否排序,bool类型。样本量较小,presort=True,即让样本数据排序,节点划分速度更快;样本量较大,presort=True,让样本排序反而会增加训练模型的时间。通常使用默认值。默认值为False。

1.4 属性

  • classes_:array类型。样本的类别标签列表。
  • max_features_:int类型。最大的特征的推断值。
  • n_classes_:int类型。fit之后训练集的类别数量。
  • n_features_:int类型。fit之后训练集的特征数。
  • n_outputs_:int类型。fit之后训练集的输出数量。
  • tree_:Tree object类型。返回树结构对象地址。

1.5 方法

  • apply(X[, check_input]):返回每个样本预测的叶子节点索引。
  • decision_path(X[, check_input]):返回样本X在树中的决策路径。
  • fit(X,y):把数据放入模型中训练模型。
  • get_params([deep]):返回模型的参数,可以用于Pipeline中。
  • predict(X):预测样本X的分类类别。
  • predict_log_proba(X):返回样本X在各个类别上对应的对数概率。
  • predict_proba(X):返回样本X在各个类别上对应的概率。
  • score(X,y[,sample_weight]):基于报告决定系数\(R^2\)评估模型。
  • set_prams(**params):创建模型参数。

二、DecisionTreeRegressor

DecisionTreeRegressor即CART回归树,它与DecisionTreeClassifier模型的区别在于criterion特征选择标准与分类树不同,它可以选择'mse'和'mae',前者是均方误差,后者是绝对值误差,一般而言'mse'比'mae'更准确。

02-25 scikit-learn库之决策树的更多相关文章

  1. (原创)(三)机器学习笔记之Scikit Learn的线性回归模型初探

    一.Scikit Learn中使用estimator三部曲 1. 构造estimator 2. 训练模型:fit 3. 利用模型进行预测:predict 二.模型评价 模型训练好后,度量模型拟合效果的 ...

  2. Scikit Learn: 在python中机器学习

    转自:http://my.oschina.net/u/175377/blog/84420#OSC_h2_23 Scikit Learn: 在python中机器学习 Warning 警告:有些没能理解的 ...

  3. scikit learn 模块 调参 pipeline+girdsearch 数据举例:文档分类 (python代码)

    scikit learn 模块 调参 pipeline+girdsearch 数据举例:文档分类数据集 fetch_20newsgroups #-*- coding: UTF-8 -*- import ...

  4. 【网络爬虫入门02】HTTP客户端库Requests的基本原理与基础应用

    [网络爬虫入门02]HTTP客户端库Requests的基本原理与基础应用 广东职业技术学院  欧浩源 1.引言 实现网络爬虫的第一步就是要建立网络连接并向服务器或网页等网络资源发起请求.urllib是 ...

  5. (原创)(四)机器学习笔记之Scikit Learn的Logistic回归初探

    目录 5.3 使用LogisticRegressionCV进行正则化的 Logistic Regression 参数调优 一.Scikit Learn中有关logistics回归函数的介绍 1. 交叉 ...

  6. Scikit Learn

    Scikit Learn Scikit-Learn简称sklearn,基于 Python 语言的,简单高效的数据挖掘和数据分析工具,建立在 NumPy,SciPy 和 matplotlib 上.

  7. Python第三方库(模块)"scikit learn"以及其他库的安装

    scikit-learn是一个用于机器学习的 Python 模块. 其主页:http://scikit-learn.org/stable/. GitHub地址: https://github.com/ ...

  8. Sklearn库例子——决策树分类

    Sklearn上关于决策树算法使用的介绍:http://scikit-learn.org/stable/modules/tree.html 1.关于决策树:决策树是一个非参数的监督式学习方法,主要用于 ...

  9. Query意图分析:记一次完整的机器学习过程(scikit learn library学习笔记)

    所谓学习问题,是指观察由n个样本组成的集合,并根据这些数据来预测未知数据的性质. 学习任务(一个二分类问题): 区分一个普通的互联网检索Query是否具有某个垂直领域的意图.假设现在有一个O2O领域的 ...

随机推荐

  1. 通过原型继承理解ES6 extends 如何实现继承

    前言 第一次接触到 ES6 中的 class 和 extends 时,就听人说这两个关键字不过是语法糖而已.它们的本质还是 ES3 的构造函数,原型链那些东西,没有什么新鲜的,只要理解了原型链等这些概 ...

  2. 阿里短信封装SDK TP3.2

    1.阿里短信接口需要企业认证: 2.短信需要短信模板 <?php /** * 阿里云短信验证码发送类 * @param string $accessKeyId key * @param stri ...

  3. Python作业本——第3章 函数

    今天看完了第三章,习题都是一些概念性的问题,就不一一解答了. 实践项目是创建一个Collatz序列,并且加上验证输入是不是一个整数. def collatz(number): if number % ...

  4. 个人IP「Android大强哥」上线啦!

    自从入职新公司之后就一直忙得不行,一边熟悉开发的流程,一边熟悉各种网站工具的使用,一边又在熟悉业务代码,好长时间都没有更文了. 不过新公司的 mentor(导师)还是很不错的,教给我很多东西,让我也能 ...

  5. 微信小程序集成腾讯云 IM SDK

    微信小程序集成腾讯云 IM SDK 1.背景 因业务功能需求需要接入IM(即时聊天)功能,一开始想到的是使用 WebSocket 来实现这个功能,然天意捉弄(哈哈)服务器版本太低不支持 wx 协议(也 ...

  6. 讲解开源项目:让你成为灵魂画手的 JS 引擎:Zdog

    本文作者:HelloGitHub-kalifun HelloGitHub 的<讲解开源项目>系列,项目地址:https://github.com/HelloGitHub-Team/Arti ...

  7. Unity3D_UGUI与NGUI的区别与优缺点

    1. NGUI与UGUI的区别 1) UGUI的Canvas 有世界坐标和屏幕坐标; 2) UGUI的Image可以使用material; 3) UGUI通过Mask来裁剪,而NGUI通过Panel的 ...

  8. javascript实现二叉搜索树

    在使用javascript实现基本的数据结构中,练习了好几周,对基本的数据结构如 栈.队列.链表.集合.哈希表.树.图等内容进行了总结并且写了笔记和代码. 在 github中可以看到  点击查看,可以 ...

  9. Linux 笔记 - 第十三章 Linux 系统日常管理之(三)Linux 系统日志和服务

    博客地址:http://www.moonxy.com 一.前言 日志文件记录了系统每天发生的各种各样的事情,比如监测系统状况.排查问题等.作为系统运维人员可以通过日志来检查错误发生的原因,或者受到攻击 ...

  10. Driect3D初始化演示

    初始化Direct3D演示 初始化Driect3D类: #include "Common\d3dApp.h" #include <DirectXColors.h> us ...