Grid search in the tidyverse
@drsimonj here to share a tidyverse method of grid search for optimizing a model’s hyperparameters.
Grid Search
For anyone who’s unfamiliar with the term, grid search involves running a model many times with combinations of various hyperparameters. The point is to identify which hyperparameters are likely to work best. A more technical definition from Wikipedia, grid search is:
an exhaustive searching through a manually specified subset of the hyperparameter space of a learning algorithm
What this post isn’t about
To keep the focus on grid search, this post does NOT cover…
- k-fold cross-validation. Although a practically essential addition to grid search, I’ll save the combination of these techniques for a future post. If you can’t wait, check out my last post for some inspiration.
- Complex learning models. We’ll stick to a simple decision tree.
- Getting a great model fit. I’ve deliberately chosen input variables and hyperparameters that highlight the approach.
Decision tree example
Say we want to run a simple decision tree to predict cars’ transmission type (am) based on their miles per gallon (mpg) and horsepower (hp) using themtcars data set. Let’s prep the data:
library(tidyverse)
d <- mtcars %>%
# Convert `am` to factor and select relevant variables
mutate(am = factor(am, labels = c("Automatic", "Manual"))) %>%
select(am, mpg, hp)
ggplot(d, aes(mpg, hp, color = am)) +
geom_point()

For a decision tree, it looks like a step-wise function until mpg > 25, at which point it’s all Manual cars. Let’s grow a full decision tree on this data:
library(rpart)
library(rpart.plot)
# Set minsplit = 2 to fit every data point
full_fit <- rpart(am ~ mpg + hp, data = d, minsplit = 2)
prp(full_fit)

