在做机器学习项目的时候,一开始我们会将数据集分为训练集和测试集,要记住测试集只能用一次,只能用来评估最终最好的模型。如果你反复去使用测试集,反复测试后从里面挑最好的,你就是在耍流氓。

建模过程中肯定有模型调整,必然涉及到模型挑选的问题,当过程中我需要做很多个模型时,问题来了,如果我不去评估我怎么知道哪一个模型是最好的?

Typically we can’t decide on which final model to use with the test set before first assessing model performance. There is a gap between our need to measure performance reliably and the data splits (training and testing) we have available.

想想在利用测试集之前,怎么也得加上一个评估过程,帮助我们确定,到底哪个模型才是最好的,才是值得最终被用到测试集上的。

这个过程就涉及到重复抽样了resampling!

Resampling methods, such as cross-validation and the bootstrap, are empirical simulation systems. They create a series of data sets similar to the training/testing split

首先理解过拟合

写重复抽样前我们先回顾过拟合的概念,数据划分后,我们会在训练集中训练好模型,怎么评估这个模型?很自然的我可以想到,就将模型用在训练集中,将真实值和预测值对比不就好了?有文章确实是这么做的,但是现在有很多的黑箱模型几乎可以做到完全复制出训练集,做到训练集预测无偏差,这个时候这个黑箱模型就一定好吗?

bias is the difference between the true pattern or relationships in data and the types of patterns that the model can emulate. Many black-box machine learning models have low bias, meaning they can reproduce complex relationships. Other models (such as linear/logistic regression, discriminant analysis, and others) are not as adaptable and are considered high bias models

不一定的。举个实际例子吧。

对于同一个数据集,我做了两个模型,一个线性回归lm_fit,另外一个随机森林rf_fit,在训练集中他们的表现如下:

看上图,明显从rmse和rsq这两个指标看,都提示随机森林模型在训练集中表现更好。按照上面的逻辑怎么说我都应该选择随机森林模型才对。

于是我真的认为随机森林模型优于线性回归模型,然后我将随机森林模型用在了测试集中去最终评估模型表现,得到结果如下。

结果显示rmse相对于训练集从0.03一下跑到了0.07,r方也有明显下降。

到这,按照原来的思路,其实我的工作已经完了,我就单纯地认为确实我选随机森林是对的,模型的预测能力确实也只能这样了。

不妨在多做一步。

虽然刚刚说线性模型不如随机森林模型,但是我又好奇这个模型在陌生的测试集中表现究竟怎样?于是我又多做一步,把我们抛弃的线性模型用在测试集中看看表现:

可以看到线性模型在训练集和测试集中的表现一致性非常强,在测试集中的表现其实和随机森林差不太多。

上面的例子给大家的启发就是,模型训练的好(在训练集中表现好)不意味着其在测试集中也好。模型在训练集中表现好,而测试集中就不行了,就是模型过拟合的表现,模型训练时避免过拟合的,保证表现一致性的方法就是重复抽样训练。

再来看重复抽样

重复抽样训练的逻辑在于:

我们会将原来的训练集进行反复抽样形成很多和抽样样本。

对于每一个抽样样本,又会分为analysis样本集和assessment样本集,我们会在analysis样本中训练模型,然后再assessment样本中评估模型,比如我现在重复抽样20,意味着我要做20个模型,每个模型评估一次,就会评估20次,整体模型好不好,是这20次的均值说了算的。这样就大大增加了模型的推广稳健性,避免过拟合。

重复抽样的常见方法包括交叉验证和自助抽样验证,其做法代码如下:

folds <- vfold_cv(cell_train, v = 10) #交叉验证设置代码

交叉验证

交叉验证属于resampling的一种方法,一个简单的例子如下,比如我训练集30个样本,3折交叉验证的图示:

30个数据别均分为3份,每一份都当做一次assessment数据集,相应地剩下的2个数据集为analysis数据集用来训练模型

数据随机切为3份之后,每一份都会用来评估模型表现。

