If ample data is available, a simple approach to model selection is to randomly split the dataset into three parts: a training set, a validation set, and a test set. The training set is used to fit the models, the validation set is used to choose among them, and the test set is used for the final evaluation of the learning method. Among the learned models of different complexity, pick the one with the smallest prediction error on the validation set. Since the validation set contains enough data, selecting the model with it is effective.
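A minimal sketch of such a random three-way split, assuming numpy/scikit-learn and an illustrative 60/20/20 ratio (the ratio is not prescribed above):

    import numpy as np
    from sklearn.model_selection import train_test_split

    # Illustrative data; X and y are assumed to be numpy arrays.
    X, y = np.random.rand(1000, 5), np.random.randint(0, 2, 1000)

    # First carve off 20% as the test set, then split the rest into train/validation.
    X_rest, X_test, y_rest, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
    # 0.25 of the remaining 80% gives the 20% validation share.
    X_train, X_val, y_train, y_val = train_test_split(X_rest, y_rest, test_size=0.25, random_state=0)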

In many practical applications, however, data is scarce. To still select a good model, cross-validation can be used.

k-fold cross validation: first randomly partition the data into k disjoint subsets of equal size; then train the model on k-1 of the subsets and test it on the remaining subset; repeat this for each of the k possible choices of held-out subset; finally select the model with the smallest average test error over the k evaluations.
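A minimal k-fold sketch, assuming scikit-learn; the estimator and k=5 are illustrative choices:

    import numpy as np
    from sklearn.model_selection import KFold
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import accuracy_score

    X, y = np.random.rand(200, 5), np.random.randint(0, 2, 200)

    kf = KFold(n_splits=5, shuffle=True, random_state=0)
    fold_scores = []
    for train_idx, test_idx in kf.split(X):
        model = LogisticRegression().fit(X[train_idx], y[train_idx])
        # Evaluate on the held-out fold; the average over the k folds is the CV score.
        fold_scores.append(accuracy_score(y[test_idx], model.predict(X[test_idx])))
    print(np.mean(fold_scores), np.std(fold_scores))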

Summary: in practice we learn the parameters on the training set, compute the error on the cross-validation set, pick the model with the smallest validation error, and finally estimate that model's generalization error on the test set.

Note that the purpose of k-fold cross validation is not to pick among fitted instances; rather, you already have a model and use cross validation to assess its accuracy. "Model" here means a model type, in the sense of the following quote: "generally when we say 'a model' we refer to a particular method for describing how some input data relates to what we are trying to predict. We don't generally refer to particular instances of that method as different models. So you might say 'I have a linear regression model' but you wouldn't call two different sets of the trained coefficients different models."

Say we have two models, e.g. a linear regression model and a neural network. How can we say which model is better? We can do K-fold cross-validation and see which one proves better at predicting the test set points. But once we have used cross-validation to select the better-performing model, we train that model (whether it be the linear regression or the neural network) on all the data. We don't use the actual model instances we trained during cross-validation for our final predictive model.
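A hedged sketch of that workflow with scikit-learn; the two candidate models and the 5 folds are illustrative, not taken from the quote:

    import numpy as np
    from sklearn.model_selection import cross_val_score
    from sklearn.linear_model import LinearRegression
    from sklearn.neural_network import MLPRegressor

    X, y = np.random.rand(300, 8), np.random.rand(300)

    candidates = {"linear": LinearRegression(),
                  "mlp": MLPRegressor(hidden_layer_sizes=(32,), max_iter=2000, random_state=0)}
    # Compare the two model types by mean CV score.
    cv_means = {name: cross_val_score(m, X, y, cv=5).mean() for name, m in candidates.items()}
    best_name = max(cv_means, key=cv_means.get)

    # Discard the fold-trained instances and refit the winning model type on all data.
    final_model = candidates[best_name].fit(X, y)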


Besides its role in model selection described above, the validation set is also used to guard against overfitting.

To make sure you don't overfit the network, you need to feed the validation dataset to the network and check that the error stays within some range. Because the validation set is not used directly to adjust the weights of the network, a good error on the validation set and on the test set indicates that the network not only predicts well on the training examples but can also be expected to perform well on new examples that were not used in the training process.

Both the training set and the validation set are used during the model's training process; the training workflow looks like this:

for each epoch
    for each training data instance
        propagate error through the network
        adjust the weights
        calculate the accuracy over training data
    for each validation data instance
        calculate the accuracy over the validation data
    if the threshold validation accuracy is met
        exit training
    else
        continue training
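A minimal runnable sketch of the same loop, assuming a scikit-learn MLP trained incrementally with partial_fit; the network size and the accuracy threshold are illustrative assumptions:

    import numpy as np
    from sklearn.neural_network import MLPClassifier
    from sklearn.model_selection import train_test_split

    X, y = np.random.rand(500, 10), np.random.randint(0, 2, 500)
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=0)

    net = MLPClassifier(hidden_layer_sizes=(16,), random_state=0)
    classes = np.unique(y_train)
    for epoch in range(100):
        # One pass over the training data: propagate errors and adjust the weights.
        net.partial_fit(X_train, y_train, classes=classes)
        train_acc = net.score(X_train, y_train)
        # Accuracy on the validation data, which never adjusts the weights.
        val_acc = net.score(X_val, y_val)
        if val_acc >= 0.95:   # illustrative threshold validation accuracy
            break             # exit training
        # otherwise continue training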

1. I don't do a separate final training on all the training data; instead I average the responses of the 10 folded models on the test data as my final result. This may make for better CV results, as you're guaranteed to be using the same models you've got CV results for.
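A hedged sketch of that fold-averaging idea; Ridge and the 10 folds are illustrative choices:

    import numpy as np
    from sklearn.model_selection import KFold
    from sklearn.linear_model import Ridge

    X, y = np.random.rand(400, 6), np.random.rand(400)
    X_test = np.random.rand(100, 6)

    kf = KFold(n_splits=10, shuffle=True, random_state=0)
    test_preds = []
    for train_idx, val_idx in kf.split(X):
        model = Ridge().fit(X[train_idx], y[train_idx])
        test_preds.append(model.predict(X_test))
    # Final prediction: the average of the 10 fold models, with no refit on all data.
    final_pred = np.mean(test_preds, axis=0)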

2. I use a holdout sample (usually it was ids 0-5k), but I occasionally change the holdout sample. My hardware isn't that great, so 5-fold CV is a bit time-consuming. It matched reasonably well with the LB. I also use a watchlist of 20% of the data to get the number of rounds before retraining, so I effectively have 3 holdout sets: LB, holdout, and watchlist.
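The "watchlist to get the number of rounds before retraining" idea corresponds to early stopping in gradient-boosting libraries; below is a hedged xgboost-style sketch, where the library choice and all parameters are assumptions rather than something stated in the quote:

    import numpy as np
    import xgboost as xgb
    from sklearn.model_selection import train_test_split

    X, y = np.random.rand(1000, 10), np.random.rand(1000)
    X_tr, X_watch, y_tr, y_watch = train_test_split(X, y, test_size=0.2, random_state=0)

    dtrain = xgb.DMatrix(X_tr, label=y_tr)
    dwatch = xgb.DMatrix(X_watch, label=y_watch)
    params = {"objective": "reg:squarederror", "eta": 0.1}

    # The 20% watchlist decides how many rounds to keep before the final retrain.
    bst = xgb.train(params, dtrain, num_boost_round=1000,
                    evals=[(dwatch, "watch")], early_stopping_rounds=50,
                    verbose_eval=False)
    best_rounds = bst.best_iteration + 1  # best_iteration is 0-indexed

    # Retrain on all data with the number of rounds found above.
    final = xgb.train(params, xgb.DMatrix(X, label=y), num_boost_round=best_rounds)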

3. I like to draw a lot of samples with n = size of the private-LB data to get an estimate of the private LB score (if train and test share the same distribution), or n = size of the public LB to check the correlation between the local score and the public LB score.
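A hedged sketch of that resampling idea, assuming out-of-fold predictions and true labels already exist; all names and sizes are illustrative:

    import numpy as np
    from sklearn.metrics import mean_squared_error

    rng = np.random.default_rng(0)
    # Assumed to exist already: out-of-fold predictions and the true labels.
    y_true, oof_pred = np.random.rand(5000), np.random.rand(5000)

    n_private = 3000   # illustrative size of the private-LB split
    scores = []
    for _ in range(200):
        idx = rng.choice(len(y_true), size=n_private, replace=False)
        scores.append(mean_squared_error(y_true[idx], oof_pred[idx]))
    # The spread of these scores approximates how much the private LB score could vary.
    print(np.mean(scores), np.std(scores))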

4. Using a fixed hold-out set is rarely a good idea, because it's very prone to overfitting. Also look at the std of your CV, not just the mean. What you can do is monitor how your k-fold scores vary together and how your LB scores behave with respect to that. You will almost always see some patterns there, which can be used to draw some conclusions.

Edit: "rarely a good idea" is misleading. It should be something like: on datasets where (stratified) k-fold CV is applicable, it is the safer bet compared to a single hold-out set.

5. My understanding of "out-of-fold" prediction is that you do the following:

  1. Run K-fold CV, and for each run generate n*(1/K) predictions from training data of size n.
  2. Aggregate the K sets of n*(1/K) predictions so that you have n predictions; this is what is referred to as the "out-of-fold" prediction.

And what you suggest is to sample from this out-of-fold prediction to calculate the error rate.
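A minimal sketch of producing such out-of-fold predictions (the model choice is illustrative; scikit-learn's cross_val_predict offers the same aggregation ready-made):

    import numpy as np
    from sklearn.model_selection import KFold
    from sklearn.linear_model import Ridge
    from sklearn.metrics import mean_squared_error

    X, y = np.random.rand(500, 6), np.random.rand(500)

    oof_pred = np.zeros(len(y))
    for train_idx, val_idx in KFold(n_splits=5, shuffle=True, random_state=0).split(X):
        model = Ridge().fit(X[train_idx], y[train_idx])
        # Each instance is predicted by the fold model that did not see it.
        oof_pred[val_idx] = model.predict(X[val_idx])

    # Error over the n aggregated out-of-fold predictions.
    print(mean_squared_error(y, oof_pred))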

6. I didn't want to state that single hold-out sets are a no-go; they have their applications. Forecasting problems are probably the best example of that. In competitions like that (e.g. Rossmann Store Sales), k-fold CV does not work very well because the data is not i.i.d. Gert mentions some other examples, like splits by geo-locations. In this case, stratified CV could be bad, but often you can still define more than one hold-out set with the desired distribution. Another application is to detect leakage. If you don't want to waste submissions in order to be sure that your pre-processing is leakage-free, you can create a local private test: put aside some training data and treat it as test data. So don't look at this set during data exploration and do not use its labels for any pre-processing.

My point is that a single hold-out gets overfitted faster and hence is more dangerous to use if you do not have a good rapport with the god of overfitting. It's easy to do the wrong things after you get an overfitting-occurred response from the LB. Besides, you do not get any information about variance with a single hold-out. So, if you are inexperienced with all the overfitting caveats, I would suggest preferring k-fold CV over a single hold-out if applicable.

7. Nevertheless, it seems 10-fold CV with out-of-fold prediction is very much an adequate solution.

Note:

On Kaggle, public leaderboards are based on validating the submissions against a random fraction of the test set, and the private ones are validated against the rest of the test set. The private leaderboards are only released after the competition is over, and the final ranking is determined by the private leaderboard. People can do well on the public leaderboard yet do really badly on the private one because of overfitting.

References:

http://cvrs.whu.edu.cn/blogs/?p=154

https://www.kaggle.com/c/telstra-recruiting-network/discussion/19277