We don’t want a model like this, as it almost certainly has overfitting problems. So the question becomes, which hyperparameter specifications would work best for our model to generalize?
Training-Test Split
To help validate our hyperparameter combinations, we’ll split our data into training and test sets (in an 80/20 split):
set.seed(245)
n <- nrow(d)
train_rows <- sample(seq(n), size = .8 * n)
train <- d[ train_rows, ]
test <- d[-train_rows, ]
Create the Grid
Step one for grid search is to define our hyperparameter combinations. Say we want to test a few values for minsplit and maxdepth. I like to setup the grid of their combinations in a tidy data frame with a list and cross_d as follows:
# Define a named list of parameter values
gs <- list(minsplit = c(2, 5, 10),
maxdepth = c(1, 3, 8)) %>%
cross_d() # Convert to data frame grid
gs
#> # A tibble: 9 × 2
#> minsplit maxdepth
#> <dbl> <dbl>
#> 1 2 1
#> 2 5 1
#> 3 10 1
#> 4 2 3
#> 5 5 3
#> 6 10 3
#> 7 2 8
#> 8 5 8
#> 9 10 8
Note that the list names are the names of the hyperparameters that we want to adjust in our model function.
Create a model function
We’ll be iterating down the gs data frame to use the hyperparameter values in a rpart model. The easiest way to handle this is to define a function that accepts a row of our data frame values and passes them correctly to our model. Here’s what I’ll use:
mod <- function(...) {
rpart(am ~ hp + mpg, data = train, control = rpart.control(...))
}
Notice the argument ... is being passed to control in rpart, which is where these hyperparameters can be used.
Fit the models
Now, to fit our models, use pmap to iterate down the values. The following is iterating through each row of our gs data frame, plugging the hyperparameter values for that row into our model.
gs <- gs %>% mutate(fit = pmap(gs, mod))
gs
#> # A tibble: 9 × 3
#> minsplit maxdepth fit
#> <dbl> <dbl> <list>
#> 1 2 1 <S3: rpart>
#> 2 5 1 <S3: rpart>
#> 3 10 1 <S3: rpart>
#> 4 2 3 <S3: rpart>
#> 5 5 3 <S3: rpart>
#> 6 10 3 <S3: rpart>
#> 7 2 8 <S3: rpart>
#> 8 5 8 <S3: rpart>
#> 9 10 8 <S3: rpart>
Obtain accuracy
Next, let’s assess the performance of each fit on our test data. To handle this efficiently, let’s write another small function:
compute_accuracy <- function(fit, test_features, test_labels) {
predicted <- predict(fit, test_features, type = "class")
mean(predicted == test_labels)
}
Now apply this to each fit:
test_features <- test %>% select(-am)
test_labels <- test$am
gs <- gs %>%
mutate(test_accuracy = map_dbl(fit, compute_accuracy,
test_features, test_labels))
gs
#> # A tibble: 9 × 4
#> minsplit maxdepth fit test_accuracy
#> <dbl> <dbl> <list> <dbl>
#> 1 2 1 <S3: rpart> 0.7142857
#> 2 5 1 <S3: rpart> 0.7142857
#> 3 10 1 <S3: rpart> 0.7142857
#> 4 2 3 <S3: rpart> 0.8571429
#> 5 5 3 <S3: rpart> 0.8571429
#> 6 10 3 <S3: rpart> 0.7142857
#> 7 2 8 <S3: rpart> 0.8571429
#> 8 5 8 <S3: rpart> 0.8571429
#> 9 10 8 <S3: rpart> 0.7142857
Arrange results
To find the best model, we arrange the data based on desc(test_accuracy). The best fitting model will then be in the first row. You might see above that we have many models with the same fit. This is unusual, and likley due to the example I’ve chosen. Still, to handle this, I’ll break ties in accuracy withdesc(minsplit) and maxdepth to find the model that is most accurate and also simplest.
gs <- gs %>% arrange(desc(test_accuracy), desc(minsplit), maxdepth)
gs
#> # A tibble: 9 × 4
#> minsplit maxdepth fit test_accuracy
#> <dbl> <dbl> <list> <dbl>
#> 1 5 3 <S3: rpart> 0.8571429
#> 2 5 8 <S3: rpart> 0.8571429
#> 3 2 3 <S3: rpart> 0.8571429
#> 4 2 8 <S3: rpart> 0.8571429
#> 5 10 1 <S3: rpart> 0.7142857
#> 6 10 3 <S3: rpart> 0.7142857
#> 7 10 8 <S3: rpart> 0.7142857
#> 8 5 1 <S3: rpart> 0.7142857
#> 9 2 1 <S3: rpart> 0.7142857
It looks like a minsplit of 5 and maxdepth of 3 is the way to go!
To compare to our fully fit tree, here’s a plot of this top-performing model. Remember, it’s in the first row so we can reference [[1]].
prp(gs$fit[[1]])

