
Resampling Solution

[This article was first published on mlr-org, and kindly contributed to R-bloggers.]

Goal

You will learn how to estimate model performance with mlr3 using resampling techniques such as 5-fold cross-validation. Additionally, you will compare a k-NN model against a logistic regression model.


German Credit Data

We work with the German credit data. You can either create the corresponding mlr3 task manually, as we did before, or use the predefined task that is already included in the mlr3 package. To see which other predefined tasks are available to experiment with, look at the output of as.data.table(mlr_tasks).

library(mlr3verse)
Loading required package: mlr3
task = tsk("german_credit")
task 
<TaskClassif:german_credit> (1000 x 21): German Credit
* Target: credit_risk
* Properties: twoclass
* Features (20):
  - fct (14): credit_history, employment_duration, foreign_worker, housing, job, other_debtors,
    other_installment_plans, people_liable, personal_status_sex, property, purpose, savings, status,
    telephone
  - int (3): age, amount, duration
  - ord (3): installment_rate, number_credits, present_residence
task$positive # (check the positive class)
[1] "good"

Exercise: Fairly evaluate the performance of two learners

We first create two mlr3 learners, a logistic regression and a KNN learner. We then compare their performance via resampling.


Create the learners

Create a logistic regression learner (store it as an R object called log_reg) and a KNN learner with k = 5 neighbors (store it as an R object called knn).

Hint 1: Check as.data.table(mlr_learners) to find the appropriate learner.
Hint 2: Make sure to have the kknn package installed.
Solution
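log_reg = lrn("classif.log_reg")   # logistic regression (from mlr3learners, wraps stats::glm)
knn = lrn("classif.kknn", k = 5)   # k-NN with k = 5 neighbors (requires the kknn package)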

Set up a resampling instance

Use mlr3 to set up a resampling instance and store it as an R object called cv5. Here, we aim for 5-fold cross-validation. A table of the resampling techniques implemented in mlr3 can be shown with as.data.table(mlr_resamplings).
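To make the idea of 5-fold cross-validation concrete, here is a minimal base R sketch (not mlr3's internal implementation) that randomly assigns each of n = 1000 rows, the size of the German credit task, to one of five folds. In iteration i, fold i serves as the test set and the remaining four folds form the training set.

```r
set.seed(1)
n = 1000                                            # number of rows, as in the German credit task
k = 5                                               # number of folds
fold_id = sample(rep(seq_len(k), length.out = n))   # random but balanced fold labels

# Iteration i: fold i is the test set, all other folds are the training set
test_sets = lapply(seq_len(k), function(i) which(fold_id == i))
train_sets = lapply(seq_len(k), function(i) which(fold_id != i))

sapply(test_sets, length)   # each test set holds n / k = 200 rows
```

Every row appears in exactly one test set, so each observation is used for testing once and for training k - 1 = 4 times.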

Hint 1: Look at the table returned by as.data.table(mlr_resamplings) and use the rsmp() function to set up a 5-fold cross-validation instance. Store the result of rsmp() in an R object called cv5.
Hint 2: rsmp("cv") sets up a 10-fold cross-validation by default. The number of folds can be set with an additional argument (see the params column of as.data.table(mlr_resamplings)).
Solution
cv5 = rsmp("cv", folds = 5)
cv5
<ResamplingCV>: Cross-Validation
* Iterations: 5
* Instantiated: FALSE
* Parameters: folds=5

Note: Instantiated: FALSE means that we have only created the resampling object; the concrete train/test splits have not yet been generated for a task.
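The sketch below shows how to instantiate cv5 explicitly, using the Resampling methods $instantiate(), $train_set() and $test_set(); the seed value 42 is an arbitrary choice. As far as we understand mlr3's behavior, an uninstantiated resampling passed to resample() is instantiated anew inside each call, so the two learners in the resampling run below may be evaluated on different random folds. Instantiating cv5 once and then passing it to resample() for both learners ensures that they see identical splits.

```r
library(mlr3)

task = tsk("german_credit")
cv5 = rsmp("cv", folds = 5)
cv5$is_instantiated            # FALSE: the folds have not been drawn yet

set.seed(42)                   # make the random fold assignment reproducible
cv5$instantiate(task)          # assign each of the 1000 rows to one of the 5 folds
cv5$is_instantiated            # TRUE

# Row ids used for training and testing in iteration 1
train_ids = cv5$train_set(1)
test_ids = cv5$test_set(1)
length(train_ids)              # 800
length(test_ids)               # 200
```

With cv5 instantiated this way, resample(task, log_reg, cv5) and resample(task, knn, cv5) use the same five train/test splits.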


Run the resampling

After having created a resampling instance, use it to apply the chosen resampling technique to both previously created learners.

Hint 1: You need to supply the task, the learner and the previously created resampling instance as arguments to the resample() function. See ?resample for further details and examples.
Hint 2: The key ingredients for resample() are a task (created by tsk()), a learner (created by lrn()) and a resampling strategy (created by rsmp()), e.g.,

resample(task = task, learner = log_reg, resampling = cv5)

Solution
res_log_reg = resample(task, log_reg, cv5)
INFO  [14:51:01.363] [mlr3] Applying learner 'classif.log_reg' on task 'german_credit' (iter 1/5)
INFO  [14:51:04.850] [mlr3] Applying learner 'classif.log_reg' on task 'german_credit' (iter 2/5)
INFO  [14:51:07.933] [mlr3] Applying learner 'classif.log_reg' on task 'german_credit' (iter 3/5)
INFO  [14:51:10.935] [mlr3] Applying learner 'classif.log_reg' on task 'german_credit' (iter 4/5)
INFO  [14:51:13.888] [mlr3] Applying learner 'classif.log_reg' on task 'german_credit' (iter 5/5)
res_knn = resample(task, knn, cv5)
INFO  [14:51:15.267] [mlr3] Applying learner 'classif.kknn' on task 'german_credit' (iter 1/5)
Warning in model.matrix.default(mt2, test, contrasts.arg = contrasts.arg): variable 'credit_risk' is absent, its contrast will be ignored
INFO  [14:51:16.672] [mlr3] Applying learner 'classif.kknn' on task 'german_credit' (iter 2/5)
Warning in model.matrix.default(mt2, test, contrasts.arg = contrasts.arg): variable 'credit_risk' is absent, its contrast will be ignored
INFO  [14:51:17.940] [mlr3] Applying learner 'classif.kknn' on task 'german_credit' (iter 3/5)
Warning in model.matrix.default(mt2, test, contrasts.arg = contrasts.arg): variable 'credit_risk' is absent, its contrast will be ignored
INFO  [14:51:19.053] [mlr3] Applying learner 'classif.kknn' on task 'german_credit' (iter 4/5)
Warning in model.matrix.default(mt2, test, contrasts.arg = contrasts.arg): variable 'credit_risk' is absent, its contrast will be ignored
INFO  [14:51:20.158] [mlr3] Applying learner 'classif.kknn' on task 'german_credit' (iter 5/5)
Warning in model.matrix.default(mt2, test, contrasts.arg = contrasts.arg): variable 'credit_risk' is absent, its contrast will be ignored
res_log_reg
<ResampleResult> with 5 resampling iterations
       task_id      learner_id resampling_id iteration     prediction_test warnings errors
 german_credit classif.log_reg            cv         1 <PredictionClassif>        0      0
 german_credit classif.log_reg            cv         2 <PredictionClassif>        0      0
 german_credit classif.log_reg            cv         3 <PredictionClassif>        0      0
 german_credit classif.log_reg            cv         4 <PredictionClassif>        0      0
 german_credit classif.log_reg            cv         5 <PredictionClassif>        0      0
res_knn
<ResampleResult> with 5 resampling iterations
       task_id   learner_id resampling_id iteration     prediction_test warnings errors
 german_credit classif.kknn            cv         1 <PredictionClassif>        0      0
 german_credit classif.kknn            cv         2 <PredictionClassif>        0      0
 german_credit classif.kknn            cv         3 <PredictionClassif>        0      0
 german_credit classif.kknn            cv         4 <PredictionClassif>        0      0
 german_credit classif.kknn            cv         5 <PredictionClassif>        0      0
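Conceptually, each resample() call repeats a fit/predict/score cycle once per fold. The following base R sketch mimics that loop for a logistic regression fitted with stats::glm. It is not how mlr3 implements resample() internally; the synthetic data set, its data-generating coefficients and the 0.5 probability threshold are assumptions made purely for illustration.

```r
set.seed(1)
n = 500
x1 = rnorm(n)
x2 = rnorm(n)
# Synthetic binary target (assumed data-generating process, for illustration only)
y = rbinom(n, 1, plogis(0.5 + 1.5 * x1 - x2))
dat = data.frame(y, x1, x2)

k = 5
fold_id = sample(rep(seq_len(k), length.out = n))   # balanced random folds

acc = numeric(k)
for (i in seq_len(k)) {
  train = dat[fold_id != i, ]                        # k - 1 folds for fitting
  test = dat[fold_id == i, ]                         # held-out fold for evaluation
  fit = glm(y ~ x1 + x2, data = train, family = binomial())
  prob = predict(fit, newdata = test, type = "response")
  pred = as.integer(prob > 0.5)                      # classify with a 0.5 threshold
  acc[i] = mean(pred == test$y)                      # accuracy on this fold
}

acc          # one accuracy per fold
mean(acc)    # cross-validated accuracy
```

Because each model is scored only on rows it never saw during fitting, the averaged accuracy estimates performance on new data rather than on the training data.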

Evaluation

Compute the cross-validated classification accuracy of both models. Which learner performed better?

Hint 1: Use msr("classif.acc") and the aggregate method of the resample result object.
Hint 2: Use res_knn$aggregate(msr(...)) to obtain the classification accuracy averaged across all folds.
Solution
res_knn$aggregate(msr("classif.acc"))
classif.acc 
       0.72 
res_log_reg$aggregate(msr("classif.acc"))
classif.acc 
      0.747 

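Logistic regression achieves a higher cross-validated accuracy (0.747) than k-NN (0.72) in this run.

Note: Use e.g. res_knn$score(msr(...)) to look at the results of each individual fold.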
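$aggregate() condenses the per-fold scores reported by $score() into a single number. To our understanding, for classif.acc mlr3 by default takes the mean of the per-fold accuracies (a macro average). The base R sketch below uses synthetic labels and predictions (not the German credit results) to show that this mean coincides with the pooled accuracy over all test predictions whenever every fold has the same size, as with 1000 rows and 5 folds.

```r
set.seed(2)
n = 1000
k = 5
# Synthetic truth and predictions; predictions agree with the truth about 75% of the time
truth = sample(c("good", "bad"), n, replace = TRUE, prob = c(0.7, 0.3))
pred = ifelse(runif(n) < 0.75, truth, ifelse(truth == "good", "bad", "good"))
fold_id = sample(rep(seq_len(k), length.out = n))   # 200 rows per fold

fold_acc = tapply(pred == truth, fold_id, mean)     # one accuracy per fold, like $score()
macro = mean(fold_acc)                              # mean over folds, like $aggregate()
micro = mean(pred == truth)                         # pooled over all 1000 test predictions

all.equal(macro, micro)   # TRUE, because all folds have equal size
```

With unequal fold sizes the two averages can differ slightly, since the macro average gives every fold the same weight regardless of how many rows it contains.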


Summary

We can now apply different resampling methods to estimate the performance of different learners. Compared with a single train/test split, averaging over several resampling iterations gives a performance estimate with lower variance, which lets us compare learners more fairly.
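For comparing several learners at once, mlr3 also offers benchmark_grid() together with benchmark(). The sketch below assumes that mlr3learners and the kknn package are installed and that the seed value is arbitrary; to our understanding, benchmark_grid() instantiates an uninstantiated resampling once per task, so both learners are evaluated on the same folds.

```r
library(mlr3)
library(mlr3learners)               # provides classif.log_reg and classif.kknn

task = tsk("german_credit")
log_reg = lrn("classif.log_reg")
knn = lrn("classif.kknn", k = 5)    # requires the kknn package
cv5 = rsmp("cv", folds = 5)

set.seed(42)                        # reproducible fold assignment
# Cross every learner with the task and resampling; the folds are shared
design = benchmark_grid(tasks = task, learners = list(log_reg, knn), resamplings = cv5)
bmr = benchmark(design)

# One row per learner with its cross-validated accuracy
acc = bmr$aggregate(msr("classif.acc"))
print(acc)
```

The accuracies obtained this way need not match the numbers above exactly, since they depend on the random fold assignment.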
