转自:https://www.cnblogs.com/ysugyl/p/8711205.html

Grid Search:一种调参手段;穷举搜索:在所有候选的参数选择中,通过循环遍历,尝试每一种可能性,表现最好的参数就是最终的结果。其原理就像是在数组里找最大值。(为什么叫网格搜索?以有两个参数的模型为例,参数a有3种可能,参数b有4种可能,把所有可能性列出来,可以表示成一个3*4的表格,其中每个cell就是一个网格,循环过程就像是在每个网格里遍历、搜索,所以叫grid search)

1.简单的网格搜索

from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split iris = load_iris()
X_train,X_test,y_train,y_test = train_test_split(iris.data,iris.target,random_state=0)
print("Size of training set:{} size of testing set:{}".format(X_train.shape[0],X_test.shape[0])) #### grid search start
best_score = 0
for gamma in [0.001,0.01,0.1,1,10,100]:
for C in [0.001,0.01,0.1,1,10,100]:
svm = SVC(gamma=gamma,C=C)#对于每种参数可能的组合,进行一次训练;
svm.fit(X_train,y_train)
score = svm.score(X_test,y_test)
if score > best_score:#找到表现最好的参数
best_score = score
best_parameters = {'gamma':gamma,'C':C}
#### grid search end print("Best score:{:.2f}".format(best_score))
print("Best parameters:{}".format(best_parameters))

输出:

Size of training set:112 size of testing set:38
Best score:0.973684
Best parameters:{'gamma': 0.001, 'C': 100}

存在的问题:

原始数据集划分成训练集和测试集以后,其中测试集除了用作调整参数,也用来测量模型的好坏;这样做导致最终的评分结果比实际效果要好。(因为测试集在调参过程中,送到了模型里,而我们的目的是将训练模型应用在unseen data上);

解决方法:

对训练集再进行一次划分,分成训练集和验证集,这样划分的结果就是:原始数据划分为3份,分别为:训练集、验证集和测试集;其中训练集用来模型训练,验证集用来调整参数,而测试集用来衡量模型表现好坏。

2.使用验证集调整参数

X_trainval,X_test,y_trainval,y_test = train_test_split(iris.data,iris.target,random_state=0)
X_train,X_val,y_train,y_val = train_test_split(X_trainval,y_trainval,random_state=1)
print("Size of training set:{} size of validation set:{} size of teseting set:{}".format(X_train.shape[0],X_val.shape[0],X_test.shape[0])) best_score = 0.0
for gamma in [0.001,0.01,0.1,1,10,100]:
for C in [0.001,0.01,0.1,1,10,100]:
svm = SVC(gamma=gamma,C=C)
svm.fit(X_train,y_train)
score = svm.score(X_val,y_val)
if score > best_score:
best_score = score
best_parameters = {'gamma':gamma,'C':C}
svm = SVC(**best_parameters) #使用最佳参数,构建新的模型
svm.fit(X_trainval,y_trainval) #使用训练集和验证集进行训练,more data always results in good performance.
test_score = svm.score(X_test,y_test) # evaluation模型评估
print("Best score on validation set:{:.2f}".format(best_score))
print("Best parameters:{}".format(best_parameters))
print("Best score on test set:{:.2f}".format(test_score))

输出:

Size of training set:84 size of validation set:28 size of teseting set:38
Best score on validation set:0.96
Best parameters:{'gamma': 0.001, 'C': 10}
Best score on test set:0.92

然而,这种间的的grid search方法,其最终的表现好坏与初始数据的划分结果有很大的关系,为了处理这种情况,我们采用交叉验证的方式来减少偶然性。

3.使用交叉验证方法调参

from sklearn.model_selection import cross_val_score

best_score = 0.0
for gamma in [0.001,0.01,0.1,1,10,100]:
for C in [0.001,0.01,0.1,1,10,100]:
svm = SVC(gamma=gamma,C=C)
scores = cross_val_score(svm,X_trainval,y_trainval,cv=5) #5折交叉验证
score = scores.mean() #取平均数
if score > best_score:
best_score = score
best_parameters = {"gamma":gamma,"C":C}
svm = SVC(**best_parameters)
svm.fit(X_trainval,y_trainval)
test_score = svm.score(X_test,y_test)
print("Best score on validation set:{:.2f}".format(best_score))
print("Best parameters:{}".format(best_parameters))
print("Score on testing set:{:.2f}".format(test_score))

输出:

Best score on validation set:0.97
Best parameters:{'gamma': 0.01, 'C': 100}
Score on testing set:0.97

交叉验证经常与网格搜索进行结合,作为参数评价的一种方法,这种方法叫做grid search with cross validation。

4.类GridSearchCV综合

from sklearn.model_selection import GridSearchCV

#把要调整的参数以及其候选值 列出来;
param_grid = {"gamma":[0.001,0.01,0.1,1,10,100],
"C":[0.001,0.01,0.1,1,10,100]}
print("Parameters:{}".format(param_grid)) grid_search = GridSearchCV(SVC(),param_grid,cv=5) #实例化一个GridSearchCV类
X_train,X_test,y_train,y_test = train_test_split(iris.data,iris.target,random_state=10)
grid_search.fit(X_train,y_train) #训练,找到最优的参数,同时使用最优的参数实例化一个新的SVC estimator。
print("Test set score:{:.2f}".format(grid_search.score(X_test,y_test)))
print("Best parameters:{}".format(grid_search.best_params_))
print("Best score on train set:{:.2f}".format(grid_search.best_score_))

