Grid search cross-validation using crossval

April 9, 2020
By

[This article was first published on T. Moudiki's Webpage - R, and kindly contributed to R-bloggers]. (You can report issue about the content on this page here)
Want to share your content on R-bloggers? click here if you have a blog, or here if you don't.

crossval is an R package which contains generic functions for cross-validation. Two weeks ago, I presented an example of time series cross-validation based on crossval. This week’s post is about cross-validation on a grid of hyperparameters. glmnet is used as statistical learning model for the demo, but it could be any other package of your choice.

Installing and loading the packages

Installing crossval from GitHub (in R console):

devtools::install_github("thierrymoudiki/crossval")

Loading packages:

library(glmnet)
library(crossval)

Load mtcars dataset

data("mtcars")
df <- mtcars[, c(1, 2, 3, 4, 6, 11)]
summary(df)

Create response and explanatory variables from mtcars dataset

X <- as.matrix(df[, -1]) # explanatory variables
y <- df$mpg # response

image-title-here

Grid of hyperparameters for glmnet

tuning_grid <- expand.grid(alpha = c(0, 0.5, 1),
                           lambda = c(0.01, 0.1, 1))
n_params <- nrow(tuning_grid)
print(tuning_grid)
##   alpha lambda
## 1   0.0   0.01
## 2   0.5   0.01
## 3   1.0   0.01
## 4   0.0   0.10
## 5   0.5   0.10
## 6   1.0   0.10
## 7   0.0   1.00
## 8   0.5   1.00
## 9   1.0   1.00

Grid search cross-validation

  • list of cross-validation results
  • 5-fold cross-validation (k)
  • repeated 3 times (repeats)
  • cross-validation of 80% of the data (p)
  • validation on the remaining 20%
cv_results <- lapply(1:n_params,
                     function(i)
                       crossval::crossval_ml(
                         x = X,
                         y = y,
                         k = 5,
                         repeats = 3,
                         p = 0.8,
                         fit_func = glmnet::glmnet,
                         predict_func = predict.glmnet,
                         packages = c("glmnet", "Matrix"),
                         fit_params = list(alpha = tuning_grid[i, "alpha"],
                                           lambda = tuning_grid[i, "lambda"])
                       ))
names(cv_results) <- paste0("params_set", 1:n_params)

Remarks are welcome.

print(cv_results)
## $params_set1
## $params_set1$folds
##                    repeat_1  repeat_2 repeat_3
## fold_training_1   2.7116571 3.4204585 2.970296
## fold_validation_1 1.1310676 2.1443185 2.038922
## fold_training_2   1.7335414 1.0317404 3.740119
## fold_validation_2 1.6528925 1.5592805 0.905873
## fold_training_3   2.9526843 4.4059576 3.063401
## fold_validation_3 2.4348686 0.9470344 1.227135
## fold_training_4   4.3206047 3.7097429 3.252773
## fold_validation_4 0.8305158 1.7408722 1.793542
## fold_training_5   1.3484699 1.9396528 1.322698
## fold_validation_5 1.4838844 1.8029411 1.288075
## 
## $params_set1$mean_training
## [1] 2.79492
## 
## $params_set1$mean_validation
## [1] 1.532082
## 
## $params_set1$sd_training
## [1] 1.089414
## 
## $params_set1$sd_validation
## [1] 0.4773198
## 
## $params_set1$median_training
## [1] 2.970296
## 
## $params_set1$median_validation
## [1] 1.559281
## 
## 
## $params_set2
## $params_set2$folds
##                    repeat_1  repeat_2  repeat_3
## fold_training_1   2.6942232 3.4034435 2.9509207
## fold_validation_1 1.1283288 2.1212168 2.0922106
## fold_training_2   1.7071382 1.0236950 3.7337454
## fold_validation_2 1.6183054 1.5395739 0.9289049
## fold_training_3   2.9572493 4.3913568 3.0347475
## fold_validation_3 2.4458845 0.9670278 1.2149724
## fold_training_4   4.3683721 3.6924562 3.2383772
## fold_validation_4 0.8786582 1.7168194 1.7714379
## fold_training_5   1.3451998 1.9338217 1.3169037
## fold_validation_5 1.4832682 1.7971002 1.2850981
## 
## $params_set2$mean_training
## [1] 2.78611
## 
## $params_set2$mean_validation
## [1] 1.532587
## 
## $params_set2$sd_training
## [1] 1.093607
## 
## $params_set2$sd_validation
## [1] 0.470716
## 
## $params_set2$median_training
## [1] 2.957249
## 
## $params_set2$median_validation
## [1] 1.539574
## 
## 
## $params_set3
## $params_set3$folds
##                    repeat_1 repeat_2  repeat_3
## fold_training_1   2.6762742 3.385479 2.9318273
## fold_validation_1 1.1267161 2.094206 2.1505220
## fold_training_2   1.6851155 1.017365 3.7272127
## fold_validation_2 1.5972579 1.519639 0.9543918
## fold_training_3   2.9614653 4.376096 3.0052024
## fold_validation_3 2.4567021 0.989157 1.2033089
## fold_training_4   4.4107761 3.674386 3.2273064
## fold_validation_4 0.9223574 1.691938 1.7506447
## fold_training_5   1.3421543 1.928040 1.3124042
## fold_validation_5 1.4833113 1.791768 1.2834694
## 
## $params_set3$mean_training
## [1] 2.777407
## 
## $params_set3$mean_validation
## [1] 1.534359
## 
## $params_set3$sd_training
## [1] 1.096777
## 
## $params_set3$sd_validation
## [1] 0.4656268
## 
## $params_set3$median_training
## [1] 2.961465
## 
## $params_set3$median_validation
## [1] 1.519639
## 
## 
## $params_set4
## $params_set4$folds
##                   repeat_1  repeat_2 repeat_3
## fold_training_1   2.582406 3.2605565 2.864273
## fold_validation_1 1.168777 1.9268152 2.078255
## fold_training_2   1.650031 0.8984717 3.686839
## fold_validation_2 1.482450 1.4721014 1.017757
## fold_training_3   2.708588 4.2802020 2.939830
## fold_validation_3 2.235334 1.0667584 1.204211
## fold_training_4   4.466894 3.5879682 3.081803
## fold_validation_4 1.004771 1.5646289 1.570059
## fold_training_5   1.326640 1.9144285 1.380771
## fold_validation_5 1.509964 1.7885983 1.335888
## 
## $params_set4$mean_training
## [1] 2.708647
## 
## $params_set4$mean_validation
## [1] 1.495091
## 
## $params_set4$sd_training
## [1] 1.086611
## 
## $params_set4$sd_validation
## [1] 0.3817352
## 
## $params_set4$median_training
## [1] 2.864273
## 
## $params_set4$median_validation
## [1] 1.48245
## 
## 
## $params_set5
## $params_set5$folds
##                    repeat_1  repeat_2 repeat_3
## fold_training_1   2.5706795 3.1103749 2.803793
## fold_validation_1 1.2068333 1.7463001 2.075065
## fold_training_2   1.5011386 0.8148667 3.688756
## fold_validation_2 1.3987359 1.3301017 1.001877
## fold_training_3   2.7010045 4.2517543 2.747419
## fold_validation_3 2.2688901 1.1557910 1.178889
## fold_training_4   4.4448265 3.4750530 3.016761
## fold_validation_4 0.9854596 1.4860992 1.437257
## fold_training_5   1.3487225 1.8742370 1.295727
## fold_validation_5 1.5372721 1.7800714 1.311683
## 
## $params_set5$mean_training
## [1] 2.643008
## 
## $params_set5$mean_validation
## [1] 1.460022
## 
## $params_set5$sd_training
## [1] 1.093856
## 
## $params_set5$sd_validation
## [1] 0.3720159
## 
## $params_set5$median_training
## [1] 2.747419
## 
## $params_set5$median_validation
## [1] 1.398736
## 
## 
## $params_set6
## $params_set6$folds
##                    repeat_1  repeat_2  repeat_3
## fold_training_1   2.5990228 2.9781640 2.7493269
## fold_validation_1 1.2121089 1.5799814 2.0848477
## fold_training_2   1.4225076 0.7604152 3.6906583
## fold_validation_2 1.4216262 1.2543630 0.9884764
## fold_training_3   2.7409312 4.2492745 2.7175981
## fold_validation_3 2.3240529 1.1598608 1.1607340
## fold_training_4   4.4339739 3.4654770 3.0117350
## fold_validation_4 0.9800525 1.4991208 1.4168583
## fold_training_5   1.3765304 1.8415788 1.3257447
## fold_validation_5 1.5496021 1.8006454 1.3220442
## 
## $params_set6$mean_training
## [1] 2.624196
## 
## $params_set6$mean_validation
## [1] 1.450292
## 
## $params_set6$sd_training
## [1] 1.097017
## 
## $params_set6$sd_validation
## [1] 0.3811666
## 
## $params_set6$median_training
## [1] 2.740931
## 
## $params_set6$median_validation
## [1] 1.416858
## 
## 
## $params_set7
## $params_set7$folds
##                   repeat_1 repeat_2 repeat_3
## fold_training_1   2.698210 2.885301 2.455576
## fold_validation_1 1.551401 1.704756 1.716643
## fold_training_2   1.783057 1.028166 3.528652
## fold_validation_2 1.688929 1.457255 1.192856
## fold_training_3   2.635762 3.951937 2.764754
## fold_validation_3 2.325906 1.361088 1.478338
## fold_training_4   4.383367 3.622788 2.966129
## fold_validation_4 1.262743 1.758041 1.628874
## fold_training_5   1.520805 1.968637 1.384429
## fold_validation_5 1.747330 2.061490 1.586987
## 
## $params_set7$mean_training
## [1] 2.638505
## 
## $params_set7$mean_validation
## [1] 1.634842
## 
## $params_set7$sd_training
## [1] 0.9764259
## 
## $params_set7$sd_validation
## [1] 0.2898281
## 
## $params_set7$median_training
## [1] 2.69821
## 
## $params_set7$median_validation
## [1] 1.628874
## 
## 
## $params_set8
## $params_set8$folds
##                   repeat_1 repeat_2 repeat_3
## fold_training_1   2.966475 2.806465 1.737976
## fold_validation_1 1.692210 1.804410 1.498461
## fold_training_2   1.392634 1.104673 3.578175
## fold_validation_2 2.068470 1.499582 1.163872
## fold_training_3   2.684285 3.930335 2.611488
## fold_validation_3 2.543810 1.441189 1.498748
## fold_training_4   4.269152 3.760451 3.327202
## fold_validation_4 1.381628 2.067037 2.049550
## fold_training_5   1.771081 2.323059 1.777073
## fold_validation_5 1.946920 2.405880 1.787871
## 
## $params_set8$mean_training
## [1] 2.669368
## 
## $params_set8$mean_validation
## [1] 1.789976
## 
## $params_set8$sd_training
## [1] 0.9775324
## 
## $params_set8$sd_validation
## [1] 0.3908084
## 
## $params_set8$median_training
## [1] 2.684285
## 
## $params_set8$median_validation
## [1] 1.787871
## 
## 
## $params_set9
## $params_set9$folds
##                   repeat_1 repeat_2 repeat_3
## fold_training_1   3.254495 2.789325 1.094701
## fold_validation_1 1.878893 1.938494 1.581836
## fold_training_2   1.546198 1.179068 3.647095
## fold_validation_2 2.495703 1.623228 1.219294
## fold_training_3   2.478050 3.922693 2.414970
## fold_validation_3 2.594489 1.592061 1.575155
## fold_training_4   4.171757 3.904126 3.695321
## fold_validation_4 1.754051 2.395155 2.455114
## fold_training_5   2.053809 2.660005 2.127650
## fold_validation_5 2.177259 2.741885 2.012704
## 
## $params_set9$mean_training
## [1] 2.729284
## 
## $params_set9$mean_validation
## [1] 2.002355
## 
## $params_set9$sd_training
## [1] 1.014183
## 
## $params_set9$sd_validation
## [1] 0.454853
## 
## $params_set9$median_training
## [1] 2.660005
## 
## $params_set9$median_validation
## [1] 1.938494

Note: I am currently looking for a gig. You can hire me on Malt or send me an email: thierry dot moudiki at pm dot me. I can do descriptive statistics, data preparation, feature engineering, model calibration, training and validation, and model outputs’ interpretation. I am fluent in Python, R, SQL, Microsoft Excel, Visual Basic (among others) and French. My résumé? Here!

To leave a comment for the author, please follow the link and comment on their blog: T. Moudiki's Webpage - R.

R-bloggers.com offers daily e-mail updates about R news and tutorials about learning R and many other topics. Click here if you're looking to post or find an R/data-science job.
Want to share your content on R-bloggers? click here if you have a blog, or here if you don't.



If you got this far, why not subscribe for updates from the site? Choose your flavor: e-mail, twitter, RSS, or facebook...

Comments are closed.

Search R-bloggers

Sponsors

Never miss an update!
Subscribe to R-bloggers to receive
e-mails with the latest R posts.
(You will not see this message again.)

Click here to close (This popup will not appear again)