Kernel SHAP

[This article was first published on R – Michael's and Christian's Blog, and kindly contributed to R-bloggers.]

Our last posts were on SHAP, one of the major ways to shed light on black-box Machine Learning models. SHAP values decompose predictions in a fair way into additive contributions from each feature. Decomposing many predictions and then analyzing the SHAP values gives a relatively quick and informative picture of the fitted model at hand.
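To make "additive contributions" concrete: for a linear model, with expectations taken over a background dataset (as Kernel SHAP does below), the SHAP value of feature j is beta_j * (x_j - mean of x_j in the background data). The following minimal base-R sketch uses a toy linear model on the built-in iris data (our choice, not part of the post) and checks that these values add up to the prediction minus the average background prediction.

```r
# Toy linear model on the built-in iris data (illustrative, not from the post)
fit <- lm(Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width, data = iris)
feats <- c("Sepal.Width", "Petal.Length", "Petal.Width")

x  <- iris[1, feats]  # observation to explain
bg <- iris[, feats]   # background data

# Closed-form SHAP values of a linear model w.r.t. the background data:
# phi_j = beta_j * (x_j - mean of feature j over the background rows)
beta <- coef(fit)[feats]
phi  <- beta * (unlist(x) - colMeans(bg))

# Additivity: the SHAP values sum to prediction minus baseline,
# where the baseline is the average prediction over the background data
baseline <- mean(predict(fit, bg))
pred     <- unname(predict(fit, x))
all.equal(sum(phi), pred - baseline)  # TRUE
```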

In their 2017 paper on SHAP, Scott Lundberg and Su-In Lee presented Kernel SHAP, an algorithm to calculate SHAP values for any model with numeric predictions. Compared to Monte-Carlo sampling (e.g., as implemented in the R package “fastshap”), Kernel SHAP is much more efficient.

I had one problem with Kernel SHAP: I never really understood how it works!

Then I found this article by Covert and Lee (2021). It not only explains all the details of Kernel SHAP, but also offers a version that iterates until convergence. As a by-product, standard errors of the SHAP values can be calculated on the fly.
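Under the hood, Kernel SHAP treats the SHAP values as coefficients of a weighted linear regression: every coalition (subset) of features gets a Shapley kernel weight, the response is the average prediction when only the features in the coalition are taken from the observation, and the coefficients are constrained to sum to the prediction minus the baseline. The base-R sketch below enumerates all coalitions, so it only works for a handful of features (at least two). The function name kernel_shap_exact() is hypothetical and this is not the code of the kernelshap package. With all coalitions included, the solution coincides with the exact Shapley values with respect to the background data.

```r
# Exact Kernel SHAP for a small number of features p >= 2 (illustrative sketch).
# f:  prediction function taking a numeric matrix, returning one number per row
# x:  numeric vector, the observation to explain
# bg: numeric matrix, the background data
kernel_shap_exact <- function(f, x, bg) {
  p <- ncol(bg)
  # All on/off coalitions of features, except the empty and the full one
  Z <- as.matrix(expand.grid(rep(list(0:1), p)))
  Z <- Z[rowSums(Z) %in% seq_len(p - 1), , drop = FALSE]
  k <- rowSums(Z)
  # Shapley kernel weights: (p - 1) / (choose(p, k) * k * (p - k))
  w <- (p - 1) / (choose(p, k) * k * (p - k))
  # Value of a coalition: average prediction when the features in the
  # coalition are fixed at x and the others come from the background rows
  v <- apply(Z, 1, function(z) {
    Xz <- bg
    Xz[, z == 1] <- matrix(x[z == 1], nrow(bg), sum(z), byrow = TRUE)
    mean(f(Xz))
  })
  v0 <- mean(f(bg))             # baseline (empty coalition)
  v1 <- f(matrix(x, nrow = 1))  # prediction of x (full coalition)
  # Weighted least squares with the constraint sum(phi) = v1 - v0
  A <- crossprod(Z, w * Z)
  b <- crossprod(Z, w * (v - v0))
  A_inv <- solve(A)
  lambda <- (sum(A_inv %*% b) - (v1 - v0)) / sum(A_inv)
  setNames(drop(A_inv %*% (b - lambda)), colnames(bg))
}

# Usage on a linear model, where the exact answer is beta_j * (x_j - mean_j)
set.seed(1)
bg <- matrix(rnorm(200 * 3), ncol = 3)
x <- c(1, -0.5, 2)
beta <- c(2, -1, 0.5)
f_lin <- function(X) drop(X %*% beta) + 1
all.equal(kernel_shap_exact(f_lin, x, bg), beta * (x - colMeans(bg)))  # TRUE
```

Sampling-based variants, such as the iterative one described by Covert and Lee, replace the full enumeration by randomly drawn coalitions; that is where convergence checks and standard errors come into play.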

This article motivated me to implement the “kernelshap” package in R, complementing “shapr”, which uses a different logic.

The new “kernelshap” package in R

The interface is quite simple: You need to pass three things to its main function kernelshap():

  • X: matrix/data.frame/tibble/data.table of observations to explain. Each column is a feature.
  • pred_fun: function that takes an object like X and provides one number per row.
  • bg_X: matrix/data.frame/tibble/data.table representing the background dataset used to calculate marginal expectations. Typically between 100 and 200 rows.
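For a model other than a Keras network, the three inputs could look like the following sketch. It uses a toy linear model on iris (our choice, not from the post) and only prepares and checks the inputs; the actual kernelshap() call is shown in the Keras example of this post.

```r
# Illustrative inputs for kernelshap(), using a toy linear model on iris
fit <- lm(Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width, data = iris)

# X: rows to explain, one column per feature
X <- iris[1:10, c("Sepal.Width", "Petal.Length", "Petal.Width")]

# pred_fun: takes an object like X and returns one number per row
pred_fun <- function(X) as.numeric(predict(fit, X))

# bg_X: background data, here 100 random rows with the same columns as X
set.seed(1)
bg_X <- iris[sample(nrow(iris), 100), colnames(X)]

length(pred_fun(X)) == nrow(X)        # TRUE
length(pred_fun(bg_X)) == nrow(bg_X)  # TRUE
```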

Example

We will use Keras to build a deep learning model with 631 parameters on the diamonds data. Then we decompose 500 predictions with kernelshap() and visualize them with “shapviz”.

We will fit a Gamma regression with log link on the four “C” features:

  • carat
  • color
  • clarity
  • cut

library(tidyverse)
library(keras)

# Response and covariates
y <- as.numeric(diamonds$price)
X <- scale(data.matrix(diamonds[c("carat", "color", "cut", "clarity")]))

# Input layer: we have 4 covariates
input <- layer_input(shape = 4)

# Two hidden layers with contracting number of nodes
output <- input %>%
  layer_dense(units = 30, activation = "tanh") %>% 
  layer_dense(units = 15, activation = "tanh") %>% 
  layer_dense(units = 1, activation = k_exp)

# Create and compile model
nn <- keras_model(inputs = input, outputs = output)
summary(nn)

# Gamma regression loss
loss_gamma <- function(y_true, y_pred) {
  -k_log(y_true / y_pred) + y_true / y_pred
}

nn %>% 
  compile(
    optimizer = optimizer_adam(learning_rate = 0.001),
    loss = loss_gamma
  )

# Callbacks
cb <- list(
  callback_early_stopping(patience = 20),
  callback_reduce_lr_on_plateau(patience = 5)
)

# Fit model
history <- nn %>% 
  fit(
    x = X,
    y = y,
    epochs = 100,
    batch_size = 400, 
    validation_split = 0.2,
    callbacks = cb
  )

history$metrics[c("loss", "val_loss")] %>% 
  data.frame() %>% 
  mutate(epoch = row_number()) %>% 
  filter(epoch >= 3) %>% 
  pivot_longer(cols = c("loss", "val_loss")) %>% 
ggplot(aes(x = epoch, y = value, group = name, color = name)) +
  geom_line(size = 1.4)
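Why is loss_gamma a Gamma regression loss? It equals half the Gamma unit deviance plus the constant 1, so minimizing it is equivalent to minimizing the Gamma deviance; the exponential output activation plays the role of the log link. A quick numeric check in base R, where loss_gamma_r() is a hypothetical plain-R twin of the Keras loss and the prices and predictions are made up:

```r
# The loss from above, rewritten for plain numeric vectors
loss_gamma_r <- function(y_true, y_pred) -log(y_true / y_pred) + y_true / y_pred

# Unit deviance of the Gamma family from base R's stats package
gamma_dev <- function(y, mu) Gamma()$dev.resids(y, mu, wt = 1)

y  <- c(326, 2500, 18000)  # hypothetical prices
mu <- c(400, 2300, 15000)  # hypothetical predictions

# loss = deviance / 2 + 1, so both are minimized by the same predictions
all.equal(loss_gamma_r(y, mu), gamma_dev(y, mu) / 2 + 1)  # TRUE

# At a perfect prediction, the loss attains its minimum of 1
loss_gamma_r(y, y)  # 1 1 1
```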

Interpretation via Kernel SHAP

To peek into the fitted model, we apply the Kernel SHAP algorithm to decompose the predictions of 500 randomly selected diamonds. The same 500 diamonds also serve as the background dataset required by Kernel SHAP.

Afterwards, we will study

  • Some SHAP values and their standard errors
  • One waterfall plot
  • A beeswarm summary plot to get a rough picture of variable importance and the direction of the feature effects
  • A SHAP dependence plot for carat

# Interpretation on 500 randomly selected diamonds
library(kernelshap)
library(shapviz)

set.seed(1)  # make the random subset reproducible
ind <- sample(nrow(X), 500)

dia_small <- X[ind, ]

# 77 seconds
system.time(
  ks <- kernelshap(
    dia_small, 
    pred_fun = function(X) as.numeric(predict(nn, X, batch_size = nrow(X))), 
    bg_X = dia_small
  )
)
ks

# Output
# 'kernelshap' object representing 
# - SHAP matrix of dimension 500 x 4 
# - feature data.frame/matrix of dimension 500 x 4 
# - baseline value of 3744.153
# 
# SHAP values of first 2 observations:
#         carat     color       cut   clarity
# [1,] -110.738 -240.2758  5.254733 -720.3610
# [2,] 2379.065  263.3112 56.413680  452.3044
# 
# Corresponding standard errors:
#         carat      color       cut  clarity
# [1,] 2.064393 0.05113337 0.1374942 2.150754
# [2,] 2.614281 0.84934844 0.9373701 0.827563

sv <- shapviz(ks, X = diamonds[ind, colnames(X)])  # unscaled feature values for plotting
sv_waterfall(sv, 1)
sv_importance(sv, "both")
sv_dependence(sv, "carat", "auto")

Note the small standard errors of the SHAP values of the first two diamonds. They are only approximate because the background data is only a sample from an unknown population. Still, they give a good impression of the stability of the results.

The waterfall plot shows a diamond whose rather poor clarity and color pull down its predicted value. Note that even though the model works with scaled numeric features, the plot shows the original feature values, because we passed the unscaled diamonds columns to shapviz().

SHAP waterfall plot of one diamond. Note its bad clarity.
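A waterfall plot visualizes additivity: starting from the baseline (the average prediction on the background data), the SHAP values of a diamond add up to its prediction. The first diamond is the one shown in the waterfall plot, so its printed SHAP values can be used to redo the arithmetic by hand. The values below are copied from the rounded printout above, so the result is only as precise as that printout.

```r
# Values copied from the printed kernelshap output (first diamond)
baseline <- 3744.153
shap_1 <- c(carat = -110.738, color = -240.2758, cut = 5.254733, clarity = -720.3610)

# Additivity: baseline + sum of SHAP values = model prediction
pred_1 <- baseline + sum(shap_1)
round(pred_1, 2)  # 2678.03

# The strongest negative contribution comes from clarity
names(which.min(shap_1))  # "clarity"
```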

The SHAP summary plot shows that "carat" is, unsurprisingly, the most important variable and that high carat means high value. "cut" is not very important, except when it is extremely bad.

SHAP summary plot with bars representing average absolute values as measure of importance.
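The bar heights in such a plot are average absolute SHAP values per feature. As a toy illustration, here is that computation on just the two diamonds printed earlier (values copied from the output). With only two rows it says nothing reliable about the model; for the plot, the same computation runs over all 500 rows.

```r
# SHAP values of the first two diamonds, copied from the output above
S <- rbind(
  c(carat = -110.738, color = -240.2758, cut = 5.254733,  clarity = -720.3610),
  c(carat = 2379.065, color =  263.3112, cut = 56.413680, clarity =  452.3044)
)

# Importance = average absolute SHAP value per feature
imp <- colMeans(abs(S))
sort(imp, decreasing = TRUE)  # order: carat, clarity, color, cut
```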

Our last plot is a SHAP dependence plot for "carat": the effect makes sense, and we can spot some interaction with color. For worse colors (H-J), the effect of carat is slightly weaker than for the very white diamonds.

Dependence plot for "carat"

Short wrap-up

  • Standard Kernel SHAP in R, yeahhhhh 🙂
  • The GitHub version is relatively fast, so you can decompose even 500 observations of a deep learning model within 1-2 minutes.

The complete R script can be found here.
