【笔记】CART与决策树中的超参数
CART与决策树中的超参数
先前的决策树其实应该称为CART
CART的英文是Classification and regression tree,全称为分类与回归树,其是在给定输入随机变量X条件下输出随机变量Y的条件概率分布的学习方法,就是假设决策树是二叉树,内部结点特征的取值为“是”和“否”,左分支是取值为“是”的分支,右分支是取值为“否”的分支,其可以解决分类问题,又可以解决回归问题,特点就是根据某一个维度d和某一个阈值v进行二分
在sklearn中的决策树都是CART的方式实现的
回顾前面,不难发现,对于之前的决策树的划分模拟,平均而言,预测的复杂度是O(logm),其中m为样本个数,这样创建决策树的训练的过程的复杂度是O(nmlogm),是很高的,n为维度数,还有一个重要的问题就是很容易产生过拟合
那么因为种种原因,在创建决策树的时候需要进行剪枝,即降低复杂度,解决过拟合的操作
剪枝的操作有很多种,实际上就是对各种参数进行平衡
通过具体操作来体现一下
(在notebook中)
加载好需要的包,使用make_moons生成虚拟数据,将数据情况绘制出来
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
X,y = datasets.make_moons(noise=0.25,random_state=666)
plt.scatter(X[y==0,0],X[y==0,1])
plt.scatter(X[y==1,0],X[y==1,1])
图像如下

然后使用DecisionTreeClassifier这个类,不进行任何限定操作,进行实例化操作以后训练数据
from sklearn.tree import DecisionTreeClassifier
dt_clf = DecisionTreeClassifier()
dt_clf.fit(X,y)
绘制函数,绘制图像
from matplotlib.colors import ListedColormap
def plot_decision_boundary(model, axis):
x0,x1 = np.meshgrid(
np.linspace(axis[0],axis[1],int((axis[1]-axis[0])*100)).reshape(-1,1),
np.linspace(axis[2],axis[3],int((axis[3]-axis[2])*100)).reshape(-1,1)
)
X_new = np.c_[x0.ravel(),x1.ravel()]
y_predict = model.predict(X_new)
zz = y_predict.reshape(x0.shape)
custom_cmap = ListedColormap(['#EF9A9A', '#FFF59D', '#90CAF9'])
plt.contourf(x0, x1, zz, linewidth=5, cmap=custom_cmap)
plot_decision_boundary(dt_clf,axis=[-1.5,2.5,-1.0,1.5])
plt.scatter(X[y==0,0],X[y==0,1])
plt.scatter(X[y==1,0],X[y==1,1])
图像如下(可以发现决策边界很不规则,很明显产生了过拟合)

设置参数max_depth为2,限制最大深度为2,然后训练数据并绘制模型
dt_clf2 = DecisionTreeClassifier(max_depth=2)
dt_clf2.fit(X,y)
plot_decision_boundary(dt_clf2,axis=[-1.5,2.5,-1.0,1.5])
plt.scatter(X[y==0,0],X[y==0,1])
plt.scatter(X[y==1,0],X[y==1,1])
图像如下(可以发现过拟合已经不明显了,很清晰的表示出了边界)

还可以设置参数min_samples_split为10,限制一个节点的样本数最小为10,然后训练数据并绘制模型
dt_clf3 = DecisionTreeClassifier(min_samples_split=10)
dt_clf3.fit(X,y)
plot_decision_boundary(dt_clf3,axis=[-1.5,2.5,-1.0,1.5])
plt.scatter(X[y==0,0],X[y==0,1])
plt.scatter(X[y==1,0],X[y==1,1])
图像如下(过拟合程度低)

还可以设置参数min_samples_leaf为6,限制一个叶子节点样本数至少为6,然后训练数据并绘制模型
dt_clf4 = DecisionTreeClassifier(min_samples_leaf=6)
dt_clf4.fit(X,y)
plot_decision_boundary(dt_clf4,axis=[-1.5,2.5,-1.0,1.5])
plt.scatter(X[y==0,0],X[y==0,1])
plt.scatter(X[y==1,0],X[y==1,1])
图像如下

还可以设置参数max_leaf_nodes为4,限制叶子节点最多为4,然后训练数据并绘制模型
dt_clf5 = DecisionTreeClassifier(max_leaf_nodes=4)
dt_clf5.fit(X,y)
plot_decision_boundary(dt_clf5,axis=[-1.5,2.5,-1.0,1.5])
plt.scatter(X[y==0,0],X[y==0,1])
plt.scatter(X[y==1,0],X[y==1,1])
图像如下

可以发现,对参数进行适当的修改可以很好的解决过拟合的问题,但是需要注意不要调节到欠拟合,那么寻找到合适的参数就可以使用网格搜索的方式

