Tree Methods

< section id="goal" class="level1">

Goal

The goal of this exercise is to familiarize yourself with two very important machine learning methods, the decision tree and the random forest. After this exercise, you should be able to train these models and extract important information to understand the model internals.

< section id="exercises" class="level1">

Exercises

< section id="fit-a-decision-tree" class="level2">

Fit a decision tree

Use task = tsk("german_credit") to create the classification task for the german_credit data and create a decision tree learner (e.g., a CART learner). Train the decision tree on the german_credit classification task. Look at the output of the trained decision tree (you have to access the raw model object).

Hint 1:

The learner we are focusing on here is a decision tree implemented in rpart. The corresponding mlr3 learner key is "classif.rpart". For this exercise, we use the learner with its default hyperparameters. The raw model object can be accessed via the $model slot of the trained learner.

Hint 2:
library(mlr3)
task = tsk(...)
lrn_rpart = lrn(...) # create the learner
lrn_rpart$train(...) # train the learner on the task
lrn_rpart$... # access the raw model object that was fitted
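For reference, one possible filled-in version could look like the following (a sketch, using the defaults as the exercise suggests):

library(mlr3)

task = tsk("german_credit")      # binary credit risk classification task
lrn_rpart = lrn("classif.rpart") # CART decision tree, default hyperparameters
lrn_rpart$train(task)            # fit the tree on the task
lrn_rpart$model                  # raw rpart object showing the split structure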
< section id="visualize-the-tree-structure" class="level2">

Visualize the tree structure

To interpret the model and to gain more insight into how predictions are made, we take a closer look at the structure of the decision tree by visualizing it.

Hint 1:

See the code example on the help page ?rpart::plot.rpart, which shows how to apply the plot and text functions to the rpart model object. Note that other packages, such as rpart.plot, exist to plot the decision tree structure in a visually more appealing way.

Hint 2:
library("rpart")
...(lrn_rpart$...)
text(lrn_rpart$...)

# Alternative using e.g. the rpart.plot package
library("rpart.plot")
...(lrn_rpart$...)
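A sketch of how the blanks might be filled in, assuming lrn_rpart from the previous exercise is already trained (the plot arguments are illustrative choices, not prescribed by the exercise):

library("rpart")
plot(lrn_rpart$model, compress = TRUE, margin = 0.1) # draw the tree skeleton
text(lrn_rpart$model, use.n = TRUE)                  # annotate splits and class counts

# Alternative using the rpart.plot package
library("rpart.plot")
rpart.plot(lrn_rpart$model)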
< section id="fit-a-random-forest" class="level2">

Fit a random forest

To get a more powerful learner, we decide to also fit a random forest. Fit one with default hyperparameters to the german_credit task.

Reminder:

One of the drawbacks of using trees is the instability of the predictor. Small changes in the data may lead to a very different model and therefore to a high variance of the predictions. The random forest takes advantage of this and reduces the variance by applying bagging to decision trees.
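As a toy illustration of this variance reduction (not part of the exercise; the data-generating process below is made up), one can compare the variance of single-tree predictions at a fixed point with the variance of bagged predictions:

library(rpart)

set.seed(1)
n = 200
dat = data.frame(x = runif(n))
dat$y = sin(2 * pi * dat$x) + rnorm(n, sd = 0.3)

# predictions of 100 trees, each fit on a bootstrap sample, at x = 0.5
preds = replicate(100, {
  boot = dat[sample(n, replace = TRUE), ]
  predict(rpart(y ~ x, data = boot), newdata = data.frame(x = 0.5))
})

var(preds)                              # variance of a single tree's prediction
var(colMeans(matrix(preds, nrow = 10))) # much smaller after averaging 10 trees each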

Hint 1:

Use the mlr3 learner classif.ranger, which uses the ranger implementation to train a random forest.

Hint 2:
library(mlr3)
library(mlr3learners)

lrn_ranger = lrn(...) # create the learner
lrn_ranger$...(...) # train the learner on the task
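One possible filled-in version (a sketch; the task object from the first exercise is assumed to exist):

library(mlr3)
library(mlr3learners)

lrn_ranger = lrn("classif.ranger") # random forest, default hyperparameters
lrn_ranger$train(task)             # train on the german_credit task
lrn_ranger$model                   # raw ranger object, including the OOB error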
< section id="roc-analysis" class="level2">

ROC Analysis

The bank wants to use a tree-based model to predict the credit risk. Conduct a simple benchmark to assess whether a decision tree or a random forest works better for this purpose. Specifically, the bank requires that, among the credit applications the system predicts to be “good”, at most 10% actually turn out to be “bad”. Simultaneously, the bank aims at correctly classifying 90% or more of all applications that are “good”. Visualize the benchmark results in a way that helps answer this question. Can the bank expect the model to fulfil these requirements? Which model performs better?

Hint 1:

A benchmark requires three arguments: a task, a list of learners, and a resampling object.
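One way to set this up (a sketch, not the only valid design): the first requirement, at most 10% “bad” among predicted “good”, is a constraint on the precision (PPV ≥ 0.9), while the second is a constraint on the true positive rate (TPR ≥ 0.9), so ROC and precision-recall curves are natural visualizations. Note that the learners need predict_type = "prob" so that these curves can be computed; the task object from the first exercise is assumed to exist:

library(mlr3)
library(mlr3learners)
library(mlr3viz)

set.seed(42)
lrns = list(
  lrn("classif.rpart", predict_type = "prob"),
  lrn("classif.ranger", predict_type = "prob"))
bmr = benchmark(benchmark_grid(task, lrns, rsmp("cv", folds = 5)))

autoplot(bmr, type = "roc") # TPR vs. FPR per learner
autoplot(bmr, type = "prc") # precision vs. recall (TPR) per learner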
< section id="understand-hyperparameters" class="level2">

Understand hyperparameters

Use task = tsk("german_credit") to create the classification task for the german_credit data. In this exercise, we want to fit decision trees and random forests with different hyperparameters (which can have a significant impact on the performance). Each learner implemented in R (e.g. ranger or rpart) has a lot of control settings that directly influence the model fitting (the so-called hyperparameters). Here, we will consdider the hyperparameters mtry for the ranger learner and maxdepth for the rpart learner.

Your task is to manually create a list containing multiple rpart and ranger learners with different hyperparameter values (e.g., try out increasing maxdepth values for rpart). In the next step, we will use this list to see how the model performance changes for different hyperparameter values.

Hint 1:

The learners we are focusing on here are a decision tree implemented in rpart and a random forest implemented in ranger. The corresponding mlr3 learner keys are "classif.rpart" and "classif.ranger". In mlr3, we can get an overview of all hyperparameters via the $param_set slot. With an mlr3 learner, it is possible to open the help page of the underlying method by using the $help() method (e.g., lrn_ranger$help()):

lrn("classif.rpart")$help()
lrn("classif.ranger")$help()
If you are looking for a short description of the meaning of a hyperparameter, you need to look at the help page of the corresponding package that implements the learner, e.g. ?rpart::rpart.control and ?ranger::ranger.

Hint 2:

The possible choices for the hyperparameters can also be viewed with $param_set. Setting the hyperparameters can be done directly in the lrn() call:

# Define a list of learners for the benchmark:
lrns = list(
  lrn("classif.rpart", ...),
  lrn("classif.rpart", ...),
  lrn("classif.rpart", ...),
  lrn("classif.ranger", ...),
  lrn("classif.ranger", ...),
  lrn("classif.ranger", ...))
< section id="comparison-of-trees-and-random-forests" class="level2">

Comparison of trees and random forests

Does it make a difference with respect to model performance if we use different hyperparameters? Use the learners from the previous exercise and compare them in a benchmark. Use 5-fold cross-validation as the resampling technique and the classification error as the performance measure. Visualize the results of the benchmark.

Hint 1:

The function to conduct the benchmark is benchmark(). It requires a resampling object defined with rsmp() and a benchmark design created with benchmark_grid().

Hint 2:
set.seed(31415L)

lrns = list(
  lrn("classif.rpart", maxdepth = 1),
  lrn("classif.rpart", maxdepth = 5),
  lrn("classif.rpart", maxdepth = 20),
  lrn("classif.ranger", mtry.ratio = 0.2),
  lrn("classif.ranger", mtry.ratio = 0.5),
  lrn("classif.ranger", mtry.ratio = 0.8))

cv5 = rsmp(..., folds = ...)
cv5$instantiate(...)

bmr = ...(...(task, lrns, cv5))

mlr3viz::autoplot(bmr, measure = msr("classif.ce"))
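In addition to the plot, the aggregated scores can be extracted from the benchmark result, e.g.:

bmr$aggregate(msr("classif.ce")) # mean classification error per learner, averaged over the 5 folds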
< section id="summary" class="level1">

Summary

< section id="further-information" class="level1">

Further information

Tree implementations: One of the longest paragraphs in the CRAN Task View on Machine Learning and Statistical Learning gives an overview of existing tree implementations:

“[…] Tree-structured models for regression, classification and survival analysis, following the ideas in the CART book, are implemented in rpart (shipped with base R) and tree. Package rpart is recommended for computing CART-like trees. A rich toolbox of partitioning algorithms is available in Weka, package RWeka provides an interface to this implementation, including the J4.8-variant of C4.5 and M5. The Cubist package fits rule-based models (similar to trees) with linear regression models in the terminal leaves, instance-based corrections and boosting. The C50 package can fit C5.0 classification trees, rule-based models, and boosted versions of these. pre can fit rule-based models for a wider range of response variable types. […]”