输出:

Parameters:{'gamma': [0.001, 0.01, 0.1, 1, 10, 100], 'C': [0.001, 0.01, 0.1, 1, 10, 100]}
Test set score:0.97
Best parameters:{'C': 10, 'gamma': 0.1}
Best score on train set:0.98

sklearn设计了一个这样的类GridSearchCV,这个类实现了fit,predict,score等方法,被当做了一个estimator,使用fit方法,该过程中:(1)搜索到最佳参数;(2)实例化了一个最佳参数的estimator;

5.总结

Grid Search:一种调优方法,在参数列表中进行穷举搜索,对每种情况进行训练,找到最优的参数;由此可知,这种方法的主要缺点是 比较耗时!

Grid Search学习的更多相关文章

  1. [转载]Grid Search

    [转载]Grid Search 初学机器学习,之前的模型都是手动调参的,效果一般.同学和我说他用了一个叫grid search的方法.可以实现自动调参,顿时感觉非常高级.吃饭的时候想调参的话最差不过也 ...

  2. Comparing randomized search and grid search for hyperparameter estimation

    Comparing randomized search and grid search for hyperparameter estimation Compare randomized search ...

  3. 3.2. Grid Search: Searching for estimator parameters

    3.2. Grid Search: Searching for estimator parameters Parameters that are not directly learnt within ...

  4. Grid search in the tidyverse

    @drsimonj here to share a tidyverse method of grid search for optimizing a model's hyperparameters. ...

  5. How to Grid Search Hyperparameters for Deep Learning Models in Python With Keras

    Hyperparameter optimization is a big part of deep learning. The reason is that neural networks are n ...

  6. grid search 超参数寻优

    http://scikit-learn.org/stable/modules/grid_search.html 1. 超参数寻优方法 gridsearchCV 和  RandomizedSearchC ...

  7. CSS Grid 布局学习笔记

    CSS Grid 布局学习笔记 好久没有写博客了, MDN 上关于 Grid 布局的知识比较零散, 正好根据我这几个月的实践对 CSS Grid 布局做一个总结, 以备查阅. 1. 基础用法 Grid ...

  8. scikit-learn:3.2. Grid Search: Searching for estimator parameters

    參考:http://scikit-learn.org/stable/modules/grid_search.html GridSearchCV通过(蛮力)搜索參数空间(參数的全部可能组合).寻找最好的 ...

  9. Elastic Search 学习之路(三)—— tutorial demo

    一.ElasticSearch tutorial demo example 1. 单机.local.CRUD操作 实现方式: SpringBoot + ElasticSearch 拷贝的小demo,原 ...

随机推荐

  1. mysql中如何在命令行中,执行一个SQL脚本文件?

    需求描述: 在mysql数据库的使用中,有的时候,需要直接在shell的命令行中,执行某个SQL脚本文件, 比如,要初始化数据库,创建特定的存储过程,创建表等操作,这里进行一个基本的测试. 一般情况, ...

  2. swift - UIScrollView 的使用

    本节详细介绍scrollview的用法 ———————————————————————————————————— UIScrollView 是一个能够滚动的视图控件,可以用来展示大量的内容,并且可以通 ...

  3. Python 字符串处理(转)

    转自:黄聪:Python 字符串操作(替换.删除.截取.复制.连接.比较.查找.包含.大小写转换.分割等) http://www.cnblogs.com/huangcong/archive/2011/ ...

  4. Python 进阶(三)面向对象编程基础

    aaarticlea/png;base64,iVBORw0KGgoAAAANSUhEUgAAAkMAAAFGCAIAAADmfgziAAAgAElEQVR4nOx993vT1v7/93/5EEt2Eg

  5. resolution will not be reattempted until the update interval of nexus has elapsed or updates are forced

    Maven在执行中报错: - Failure to transfer org.slf4j:slf4j-api:jar:1.7.24 from http://localhost:8081/nexus/c ...

  6. 多线程模块:threading

    threading 常见用法: (1) 在 thread 中,我们是通过 thread.start_new_thread(function, args) 来开启一个线程,接收一个函数和该函数的参数,函 ...

  7. oracle常用管理命令

    启动数据库和监听 lsnrctl start sqlplus /nolog conn sys/as sysdba startup  查看当前的实例名 show parameter instance_n ...

  8. LeetCode——Contains Duplicate II

    Description: Given an array of integers and an integer k, find out whether there there are two disti ...

  9. 报错 ERROR in static/js/vendor.b3f56e9e0cd56988d890.js from UglifyJs

    开发vux项目在引入 // 表单验证组件-start import zh_CN from 'vee-validate/dist/locale/zh_CN' import Validator from ...

  10. all index range ref eq_ref const system 索引type说明

    背景知识 在使用sql的过程中经常需要建立索引,而每种索引是怎么处罚的又是怎么起到作用的,首先必须知道索引和索引的类型. 索引类型type 我们可以清楚的看到type那一栏有index ALL eq_ ...