随机森林回归Random Forest Regression)是一种在机器学习领域广泛应用的算法,由美国科学家 Leo Breiman 在2001年提出。
它是一种集成学习方法,通过整合多个决策树的预测结果来提高预测精度和稳定性。

随机森林回归适用于各种需要预测连续数值输出的问题,
如金融领域的股票价格预测、客户信用评分,医疗领域的疾病诊断和药物发现等。

1. 算法概述

随机森林回归算法通过引入随机性来构建多个决策树,再通过对这些树的预测结果进行平均或投票来得出最终的预测结果。
这里的随机性主要体现在两个方面:一是训练样本的随机选取,二是在训练过程中特征的随机选取。

随机森林的算法过程并不复杂,主要的步骤如下:

  1. 从原始训练集中随机选择一部分样本,构成一个新的子样本集。这样可以使得每棵决策树都在不同的样本集上进行训练,增加模型的多样性。
  2. 对于每个决策树的每个节点,在选择最佳划分特征时,只考虑随机选择的一部分特征。这样可以防止某些特征对整个模型的影响过大,提高模型的鲁棒性。
  3. 在每个子样本集上使用某种决策树算法构建一棵决策树。决策树的生长过程中,通常采用递归地选择最佳划分特征,将数据集划分为不纯度最小的子集。
  4. 通过上述步骤生成的大量决策树最终组合成随机森林。

上面第一,第二步骤中的随机性就是随机森林这个名称的由来。

2. 创建样本数据

这次的回归样本数据,我们用 scikit-learn 自带的样本生成器来生成回归样本。
关于样本生成器的内容,可以参考:TODO

from sklearn.datasets import make_regression

# 回归样本生成器
X, y = make_regression(n_features=4, n_informative=2)

每个样本有4个特征。

3. 模型训练

训练之前,为了减少算法误差,先对数据进行标准化处理。

from sklearn import preprocessing as pp

# 数据标准化
X = pp.scale(X)
y = pp.scale(y)

接下来分割训练集测试集

from sklearn.model_selection import train_test_split

# 分割训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)

然后用scikit-learn中的RandomForestRegressor模型来训练:

from sklearn.ensemble import RandomForestRegressor

# 定义随机森林回归模型
reg = RandomForestRegressor(max_depth=2) # 训练模型
reg.fit(X_train, y_train) # 在测试集上进行预测
y_pred = reg.predict(X_test)

RandomForestRegressor的主要参数包括:

  1. n_estimators:森林中决策树的数量。默认值为100,表示这是森林中树木的数量,即基评估器的数量。但是,任何模型都有决策边界,当n_estimators达到一定的程度之后,随机森林的精确性往往不再上升或开始波动。同时,n_estimators越大,需要的计算量和内存也越大,训练的时间也会越来越长。
  2. max_depth:树的最大深度。默认是None,与剪枝相关。设置为None时,树的节点会一直分裂,直到每个叶子都是“纯”的,或者叶子中包含于min_samples_split个样本。可以从3开始尝试增加,观察是否应该继续加大深度。
  3. min_samples_split:在叶节点处需要的最小样本数。默认值是2,指定每个内部节点(非叶子节点)包含的最少的样本数。
  4. min_samples_leaf:每个叶子结点包含的最少的样本数。参数的取值除了整数之外,还可以是浮点数。如果参数的值设置过小会导致过拟合,反之就会欠拟合。
  5. min_weight_fraction_leaf:叶子节点所需要的最小权值。
  6. max_features:用于限制分枝时考虑的特征个数。超过限制个数的特征都会被舍弃。此参数可以设为整数、浮点数、字符或None,默认为'auto'。
  7. max_leaf_nodes:最大叶子节点数,整数,默认为None。这个参数通过限制树的最大叶子数量来防止过拟合,如果设置了一个正整数,则会在建立的最大叶节点内的树中选择最优的决策树。
  8. min_impurity_decrease:如果分裂指标的减少量大于该值,则进行分裂。
  9. min_impurity_split:决策树生长的最小纯净度。默认是0。

最后验证模型的训练效果:

from sklearn import metrics

