Grid search in the tidyverse
@drsimonj here to share a tidyverse method of grid search for optimizing a model’s hyperparameters.
Grid Search
For anyone who’s unfamiliar with the term, grid search involves running a model many times with combinations of various hyperparameters. The point is to identify which hyperparameters are likely to work best. For a more technical definition, Wikipedia describes grid search as:
an exhaustive searching through a manually specified subset of the hyperparameter space of a learning algorithm
What this post isn’t about
To keep the focus on grid search, this post does NOT cover…
- k-fold cross-validation. Although a practically essential addition to grid search, I’ll save the combination of these techniques for a future post. If you can’t wait, check out my last post for some inspiration.
- Complex learning models. We’ll stick to a simple decision tree.
- Getting a great model fit. I’ve deliberately chosen input variables and hyperparameters that highlight the approach.
Decision tree example
Say we want to run a simple decision tree to predict cars’ transmission type (am) based on their miles per gallon (mpg) and horsepower (hp) using the mtcars data set. Let’s prep the data:
library(tidyverse)
d <- mtcars %>%
  # Convert `am` to factor and select relevant variables
  mutate(am = factor(am, labels = c("Automatic", "Manual"))) %>%
  select(am, mpg, hp)
ggplot(d, aes(mpg, hp, color = am)) +
  geom_point()
[Plot: scatterplot of hp against mpg, colored by transmission type]
For a decision tree, the class boundary looks like a step-wise function until mpg > 25, at which point all cars are Manual. Let’s grow a full decision tree on this data:
library(rpart)
library(rpart.plot)
# Set minsplit = 2 to fit every data point
full_fit <- rpart(am ~ mpg + hp, data = d, minsplit = 2)
prp(full_fit)
[Plot: the fully grown decision tree]
We don’t want a model like this, as it almost certainly overfits the data. So the question becomes: which hyperparameter values would work best for our model to generalize?
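To see how these controls rein a tree in, here’s a minimal sketch (the specific values are arbitrary, chosen purely for illustration) of fitting with a restrictive rpart.control:
# Illustration only: a deliberately restrictive control, to contrast
# with the fully grown tree above
small_fit <- rpart(am ~ mpg + hp, data = d,
                   control = rpart.control(minsplit = 10, maxdepth = 2))
prp(small_fit)
Grid search simply does this systematically, over many combinations at once.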
Training-Test Split
To help validate our hyperparameter combinations, we’ll split our data into training and test sets (in an 80/20 split):
set.seed(245)
n <- nrow(d)
train_rows <- sample(seq(n), size = .8 * n)
train <- d[ train_rows, ]
test <- d[-train_rows, ]
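As a quick sanity check, note that .8 * n is 25.6 here (mtcars has 32 rows); sample() silently truncates this, so we end up with a 25/7 split:
# Optional check of the split sizes (32 rows total)
nrow(train)
#> [1] 25
nrow(test)
#> [1] 7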
Create the Grid
Step one for grid search is to define our hyperparameter combinations. Say we want to test a few values for minsplit and maxdepth. I like to set up the grid of their combinations in a tidy data frame with a list and cross_d as follows:
# Define a named list of parameter values
gs <- list(minsplit = c(2, 5, 10),
           maxdepth = c(1, 3, 8)) %>%
  cross_d() # Convert to data frame grid
gs
#> # A tibble: 9 × 2
#>   minsplit maxdepth
#>      <dbl>    <dbl>
#> 1        2        1
#> 2        5        1
#> 3       10        1
#> 4        2        3
#> 5        5        3
#> 6       10        3
#> 7        2        8
#> 8        5        8
#> 9       10        8
Note that the list names are the names of the hyperparameters that we want to adjust in our model function.
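One handy consequence of this setup: adding another hyperparameter is just another list element. For example, a sketch of also searching over cp (the values here are arbitrary, purely for illustration):
# Hypothetical extension: adding cp expands the grid to all
# 3 x 3 x 2 = 18 combinations
list(minsplit = c(2, 5, 10),
     maxdepth = c(1, 3, 8),
     cp       = c(0.01, 0.1)) %>%
  cross_d()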
Create a model function
We’ll be iterating down the gs data frame to use the hyperparameter values in an rpart model. The easiest way to handle this is to define a function that accepts a row of our data frame values and passes them correctly to our model. Here’s what I’ll use:
mod <- function(...) {
  rpart(am ~ hp + mpg, data = train, control = rpart.control(...))
}
Notice the argument ... is being passed to control in rpart, which is where these hyperparameters can be used.
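As a quick check that the plumbing works, we can call mod directly with a single combination before mapping it over the whole grid:
# One-off call: these named arguments flow through ... into rpart.control()
mod(minsplit = 2, maxdepth = 1)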
Fit the models
Now, to fit our models, we use pmap to iterate down the grid. The following iterates through each row of our gs data frame, plugging that row’s hyperparameter values into our model function. (This works because gs currently contains only hyperparameter columns, each of which is passed to mod as a named argument.)
gs <- gs %>% mutate(fit = pmap(gs, mod))
gs
#> # A tibble: 9 × 3
#>   minsplit maxdepth         fit
#>      <dbl>    <dbl>      <list>
#> 1        2        1 <S3: rpart>
#> 2        5        1 <S3: rpart>
#> 3       10        1 <S3: rpart>
#> 4        2        3 <S3: rpart>
#> 5        5        3 <S3: rpart>
#> 6       10        3 <S3: rpart>
#> 7        2        8 <S3: rpart>
#> 8        5        8 <S3: rpart>
#> 9       10        8 <S3: rpart>
Obtain accuracy
Next, let’s assess the performance of each fit on our test data. To handle this efficiently, let’s write another small function:
compute_accuracy <- function(fit, test_features, test_labels) {
  predicted <- predict(fit, test_features, type = "class")
  mean(predicted == test_labels)
}
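Before mapping over every fit, a quick spot check on a single one (the first row of gs) helps confirm the function behaves:
# Spot check: test accuracy of the first fitted model
compute_accuracy(gs$fit[[1]], test %>% select(-am), test$am)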
Now apply this to each fit:
test_features <- test %>% select(-am)
test_labels <- test$am
gs <- gs %>%
  mutate(test_accuracy = map_dbl(fit, compute_accuracy,
                                 test_features, test_labels))
gs
#> # A tibble: 9 × 4
#>   minsplit maxdepth         fit test_accuracy
#>      <dbl>    <dbl>      <list>         <dbl>
#> 1        2        1 <S3: rpart>     0.7142857
#> 2        5        1 <S3: rpart>     0.7142857
#> 3       10        1 <S3: rpart>     0.7142857
#> 4        2        3 <S3: rpart>     0.8571429
#> 5        5        3 <S3: rpart>     0.8571429
#> 6       10        3 <S3: rpart>     0.7142857
#> 7        2        8 <S3: rpart>     0.8571429
#> 8        5        8 <S3: rpart>     0.8571429
#> 9       10        8 <S3: rpart>     0.7142857
Arrange results
To find the best model, we arrange the data based on desc(test_accuracy); the best-fitting model will then be in the first row. You might notice above that many models achieve the same test accuracy. This is unusual, and likely due to the simple example I’ve chosen. Still, to handle it, I’ll break ties in accuracy with desc(minsplit) and maxdepth to find the model that is most accurate and also simplest.
gs <- gs %>% arrange(desc(test_accuracy), desc(minsplit), maxdepth)
gs
#> # A tibble: 9 × 4
#>   minsplit maxdepth         fit test_accuracy
#>      <dbl>    <dbl>      <list>         <dbl>
#> 1        5        3 <S3: rpart>     0.8571429
#> 2        5        8 <S3: rpart>     0.8571429
#> 3        2        3 <S3: rpart>     0.8571429
#> 4        2        8 <S3: rpart>     0.8571429
#> 5       10        1 <S3: rpart>     0.7142857
#> 6       10        3 <S3: rpart>     0.7142857
#> 7       10        8 <S3: rpart>     0.7142857
#> 8        5        1 <S3: rpart>     0.7142857
#> 9        2        1 <S3: rpart>     0.7142857
It looks like a minsplit of 5 and maxdepth of 3 is the way to go!
To compare with our fully grown tree, here’s a plot of this top-performing model. Remember, it’s in the first row, so we can reference it with [[1]].
prp(gs$fit[[1]])
[Plot: decision tree for the top-performing hyperparameter combination]
Food for thought
Having the results in a tidy data frame lets us do a lot more than just pick the optimal hyperparameters. It lets us quickly wrangle and visualize the results of the various combinations. Here are some ideas:
- Search among the top performers for the simplest model.
- Plot performance across the hyperparameter combinations (see the sketch after this list).
- Save time by restricting the hyperparameter grid before model fitting. For example, in a large data set, it’s practically pointless to try a small minsplit and a small maxdepth. In this case, before fitting the models, we can filter the gs data frame to exclude certain combinations.
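As a minimal sketch of the plotting idea above (using the gs column names from earlier; the aesthetic mapping is just one reasonable choice):
# Sketch: test accuracy across the grid, one line per maxdepth value
ggplot(gs, aes(factor(minsplit), test_accuracy,
               color = factor(maxdepth), group = factor(maxdepth))) +
  geom_line() +
  geom_point() +
  labs(x = "minsplit", color = "maxdepth")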
Sign off
Thanks for reading and I hope this was useful for you.
For updates of recent blog posts, follow @drsimonj on Twitter, or email me at drsimonjackson@gmail.com to get in touch.
If you’d like the code that produced this blog, check out the blogR GitHub repository.
Source: https://drsimonj.svbtle.com/grid-search-in-the-tidyverse