仔细想一下,上面的交叉验证其实还有随机性,就是你一开始就将数据切成了3份,如果只切一次其实也是有随机性的,所以我们实际使用交叉验证的时候要考虑这一点,我们会重复很多次,比如10折交叉验证再重复10次。这个就是反复交叉验证的思想,叫做Repeated cross-validation。这也是为什么交叉验证函数都会有一个repeats参数的原因。

自助法Bootstrapping

Bootstrap本身是一种确定统计量的样本分布的方法,上篇文章刚刚提到过哈

Bootstrap resampling was originally invented as a method for approximating the sampling distribution of statistics whose theoretical properties are intractable

在机器学习中,我们对训练集进行自助抽样就是在训练集中有放回地随机抽一个和训练集一样大的样本。同样的,我们还是看一个30个样本的训练集的自助抽样例子:

可以看到,我们对原始30个训练集样本进行了3次自助抽样,每次抽出来的30个样本都是有重复的,比如在第一次的时候8这个样本就重复了,而2这个样本没抽到。这样我们就让自助样本做训练,没抽到的样本做assessment set。没抽到的样本也叫做out-of-bag sample。论文中的out-of-bag验证就是指的这个意思。

滚动抽样

对于时间依赖的数据,比如面板数据,我们再考虑抽样的时候一定要将时间的先后顺序考虑进去,这时候我们用到的方法叫做Rolling forecast origin resampling:下面是这个方法的图示:

可以看到我们的抽样是按时间前进的,保证每次我们都是用老数据训练,新数据评估。上面的示例是每次丢掉一个样本,前进一个样本,实际使用的时候我们可以不丢掉,一次前进多个。

理解随机抽样的地位

上面又再次回忆了不同的重复抽样的方法,始终需要记得的是,重复抽样是服务于发现最优模型的,服务于减少欠拟合和过拟合的(很多同学做预测模型其实是略过这一步的,只能说不完美,不能说错),使用重复抽样我们会在每一个样本集中训练模型并对其进行评估,比如我某种抽样方法抽出20个样本集那么我就训练并评估模型20次,最终20个模型的平均表现作为该模型的表现。通过这么样的方式尽最大努力使得用到测试集中进行测试的模型是最优的,保证测试集只用一次并且这一次确实反映了最优模型的表现。

This sequence repeats for every resample. If there are B resamples, there are B replicates of each of the performance metrics. The final resampling estimate is the average of these B statistics. If B = 1, as with a validation set, the individual statistics represent overall performance.

这个方法怎么用呢?tidymodels给了我们相应的使用界面:

model_spec %>% fit_resamples(formula,  resamples, ...)
model_spec %>% fit_resamples(recipe, resamples, ...)
workflow %>% fit_resamples( resamples, ...)

如果你看不懂上面的界面,之后我会专门写tidymodels框架给大家,请持续关注。