# 在测试集上进行预测
y_pred = reg.predict(X_test) mse, r2, m_error = 0.0, 0.0, 0.0
y_pred = reg.predict(X_test)
mse = metrics.mean_squared_error(y_test, y_pred)
r2 = metrics.r2_score(y_test, y_pred)
m_error = metrics.median_absolute_error(y_test, y_pred) print("均方误差:{}".format(mse))
print("复相关系数:{}".format(r2))
print("中位数绝对误差:{}".format(m_error)) # 运行结果
均方误差:0.0918182629293023
复相关系数:0.9137032593574914
中位数绝对误差:0.17199566634564867

从预测的误差来看,训练的效果非常好

有同样的数据试了下上一篇介绍的决策树回归算法,发现还是随机森林回归的效果要好一些。
决策数回归的模型效果:

from sklearn.tree import DecisionTreeRegressor
from sklearn import metrics # 定义决策树回归模型
reg = DecisionTreeRegressor(max_depth=2) # 训练模型
reg.fit(X_train, y_train) # 在测试集上进行预测
y_pred = reg.predict(X_test) mse, r2, m_error = 0.0, 0.0, 0.0
y_pred = reg.predict(X_test)
mse = metrics.mean_squared_error(y_test, y_pred)
r2 = metrics.r2_score(y_test, y_pred)
m_error = metrics.median_absolute_error(y_test, y_pred) print("均方误差:{}".format(mse))
print("复相关系数:{}".format(r2))
print("中位数绝对误差:{}".format(m_error)) # 运行结果
均方误差:0.1681399575883647
复相关系数:0.8419711956126009
中位数绝对误差:0.36483491370039456

从运行结果来看,决策树回归的误差比随机森林回归要大不少。

4. 总结

随机森林回归算法的优势主要在于可以有效地处理大量的输入变量,并且可以处理非线性关系和交互作用,
同时 ,由于它是集成学习方法,所以可以有效地减少过拟合和欠拟合的问题,提高预测的准确性。

此外,在训练过程中,它可以自动进行特征选择和降维,帮助找到最重要的特征,
还可以处理缺失值和异常值,不需要进行特殊的数据预处理。

然而,随机森林回归算法也有一些劣势,
首先,它的训练过程相对较慢,尤其是在数据集较大或特征维度较高的情况下;
其次,在某些情况下,它可能过于依赖输入数据的随机性,导致预测结果的不稳定。
此外,随机森林算法在处理那些需要精确控制的问题时可能效果不佳

