Grid Search学习
转自:https://www.cnblogs.com/ysugyl/p/8711205.html
Grid Search:一种调参手段;穷举搜索:在所有候选的参数选择中,通过循环遍历,尝试每一种可能性,表现最好的参数就是最终的结果。其原理就像是在数组里找最大值。(为什么叫网格搜索?以有两个参数的模型为例,参数a有3种可能,参数b有4种可能,把所有可能性列出来,可以表示成一个3*4的表格,其中每个cell就是一个网格,循环过程就像是在每个网格里遍历、搜索,所以叫grid search)
1.简单的网格搜索
from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split iris = load_iris()
X_train,X_test,y_train,y_test = train_test_split(iris.data,iris.target,random_state=0)
print("Size of training set:{} size of testing set:{}".format(X_train.shape[0],X_test.shape[0])) #### grid search start
best_score = 0
for gamma in [0.001,0.01,0.1,1,10,100]:
for C in [0.001,0.01,0.1,1,10,100]:
svm = SVC(gamma=gamma,C=C)#对于每种参数可能的组合,进行一次训练;
svm.fit(X_train,y_train)
score = svm.score(X_test,y_test)
if score > best_score:#找到表现最好的参数
best_score = score
best_parameters = {'gamma':gamma,'C':C}
#### grid search end print("Best score:{:.2f}".format(best_score))
print("Best parameters:{}".format(best_parameters))
输出:
Size of training set:112 size of testing set:38
Best score:0.973684
Best parameters:{'gamma': 0.001, 'C': 100}
存在的问题:
原始数据集划分成训练集和测试集以后,其中测试集除了用作调整参数,也用来测量模型的好坏;这样做导致最终的评分结果比实际效果要好。(因为测试集在调参过程中,送到了模型里,而我们的目的是将训练模型应用在unseen data上);
解决方法:
对训练集再进行一次划分,分成训练集和验证集,这样划分的结果就是:原始数据划分为3份,分别为:训练集、验证集和测试集;其中训练集用来模型训练,验证集用来调整参数,而测试集用来衡量模型表现好坏。
2.使用验证集调整参数
X_trainval,X_test,y_trainval,y_test = train_test_split(iris.data,iris.target,random_state=0)
X_train,X_val,y_train,y_val = train_test_split(X_trainval,y_trainval,random_state=1)
print("Size of training set:{} size of validation set:{} size of teseting set:{}".format(X_train.shape[0],X_val.shape[0],X_test.shape[0])) best_score = 0.0
for gamma in [0.001,0.01,0.1,1,10,100]:
for C in [0.001,0.01,0.1,1,10,100]:
svm = SVC(gamma=gamma,C=C)
svm.fit(X_train,y_train)
score = svm.score(X_val,y_val)
if score > best_score:
best_score = score
best_parameters = {'gamma':gamma,'C':C}
svm = SVC(**best_parameters) #使用最佳参数,构建新的模型
svm.fit(X_trainval,y_trainval) #使用训练集和验证集进行训练,more data always results in good performance.
test_score = svm.score(X_test,y_test) # evaluation模型评估
print("Best score on validation set:{:.2f}".format(best_score))
print("Best parameters:{}".format(best_parameters))
print("Best score on test set:{:.2f}".format(test_score))
输出:
Size of training set:84 size of validation set:28 size of teseting set:38
Best score on validation set:0.96
Best parameters:{'gamma': 0.001, 'C': 10}
Best score on test set:0.92
然而,这种间的的grid search方法,其最终的表现好坏与初始数据的划分结果有很大的关系,为了处理这种情况,我们采用交叉验证的方式来减少偶然性。
3.使用交叉验证方法调参
from sklearn.model_selection import cross_val_score best_score = 0.0
for gamma in [0.001,0.01,0.1,1,10,100]:
for C in [0.001,0.01,0.1,1,10,100]:
svm = SVC(gamma=gamma,C=C)
scores = cross_val_score(svm,X_trainval,y_trainval,cv=5) #5折交叉验证
score = scores.mean() #取平均数
if score > best_score:
best_score = score
best_parameters = {"gamma":gamma,"C":C}
svm = SVC(**best_parameters)
svm.fit(X_trainval,y_trainval)
test_score = svm.score(X_test,y_test)
print("Best score on validation set:{:.2f}".format(best_score))
print("Best parameters:{}".format(best_parameters))
print("Score on testing set:{:.2f}".format(test_score))
输出:
Best score on validation set:0.97
Best parameters:{'gamma': 0.01, 'C': 100}
Score on testing set:0.97
交叉验证经常与网格搜索进行结合,作为参数评价的一种方法,这种方法叫做grid search with cross validation。
4.类GridSearchCV综合
from sklearn.model_selection import GridSearchCV #把要调整的参数以及其候选值 列出来;
param_grid = {"gamma":[0.001,0.01,0.1,1,10,100],
"C":[0.001,0.01,0.1,1,10,100]}
print("Parameters:{}".format(param_grid)) grid_search = GridSearchCV(SVC(),param_grid,cv=5) #实例化一个GridSearchCV类
X_train,X_test,y_train,y_test = train_test_split(iris.data,iris.target,random_state=10)
grid_search.fit(X_train,y_train) #训练,找到最优的参数,同时使用最优的参数实例化一个新的SVC estimator。
print("Test set score:{:.2f}".format(grid_search.score(X_test,y_test)))
print("Best parameters:{}".format(grid_search.best_params_))
print("Best score on train set:{:.2f}".format(grid_search.best_score_))
输出:
Parameters:{'gamma': [0.001, 0.01, 0.1, 1, 10, 100], 'C': [0.001, 0.01, 0.1, 1, 10, 100]}
Test set score:0.97
Best parameters:{'C': 10, 'gamma': 0.1}
Best score on train set:0.98
sklearn设计了一个这样的类GridSearchCV,这个类实现了fit,predict,score等方法,被当做了一个estimator,使用fit方法,该过程中:(1)搜索到最佳参数;(2)实例化了一个最佳参数的estimator;
5.总结
Grid Search:一种调优方法,在参数列表中进行穷举搜索,对每种情况进行训练,找到最优的参数;由此可知,这种方法的主要缺点是 比较耗时!
Grid Search学习的更多相关文章
- [转载]Grid Search
[转载]Grid Search 初学机器学习,之前的模型都是手动调参的,效果一般.同学和我说他用了一个叫grid search的方法.可以实现自动调参,顿时感觉非常高级.吃饭的时候想调参的话最差不过也 ...
- Comparing randomized search and grid search for hyperparameter estimation
Comparing randomized search and grid search for hyperparameter estimation Compare randomized search ...
- 3.2. Grid Search: Searching for estimator parameters
3.2. Grid Search: Searching for estimator parameters Parameters that are not directly learnt within ...
- Grid search in the tidyverse
@drsimonj here to share a tidyverse method of grid search for optimizing a model's hyperparameters. ...
- How to Grid Search Hyperparameters for Deep Learning Models in Python With Keras
Hyperparameter optimization is a big part of deep learning. The reason is that neural networks are n ...
- grid search 超参数寻优
http://scikit-learn.org/stable/modules/grid_search.html 1. 超参数寻优方法 gridsearchCV 和 RandomizedSearchC ...
- CSS Grid 布局学习笔记
CSS Grid 布局学习笔记 好久没有写博客了, MDN 上关于 Grid 布局的知识比较零散, 正好根据我这几个月的实践对 CSS Grid 布局做一个总结, 以备查阅. 1. 基础用法 Grid ...
- scikit-learn:3.2. Grid Search: Searching for estimator parameters
參考:http://scikit-learn.org/stable/modules/grid_search.html GridSearchCV通过(蛮力)搜索參数空间(參数的全部可能组合).寻找最好的 ...
- Elastic Search 学习之路(三)—— tutorial demo
一.ElasticSearch tutorial demo example 1. 单机.local.CRUD操作 实现方式: SpringBoot + ElasticSearch 拷贝的小demo,原 ...
随机推荐
- mysqldump进行数据库的全备时,备份数据库的顺序是什么,就是先备份哪个库,然后再备份哪个库
需求描述: 今天在用mysqldump工具进行数据库的备份的时候,突然想了一个问题,比如我有10个库要进行备份 那么是先备份哪个,然后再备份哪个呢,所以,做了实验,验证下. 操作过程: 1.使用--a ...
- [转]ASP.NET MVC 5 -从控制器访问数据模型
在本节中,您将创建一个新的MoviesController类,并在这个Controller类里编写代码来取得电影数据,并使用视图模板将数据展示在浏览器里. 在开始下一步前,先Build一下应用程序(生 ...
- HTML&CSS精选笔记_CSS高级技巧
CSS高级技巧 CSS精灵技术 需求分析 CSS精灵是一种处理网页背景图像的方式.它将一个页面涉及到的所有零星背景图像都集中到一张大图中去,然后将大图应用于网页,这样,当用户访问该页面时,只需向服务发 ...
- ArcGIS ArcPy Python处理数据
1.使用搜索游标查看行中的字段值.import arcpy # Set the workspace arcpy.env.workspace = "c:/base/data.gdb" ...
- unable to execute dex:GC overhead limit exceeded unable to execute dex:java heap space 解决方案
最近做厂商适配,厂商提供了一部分Framework的jar包,把jar包通过Add Jar放到Build Path中, 在生成APK过程中,Eclipse长时间停留在100%那个进度. 最后Eclip ...
- 格式化输出%s和%S的区别
使用s时,printf是针对单字节字符的字符串,而wprintf是针对宽字符的 使用S时,正好相反,printf针对宽字符 CString中的format与printf类似,在unicode字符集的工 ...
- DiscuzX的目录权限设置1
经常有朋友遇到Discuz目录权限设置出错的问题,网上千奇百怪的教程非常多,所谓的终极安全的教程更是满天飞,各种所谓的安全加强软件也随处可见,可实际过程中发现,老手用不上,新手则只会因为这些东西徒增麻 ...
- php学习三:函数
1. php中的函数和js中的区别 在php中,函数的形参可以给一个默认值,若有实参的传递则函数使用传递过来的参数,没有的话显示默认值 代码如下: function showSelf($name=& ...
- 17,UC(06)
/* 达内学习 UC day06 2013-10-10 */ 回忆过去: 系统调用 - UNIX操作系统提供的一些列函数皆苦,用于访问内核空间,遵循posix规范 文件操作:open()\rea ...
- serializeArray()与serialize()的区别
serialize()序列化表单元素为字符串,用于 Ajax 请求. serializeArray()序列化表单元素为JSON数据. <script type="text/javasc ...