R机器学习:重复抽样在机器学习模型建立过程中的地位理解的更多相关文章

  1. vs2013在使用ef6时,创建模型向导过程中,四种模型方式缺少2种

    下载eftool,并安装 https://download.microsoft.com/download/2/C/F/2CF7AFAB-4068-4DAB-88C6-CEFD770FAECD/EFTo ...

  2. TensorFlow之tf.nn.dropout():防止模型训练过程中的过拟合问题

    一:适用范围: tf.nn.dropout是TensorFlow里面为了防止或减轻过拟合而使用的函数,它一般用在全连接层 二:原理: dropout就是在不同的训练过程中随机扔掉一部分神经元.也就是让 ...

  3. 字典转模型的过程中,空值和id特殊字符的处理

    在IOS 中id是特殊字符,可是非常多时候从网络中下载的数据是以id保存的 假设在定义属性的时候 @property(nonatomic, copy) NSString *id; 就不会出现错误 当键 ...

  4. <转>机器学习系列(9)_机器学习算法一览(附Python和R代码)

    转自http://blog.csdn.net/han_xiaoyang/article/details/51191386 – 谷歌的无人车和机器人得到了很多关注,但我们真正的未来却在于能够使电脑变得更 ...

  5. 机器学习技法课之Aggregation模型

    Courses上台湾大学林轩田老师的机器学习技法课之Aggregation 模型学习笔记. 混合(blending) 本笔记是Course上台湾大学林轩田老师的<机器学习技法课>的学习笔记 ...

  6. Stanford机器学习---第七讲. 机器学习系统设计

    原文:http://blog.csdn.net/abcjennifer/article/details/7834256 本栏目(Machine learning)包括单参数的线性回归.多参数的线性回归 ...

  7. 【机器学习】搞清楚机器学习的TP、FN、FP、TN,查全率和查准率,PR曲线和ROC曲线的含义与关系

    最近重新学习了一下机器学习的一些基础知识,这里对性能度量涉及到的各种值与图像做一个总结. 西瓜书里的这一部分讲的比较快,这些概念个人感觉非常绕,推敲了半天才搞清楚. 这些概念分别是:TP.FN.FP. ...

  8. 机器学习入门18 - 生产机器学习系统(Production ML Systems)

    除了实现机器学习算法之外,机器学习还包含许多其他内容.生产环境机器学习系统包含大量组件.无需自行构建所有内容,而是应该尽可能重复使用常规机器学习系统组件.通过了解机器学习系统的一些范例及其要求,可以明 ...

  9. 一书吃透机器学习!新版《机器学习基础》来了,教材PDF、PPT可下载 | 资源

    不出家门,也能学习到国外高校的研究生机器学习课程了. 今天,一本名为Foundations of Machine Learning(<机器学习基础>)的课在Reddit上热度飙升至300, ...

  10. 机器学习实战基础(十七):sklearn中的数据预处理和特征工程(十)特征选择 之 Embedded嵌入法

    Embedded嵌入法 嵌入法是一种让算法自己决定使用哪些特征的方法,即特征选择和算法训练同时进行.在使用嵌入法时,我们先使用某些机器学习的算法和模型进行训练,得到各个特征的权值系数,根据权值系数从大 ...

随机推荐

  1. Nuxt.js 应用中的 app:redirected 钩子详解

    title: Nuxt.js 应用中的 app:redirected 钩子详解 date: 2024/10/3 updated: 2024/10/3 author: cmdragon excerpt: ...

  2. springboot的启动类必须和controller在同一层级

    springboot的启动类必须和controller在同一层级

  3. go 实现sse

    package chat import ( "encoding/json" "github.com/zeromicro/go-zero/core/logx" & ...

  4. 11-react使用props.children 处理父子组件之间的传值

    // props.children 组件传值 import { Component } from "react" import reactDom from "react- ...

  5. 怎么理解vue的单向数据流

    单向数据流是父组件传给子组件的数据,子组件没有权利修改,只能委托父组件修改,然后子组件更新

  6. 18 . 介绍一下 Promise

    Promise 是js内置的构造函数,也叫做期约函数 ,它有 3 种状态 ,等待状态 pending ,成功状态 fullfilled ,失败状态 reject :2 个过程, 等待状态到成功状态 会 ...

  7. 再见 Dockerfile,拥抱新型镜像构建技术 Buildpacks

    作者:米开朗基杨,方阗 云原生正在吞并软件世界,容器改变了传统的应用开发模式,如今研发人员不仅要构建应用,还要使用 Dockerfile 来完成应用的容器化,将应用及其依赖关系打包,从而获得更可靠的产 ...

  8. es之增删改查

    查询 index: GET task_results/_search/ 普通查询: {"query":{"bool":{"must":[{& ...

  9. 数据运算中关于字符串""的拼接问题

    例子中准备了3种类型数据,分别针对是否在运算存在空字符串参与运算进行了演示,结果如下: 1 int x = 10; 2 double y = 20.2; 3 long z = 10L; 4 Syste ...

  10. MySql5.7及以上 ORDER BY 报错问题

    一.问题 本人使用的MySql版本是8.0的 当MySql5.7及以上的版本执行带有 ORDER BY 的SQL语句时可能会报错. 例如,执行以下mysql语句: SELECT id, user_id ...