【scikit-learn基础】--『监督学习』之 随机森林回归的更多相关文章

  1. 机器学习实战基础(三十八):随机森林 (五)RandomForestRegressor 之 用随机森林回归填补缺失值

    简介 我们从现实中收集的数据,几乎不可能是完美无缺的,往往都会有一些缺失值.面对缺失值,很多人选择的方式是直接将含有缺失值的样本删除,这是一种有效的方法,但是有时候填补缺失值会比直接丢弃样本效果更好, ...

  2. MATLAB随机森林回归模型

    MATLAB随机森林回归模型: 调用matlab自带的TreeBagger.m T=textread('E:\datasets-orreview\discretized-regression\10bi ...

  3. 机器学习之路:python 集成回归模型 随机森林回归RandomForestRegressor 极端随机森林回归ExtraTreesRegressor GradientBoostingRegressor回归 预测波士顿房价

    python3 学习机器学习api 使用了三种集成回归模型 git: https://github.com/linyi0604/MachineLearning 代码: from sklearn.dat ...

  4. Python基础『一』

    内置数据类型 数据名称 例子 数字: Bool,Complex,Float,Integer True/False; z=a+bj; 1.23; 123 字符串: String '123456' 元组: ...

  5. Python基础『二』

    目录 语句,表达式 赋值语句 打印语句 分支语句 循环语句 函数 函数的作用 函数的三要素 函数定义 DEF语句 RETURN语句 函数调用 作用域 闭包 递归函数 匿名函数 迭代 语句,表达式 赋值 ...

  6. pyspark RandomForestRegressor 随机森林回归

    #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Fri Jun 8 09:27:08 2018 ...

  7. 机器学习实战基础(三十七):随机森林 (四)之 RandomForestRegressor 重要参数,属性与接口

    RandomForestRegressor class sklearn.ensemble.RandomForestRegressor (n_estimators=’warn’, criterion=’ ...

  8. 机器学习实战基础(三十五):随机森林 (二)之 RandomForestClassifier 之重要参数

    RandomForestClassifier class sklearn.ensemble.RandomForestClassifier (n_estimators=’10’, criterion=’g ...

  9. Python机器学习笔记——随机森林算法

    随机森林算法的理论知识 随机森林是一种有监督学习算法,是以决策树为基学习器的集成学习算法.随机森林非常简单,易于实现,计算开销也很小,但是它在分类和回归上表现出非常惊人的性能,因此,随机森林被誉为“代 ...

  10. 随机森林random forest及python实现

    引言想通过随机森林来获取数据的主要特征 1.理论根据个体学习器的生成方式,目前的集成学习方法大致可分为两大类,即个体学习器之间存在强依赖关系,必须串行生成的序列化方法,以及个体学习器间不存在强依赖关系 ...

随机推荐

  1. Azure Data Factory(九)基础知识回顾

    一,引言 在本文中,我们将继续了解什么是 Azure Data Factory,Azure Data Factory 的工作原理,Azure Data Factory 数据工程中的数据管道,并了解继承 ...

  2. 《流畅的Python》 读书笔记 230926

    写在最前面的话 缘由 关于Python的资料市面上非常多,好的其实并不太多. 个人认为,基础的,下面的都还算可以 B站小甲鱼 黑马的视频 刘江的博客 廖雪峰的Python课程 进阶的更少,<流畅 ...

  3. MySQL运维2-主从复制

    一.主从复制概念 主从复制是指将主数据库的DDL和DML操作通过二进制日志传到从服务器中,然后在从服务器上对这些日志重新执行也叫重做,从而使得从数据库和主库的数据保持同步. MySQL支持一台主库同时 ...

  4. FX3U-3A-ADP模拟量和数字量之间转换

    简单的例子: 0-10V对应0-8,4-20mA对应0-30 以下是对上面例子的详解: 电压: 电压(0-10V) 0-10V对应着数字量0-4000 数字量与变频器HZ量之间的关系是(4000-0) ...

  5. windows平板的开发和选型

    今天谈一个老话题,windows系统的选型和开发.问题的起因是我们一个客户说,用安卓平板不安全,苹果系统不考虑,于是他们要用自认为安全的WIN7系统. 提到WINDOWS平台下的的平板系统,此事说来话 ...

  6. 给STM32装点中国风——华为LiteOS移植

    我都二手程序员好几个礼拜了!想给我的STM32来点"中国风",装个华为LiteOS操作系统. 在此之前,我也试过STM32CubeMX自带的FreeRTOS操作系统,不知是何缘故, ...

  7. .NET的各种对象在内存中如何布局[博文汇总]

    在过去一段时间里,我陆陆续续写一些关于.NET对象类型布局的文章,其中包括值类型和引用类型的内存布局.字符串对象和数组的内存布局等,这里作一个简单的汇总. [1] 如何计算一个实例占用多少内存? 我们 ...

  8. Util应用框架基础(七) - 缓存

    本节介绍Util应用框架如何操作缓存. 概述 缓存是提升性能的关键手段之一. 除了提升性能,缓存对系统健壮性和安全性也有影响. 不同类型的系统对缓存的依赖程度不同. 对于后台管理系统,由于是给管理人员 ...

  9. 微信小程序记住密码,让登录解放双手

    密码是用户最重要的数据,也是系统最需要保护的数据,我们在登录的时候需要用账号密码请求登录接口,如果用户勾选记住密码,那么下一次登录时,我们需要将账号密码回填到输入框,用户可以直接登录系统.我们分别对这 ...

  10. C语言二进制转换成八,十,十六进制

    代码目前只支持整数转换,小数转换后续更新呀 #include <stdio.h> #include <math.h> void B_O(int n); void B_H(int ...