在做机器学习项目的时候,一开始我们会将数据集分为训练集和测试集,要记住测试集只能用一次,只能用来评估最终最好的模型。如果你反复去使用测试集,反复测试后从里面挑最好的,你就是在耍流氓。

建模过程中肯定有模型调整,必然涉及到模型挑选的问题,当过程中我需要做很多个模型时,问题来了,如果我不去评估我怎么知道哪一个模型是最好的?

Typically we can’t decide on which final model to use with the test set before first assessing model performance. There is a gap between our need to measure performance reliably and the data splits (training and testing) we have available.

想想在利用测试集之前,怎么也得加上一个评估过程,帮助我们确定,到底哪个模型才是最好的,才是值得最终被用到测试集上的。

这个过程就涉及到重复抽样了resampling!

Resampling methods, such as cross-validation and the bootstrap, are empirical simulation systems. They create a series of data sets similar to the training/testing split

首先理解过拟合

写重复抽样前我们先回顾过拟合的概念,数据划分后,我们会在训练集中训练好模型,怎么评估这个模型?很自然的我可以想到,就将模型用在训练集中,将真实值和预测值对比不就好了?有文章确实是这么做的,但是现在有很多的黑箱模型几乎可以做到完全复制出训练集,做到训练集预测无偏差,这个时候这个黑箱模型就一定好吗?

bias is the difference between the true pattern or relationships in data and the types of patterns that the model can emulate. Many black-box machine learning models have low bias, meaning they can reproduce complex relationships. Other models (such as linear/logistic regression, discriminant analysis, and others) are not as adaptable and are considered high bias models

不一定的。举个实际例子吧。

对于同一个数据集,我做了两个模型,一个线性回归lm_fit,另外一个随机森林rf_fit,在训练集中他们的表现如下:

看上图,明显从rmse和rsq这两个指标看,都提示随机森林模型在训练集中表现更好。按照上面的逻辑怎么说我都应该选择随机森林模型才对。

于是我真的认为随机森林模型优于线性回归模型,然后我将随机森林模型用在了测试集中去最终评估模型表现,得到结果如下。

结果显示rmse相对于训练集从0.03一下跑到了0.07,r方也有明显下降。

到这,按照原来的思路,其实我的工作已经完了,我就单纯地认为确实我选随机森林是对的,模型的预测能力确实也只能这样了。

不妨在多做一步。

虽然刚刚说线性模型不如随机森林模型,但是我又好奇这个模型在陌生的测试集中表现究竟怎样?于是我又多做一步,把我们抛弃的线性模型用在测试集中看看表现:

可以看到线性模型在训练集和测试集中的表现一致性非常强,在测试集中的表现其实和随机森林差不太多。

上面的例子给大家的启发就是,模型训练的好(在训练集中表现好)不意味着其在测试集中也好。模型在训练集中表现好,而测试集中就不行了,就是模型过拟合的表现,模型训练时避免过拟合的,保证表现一致性的方法就是重复抽样训练。

再来看重复抽样

重复抽样训练的逻辑在于:

我们会将原来的训练集进行反复抽样形成很多和抽样样本。

对于每一个抽样样本,又会分为analysis样本集和assessment样本集,我们会在analysis样本中训练模型,然后再assessment样本中评估模型,比如我现在重复抽样20,意味着我要做20个模型,每个模型评估一次,就会评估20次,整体模型好不好,是这20次的均值说了算的。这样就大大增加了模型的推广稳健性,避免过拟合。

重复抽样的常见方法包括交叉验证和自助抽样验证,其做法代码如下:

folds <- vfold_cv(cell_train, v = 10) #交叉验证设置代码

交叉验证

交叉验证属于resampling的一种方法,一个简单的例子如下,比如我训练集30个样本,3折交叉验证的图示:

30个数据别均分为3份,每一份都当做一次assessment数据集,相应地剩下的2个数据集为analysis数据集用来训练模型

数据随机切为3份之后,每一份都会用来评估模型表现。

仔细想一下,上面的交叉验证其实还有随机性,就是你一开始就将数据切成了3份,如果只切一次其实也是有随机性的,所以我们实际使用交叉验证的时候要考虑这一点,我们会重复很多次,比如10折交叉验证再重复10次。这个就是反复交叉验证的思想,叫做Repeated cross-validation。这也是为什么交叉验证函数都会有一个repeats参数的原因。