【笔记】CART与决策树中的超参数的更多相关文章
- 机器学习:决策树(CART 、决策树中的超参数)
老师:非参数学习的算法都容易产生过拟合: 一.决策树模型的创建方式.时间复杂度 1)创建方式 决策树算法 既可以解决分类问题,又可以解决回归问题: CART 创建决策树的方式:根据某一维度 d 和某一 ...
- DeepLearning.ai学习笔记(二)改善深层神经网络:超参数调试、正则化以及优化--Week2优化算法
1. Mini-batch梯度下降法 介绍 假设我们的数据量非常多,达到了500万以上,那么此时如果按照传统的梯度下降算法,那么训练模型所花费的时间将非常巨大,所以我们对数据做如下处理: 如图所示,我 ...
- 【笔记】KNN之网格搜索与k近邻算法中更多超参数
网格搜索与k近邻算法中更多超参数 网格搜索与k近邻算法中更多超参数 网络搜索 前笔记中使用的for循环进行的网格搜索的方式,我们可以发现不同的超参数之间是存在一种依赖关系的,像是p这个超参数,只有在 ...
- Coursera Deep Learning笔记 改善深层神经网络:超参数调试 正则化以及梯度相关
笔记:Andrew Ng's Deeping Learning视频 参考:https://xienaoban.github.io/posts/41302.html 参考:https://blog.cs ...
- deeplearning.ai 改善深层神经网络 week3 超参数调试、Batch正则化和程序框架 听课笔记
这一周的主体是调参. 1. 超参数:No. 1最重要,No. 2其次,No. 3其次次. No. 1学习率α:最重要的参数.在log取值空间随机采样.例如取值范围是[0.001, 1],r = -4* ...
- [DeeplearningAI笔记]02_3.1-3.2超参数搜索技巧与对数标尺
Hyperparameter search 超参数搜索 觉得有用的话,欢迎一起讨论相互学习~Follow Me 3.1 调试处理 需要调节的参数 级别一:\(\alpha\)学习率是最重要的需要调节的 ...
- Deep Learning.ai学习笔记_第二门课_改善深层神经网络:超参数调试、正则化以及优化
目录 第一周(深度学习的实践层面) 第二周(优化算法) 第三周(超参数调试.Batch正则化和程序框架) 目标: 如何有效运作神经网络,内容涉及超参数调优,如何构建数据,以及如何确保优化算法快速运行, ...
- ng-深度学习-课程笔记-8: 超参数调试,Batch正则(Week3)
1 调试处理( tuning process ) 如下图所示,ng认为学习速率α是需要调试的最重要的超参数. 其次重要的是momentum算法的β参数(一般设为0.9),隐藏单元数和mini-batc ...
- 【笔记】KNN之超参数
超参数 超参数 很多时候,对于算法来说,关于这个传入的参数,传什么样的值是最好的? 这就涉及到了机器学习领域的超参数 超参数简单来说就是在我们运行机器学习之前用来指定的那个参数,就是在算法运行前需要决 ...
随机推荐
- 被swoole坑哭的PHP程序员 (转)
本文主要记录一下学习swoole的过程.填过的坑以及swoole究竟有多么强大! 首先说一下对swoole的理解:披着PHP外衣的C程序.很多PHPer朋友看到swoole提供的强大功能.外界对其的崇 ...
- XCTF Guess-the-Number
一.发现是jar文件,一定想到反汇编gdui这个工具,而且运行不起来,可能是我电脑问题,我就直接反编译出来了. 也发现了flag,和对应的算法,直接拉出来,在本地运行,后得到flag 二.java代码
- C#下通过wbemtest和WMI Code Cretor更加高效的访问WMI
能找到这篇博客的,相信都是有操作WMI需求的了.但是不知道如何快速验证.并集成到C#来操作WMI.我们分为3步: ##第一步:官网(或跟硬件开发WMI的人沟通你需要的接口和参数定义,如果是和硬件开发的 ...
- Java | 字符串的使用 & 分析
字符串 字符串广泛应用 在 Java 编程中,在 Java 中字符串属于对象,在程序中所有的双引号字符串,都是String类的对象. 字符串的特点 1.字符串的内容永不可变. 2.正在是因为字符串的不 ...
- 「AGC032E」 Modulo Pairing
「AGC032E」 Modulo Pairing 传送门 如果所有数都 \(<\lfloor \frac m 2\rfloor\),一个自然的想法是对所有数排序过后大小搭配,这样显然是最优秀的. ...
- c语言字符串占据字节数
# include <stdio.h> //字符串占据的字节数 /* 不能将一个字符串常量赋给一个字符变量 为什么不能将一个字符串常量赋给一个字符变量?可以从两个方面作出解释: 前面讲过, ...
- Qt5MV自定义模型与实例浅析
1. Model/View结构 这种结构,其实就是将界面组件与所编辑的数据分离开来,又通过数据源的方式连接起来,相当于解耦,视图层只关心显示和与用户交互,而数据层负责与实际的数据进行通信,并为视图组件 ...
- urllib库中的URL编码解码和GETPOST请求
在urllib库的使用过程中,会在请求发送之前按照发送请求的方式进行编码处理,来使得传递的参数更加的安全,也更加符合模拟浏览器发送请求的形式.这就需要用urllib中的parse模块.parse的使用 ...
- Redis学习——常用小功能
一.慢查询分析(查询日志:所谓慢查询日志就是系统在命令执行前后计算每条命令的执行时间,当超过预设阀值,就将这条命令的相关信息(例如:发生时间,耗时,命令的详细信息)记录下来,Redis也提供了类似的功 ...
- grub2引导各种ISO系统镜像
1 grub2引导winpe 1下载winpe.iso镜像,放置在U盘的根目录下 2下载syslinux的memdisk,放置在U盘的根目录下 3 set root=(hd0,msdos1) 设置U ...