Parallel grid search cross-validation using `crossvalidation`
[This article was first published on T. Moudiki's Webpage - R, and kindly contributed to R-bloggers].
Install package ‘crossvalidation’
options(repos = c(
  techtonique = 'https://techtonique.r-universe.dev',
  CRAN = 'https://cloud.r-project.org'))
install.packages("crossvalidation")
Import packages
library(crossvalidation)
library(randomForest)
library(microbenchmark)
Input data
set.seed(123)
n <- 1000 ; p <- 10
X <- matrix(rnorm(n * p), n, p)
y <- rnorm(n)
Random forest hyperparameters for a grid search
tuning_grid <- base::expand.grid(mtry = c(2, 3, 4),
                                 ntree = c(100, 200, 300))
n_params <- nrow(tuning_grid)
print(tuning_grid)
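As a small illustration of how the grid is consumed later, each row can be turned into a named list of hyperparameters. This is only a sketch: `params_i` is a hypothetical name, and `KEEP.OUT.ATTRS = FALSE` is an optional `expand.grid` argument (not used in the original code) that drops the extra bookkeeping attributes.

```r
# Same grid as in the post, without the "out.attrs" bookkeeping attribute
tuning_grid <- base::expand.grid(mtry = c(2, 3, 4),
                                 ntree = c(100, 200, 300),
                                 KEEP.OUT.ATTRS = FALSE)

# One row per hyperparameter combination; the first column varies fastest
n_params <- nrow(tuning_grid)

# Hypothetical helper value: the 5th combination as a named list,
# in the shape that an argument such as `fit_params` expects
params_i <- as.list(tuning_grid[5, , drop = FALSE])
print(params_i)
```

Indexing columns one by one, as in `tuning_grid[i, "mtry"]`, gives the same values; the list form is just more compact when the grid has many columns.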
Sequential and parallel execution of cross-validation on a tuning grid
n_cores <- 4
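The value 4 is hard-coded above. On a machine with fewer cores, one possible safeguard is to cap the request using `parallel::detectCores()`. This is a sketch under assumptions: `n_cores_safe` is a hypothetical name, and leaving one core free is a common convention rather than a requirement.

```r
# detectCores() can return NA when the core count cannot be determined
n_available <- parallel::detectCores()
if (is.na(n_available)) n_available <- 1L

# Request at most 4 workers, leave one core free, never go below 1
n_cores_safe <- max(1L, min(4L, n_available - 1L))
print(n_cores_safe)
```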
Sequential
f1 <- function() base::lapply(1:n_params,
  function(i) crossvalidation::crossval_ml(
    x = X, y = y, k = 5, repeats = 3,
    fit_func = randomForest::randomForest,
    predict_func = predict,
    packages = "randomForest",
    fit_params = list(mtry = tuning_grid[i, "mtry"],
                      ntree = tuning_grid[i, "ntree"])
  ))
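To make the meaning of `k = 5, repeats = 3` concrete, here is a base-R sketch of repeated k-fold cross-validation. It is not the implementation used by `crossval_ml`: the helper `repeated_kfold_rmse` is hypothetical, and a linear model stands in for the random forest so that the example has no package dependencies.

```r
# Hypothetical helper: repeated k-fold CV of a linear model, returning
# one out-of-fold RMSE per (fold, repeat) pair
repeated_kfold_rmse <- function(x, y, k = 5, repeats = 3, seed = 123) {
  set.seed(seed)
  n <- nrow(x)
  df <- data.frame(y = y, x)
  errs <- matrix(NA_real_, nrow = k, ncol = repeats)
  for (r in seq_len(repeats)) {
    # Reshuffle the fold assignment at every repeat
    fold_id <- sample(rep(seq_len(k), length.out = n))
    for (j in seq_len(k)) {
      test_idx <- fold_id == j
      fit <- stats::lm(y ~ ., data = df[!test_idx, , drop = FALSE])
      pred <- stats::predict(fit, newdata = df[test_idx, , drop = FALSE])
      errs[j, r] <- sqrt(mean((df$y[test_idx] - pred)^2))
    }
  }
  errs
}

# Toy data, smaller than in the post
set.seed(123)
X_toy <- matrix(rnorm(200 * 3), 200, 3)
y_toy <- rnorm(200)
errs <- repeated_kfold_rmse(X_toy, y_toy)
print(dim(errs))  # k rows, repeats columns
```

Each grid point in the sequential loop therefore costs k times repeats model fits, which is why spreading the work across cores is attractive.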
Parallel 1
# Parallelism across grid points: each worker runs one full cross-validation
f2 <- function() parallel::mclapply(1:n_params,
  function(i) crossvalidation::crossval_ml(
    x = X, y = y, k = 5, repeats = 3,
    fit_func = randomForest::randomForest,
    predict_func = predict,
    packages = "randomForest",
    fit_params = list(mtry = tuning_grid[i, "mtry"],
                      ntree = tuning_grid[i, "ntree"])
  ), mc.cores = n_cores)
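`parallel::mclapply` relies on forking, which is not available on Windows, where `mc.cores` must be 1. The sketch below shows one way to keep such code portable. `toy_cv` is a hypothetical stand-in for a cross-validation run, and the choice of 2 cores is arbitrary.

```r
library(parallel)

# Hypothetical stand-in for one (deterministic) cross-validation run
toy_cv <- function(i) mean((seq_len(10) * i)^2)

# Forking is unavailable on Windows, so fall back to a single core there
cores <- if (.Platform$OS.type == "windows") 1L else 2L

res_seq <- lapply(1:9, toy_cv)
res_par <- mclapply(1:9, toy_cv, mc.cores = cores)

# mclapply returns results in input order, like lapply
print(isTRUE(all.equal(res_seq, res_par)))
```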
Parallel 2
# Sequential loop over grid points; parallelism is requested
# inside crossval_ml through its `cl` argument
f3 <- function() base::lapply(1:n_params,
  function(i) crossvalidation::crossval_ml(
    x = X, y = y, k = 5, repeats = 3,
    fit_func = randomForest::randomForest,
    predict_func = predict,
    packages = "randomForest",
    fit_params = list(mtry = tuning_grid[i, "mtry"],
                      ntree = tuning_grid[i, "ntree"]),
    cl = n_cores
  ))
Check that the three functions return the same result
all.equal(f1(), f2())
all.equal(f2(), f3())
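A note on reading these checks: `all.equal()` returns `TRUE` when its arguments match within a numerical tolerance, and otherwise a character vector describing the differences (not `FALSE`). In a script, wrapping the call in `isTRUE()` is therefore safer. The toy lists below are made up for illustration.

```r
a <- list(1, 2 + 1e-10)
b <- list(1, 2)
d <- list(1, 3)

# Equal within the default tolerance, but not bit-for-bit identical
print(isTRUE(all.equal(a, b)))   # TRUE
print(identical(a, b))           # FALSE

# A real difference yields a description, not FALSE
msg <- all.equal(b, d)
print(is.character(msg))         # TRUE
print(isTRUE(msg))               # FALSE
```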
Timings for f1, f2, f3
(timings <- microbenchmark::microbenchmark(f1(), f2(), f3(), times = 10L))
Plot results:
boxplot(timings, xlab = "function")
print(sessionInfo())

R version 4.0.4 (2021-02-15)
Platform: x86_64-apple-darwin17.0 (64-bit)
Running under: macOS Big Sur 10.16

Matrix products: default
LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib

locale:
[1] fr_FR.UTF-8/fr_FR.UTF-8/fr_FR.UTF-8/C/fr_FR.UTF-8/fr_FR.UTF-8

attached base packages:
[1] stats graphics grDevices utils datasets methods base

other attached packages:
[1] microbenchmark_1.4-7 randomForest_4.6-14 crossvalidation_0.3.0
[4] foreach_1.5.1 forecast_8.14 httr_1.4.2

loaded via a namespace (and not attached):
 [1] Rcpp_1.0.6 urca_1.3-0 pillar_1.4.6 compiler_4.0.4
 [5] iterators_1.0.12 tseries_0.10-47 tools_4.0.4 xts_0.12.1
 [9] digest_0.6.25 jsonlite_1.7.2 nlme_3.1-152 lifecycle_0.2.0
[13] tibble_3.0.3 gtable_0.3.0 lattice_0.20-41 doSNOW_1.0.19
[17] pkgconfig_2.0.3 rlang_0.4.10 rstudioapi_0.11 curl_4.3
[21] parallel_4.0.4 dplyr_1.0.2 xml2_1.3.2 generics_0.0.2
[25] vctrs_0.3.4 lmtest_0.9-38 grid_4.0.4 nnet_7.3-15
[29] tidyselect_1.1.0 glue_1.4.2 R6_2.5.0 snow_0.4-3
[33] crossval_0.2.1 farver_2.0.3 ggplot2_3.3.3 purrr_0.3.4
[37] TTR_0.24.2 magrittr_1.5 codetools_0.2-18 scales_1.1.1
[41] ellipsis_0.3.1 quantmod_0.4.17 mime_0.9 timeDate_3043.102
[45] colorspace_1.4-1 fracdiff_1.5-1 quadprog_1.5-8 munsell_0.5.0
[49] crayon_1.3.4 zoo_1.8-8