自助法Bootstrapping

Bootstrap本身是一种确定统计量的样本分布的方法,上篇文章刚刚提到过哈

Bootstrap resampling was originally invented as a method for approximating the sampling distribution of statistics whose theoretical properties are intractable

在机器学习中,我们对训练集进行自助抽样就是在训练集中有放回地随机抽一个和训练集一样大的样本。同样的,我们还是看一个30个样本的训练集的自助抽样例子:

可以看到,我们对原始30个训练集样本进行了3次自助抽样,每次抽出来的30个样本都是有重复的,比如在第一次的时候8这个样本就重复了,而2这个样本没抽到。这样我们就让自助样本做训练,没抽到的样本做assessment set。没抽到的样本也叫做out-of-bag sample。论文中的out-of-bag验证就是指的这个意思。

滚动抽样

对于时间依赖的数据,比如面板数据,我们再考虑抽样的时候一定要将时间的先后顺序考虑进去,这时候我们用到的方法叫做Rolling forecast origin resampling:下面是这个方法的图示:

可以看到我们的抽样是按时间前进的,保证每次我们都是用老数据训练,新数据评估。上面的示例是每次丢掉一个样本,前进一个样本,实际使用的时候我们可以不丢掉,一次前进多个。

理解随机抽样的地位

上面又再次回忆了不同的重复抽样的方法,始终需要记得的是,重复抽样是服务于发现最优模型的,服务于减少欠拟合和过拟合的(很多同学做预测模型其实是略过这一步的,只能说不完美,不能说错),使用重复抽样我们会在每一个样本集中训练模型并对其进行评估,比如我某种抽样方法抽出20个样本集那么我就训练并评估模型20次,最终20个模型的平均表现作为该模型的表现。通过这么样的方式尽最大努力使得用到测试集中进行测试的模型是最优的,保证测试集只用一次并且这一次确实反映了最优模型的表现。

This sequence repeats for every resample. If there are B resamples, there are B replicates of each of the performance metrics. The final resampling estimate is the average of these B statistics. If B = 1, as with a validation set, the individual statistics represent overall performance.

这个方法怎么用呢?tidymodels给了我们相应的使用界面:

model_spec %>% fit_resamples(formula,  resamples, ...)
model_spec %>% fit_resamples(recipe, resamples, ...)
workflow %>% fit_resamples( resamples, ...)

如果你看不懂上面的界面,之后我会专门写tidymodels框架给大家,请持续关注。

R机器学习:重复抽样在机器学习模型建立过程中的地位理解的更多相关文章

  1. vs2013在使用ef6时,创建模型向导过程中,四种模型方式缺少2种

    下载eftool,并安装 https://download.microsoft.com/download/2/C/F/2CF7AFAB-4068-4DAB-88C6-CEFD770FAECD/EFTo ...

  2. TensorFlow之tf.nn.dropout():防止模型训练过程中的过拟合问题

    一:适用范围: tf.nn.dropout是TensorFlow里面为了防止或减轻过拟合而使用的函数,它一般用在全连接层 二:原理: dropout就是在不同的训练过程中随机扔掉一部分神经元.也就是让 ...

  3. 字典转模型的过程中,空值和id特殊字符的处理

    在IOS 中id是特殊字符,可是非常多时候从网络中下载的数据是以id保存的 假设在定义属性的时候 @property(nonatomic, copy) NSString *id; 就不会出现错误 当键 ...

  4. <转>机器学习系列(9)_机器学习算法一览(附Python和R代码)

    转自http://blog.csdn.net/han_xiaoyang/article/details/51191386 – 谷歌的无人车和机器人得到了很多关注,但我们真正的未来却在于能够使电脑变得更 ...

  5. 机器学习技法课之Aggregation模型

    Courses上台湾大学林轩田老师的机器学习技法课之Aggregation 模型学习笔记. 混合(blending) 本笔记是Course上台湾大学林轩田老师的<机器学习技法课>的学习笔记 ...

  6. Stanford机器学习---第七讲. 机器学习系统设计

    原文:http://blog.csdn.net/abcjennifer/article/details/7834256 本栏目(Machine learning)包括单参数的线性回归.多参数的线性回归 ...

  7. 【机器学习】搞清楚机器学习的TP、FN、FP、TN,查全率和查准率,PR曲线和ROC曲线的含义与关系

    最近重新学习了一下机器学习的一些基础知识,这里对性能度量涉及到的各种值与图像做一个总结. 西瓜书里的这一部分讲的比较快,这些概念个人感觉非常绕,推敲了半天才搞清楚. 这些概念分别是:TP.FN.FP. ...

  8. 机器学习入门18 - 生产机器学习系统(Production ML Systems)

    除了实现机器学习算法之外,机器学习还包含许多其他内容.生产环境机器学习系统包含大量组件.无需自行构建所有内容,而是应该尽可能重复使用常规机器学习系统组件.通过了解机器学习系统的一些范例及其要求,可以明 ...

  9. 一书吃透机器学习!新版《机器学习基础》来了,教材PDF、PPT可下载 | 资源

    不出家门,也能学习到国外高校的研究生机器学习课程了. 今天,一本名为Foundations of Machine Learning(<机器学习基础>)的课在Reddit上热度飙升至300, ...

  10. 机器学习实战基础(十七):sklearn中的数据预处理和特征工程(十)特征选择 之 Embedded嵌入法

    Embedded嵌入法 嵌入法是一种让算法自己决定使用哪些特征的方法,即特征选择和算法训练同时进行.在使用嵌入法时,我们先使用某些机器学习的算法和模型进行训练,得到各个特征的权值系数,根据权值系数从大 ...

