SHAP + XGBoost + Tidymodels = LOVE

This article was first published on R – Michael's and Christian's Blog, and kindly contributed to R-bloggers.

In this recent post, we explained how to use Kernel SHAP for interpreting complex linear models. As the plotting backend, we used our fresh CRAN package “shapviz”.

“shapviz” has direct connectors to several packages such as XGBoost, LightGBM, H2O, kernelshap, and more. People have asked me multiple times how to use shapviz when the XGBoost model was fitted with Tidymodels. The workflow was not 100% clear to me either, but the answer is actually very simple, thanks to Julia’s post, where the plots were made with SHAPforxgboost, another cool package for visualizing SHAP values.

Example with shiny diamonds

Step 1: Preprocessing

We first write the data preprocessing recipe and apply it to the data rows that we want to explain. In our case, these are 1000 randomly sampled diamonds.

library(tidyverse)
library(tidymodels)
library(shapviz)

# Integer encode factors
dia_recipe <- diamonds %>%
  recipe(price ~ carat + cut + clarity + color) %>% 
  step_integer(all_nominal())

# Will explain THIS dataset later
set.seed(2)
dia_small <- diamonds[sample(nrow(diamonds), 1000), ]
dia_small_prep <- bake(
  prep(dia_recipe), 
  has_role("predictor"),
  new_data = dia_small, 
  composition = "matrix"
)
head(dia_small_prep)

#     carat cut clarity color
#[1,]  0.57   5       4     4
#[2,]  1.01   5       2     1
#[3,]  0.45   1       4     3
#[4,]  1.04   4       6     5
#[5,]  0.90   3       6     4
#[6,]  1.20   3       4     6
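
To read the integer codes in the output above, it helps to see how they relate to the factor labels. The following base R sketch assumes that step_integer() assigns codes in the order of the factor levels (its default behavior, to my understanding). The levels of cut are written out by hand so that the snippet runs without any packages, and the four example values are made up for illustration.

```r
# Levels of `cut` in ggplot2::diamonds, in their stored order
cut_levels <- c("Fair", "Good", "Very Good", "Premium", "Ideal")

# A few illustrative values (not taken from the sampled rows)
cut <- factor(
  c("Ideal", "Ideal", "Fair", "Premium"),
  levels = cut_levels,
  ordered = TRUE
)

# Integer encoding: position of each value within the level order
codes <- as.integer(cut)

# Decoding: look the codes up in the level vector again
labels <- cut_levels[codes]

print(data.frame(label = as.character(cut), code = codes, decoded = labels))
```

Under this assumption, a value of 5 in the cut column stands for "Ideal" and a value of 1 for "Fair".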

Step 2: Fit Model

The next step is to tune and build the model. For simplicity, we skipped the tuning part. Bad, bad 🙂

# Just for illustration - in practice needs tuning!
xgboost_model <- boost_tree(
  mode = "regression",
  trees = 200,
  tree_depth = 5,
  learn_rate = 0.05,
  engine = "xgboost"
)

dia_wf <- workflow() %>%
  add_recipe(dia_recipe) %>%
  add_model(xgboost_model)

fit <- dia_wf %>%
  fit(diamonds)
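
For completeness, here is a rough sketch of what the skipped tuning step could look like. It is only one of many possible setups: the 5000-row subsample, the three folds, the choice of tree_depth and learn_rate as tuning parameters, and the grid size of four are all assumptions made to keep the run short, not recommendations. The object names (dia_tune, folds, xgb_spec, tune_wf, tuned, best_params, final_fit) are hypothetical.

```r
library(tidymodels)

# Same recipe as in Step 1
dia_recipe <- diamonds %>%
  recipe(price ~ carat + cut + clarity + color) %>%
  step_integer(all_nominal())

# Assumption: small subsample and few folds, only to keep the example fast
set.seed(1)
dia_tune <- diamonds[sample(nrow(diamonds), 5000), ]
folds <- vfold_cv(dia_tune, v = 3)

# Parameters marked with tune() are filled in by the tuning procedure
xgb_spec <- boost_tree(
  mode = "regression",
  trees = 200,
  tree_depth = tune(),
  learn_rate = tune(),
  engine = "xgboost"
)

tune_wf <- workflow() %>%
  add_recipe(dia_recipe) %>%
  add_model(xgb_spec)

# Evaluate four automatically generated parameter combinations by RMSE
tuned <- tune_grid(
  tune_wf,
  resamples = folds,
  grid = 4,
  metrics = metric_set(rmse)
)
best_params <- select_best(tuned, metric = "rmse")

# Plug the best combination into the workflow and refit on all diamonds
final_fit <- tune_wf %>%
  finalize_workflow(best_params) %>%
  fit(diamonds)
```

The resulting final_fit is a fitted workflow, so it could be used in Step 3 in the same way as the untuned fit.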

Step 3: SHAP Analysis

We now need to call shapviz() on the fitted model. In order to have neat interpretations with the original factor labels, we pass not only the prediction data prepared in Step 1 via bake() (as X_pred), but also the original data (as X).

shap <- shapviz(extract_fit_engine(fit), X_pred = dia_small_prep, X = dia_small)

sv_importance(shap, kind = "both", show_numbers = TRUE)
sv_dependence(shap, "carat", color_var = "auto")
sv_dependence(shap, "clarity", color_var = "auto")
sv_force(shap, row_id = 1)
sv_waterfall(shap, row_id = 1)

Figure: Variable importance plot overlaid with SHAP summary beeswarms
Figure: Dependence plot for carat. Note that clarity is shown with original labels, not only integers.
Figure: Dependence plot for clarity. Note again that the x-scale uses the original factor levels, not the integer encoded values.
Figure: Force plot of the first observation
Figure: Waterfall plot for the first observation
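
Force and waterfall plots rest on the additivity of SHAP values: for each row, the baseline plus the sum of the feature contributions gives the model prediction. The sketch below checks this directly with XGBoost's predict(..., predcontrib = TRUE) on simulated data. The data, the parameter values, and the object names (X, y, model, contrib, max_gap) are made up for illustration. It is a sanity check under these assumptions, not a description of what shapviz does internally.

```r
library(xgboost)

# Simulated regression data with three numeric features
set.seed(1)
n <- 500
X <- cbind(x1 = runif(n), x2 = runif(n), x3 = runif(n))
y <- 3 * X[, "x1"] + sin(6 * X[, "x2"]) + rnorm(n, sd = 0.1)

# Small model, parameters chosen arbitrarily
dtrain <- xgb.DMatrix(X, label = y)
model <- xgb.train(
  params = list(objective = "reg:squarederror", max_depth = 3, eta = 0.1),
  data = dtrain,
  nrounds = 50
)

# One column of SHAP values per feature, plus a last column with the baseline
contrib <- predict(model, X, predcontrib = TRUE)
pred <- predict(model, X)

# Additivity: row sums of the contributions should reproduce the predictions
max_gap <- max(abs(rowSums(contrib) - pred))
print(dim(contrib))
print(max_gap)
```

The largest gap should be tiny (floating point noise), which is exactly what the waterfall plot visualizes for a single row.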

Summary

Making SHAP analyses of XGBoost models fitted with Tidymodels is super easy.

The complete R script can be found here.
