Machine learning with {tidymodels}


Intro: what is {tidymodels}?

I have already written about {tidymodels} in the past, but since then the {tidymodels} meta-package has evolved quite a lot. If you don’t know what {tidymodels} is, it is a suite of packages that make machine learning with R a breeze. R has many packages for machine learning, each with its own syntax and function arguments. {tidymodels} aims at providing a unified interface, which allows data scientists to focus on the problem they’re trying to solve instead of wasting time learning package specificities.

The packages included in {tidymodels} are:

  • {parsnip} for model definition
  • {recipes} for data preprocessing and feature engineering
  • {rsample} to resample data (useful for cross-validation)
  • {yardstick} to evaluate model performance
  • {dials} to define tuning parameters of your models
  • {tune} for model tuning
  • {workflows} which allows you to bundle everything together and train models easily

There are some others, but I will not cover these. This is a lot of packages, and you might be worried about getting lost; however, in practice I noticed that loading {tidymodels} and then using the functions I needed was good enough. Only rarely did I need to know which package a certain function came from, and the more you use them, the better you know them, obviously. Before continuing, one final and important note: these packages are still in heavy development, so you might not want to use them in production yet. I don’t know how likely it is that the API will still evolve, but my guess is that it is likely. However, even though it might be a bit early to use these packages for production code, I think it is important to learn about them as soon as possible and see what is possible with them.

As I will show you, these packages do make the process of training machine learning models a breeze, and of course they integrate very well with the rest of the {tidyverse} packages. The problem we’re going to tackle is understanding which variables play an important role in the probability of someone looking for a job. I’ll use Eustat’s microdata, which I already discussed in my previous blog post. The dataset can be downloaded from here and is called Population with relation to activity (PRA).

The problem at hand

The dataset contains information on residents of the Basque Country and focuses on their labour supply. Thus, we have information on how many hours people work a week, whether they work, in which industry, what their educational attainment is, and whether they’re looking for a job. The first step, as usual, is to load the data and required packages:

library(tidyverse)
library(tidymodels)
library(readxl)
library(naniar)
library(janitor)
library(furrr)

list_data <- Sys.glob("~/Documents/b-rodrigues.github.com/content/blog/MICRO*.csv")

dataset <- map(list_data, read_csv2) %>%
  bind_rows()

dictionary <- read_xlsx("~/Documents/b-rodrigues.github.com/content/blog/Microdatos_PRA_2019/diseño_registro_microdatos_pra.xlsx", sheet="Valores",
                        col_names = FALSE)

col_names <- dictionary %>%
  filter(!is.na(...1)) %>%
  dplyr::select(1:2)

english <- readRDS("~/Documents/b-rodrigues.github.com/content/blog/english_col_names.rds")

col_names$english <- english

colnames(dataset) <- col_names$english

dataset <- janitor::clean_names(dataset)

Let’s take a look at the data:

head(dataset)
## # A tibble: 6 x 33
##   household_number survey_year reference_quart… territory capital   sex
##              <dbl>       <dbl>            <dbl> <chr>       <dbl> <dbl>
## 1                1        2019                1 48              9     6
## 2                1        2019                1 48              9     1
## 3                2        2019                1 48              1     1
## 4                2        2019                1 48              1     6
## 5                2        2019                1 48              1     6
## 6                2        2019                1 48              1     1
## # … with 27 more variables: place_of_birth <dbl>, age <chr>, nationality <dbl>,
## #   level_of_studies_completed <dbl>, ruled_teaching_system <chr>,
## #   occupational_training <chr>, retirement_situation <dbl>,
## #   homework_situation <dbl>, part_time_employment <dbl>,
## #   short_time_cause <dbl>, job_search <chr>, search_reasons <dbl>,
## #   day_searched <dbl>, make_arrangements <chr>, search_form <chr>,
## #   search_months <dbl>, availability <chr>,
## #   relationship_with_the_activity <dbl>,
## #   relationship_with_the_activity_2 <chr>, main_occupation <dbl>,
## #   main_activity <chr>, main_professional_situation <dbl>,
## #   main_institutional_sector <dbl>, type_of_contract <dbl>, hours <dbl>,
## #   relationship <dbl>, elevator <dbl>

There are many columns, most of them categorical variables, and unfortunately the levels in the data are only non-explicit codes. The Excel file I loaded, which I called dictionary, contains the codes and their explanations. I kept the file open while I was working, especially for missing value imputation. Indeed, there are missing values in the data, and one should always try to understand why before blindly imputing them: there might be a very good reason why data is missing for a particular column. For instance, if children are also surveyed, they would have an NA in, say, the main_occupation column, which gives the main occupation of the surveyed person. This might seem very obvious, but sometimes these reasons are not obvious at all. You should always go back to the data owners/producers with such questions, because if you don’t, you will certainly miss something very important.

Anyway, the way I tackled this issue was by looking at the variables with missing data and checking two-way tables with other variables. For instance, to go back to my example from before, I would take a look at the two-way frequency table between age and main_occupation. If all the missing values of main_occupation were only for people aged 16 or younger, then it would be quite safe to assume that I was right, and I could recode these NAs in main_occupation to "without occupation", for instance. I’ll spare you the full exploration (a minimal sketch of the kind of check I mean follows), and then go straight to the data cleaning.
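
Purely for illustration, here is a hypothetical version of such a check; the occupation_missing helper column is mine and not part of the original data:

# cross-tabulate age with missingness in main_occupation to see whether
# the NAs are concentrated in the youngest age groups
dataset %>%
  mutate(occupation_missing = is.na(main_occupation)) %>%
  tabyl(age, occupation_missing)

And now the cleaning itself: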

dataset <- dataset %>%
  mutate(main_occupation2 = ifelse(is.na(main_occupation),
                                   "without_occupation",
                                   main_occupation))

dataset <- dataset %>%
  mutate(main_professional_situation2 = ifelse(is.na(main_professional_situation),
                                               "without_occupation",
                                               main_professional_situation))

# People with missing hours are actually not working, so I put them to 0
dataset <- dataset %>%
  mutate(hours = ifelse(is.na(hours), 0, hours))

# Short time cause gives the reason why people are working fewer hours than specified in their contract
dataset <- dataset %>%
  mutate(short_time_cause = ifelse(hours == 0 | is.na(short_time_cause), 
                                   "without_occupation",
                                   short_time_cause))

dataset <- dataset %>%
  mutate(type_of_contract = ifelse(is.na(type_of_contract),
                                   "other_contract",
                                   type_of_contract))

Let’s now apply some further cleaning:

pra <- dataset %>%
  filter(age %in% c("04", "05", "06", "07", "08", "09", "10", "11", "12", "13")) %>%
  filter(retirement_situation == 4) %>%    
  filter(!is.na(job_search)) %>%  
  select(capital, sex, place_of_birth, age, nationality, level_of_studies_completed,
         occupational_training, job_search, main_occupation2, type_of_contract,
         hours, short_time_cause, homework_situation,
         main_professional_situation2) %>%
  mutate_at(.vars = vars(-hours), .funs=as.character) %>%
  mutate(job_search = as.factor(job_search))

I only keep people who are not retired and of an age at which they could work. I remove rows where job_search, the target, is missing, mutate all variables but hours to character, and turn job_search into a factor. At first, I made every categorical column a factor, but I ran into problems with certain models. I think the issue came from the recipe that I defined (I’ll talk about it below), but the problem went away when categorical variables were defined as character variables. However, for certain models (I think it was xgboost), the target needs to be a factor for classification problems.

Let’s take a look at the data and check if any more data is missing:

str(pra)
## Classes 'spec_tbl_df', 'tbl_df', 'tbl' and 'data.frame': 29083 obs. of  14 variables:
##  $ capital                     : chr  "9" "9" "1" "1" ...
##  $ sex                         : chr  "6" "1" "1" "6" ...
##  $ place_of_birth              : chr  "1" "1" "1" "1" ...
##  $ age                         : chr  "09" "09" "11" "10" ...
##  $ nationality                 : chr  "1" "1" "1" "1" ...
##  $ level_of_studies_completed  : chr  "1" "2" "3" "3" ...
##  $ occupational_training       : chr  "N" "N" "N" "N" ...
##  $ job_search                  : Factor w/ 2 levels "N","S": 1 1 1 1 1 1 1 1 1 1 ...
##  $ main_occupation2            : chr  "5" "7" "3" "2" ...
##  $ type_of_contract            : chr  "1" "other_contract" "other_contract" "1" ...
##  $ hours                       : num  36 40 40 40 0 0 22 38 40 0 ...
##  $ short_time_cause            : chr  "2" "2" "2" "2" ...
##  $ homework_situation          : chr  "1" "2" "2" "2" ...
##  $ main_professional_situation2: chr  "4" "2" "3" "4" ...
vis_miss(pra)

The final dataset contains 29083 observations. Looks like we’re good to go.

Setting up the training: resampling

In order to properly train a model, one needs to split the data into two parts: one part for trying out models with different configurations of hyper-parameters, and another part for the final evaluation of the model. This is achieved with rsample::initial_split():

pra_split <- initial_split(pra, prop = 0.9)
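
As an aside, initial_split() also accepts a strata argument, which is useful when the target variable is imbalanced; a sketch of a stratified split, not used in the rest of this post:

# stratified alternative: keeps the share of job seekers roughly equal
# in the training and testing sets
pra_split_strat <- initial_split(pra, prop = 0.9, strata = "job_search")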

pra_split now contains a training set and a testing set. We can get these by using the rsample::training() and rsample::testing() functions:

pra_train <- training(pra_split)
pra_test <- testing(pra_split)

We can’t stop here though. We first need to split the training set further in order to perform cross-validation. Cross-validation will allow us to select the best model; by best I mean a model with a good hyper-parameter configuration, one that enables the model to generalize well to unseen data. I do this by creating 10 splits from the training data (I won’t touch the testing data until the very end; this testing data is thus sometimes also called the holdout set):

pra_cv_splits <- vfold_cv(pra_train, v = 10)

Let’s take a look at this object:

pra_cv_splits
## #  10-fold cross-validation 
## # A tibble: 10 x 2
##    splits               id    
##    <named list>         <chr> 
##  1 <split [23.6K/2.6K]> Fold01
##  2 <split [23.6K/2.6K]> Fold02
##  3 <split [23.6K/2.6K]> Fold03
##  4 <split [23.6K/2.6K]> Fold04
##  5 <split [23.6K/2.6K]> Fold05
##  6 <split [23.6K/2.6K]> Fold06
##  7 <split [23.6K/2.6K]> Fold07
##  8 <split [23.6K/2.6K]> Fold08
##  9 <split [23.6K/2.6K]> Fold09
## 10 <split [23.6K/2.6K]> Fold10

Preprocessing the data

I have already dealt with the missing values in the dataset, so there is not much more to do. I will simply create dummy variables out of the categorical variables using step_dummy():

preprocess <- recipe(job_search ~ ., data = pra) %>%
  step_dummy(all_predictors())

preprocess is a recipe that defines the transformations that must be applied to the training data before fitting. In this case there is only one step: transforming all the predictors into dummies (hours is a numeric variable and will be ignored by this step). The recipe also defines the formula that will be fitted by the models, job_search ~ ., and takes data as a further argument. This is only to give the data frame specification to recipe(): it could even be an empty data frame with the right column names and types. This is why I give it the original data pra and not the training set pra_train.

Because this recipe is very simple, it could be applied to the original raw data pra before splitting the data into training and testing sets, and before further splitting the training set into 10 cross-validation sets. However, this is not the recommended way of applying pre-processing steps. Pre-processing needs to happen inside the cross-validation loop, not outside of it. Why? Suppose that you are normalizing a numeric variable, meaning subtracting its mean and dividing by its standard deviation. If you do this operation outside of cross-validation, and even worse, before splitting the data into training and testing sets, you will be leaking information from the testing set into the training set: the mean will contain information from the testing set, which will be picked up by the model. It is much better and more “realistic” to first split the data and then apply the pre-processing (remember that hiding the test set from the model is supposed to simulate the fact that new, completely unseen data is thrown at your model once it’s put into production). The same logic applies to the cross-validation splits; each split also contains a training and a testing set (which I will be calling analysis and assessment sets, following {tidymodels}’s author, Max Kuhn), and thus the pre-processing needs to be applied inside the cross-validation loop, meaning that each analysis set is processed on the fly.
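
To make the normalization example concrete, here is a hypothetical variant of the recipe (not used in this post, since I have nothing to normalize): because the step lives inside the recipe, the mean and standard deviation of hours would be re-estimated on each analysis set instead of being computed once on the whole dataset.

preprocess_normalized <- recipe(job_search ~ ., data = pra) %>%
  step_normalize(hours) %>% # centre and scale hours
  step_dummy(all_predictors())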

Model definition

We now come to the very interesting part: model definition. With {parsnip}, another {tidymodels} package, defining models always works the same way, regardless of the underlying package doing the heavy lifting. For instance, to define a logistic regression one would simply write:

# logistic regression 
logit_tune_pra <- logistic_reg() %>%
  set_engine("glm")

This defines a standard logistic regression, powered by the glm() engine (that is, the glm() function). The way to do this in vanilla R would be:

glm(y ~ ., data = mydata, family = "binomial")

The difference here is that the formula is contained in the glm() call; in our case it is contained in the recipe, which is why I don’t repeat it in the model definition above. You might wonder what the added value of using {tidymodels} is here. Well, suppose now that I would like to run a logistic regression with regularization. I would use {glmnet} for this, but I would need to know the specific syntax of glmnet() which, as you will see, is very different from the one for glm():

glmnet(x_vars[train, ], y_var[train], family = "binomial", alpha = 1, lambda = 1.6)

glmnet(), unlike glm(), does not take a formula as input, but a design matrix and a vector for the target variable. Using {parsnip}, however, I simply need to change the engine from "glm" to "glmnet":

# logistic regression 
logit_tune_pra <- logistic_reg() %>%
  set_engine("glmnet")

This makes things much simpler, as users now only need to learn how to use {parsnip}. However, it is of course still important to read the documentation of the original packages, because that is where the hyper-parameters are discussed. Another advantage of {parsnip} is that the same words are used to refer to the same hyper-parameters. For instance, for tree-based methods, the number of trees is called ntree in one package, num_trees in another, and something different again in yet another. In {parsnip}’s interface for tree-based methods, this parameter is simply called trees. Users can fix the value of hyper-parameters directly by passing values to, say, trees (as in trees = 200), or they can tune these hyper-parameters. To do so, one needs to tag them, like so:

# logistic regression 
logit_tune_pra <- logistic_reg(penalty = tune(), mixture = tune()) %>%
  set_engine("glmnet")

This defines logit_tune_pra with 2 hyper-parameters that must be tuned using cross-validation: the penalty and the amount of mixture between the two penalties (this is elasticnet regularization).
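
For comparison, here is a sketch of what fixing the hyper-parameters directly would look like (the values are arbitrary, purely for illustration):

logit_fixed_pra <- logistic_reg(penalty = 0.01, mixture = 0.5) %>%
  set_engine("glmnet")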

Now, I will define 5 different models, with different hyper-parameters to tune, and I will also define a grid of hyper-parameters of size 10 for each model. This means that each of the 5 models will be tried with 10 different hyper-parameter configurations (each configuration being evaluated on the 10 cross-validation folds). To define the grid, I use the grid_max_entropy() function from the {dials} package. This creates a grid with points that are randomly drawn from the parameter space in a way that ensures that the combinations we get cover the whole space, or at least are not too far away from any portion of it. Of course, the more configurations you try, the better, but the longer the training will run.

# Logistic regression
logit_tune_pra <- logistic_reg(penalty = tune(), mixture = tune()) %>%
  set_engine("glmnet")

# Hyperparameter grid
logit_grid <- logit_tune_pra %>%
  parameters() %>%
  grid_max_entropy(size = 10)

# Workflow bundling every step 
logit_wflow <- workflow() %>%
  add_recipe(preprocess) %>%
  add_model(logit_tune_pra)

# random forest
rf_tune_pra <- rand_forest(mtry = tune(), trees = tune()) %>%
  set_engine("ranger") %>%
  set_mode("classification")

rf_grid <- rf_tune_pra %>%
  parameters() %>%
  finalize(select(pra, -job_search)) %>%  
  grid_max_entropy(size = 10)

rf_wflow <- workflow() %>%
  add_recipe(preprocess) %>%
  add_model(rf_tune_pra)

# mars model
mars_tune_pra <- mars(num_terms = tune(), prod_degree = 2, prune_method = tune()) %>%
  set_engine("earth") %>%
  set_mode("classification")

mars_grid <- mars_tune_pra %>%
  parameters() %>%
  grid_max_entropy(size = 10)

mars_wflow <- workflow() %>%
  add_recipe(preprocess) %>%
  add_model(mars_tune_pra)

#boosted trees
boost_tune_pra <- boost_tree(mtry = tune(), trees = tune(),
                             learn_rate = tune(), tree_depth = tune()) %>%
  set_engine("xgboost") %>%
  set_mode("classification")

boost_grid <- boost_tune_pra %>%
  parameters() %>%
  finalize(select(pra, -job_search)) %>%  
  grid_max_entropy(size = 10)

boost_wflow <- workflow() %>%
  add_recipe(preprocess) %>%
  add_model(boost_tune_pra)

#neural nets
keras_tune_pra <- mlp(hidden_units = tune(), penalty = tune(), activation = "relu") %>%
  set_engine("keras") %>%
  set_mode("classification")

keras_grid <- keras_tune_pra %>%
  parameters() %>%
  grid_max_entropy(size = 10)

keras_wflow <- workflow() %>%
  add_recipe(preprocess) %>%
  add_model(keras_tune_pra)

For each model, I defined three objects: the model itself (for instance keras_tune_pra), a grid of hyper-parameters, and a workflow. To define the grid, I need to extract the parameters to tune using the parameters() function, and for tree-based methods I also need to use finalize() to set the range of the mtry parameter. This is because mtry depends on the dimensions of the data (its value cannot be larger than the number of features), so I need to pass on this information to… well, finalize the grid. Then I can choose the size of the grid and how I want to create it (randomly, using maximum entropy, regularly spaced…). A workflow bundles the pre-processing and the model definition together, and makes fitting the model very easy. Workflows also make it easy to run the pre-processing inside the cross-validation loop. Workflow objects can be passed to the fitting function, as we shall see in the next section.
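
For reference, here is a sketch of the alternative grid constructors I just mentioned (sizes and levels are arbitrary; these grids are not used in this post):

# random grid instead of a maximum entropy grid
logit_grid_random <- logit_tune_pra %>%
  parameters() %>%
  grid_random(size = 10)

# regularly spaced grid, 3 values per hyper-parameter
logit_grid_regular <- logit_tune_pra %>%
  parameters() %>%
  grid_regular(levels = 3)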

Fitting models with {tidymodels}

Fitting one model with {tidymodels} is quite easy:

fitted_model <- fit(model_spec, model_formula, data = data_train)

and that’s it. If you defined a workflow, which bundles the pre-processing and the model definition into one object, you pass the workflow to fit() instead (no formula needed, since the recipe already contains it):

fitted_wflow <- fit(model_wflow, data = data_train)

However, a single call to fit() does not perform cross-validation. It simply trains the model on the training data, and that’s it. To perform cross-validation, you can use either fit_resamples():

fitted_resamples <- fit_resamples(model_wflow,
                               resamples = my_cv_splits,
                               control = control_resamples(save_pred = TRUE))

or tune_grid():

tuned_model <- tune_grid(model_wflow,
                         resamples = my_cv_splits,
                         grid = my_grid,
                         control = control_resamples(save_pred = TRUE))

As you probably guessed, fit_resamples() does not perform tuning; it simply fits a model specification (without varying hyper-parameters) to all the analysis sets contained in the my_cv_splits object (which holds the resampled training data for cross-validation), while tune_grid() does the same but allows hyper-parameters to vary.
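
In both cases, the performance across folds can then be summarised with collect_metrics(); a minimal sketch, using the hypothetical objects defined just above:

collect_metrics(fitted_resamples)
collect_metrics(tuned_model)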

We are thus going to use tune_grid() to fit our models and perform hyper-parameter tuning. However, since I have 5 models and 5 grids, I’ll be using map2() for this. If you’re not familiar with map2(), here’s a quick example:

map2(c(1, 1, 1), c(2,2,2), `+`)
## [[1]]
## [1] 3
## 
## [[2]]
## [1] 3
## 
## [[3]]
## [1] 3

map2() maps the `+`() function to the elements of both vectors pairwise. I’m going to use this to map the tune_grid() function to a list of models and a list of grids. But because this is going to take some time to run, and because I have an AMD Ryzen 5 1600X processor with 6 physical cores and 12 logical cores, I’ll be running this in parallel using furrr::future_map2().

furrr::future_map2() will run one model per core; to use it, I simply define how many workers I want with plan(), and then replace map2() in my code with future_map2():

wflow_list <- list(logit_wflow, rf_wflow, mars_wflow, boost_wflow, keras_wflow)
grid_list <- list(logit_grid, rf_grid, mars_grid, boost_grid, keras_grid)

plan(multiprocess, workers = 6)

trained_models_list <- future_map2(.x = wflow_list,
                                   .y = grid_list,
                                   ~tune_grid(.x , resamples = pra_cv_splits, grid = .y))

Running this code took almost 3 hours. In the end, here is the result:

trained_models_list
## [[1]]
## #  10-fold cross-validation 
## # A tibble: 10 x 4
##    splits               id     .metrics          .notes          
##  * <list>               <chr>  <list>            <list>          
##  1 <split [23.6K/2.6K]> Fold01 <tibble [20 × 5]> <tibble [1 × 1]>
##  2 <split [23.6K/2.6K]> Fold02 <tibble [20 × 5]> <tibble [1 × 1]>
##  3 <split [23.6K/2.6K]> Fold03 <tibble [20 × 5]> <tibble [1 × 1]>
##  4 <split [23.6K/2.6K]> Fold04 <tibble [20 × 5]> <tibble [1 × 1]>
##  5 <split [23.6K/2.6K]> Fold05 <tibble [20 × 5]> <tibble [1 × 1]>
##  6 <split [23.6K/2.6K]> Fold06 <tibble [20 × 5]> <tibble [1 × 1]>
##  7 <split [23.6K/2.6K]> Fold07 <tibble [20 × 5]> <tibble [1 × 1]>
##  8 <split [23.6K/2.6K]> Fold08 <tibble [20 × 5]> <tibble [1 × 1]>
##  9 <split [23.6K/2.6K]> Fold09 <tibble [20 × 5]> <tibble [1 × 1]>
## 10 <split [23.6K/2.6K]> Fold10 <tibble [20 × 5]> <tibble [1 × 1]>
## 
## [[2]]
## #  10-fold cross-validation 
## # A tibble: 10 x 4
##    splits               id     .metrics          .notes          
##  * <list>               <chr>  <list>            <list>          
##  1 <split [23.6K/2.6K]> Fold01 <tibble [20 × 5]> <tibble [1 × 1]>
##  2 <split [23.6K/2.6K]> Fold02 <tibble [20 × 5]> <tibble [1 × 1]>
##  3 <split [23.6K/2.6K]> Fold03 <tibble [20 × 5]> <tibble [1 × 1]>
##  4 <split [23.6K/2.6K]> Fold04 <tibble [20 × 5]> <tibble [1 × 1]>
##  5 <split [23.6K/2.6K]> Fold05 <tibble [20 × 5]> <tibble [1 × 1]>
##  6 <split [23.6K/2.6K]> Fold06 <tibble [20 × 5]> <tibble [1 × 1]>
##  7 <split [23.6K/2.6K]> Fold07 <tibble [20 × 5]> <tibble [1 × 1]>
##  8 <split [23.6K/2.6K]> Fold08 <tibble [20 × 5]> <tibble [1 × 1]>
##  9 <split [23.6K/2.6K]> Fold09 <tibble [20 × 5]> <tibble [1 × 1]>
## 10 <split [23.6K/2.6K]> Fold10 <tibble [20 × 5]> <tibble [1 × 1]>
## 
## [[3]]
## #  10-fold cross-validation 
## # A tibble: 10 x 4
##    splits               id     .metrics          .notes          
##  * <list>               <chr>  <list>            <list>          
##  1 <split [23.6K/2.6K]> Fold01 <tibble [20 × 5]> <tibble [1 × 1]>
##  2 <split [23.6K/2.6K]> Fold02 <tibble [20 × 5]> <tibble [1 × 1]>
##  3 <split [23.6K/2.6K]> Fold03 <tibble [20 × 5]> <tibble [1 × 1]>
##  4 <split [23.6K/2.6K]> Fold04 <tibble [20 × 5]> <tibble [1 × 1]>
##  5 <split [23.6K/2.6K]> Fold05 <tibble [20 × 5]> <tibble [1 × 1]>
##  6 <split [23.6K/2.6K]> Fold06 <tibble [20 × 5]> <tibble [1 × 1]>
##  7 <split [23.6K/2.6K]> Fold07 <tibble [20 × 5]> <tibble [1 × 1]>
##  8 <split [23.6K/2.6K]> Fold08 <tibble [20 × 5]> <tibble [1 × 1]>
##  9 <split [23.6K/2.6K]> Fold09 <tibble [20 × 5]> <tibble [1 × 1]>
## 10 <split [23.6K/2.6K]> Fold10 <tibble [20 × 5]> <tibble [1 × 1]>
## 
## [[4]]
## #  10-fold cross-validation 
## # A tibble: 10 x 4
##    splits               id     .metrics          .notes          
##  * <list>               <chr>  <list>            <list>          
##  1 <split [23.6K/2.6K]> Fold01 <tibble [20 × 7]> <tibble [1 × 1]>
##  2 <split [23.6K/2.6K]> Fold02 <tibble [20 × 7]> <tibble [1 × 1]>
##  3 <split [23.6K/2.6K]> Fold03 <tibble [20 × 7]> <tibble [1 × 1]>
##  4 <split [23.6K/2.6K]> Fold04 <tibble [20 × 7]> <tibble [1 × 1]>
##  5 <split [23.6K/2.6K]> Fold05 <tibble [20 × 7]> <tibble [1 × 1]>
##  6 <split [23.6K/2.6K]> Fold06 <tibble [20 × 7]> <tibble [1 × 1]>
##  7 <split [23.6K/2.6K]> Fold07 <tibble [20 × 7]> <tibble [1 × 1]>
##  8 <split [23.6K/2.6K]> Fold08 <tibble [20 × 7]> <tibble [1 × 1]>
##  9 <split [23.6K/2.6K]> Fold09 <tibble [20 × 7]> <tibble [1 × 1]>
## 10 <split [23.6K/2.6K]> Fold10 <tibble [20 × 7]> <tibble [1 × 1]>
## 
## [[5]]
## #  10-fold cross-validation 
## # A tibble: 10 x 4
##    splits               id     .metrics          .notes          
##  * <list>               <chr>  <list>            <list>          
##  1 <split [23.6K/2.6K]> Fold01 <tibble [20 × 5]> <tibble [1 × 1]>
##  2 <split [23.6K/2.6K]> Fold02 <tibble [20 × 5]> <tibble [1 × 1]>
##  3 <split [23.6K/2.6K]> Fold03 <tibble [20 × 5]> <tibble [1 × 1]>
##  4 <split [23.6K/2.6K]> Fold04 <tibble [20 × 5]> <tibble [1 × 1]>
##  5 <split [23.6K/2.6K]> Fold05 <tibble [20 × 5]> <tibble [1 × 1]>
##  6 <split [23.6K/2.6K]> Fold06 <tibble [20 × 5]> <tibble [1 × 1]>
##  7 <split [23.6K/2.6K]> Fold07 <tibble [20 × 5]> <tibble [1 × 1]>
##  8 <split [23.6K/2.6K]> Fold08 <tibble [20 × 5]> <tibble [1 × 1]>
##  9 <split [23.6K/2.6K]> Fold09 <tibble [20 × 5]> <tibble [1 × 1]>
## 10 <split [23.6K/2.6K]> Fold10 <tibble [20 × 5]> <tibble [1 × 1]>

I now have a list of 5 tibbles containing the analysis/assessment splits, the id identifying the cross-validation fold, a list-column with information on model performance for that given split, and some notes (if everything goes well, the notes are empty). Let’s take a look at the .metrics column for the first model and the first fold:

trained_models_list[[1]]$.metrics[[1]]
## # A tibble: 20 x 5
##     penalty mixture .metric  .estimator .estimate
##       <dbl>   <dbl> <chr>    <chr>          <dbl>
##  1 4.25e- 3  0.0615 accuracy binary         0.906
##  2 4.25e- 3  0.0615 roc_auc  binary         0.895
##  3 6.57e-10  0.0655 accuracy binary         0.908
##  4 6.57e-10  0.0655 roc_auc  binary         0.897
##  5 1.18e- 6  0.167  accuracy binary         0.908
##  6 1.18e- 6  0.167  roc_auc  binary         0.897
##  7 2.19e-10  0.371  accuracy binary         0.907
##  8 2.19e-10  0.371  roc_auc  binary         0.897
##  9 2.73e- 1  0.397  accuracy binary         0.885
## 10 2.73e- 1  0.397  roc_auc  binary         0.5  
## 11 1.72e- 6  0.504  accuracy binary         0.907
## 12 1.72e- 6  0.504  roc_auc  binary         0.897
## 13 1.25e- 9  0.633  accuracy binary         0.907
## 14 1.25e- 9  0.633  roc_auc  binary         0.897
## 15 6.62e- 6  0.880  accuracy binary         0.907
## 16 6.62e- 6  0.880  roc_auc  binary         0.897
## 17 6.00e- 1  0.899  accuracy binary         0.885
## 18 6.00e- 1  0.899  roc_auc  binary         0.5  
## 19 4.57e-10  0.989  accuracy binary         0.907
## 20 4.57e-10  0.989  roc_auc  binary         0.897

This shows how the 10 different configurations of the elasticnet model performed on the first fold. To see how the model performed on the second fold:

trained_models_list[[1]]$.metrics[[2]]
## # A tibble: 20 x 5
##     penalty mixture .metric  .estimator .estimate
##       <dbl>   <dbl> <chr>    <chr>          <dbl>
##  1 4.25e- 3  0.0615 accuracy binary         0.913
##  2 4.25e- 3  0.0615 roc_auc  binary         0.874
##  3 6.57e-10  0.0655 accuracy binary         0.913
##  4 6.57e-10  0.0655 roc_auc  binary         0.877
##  5 1.18e- 6  0.167  accuracy binary         0.913
##  6 1.18e- 6  0.167  roc_auc  binary         0.878
##  7 2.19e-10  0.371  accuracy binary         0.913
##  8 2.19e-10  0.371  roc_auc  binary         0.878
##  9 2.73e- 1  0.397  accuracy binary         0.901
## 10 2.73e- 1  0.397  roc_auc  binary         0.5  
## 11 1.72e- 6  0.504  accuracy binary         0.913
## 12 1.72e- 6  0.504  roc_auc  binary         0.878
## 13 1.25e- 9  0.633  accuracy binary         0.913
## 14 1.25e- 9  0.633  roc_auc  binary         0.878
## 15 6.62e- 6  0.880  accuracy binary         0.913
## 16 6.62e- 6  0.880  roc_auc  binary         0.878
## 17 6.00e- 1  0.899  accuracy binary         0.901
## 18 6.00e- 1  0.899  roc_auc  binary         0.5  
## 19 4.57e-10  0.989  accuracy binary         0.913
## 20 4.57e-10  0.989  roc_auc  binary         0.878

The hyper-parameters are the same; it is only the cross-validation fold that is different. To get the best performing configurations from such an object, you can use show_best(), which averages performance across all the cross-validation folds and returns the top ones:

show_best(trained_models_list[[1]], metric = "accuracy")
## # A tibble: 5 x 7
##    penalty mixture .metric  .estimator  mean     n std_err
##      <dbl>   <dbl> <chr>    <chr>      <dbl> <int>   <dbl>
## 1 6.57e-10  0.0655 accuracy binary     0.916    10 0.00179
## 2 1.18e- 6  0.167  accuracy binary     0.916    10 0.00180
## 3 1.72e- 6  0.504  accuracy binary     0.916    10 0.00182
## 4 4.57e-10  0.989  accuracy binary     0.916    10 0.00181
## 5 6.62e- 6  0.880  accuracy binary     0.916    10 0.00181

This shows the 5 best configurations for elasticnet when looking at accuracy. Now, how do we get the best performing elasticnet regression, random forest, boosted trees, and so on? Easy: using map():

map(trained_models_list, show_best, metric = "accuracy")
## [[1]]
## # A tibble: 5 x 7
##    penalty mixture .metric  .estimator  mean     n std_err
##      <dbl>   <dbl> <chr>    <chr>      <dbl> <int>   <dbl>
## 1 6.57e-10  0.0655 accuracy binary     0.916    10 0.00179
## 2 1.18e- 6  0.167  accuracy binary     0.916    10 0.00180
## 3 1.72e- 6  0.504  accuracy binary     0.916    10 0.00182
## 4 4.57e-10  0.989  accuracy binary     0.916    10 0.00181
## 5 6.62e- 6  0.880  accuracy binary     0.916    10 0.00181
## 
## [[2]]
## # A tibble: 5 x 7
##    mtry trees .metric  .estimator  mean     n std_err
##   <int> <int> <chr>    <chr>      <dbl> <int>   <dbl>
## 1    13  1991 accuracy binary     0.929    10 0.00172
## 2    13  1180 accuracy binary     0.929    10 0.00168
## 3    12   285 accuracy binary     0.928    10 0.00168
## 4     8  1567 accuracy binary     0.927    10 0.00171
## 5     8   647 accuracy binary     0.927    10 0.00191
## 
## [[3]]
## # A tibble: 5 x 7
##   num_terms prune_method .metric  .estimator  mean     n std_err
##       <int> <chr>        <chr>    <chr>      <dbl> <int>   <dbl>
## 1         5 backward     accuracy binary     0.904    10 0.00186
## 2         5 forward      accuracy binary     0.902    10 0.00185
## 3         4 exhaustive   accuracy binary     0.901    10 0.00167
## 4         4 seqrep       accuracy binary     0.901    10 0.00167
## 5         2 backward     accuracy binary     0.896    10 0.00209
## 
## [[4]]
## # A tibble: 5 x 9
##    mtry trees tree_depth learn_rate .metric  .estimator  mean     n std_err
##   <int> <int>      <int>      <dbl> <chr>    <chr>      <dbl> <int>   <dbl>
## 1    12  1245         12   7.70e- 2 accuracy binary     0.929    10 0.00175
## 2     1   239          8   8.23e- 2 accuracy binary     0.927    10 0.00186
## 3     1   835         14   8.53e-10 accuracy binary     0.913    10 0.00232
## 4     4  1522         12   2.22e- 5 accuracy binary     0.896    10 0.00209
## 5     6   313          2   1.21e- 8 accuracy binary     0.896    10 0.00209
## 
## [[5]]
## # A tibble: 5 x 7
##   hidden_units  penalty .metric  .estimator  mean     n std_err
##          <int>    <dbl> <chr>    <chr>      <dbl> <int>   <dbl>
## 1           10 3.07e- 6 accuracy binary     0.917    10 0.00209
## 2            6 1.69e-10 accuracy binary     0.917    10 0.00216
## 3            4 2.32e- 7 accuracy binary     0.916    10 0.00194
## 4            7 5.52e- 5 accuracy binary     0.916    10 0.00163
## 5            8 1.13e- 9 accuracy binary     0.916    10 0.00173
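
To extract a single winning configuration per model, rather than the top 5, there is also select_best(); a quick sketch (output not shown):

map(trained_models_list, select_best, metric = "accuracy")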

Now we need to test these models on the holdout set, but this post is already quite long. In the next blog post, I will retrain the best performing models of each type and see how they fare against the holdout set. I’ll also be looking at explainability, so stay tuned!

Hope you enjoyed! If you found this blog post useful, you might want to follow me on twitter for blog post updates and watch my youtube channel. If you want to support my blog and channel, you could buy me an espresso or paypal.me, or buy my ebook on Leanpub.