随机推荐

  1. 主要将子文件下大量图片进行路径编号,并保存到csv文件当中。方便直接从文件读取图片路径以及其他图片信息

    # coding: utf-8 #主要将子文件下大量图片进行路径编号,并保存到csv文件当中.方便直接从文件读取图片路径以及其他图片信息. #我做的是图像分割,所以存在三类分割区域:["la ...

  2. python:将文件从一个目录移动到另一个目录。附:nnUnet使用

    在使用nn-Unet做BraTS2019数据集预测时,预测文件分别生成了三类文件:.pkl  .npz  .nii.gz,我们需要的是.nii.gz文件.所以需要进行文件移动. # coding:ut ...

  3. ssr屏幕空间射线追踪

    本轮作业中,我们需要在一个光源为方向光,材质为漫反射 (Diffuse) 的场景 中,完成屏幕空间下的全局光照效果(两次反射). 为了在作业框架中实现上述效果,基于我们需要的信息不同我们会分三阶段 着 ...

  4. element输入天数,获取当前时间加上天数 【时间获取】

    handleInput (val) { // console.log(this.formModel.ITEM_PM) if (!(/[^\d]/g).test(val)) { // console.l ...

  5. [NOI Online 2022 入门组] 数学游戏

    P8255 [NOI Online 2022 入门组] 数学游戏 注:妙哉,此题可以理解为数学题. 思路 由题易得: \[\notag z=d_x\times d_y\times \gcd(x,y)^ ...

  6. Iterator和Iterable

    Java遍历List有三种方式 public static void main(String[] args) { List<String> list = new ArrayList< ...

  7. 将NC栅格表示时间维度的数据提取出来的方法

      本文介绍基于Python语言,逐一读取大量.nc格式的多时相栅格文件,导出其中所具有的全部时间信息的方法.   .nc是NetCDF(Network Common Data Form)文件的扩展名 ...

  8. Matrix Calculus

    1 Scalar Function \(\text{If }f(\mathbf{x})\in\mathbf{R},\mathrm{then}\) \[df=\frac{\partial f}{\par ...

  9. EXCEL获取拼音首字母

    Excel 2016 按组合键ALT+F11调出VB窗口--插入--模块(复制代码到新模块中,复制完后始可关闭VB窗口) 复制以下代码到模块中 Function getpychar(char) tmp ...

  10. 《一篇就够系列》之HTTP详解,覆盖高频面试考点!

    一.写在开头 前几篇博文大概介绍了什么是网络编程,以及网络编程的实战作用,今日起,我们将针对里面涉及到的重要知识点,进行详细的梳理与学习! 在整个WEB编程中,有个应用层的协议是我们无法跳过的,那就是 ...