Food for thought
Having the results in a tidy data frame lets us do a lot more than just pick the optimal hyperparameters. It lets us quickly wrangle with and visualize the results of the various combinations. Here are some ideas:
- Search among the top performers for the simplest model.
- Plot performance across the hyperparameter combinations.
- Save time by restricting the hypotheses before model fitting. For example, in a large data set, it’s practically pointless to try a small
minsplitand smallmaxdepth. In this case, before fitting the models, we canfilterthegsdata frame to exclude certain combinations.
Sign off
Thanks for reading and I hope this was useful for you.
For updates of recent blog posts, follow @drsimonj on Twitter, or email me atdrsimonjackson@gmail.com to get in touch.
If you’d like the code that produced this blog, check out the blogR GitHub repository.
转自:https://drsimonj.svbtle.com/grid-search-in-the-tidyverse
Grid search in the tidyverse的更多相关文章
- Comparing randomized search and grid search for hyperparameter estimation
Comparing randomized search and grid search for hyperparameter estimation Compare randomized search ...
- 3.2. Grid Search: Searching for estimator parameters
3.2. Grid Search: Searching for estimator parameters Parameters that are not directly learnt within ...
- How to Grid Search Hyperparameters for Deep Learning Models in Python With Keras
Hyperparameter optimization is a big part of deep learning. The reason is that neural networks are n ...
- Grid Search学习
转自:https://www.cnblogs.com/ysugyl/p/8711205.html Grid Search:一种调参手段:穷举搜索:在所有候选的参数选择中,通过循环遍历,尝试每一种可能性 ...
- grid search 超参数寻优
http://scikit-learn.org/stable/modules/grid_search.html 1. 超参数寻优方法 gridsearchCV 和 RandomizedSearchC ...
- scikit-learn:3.2. Grid Search: Searching for estimator parameters
參考:http://scikit-learn.org/stable/modules/grid_search.html GridSearchCV通过(蛮力)搜索參数空间(參数的全部可能组合).寻找最好的 ...
- [转载]Grid Search
[转载]Grid Search 初学机器学习,之前的模型都是手动调参的,效果一般.同学和我说他用了一个叫grid search的方法.可以实现自动调参,顿时感觉非常高级.吃饭的时候想调参的话最差不过也 ...
- grid search
sklearn.metrics.make_scorer(score_func, greater_is_better=True, needs_proba=False, needs_threshold=F ...
- Hackerrank - The Grid Search
https://www.hackerrank.com/challenges/the-grid-search/forum 今天碰见这题,看见难度是Moderate,觉得应该能半小时内搞定. 读完题目发现 ...
随机推荐
- mysql 免安装版 + sqlyog 安装 步骤 --- 发的有点晚
总有些朋友不会安装mysql,其实软件安装不是学习mysql的重点,基本上也就安装一次,工作后一般公司里也不会让你安装,如果非要安装,百度一下就行了.安装版本百度上有许多,下面就提供一个免安装版的步骤 ...
- 【原创101】Servlet精细笔记
一.什么Servlet? servlet 是运行在 Web 服务器中的小型 Java 程序(即:服务器端的小应用程序).servlet 通常通过 HTTP(超文本传输协议)接收和响应来自 Web 客户 ...
- 现代3D图形编程学习--opengl使用不同的缓存对象(译者添加)
现代3D图形编程学习系列翻译地址 http://www.cnblogs.com/grass-and-moon/category/920962.html opengl使用不同的缓存对象 在设置颜色一章中 ...
- T_SQL编程赋值、分支语句、循环
咱们在C#中会常用到赋值.循环.分支语句什么的 今天咱们来看下当初在C#用到的一点东西放到SQL中是怎么使用的 创建变量 在C#中创建一个值类型变量很简单 int a:这就可以了 SQL: decla ...
- Linux--线程安全与可重入函数的异同
线程安全 比如一个 ArrayList 类,在添加一个元素的时候,它可能会有两步来完成: 1. 在 Items[Size] 的位置存放此元素: 2. 增大 Size 的值. 在单线程运行的情况下,如果 ...
- js错误问题 The operation is insecure.
问题: 当我使用canvas的ctx.getImageData 方法时,js报错,错误是 The operation is insecure. 解决: 我使用ctx.getImageData获取can ...
- HubbleDotNet 最新绿色版,服务端免安装,基于eaglet 最后V1.2.8.9版本开发,bug修正,支持一键生成同步表
HubbleDotNet 是一个基于.net framework 的开源免费的全文搜索数据库组件.开源协议是 Apache 2.0.HubbleDotNet提供了基于SQL的全文检索接口,使用者只需会 ...
- ubutun 安装php7.1x
服务器ecs上本来跑了一套nginx+php5.5,由于新项目使用的是laravel5.4,所以不得不把php升级,在此记录下在此安装的过程和遇到的问题,总体来说还算顺利 cd /usr/local/ ...
- 在同一个系统上装两个不同版本的jdk,配置环境变量不起作用,jdk版本的切换问题
本人这台笔记本前面装了jdk8,现在准备用jdk7,我安装好了jdk7:把系统变量中的JAVA_HOME 改为 D:\java\jdk\jdk7\jdk1.7.0_67,Path 下添加如下变量,记得 ...
- MySQL 完整和增量备份与恢复
MySQL 完全备份与恢复 1.数据备份的重要性 在企业中数据的价值至关重要,数据保障了企业的业务的运行,因此数据的安全性及可靠性是运维的重中之重,任何数据的丢失都有可能会对企业产生严重的后果.造成数 ...