sklearn.model_selection Part 2: Model validation
1. check_cv()
def check_cv(cv=3, y=None, classifier=False):
if cv is None:
cv = 3
if isinstance(cv, numbers.Integral):
# 如果classifier为True 并且y 是 二类或者多类,就返回 StratifiedKFold,否则返回KFold
if (classifier and (y is not None) and
(type_of_target(y) in ('binary', 'multiclass'))):
return StratifiedKFold(cv)
else:
return KFold(cv)
# if not hasattr(cv, 'split') or isinstance(cv, str):
# if not isinstance(cv, Iterable) or isinstance(cv, str):
# raise ValueError("Expected cv as an integer, cross-validation "
# "object (from sklearn.model_selection) "
# "or an iterable. Got %s." % cv)
# return _CVIterableWrapper(cv)
return cv # New style cv objects are passed without any modification
阅读源代码要抓主干,所以我把细枝末节的代码注释掉了。
2. cross_validate()
这个函数的代码有点复杂,讲解其他有用的代码。
从 这里 可以找到 scoring的名字对应的函数
注意:
得分函数(score function)是返回的值越高越好,而损失函数(loss function)是返回的值越低越好。原来scoring中的"mean_squared_error"已经改成了"neg_mean_squared_error"。
3. cross_val_score()
返回一个 estimator 做 k 折交叉验证后产生的 k 个 在测试集上的得分。返回值是一个有k个元素的数组。
4. cross_val_predict()
cross_val_predict提供了和cross_val_score相似的接口,但是后者返回k次得分,而前者返回所有数据集上的预测结果。如果传入的参数 method = 'predict'(默认情况),那么返回(n_samples, ) 形状的ndarray,如果传入的参数 method = 'predict_proba',那么返回(n_samples, n_classes) 形状的ndarray。
5. learning_curve()
def learning_curve(estimator, X, y, groups=None,
train_sizes=np.linspace(0.1, 1.0, 5), cv=None, scoring=None,
exploit_incremental_learning=False, n_jobs=1,
pre_dispatch="all", verbose=0, shuffle=False,
random_state=None):
# if exploit_incremental_learning and not hasattr(estimator, "partial_fit"):
# raise ValueError("An estimator must support the partial_fit interface "
# "to exploit incremental learning")
X, y, groups = indexable(X, y, groups)
cv = check_cv(cv, y, classifier=is_classifier(estimator)) # 默认是KFold(3)
# Store it as list as we will be iterating over the list multiple times
cv_iter = list(cv.split(X, y, groups))
scorer = check_scoring(estimator, scoring=scoring)
n_max_training_samples = len(cv_iter[0][0]) # 取出的是第一折中的train的数目
# Because the lengths of folds can be significantly different, it is
# not guaranteed that we use all of the available training data when we
# use the first 'n_max_training_samples' samples.
# 因为不同折中的数据数目可能会是不同的,当我们使用第一个 'n_max_training_samples'
# 并不能保证我们使用所有可用的训练数据。(第一折长度一定是最短的)
train_sizes_abs = _translate_train_sizes(train_sizes, # 将train_sizes中样本比例转换为具体数目(绝对size)
n_max_training_samples)
n_unique_ticks = train_sizes_abs.shape[0]
if verbose > 0:
print("[learning_curve] Training set sizes: " + str(train_sizes_abs))
parallel = Parallel(n_jobs=n_jobs, pre_dispatch=pre_dispatch,
verbose=verbose)
if shuffle:
rng = check_random_state(random_state)
cv_iter = ((rng.permutation(train), test) for train, test in cv_iter)
# if exploit_incremental_learning: # 默认为False,暂时先忽略
# classes = np.unique(y) if is_classifier(estimator) else None
# out = parallel(delayed(_incremental_fit_estimator)(
# clone(estimator), X, y, classes, train, test, train_sizes_abs,
# scorer, verbose) for train, test in cv_iter)
else:
train_test_proportions = []
for train, test in cv_iter: # 在每一折中train有一个逐渐增大的变化,test不变
for n_train_samples in train_sizes_abs:
train_test_proportions.append((train[:n_train_samples], test))
out = parallel(delayed(_fit_and_score)(
clone(estimator), X, y, scorer, train, test,
verbose, parameters=None, fit_params=None, return_train_score=True)
for train, test in train_test_proportions)
out = np.array(out)
n_cv_folds = out.shape[0] // n_unique_ticks
out = out.reshape(n_cv_folds, n_unique_ticks, 2)
out = np.asarray(out).transpose((2, 1, 0))
函数的返回值:
- train_sizes_abs :array, shape=(n_unique_ticks),曲线上每个点对应的训练数据集的size
- train_scores:array, shape=(n_ticks, n_cv_folds),所有的在训练集上的分数。
- test_scores: array, shape=(n_ticks, n_cv_folds),所有的在测试集上的分数。
注释掉了不重要的代码便于分析。该方法就是计算交叉验证中对于不同训练数据集大小的训练分数和测试分数。算法的中文描述如下:
- 将数据做 k(默认为3)折划分。
- 在每一折的验证中,有一个参数train_sizes (默认为np.linspace(0.1, 1.0, 5) = array([0.1 , 0.325, 0.55 , 0.775, 1. ])表示训练集依次取这一折中train_set的0.1比例的数据,0.325比例的数据,0.55比例的数据,0.775比例的数据,1.0比例的数据,然后分别和这一折中 test_set 组成新的训练集—测试集对,分别计算每一对上的训练分数和测试分数。
如果指定折数k,train_sizes的长度为m,那么训练estimator并验证其性能的过程要重复 k*m次。简单的就是在每一折上看训练集逐渐变大时,estimator的性能的变化情况。
这个函数计算的结果如何可视化的代码可以参考 这里
6. validation_curve()
计算在不同参数取值下的训练集分数和测试集分数。
需要设置的参数:
validation_curve(estimator, X, y, param_name, param_range)
param_name:string,参数名字
param_range: array-like,参数范围
返回的结果:
train_scores : array, shape (n_ticks, n_cv_folds),Scores on training sets.
test_scores : array, shape (n_ticks, n_cv_folds),Scores on test set.
可视化的代码可以参考 这里
7. permutation_test_score
不是很重要,几乎没见有用到过
sklearn.model_selection Part 2: Model validation的更多相关文章
- sklearn.model_selection 的train_test_split方法和参数
train_test_split是sklearn中用于划分数据集,即将原始数据集划分成测试集和训练集两部分的函数. from sklearn.model_selection import train_ ...
- Model Validation in ASP.NET Web API
Model Validation in ASP.NET Web API 原文:http://www.asp.net/web-api/overview/formats-and-model-binding ...
- <转>ASP.NET学习笔记之MVC 3 数据验证 Model Validation 详解
MVC 3 数据验证 Model Validation 详解 再附加一些比较好的验证详解:(以下均为引用) 1.asp.net mvc3 的数据验证(一) - zhangkai2237 - 博客园 ...
- Model Validation in Asp.net MVC
原文:Model Validation in Asp.net MVC 本文用于记录Pro ASP.NET MVC 3 Framework中阐述的数据验证的方式. 先说服务器端的吧.最简单的一种方式自然 ...
- Model Validation(模型验证)
Model Validation(模型验证) 前言 阅读本文之前,您也可以到Asp.Net Web API 2 系列导航进行查看 http://www.cnblogs.com/aehyok/p/344 ...
- sklearn.model_selection 的 train_test_split作用
train_test_split函数用于将数据划分为训练数据和测试数据. train_test_split是交叉验证中常用的函数,功能是从样本中随机的按比例选取train_data和test_data ...
- sklearn.model_selection.StratifiedShuffleSplit
sklearn.model_selection.StratifiedShuffleSplit
- sklearn.model_selection模块
后续补代码 sklearn.model_selection模块的几个方法参数
- 【sklearn】网格搜索 from sklearn.model_selection import GridSearchCV
GridSearchCV用于系统地遍历模型的多种参数组合,通过交叉验证确定最佳参数. 1.GridSearchCV参数 # 不常用的参数 pre_dispatch 没看懂 refit 默认为Tr ...
随机推荐
- win10 64支持承载网络
在intel官网找到对应型号的网卡驱动. 下载win7版本的,更新驱动.安装完毕之后还要在设备管理里面更新2019 7 30这个版本的驱动. 英特尔® PROSet/无线软件和面向 IT 管理员的驱动 ...
- python-redis-订阅和发布
发布:redishelper.py import redis class RedisHelper: def __init__(self): self.__conn = redis.Redis(host ...
- Thinkphp5.0快速入门笔记(3)
学习来源与说明 https://www.kancloud.cn/thinkphp/thinkphp5_quickstart 测试与部署均在windows10下进行学习. 快速入门第三节 获取当前的请求 ...
- spring boot本地开发与docker容器化部署的差异
spring boot本地开发与docker容器化部署的差异: 1. 文件路径及文件名区别大小写: 本地开发环境为windows操作系统,是忽略大小写的,但容器中区分大小写 2. docker中的容器 ...
- 【Java面试题】解释内存中的栈(stack)、堆(heap)和静态存储区的用法
Java面试题:解释内存中的栈(stack).堆(heap)和静态存储区的用法 堆区: 专门用来保存对象的实例(new 创建的对象和数组),实际上也只是保存对象实例的属性值,属性的类型和对象本身的类型 ...
- Caffe常用算子GPU和CPU对比
通过整理LeNet.AlexNet.VGG16.googLeNet.ResNet.MLP统计出的常用算子(不包括ReLU),表格是对比. Prelu Cpu版 Gpu版 for (int i = 0; ...
- linux下sendmail邮件系统安装详情
介绍 sendmail是linux系统中一个邮箱系统,如果我们在系统中配置好sendmail就可以直接使用它来发送邮箱.sendmail的配置文件/etc/mail/sendmail.cf ...
- 织梦DEDEcms5.7解决arclist标签调用副栏目文章
使用arclist标签调用文章的时候才发现,根本无法调用相关文章. 下面给出解决办法,希望帮到需要的人. 找到/include/taglib/arclist.lib.php文件然后打开.然后在大约30 ...
- maven 依赖包找不到 (转)
1,手动添加jar包 例: maven在集成Oracle驱动的时候从远程仓库下载不下来ojdbc14 报missing artifact com.oracle:ojdbc14:jar:10.2.0.3 ...
- RAID原理详解
RAID 0(stripe,条带化存储):在RAID级别中最高的存储性能. 原理:是把连续的数据分散到多个磁盘上存取,系统有数据请求就可以被多个磁盘并行的执行,每个磁盘执行属于他自己的那部分数据请求. ...