Explaining Predictions of Machine Learning Models with LIME – Münster Data Science Meetup

[This article was first published on Shirin's playgRound, and kindly contributed to R-bloggers].
Want to share your content on R-bloggers? click here if you have a blog, or here if you don't.

Slides from Münster Data Science Meetup

These are my slides from the Münster Data Science Meetup on December 12th, 2017.


My sketchnotes were collected from these two podcasts:

Sketchnotes: TWiML Talk #7 with Carlos Guestrin – Explaining the Predictions of Machine Learning Models & Data Skeptic Podcast - Trusting Machine Learning Models with Lime



Example Code

  • the following libraries were loaded:
library(tidyverse)  # for tidy data analysis
library(farff)      # for reading ARFF files
library(missForest) # for imputing missing values
library(dummies)    # for creating dummy variables
library(caret)      # for modeling
library(lime)       # for explaining predictions

Data

The Chronic Kidney Disease dataset was downloaded from the UC Irvine Machine Learning Repository: http://archive.ics.uci.edu/ml/datasets/Chronic_Kidney_Disease

data_file <- file.path("path/to/chronic_kidney_disease_full.arff")  # adjust to where the file was saved
  • load data with the farff package
data <- readARFF(data_file)

Features

  • age - age
  • bp - blood pressure
  • sg - specific gravity
  • al - albumin
  • su - sugar
  • rbc - red blood cells
  • pc - pus cell
  • pcc - pus cell clumps
  • ba - bacteria
  • bgr - blood glucose random
  • bu - blood urea
  • sc - serum creatinine
  • sod - sodium
  • pot - potassium
  • hemo - hemoglobin
  • pcv - packed cell volume
  • wc - white blood cell count
  • rc - red blood cell count
  • htn - hypertension
  • dm - diabetes mellitus
  • cad - coronary artery disease
  • appet - appetite
  • pe - pedal edema
  • ane - anemia
  • class - class (ckd or notckd)

Missing data

  • impute missing data with Nonparametric Missing Value Imputation using Random Forest (missForest package)
data_imp <- missForest(data)
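Before and after imputing, it helps to check how much is actually missing in each column. The following is a minimal base-R sketch: `missing_summary()` is a hypothetical helper (not part of missForest), and the small toy data frame with made-up values stands in for the CKD data.

```r
# Hypothetical helper: number and share of missing values per column
missing_summary <- function(df) {
  n_missing <- colSums(is.na(df))   # is.na() on a data frame gives a logical matrix
  data.frame(
    column    = names(df),
    n_missing = unname(n_missing),
    share     = unname(n_missing) / nrow(df),
    stringsAsFactors = FALSE
  )
}

# Toy data standing in for the CKD table (values are illustrative only)
toy <- data.frame(
  age = c(48, NA, 62, 51),
  bp  = c(80, 50, NA, NA),
  htn = factor(c("yes", "no", NA, "no"))
)

missing_summary(toy)
# After imputation, missing_summary(data_imp$ximp) should report 0 everywhere
```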

One-hot encoding

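  • create dummy variables (dummies::dummy.data.frame())
  • rescale: subtract each column's minimum and divide by its maximum (for non-negative columns this keeps values within [0, 1])
data_imp_final <- data_imp$ximp
# one-hot encode all factor predictors; the response `class` is kept aside
data_dummy <- dummy.data.frame(dplyr::select(data_imp_final, -class), sep = "_")
# shift each column by its minimum and divide by its maximum
data <- cbind(dplyr::select(data_imp_final, class),
              scale(data_dummy,
                    center = apply(data_dummy, 2, min),
                    scale  = apply(data_dummy, 2, max)))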
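Both preprocessing steps can also be sketched in base R, without the dummies package. This is an illustration under assumptions, not the code used for the slides: the toy data frame is made up, `min_max()` is a hypothetical helper, and it uses standard min-max normalization, (x - min) / (max - min), whereas the code above divides by each column's maximum.

```r
# Toy data standing in for the imputed CKD predictors (values are illustrative)
df <- data.frame(
  bp  = c(50, 70, 80, 100),
  htn = factor(c("yes", "no", "no", "yes"))
)

# One-hot encoding with base R: one 0/1 column per level of every factor.
# contrasts = FALSE keeps all levels (as dummies::dummy.data.frame() does),
# but the column names differ ("htnno" instead of "htn_no").
factor_cols <- df[vapply(df, is.factor, logical(1))]
onehot <- model.matrix(~ . - 1, data = df,
                       contrasts.arg = lapply(factor_cols, contrasts, contrasts = FALSE))

# Hypothetical helper: min-max normalization of every column to [0, 1]
min_max <- function(m) {
  mins   <- apply(m, 2, min)
  ranges <- apply(m, 2, max) - mins
  ranges[ranges == 0] <- 1  # constant columns: avoid division by zero
  scale(m, center = mins, scale = ranges)
}

scaled <- min_max(onehot)
scaled
```

For the toy `bp` column, dividing by the maximum (100) instead would give 0, 0.2, 0.3, 0.5 rather than 0, 0.4, 0.6, 1; for dummy columns (minimum 0, maximum 1) both versions agree. A random forest splits on the ordering of values, so it should not be sensitive to either monotone rescaling.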

Modeling

# stratified 90/10 split into training and test set
set.seed(42)
index <- createDataPartition(data$class, p = 0.9, list = FALSE)
train_data <- data[index, ]
test_data  <- data[-index, ]

# random forest, tuned by 10-fold cross-validation repeated 5 times
model_rf <- caret::train(class ~ .,
                         data = train_data,
                         method = "rf", # random forest
                         trControl = trainControl(method = "repeatedcv",
                                                  number = 10,
                                                  repeats = 5,
                                                  verboseIter = FALSE))
model_rf
## Random Forest 
## 
## 360 samples
##  48 predictor
##   2 classes: 'ckd', 'notckd' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 5 times) 
## Summary of sample sizes: 324, 324, 324, 324, 325, 324, ... 
## Resampling results across tuning parameters:
## 
##   mtry  Accuracy   Kappa    
##    2    0.9922647  0.9838466
##   25    0.9917392  0.9826070
##   48    0.9872930  0.9729881
## 
## Accuracy was used to select the optimal model using  the largest value.
## The final value used for the model was mtry = 2.
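With `method = "repeatedcv"`, `number = 10` and `repeats = 5`, each candidate `mtry` is evaluated on 10 × 5 = 50 resamples, each training on roughly 9/10 of the 360 training rows. The base-R sketch below only illustrates that fold logic; `repeated_cv_folds()` is a hypothetical helper. caret stratifies its folds by class, which is presumably why some of its resamples above have 325 rows; this sketch assigns folds purely at random.

```r
# Hypothetical helper: training-row indices for k-fold CV repeated `repeats` times
repeated_cv_folds <- function(n, k = 10, repeats = 5, seed = 42) {
  set.seed(seed)
  folds <- list()
  for (r in seq_len(repeats)) {
    # shuffle a balanced vector of fold labels 1..k over the n rows
    fold_id <- sample(rep_len(seq_len(k), n))
    for (f in seq_len(k)) {
      # training rows for this resample = all rows not held out in fold f
      folds[[sprintf("Fold%02d.Rep%d", f, r)]] <- which(fold_id != f)
    }
  }
  folds
}

folds <- repeated_cv_folds(n = 360)
length(folds)          # 50 resamples
table(lengths(folds))  # every resample trains on 324 of the 360 rows
```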
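# predictions: class probabilities, actual class and predicted class (the one with the higher probability)
pred <- data.frame(sample_id = 1:nrow(test_data), predict(model_rf, test_data, type = "prob"), actual = test_data$class) %>%
  mutate(prediction = colnames(.)[2:3][apply(.[, 2:3], 1, which.max)], correct = ifelse(actual == prediction, "correct", "wrong"))

# confusionMatrix(data, reference) takes predictions first and true classes second, as factors;
# the output below was produced with the arguments swapped, so its "Prediction" rows are the actual classes
confusionMatrix(factor(pred$prediction, levels = levels(pred$actual)), pred$actual)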
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction ckd notckd
##     ckd     23      2
##     notckd   0     15
##                                           
##                Accuracy : 0.95            
##                  95% CI : (0.8308, 0.9939)
##     No Information Rate : 0.575           
##     P-Value [Acc > NIR] : 1.113e-07       
##                                           
##                   Kappa : 0.8961          
##  Mcnemar's Test P-Value : 0.4795          
##                                           
##             Sensitivity : 1.0000          
##             Specificity : 0.8824          
##          Pos Pred Value : 0.9200          
##          Neg Pred Value : 1.0000          
##              Prevalence : 0.5750          
##          Detection Rate : 0.5750          
##    Detection Prevalence : 0.6250          
##       Balanced Accuracy : 0.9412          
##                                           
##        'Positive' Class : ckd             
## 
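The printed table was produced with the actual classes passed as caret's first (prediction) argument. Its rows labelled "Prediction" therefore hold the actual classes, and the reported sensitivity/specificity and PPV/NPV are swapped. The base-R sketch below recomputes the metrics from the same 40 test cases with the roles made explicit. The counts are copied from the printed table and `ckd` is treated as the positive class.

```r
# Counts from the printed table: rows = actual class, columns = predicted class
counts <- matrix(c(23, 0, 2, 15), nrow = 2,
                 dimnames = list(actual    = c("ckd", "notckd"),
                                 predicted = c("ckd", "notckd")))

tp <- counts["ckd", "ckd"]        # CKD correctly predicted as CKD
fn <- counts["ckd", "notckd"]     # CKD cases missed by the model
fp <- counts["notckd", "ckd"]     # non-CKD cases predicted as CKD
tn <- counts["notckd", "notckd"]  # non-CKD correctly predicted as non-CKD

metrics <- c(
  accuracy    = (tp + tn) / sum(counts),
  sensitivity = tp / (tp + fn),
  specificity = tn / (tn + fp),
  ppv         = tp / (tp + fp),
  npv         = tn / (tn + fn)
)
round(metrics, 4)
```

Accuracy stays at 0.95. With the correct orientation, however, sensitivity for `ckd` is 23/25 = 0.92 and specificity is 15/15 = 1, because the model missed two CKD patients and produced no false alarms.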

LIME

  • LIME needs the predictor data without the response variable
train_x <- dplyr::select(train_data, -class)
test_x <- dplyr::select(test_data, -class)

train_y <- dplyr::select(train_data, class)
test_y <- dplyr::select(test_data, class)
  • build the explainer (continuous features are binned into 5 quantile-based bins)
explainer <- lime(train_x, model_rf, n_bins = 5, quantile_bins = TRUE)
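  • run explain(): one label per test case, the 8 most relevant features chosen by forward selection, 1000 permutations per case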
explanation_df <- lime::explain(test_x, explainer, n_labels = 1, n_features = 8, n_permutations = 1000, feature_select = "forward_selection")
  • model reliability: distribution of model_r2 (the R² of each local explanation model) per label
explanation_df %>%
  ggplot(aes(x = model_r2, fill = label)) +
    geom_density(alpha = 0.5)

  • plot the explanations of the first three test cases (8 features per case, i.e. rows 1 to 24)
plot_features(explanation_df[1:24, ], ncol = 1)
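Conceptually, `explain()` fits a simple, locally weighted model around each test case, and `model_r2` measures how well that local model fits. The stripped-down base-R sketch below is for intuition only. It perturbs numeric features with Gaussian noise, weights the samples with an exponential kernel on Euclidean distance, and fits a weighted linear model with `lm()`. The lime package's actual procedure differs in several ways: it derives perturbations from training-data statistics, uses Gower distance by default, and selects features. `black_box()`, `lime_sketch()`, the noise level and the kernel width are all hypothetical. `sc` and `hemo` merely borrow feature names from the list above, with a made-up relationship.

```r
# Hypothetical black box: probability of "ckd" from two numeric features
# (stands in for predict(model_rf, ..., type = "prob"); the formula is made up)
black_box <- function(X) {
  plogis(3 * X[, "sc"] - 2 * X[, "hemo"])
}

# Hypothetical helper: local linear surrogate for one instance x0
lime_sketch <- function(x0, predict_fun, n_perm = 1000, sd = 0.1,
                        kernel_width = 0.75, seed = 42) {
  set.seed(seed)
  p <- length(x0)
  # 1. perturb the instance with Gaussian noise (one column per feature)
  X <- matrix(rnorm(n_perm * p, mean = rep(x0, each = n_perm), sd = sd),
              ncol = p, dimnames = list(NULL, names(x0)))
  # 2. query the black box on the perturbed samples
  y <- predict_fun(X)
  # 3. weight samples by proximity to x0 (exponential kernel on Euclidean distance)
  d <- sqrt(rowSums(sweep(X, 2, x0)^2))
  w <- exp(-d^2 / kernel_width^2)
  # 4. weighted linear surrogate: slopes = local feature effects
  fit <- lm(y ~ ., data = data.frame(X, y = y), weights = w)
  list(coefficients = coef(fit)[-1],          # drop the intercept
       r2 = summary(fit)$r.squared)           # plays the role of lime's model_r2
}

x0 <- c(sc = 0.5, hemo = 0.2)
res <- lime_sketch(x0, black_box)
res$coefficients  # sc positive, hemo negative for this black box
res$r2
```

A low R² would signal that a linear model describes the black box poorly around that case, which is what the `model_r2` density plot above is meant to reveal.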

Session Info

## Session info -------------------------------------------------------------
##  setting  value                       
##  version  R version 3.4.2 (2017-09-28)
##  system   x86_64, darwin15.6.0        
##  ui       X11                         
##  language (EN)                        
##  collate  de_DE.UTF-8                 
##  tz       <NA>                        
##  date     2017-12-12
## Packages -----------------------------------------------------------------
##  package      * version  date       source        
##  assertthat     0.2.0    2017-04-11 CRAN (R 3.4.0)
##  backports      1.1.1    2017-09-25 CRAN (R 3.4.2)
##  base         * 3.4.2    2017-10-04 local         
##  BBmisc         1.11     2017-03-10 CRAN (R 3.4.0)
##  bindr          0.1      2016-11-13 CRAN (R 3.4.0)
##  bindrcpp     * 0.2      2017-06-17 CRAN (R 3.4.0)
##  blogdown       0.3      2017-11-13 CRAN (R 3.4.2)
##  bookdown       0.5      2017-08-20 CRAN (R 3.4.1)
##  broom          0.4.3    2017-11-20 CRAN (R 3.4.2)
##  caret        * 6.0-77   2017-09-07 CRAN (R 3.4.1)
##  cellranger     1.1.0    2016-07-27 CRAN (R 3.4.0)
##  checkmate      1.8.5    2017-10-24 CRAN (R 3.4.2)
##  class          7.3-14   2015-08-30 CRAN (R 3.4.2)
##  cli            1.0.0    2017-11-05 CRAN (R 3.4.2)
##  codetools      0.2-15   2016-10-05 CRAN (R 3.4.2)
##  colorspace     1.3-2    2016-12-14 CRAN (R 3.4.0)
##  compiler       3.4.2    2017-10-04 local         
##  crayon         1.3.4    2017-09-16 cran (@1.3.4) 
##  CVST           0.2-1    2013-12-10 CRAN (R 3.4.0)
##  datasets     * 3.4.2    2017-10-04 local         
##  ddalpha        1.3.1    2017-09-27 CRAN (R 3.4.2)
##  DEoptimR       1.0-8    2016-11-19 CRAN (R 3.4.0)
##  devtools       1.13.4   2017-11-09 CRAN (R 3.4.2)
##  digest         0.6.12   2017-01-27 CRAN (R 3.4.0)
##  dimRed         0.1.0    2017-05-04 CRAN (R 3.4.0)
##  dplyr        * 0.7.4    2017-09-28 CRAN (R 3.4.2)
##  DRR            0.0.2    2016-09-15 CRAN (R 3.4.0)
##  dummies      * 1.5.6    2012-06-14 CRAN (R 3.4.0)
##  e1071          1.6-8    2017-02-02 CRAN (R 3.4.0)
##  evaluate       0.10.1   2017-06-24 CRAN (R 3.4.0)
##  farff        * 1.0      2016-09-11 CRAN (R 3.4.0)
##  forcats      * 0.2.0    2017-01-23 CRAN (R 3.4.0)
##  foreach      * 1.4.3    2015-10-13 CRAN (R 3.4.0)
##  foreign        0.8-69   2017-06-22 CRAN (R 3.4.1)
##  ggplot2      * 2.2.1    2016-12-30 CRAN (R 3.4.0)
##  glmnet         2.0-13   2017-09-22 CRAN (R 3.4.2)
##  glue           1.2.0    2017-10-29 CRAN (R 3.4.2)
##  gower          0.1.2    2017-02-23 CRAN (R 3.4.0)
##  graphics     * 3.4.2    2017-10-04 local         
##  grDevices    * 3.4.2    2017-10-04 local         
##  grid           3.4.2    2017-10-04 local         
##  gtable         0.2.0    2016-02-26 CRAN (R 3.4.0)
##  haven          1.1.0    2017-07-09 CRAN (R 3.4.0)
##  hms            0.4.0    2017-11-23 CRAN (R 3.4.3)
##  htmltools      0.3.6    2017-04-28 CRAN (R 3.4.0)
##  htmlwidgets    0.9      2017-07-10 CRAN (R 3.4.1)
##  httpuv         1.3.5    2017-07-04 CRAN (R 3.4.1)
##  httr           1.3.1    2017-08-20 CRAN (R 3.4.1)
##  ipred          0.9-6    2017-03-01 CRAN (R 3.4.0)
##  iterators    * 1.0.8    2015-10-13 CRAN (R 3.4.0)
##  itertools    * 0.1-3    2014-03-12 CRAN (R 3.4.0)
##  jsonlite       1.5      2017-06-01 CRAN (R 3.4.0)
##  kernlab        0.9-25   2016-10-03 CRAN (R 3.4.0)
##  knitr          1.17     2017-08-10 CRAN (R 3.4.1)
##  labeling       0.3      2014-08-23 CRAN (R 3.4.0)
##  lattice      * 0.20-35  2017-03-25 CRAN (R 3.4.2)
##  lava           1.5.1    2017-09-27 CRAN (R 3.4.1)
##  lazyeval       0.2.1    2017-10-29 CRAN (R 3.4.2)
##  lime         * 0.3.1    2017-11-24 CRAN (R 3.4.3)
##  lubridate      1.7.1    2017-11-03 CRAN (R 3.4.2)
##  magrittr       1.5      2014-11-22 CRAN (R 3.4.0)
##  MASS           7.3-47   2017-02-26 CRAN (R 3.4.2)
##  Matrix         1.2-12   2017-11-15 CRAN (R 3.4.2)
##  memoise        1.1.0    2017-04-21 CRAN (R 3.4.0)
##  methods      * 3.4.2    2017-10-04 local         
##  mime           0.5      2016-07-07 CRAN (R 3.4.0)
##  missForest   * 1.4      2013-12-31 CRAN (R 3.4.0)
##  mnormt         1.5-5    2016-10-15 CRAN (R 3.4.0)
##  ModelMetrics   1.1.0    2016-08-26 CRAN (R 3.4.0)
##  modelr         0.1.1    2017-07-24 CRAN (R 3.4.1)
##  munsell        0.4.3    2016-02-13 CRAN (R 3.4.0)
##  nlme           3.1-131  2017-02-06 CRAN (R 3.4.2)
##  nnet           7.3-12   2016-02-02 CRAN (R 3.4.2)
##  parallel       3.4.2    2017-10-04 local         
##  pkgconfig      2.0.1    2017-03-21 CRAN (R 3.4.0)
##  plyr           1.8.4    2016-06-08 CRAN (R 3.4.0)
##  prodlim        1.6.1    2017-03-06 CRAN (R 3.4.0)
##  psych          1.7.8    2017-09-09 CRAN (R 3.4.1)
##  purrr        * 0.2.4    2017-10-18 CRAN (R 3.4.2)
##  R6             2.2.2    2017-06-17 CRAN (R 3.4.0)
##  randomForest * 4.6-12   2015-10-07 CRAN (R 3.4.0)
##  Rcpp           0.12.14  2017-11-23 CRAN (R 3.4.3)
##  RcppRoll       0.2.2    2015-04-05 CRAN (R 3.4.0)
##  readr        * 1.1.1    2017-05-16 CRAN (R 3.4.0)
##  readxl         1.0.0    2017-04-18 CRAN (R 3.4.0)
##  recipes        0.1.1    2017-11-20 CRAN (R 3.4.3)
##  reshape2       1.4.2    2016-10-22 CRAN (R 3.4.0)
##  rlang          0.1.4    2017-11-05 CRAN (R 3.4.2)
##  rmarkdown      1.8      2017-11-17 CRAN (R 3.4.2)
##  robustbase     0.92-8   2017-11-01 CRAN (R 3.4.2)
##  rpart          4.1-11   2017-03-13 CRAN (R 3.4.2)
##  rprojroot      1.2      2017-01-16 CRAN (R 3.4.0)
##  rstudioapi     0.7      2017-09-07 CRAN (R 3.4.1)
##  rvest          0.3.2    2016-06-17 CRAN (R 3.4.0)
##  scales         0.5.0    2017-08-24 CRAN (R 3.4.1)
##  sfsmisc        1.1-1    2017-06-08 CRAN (R 3.4.0)
##  shiny          1.0.5    2017-08-23 CRAN (R 3.4.1)
##  shinythemes    1.1.1    2016-10-12 CRAN (R 3.4.0)
##  splines        3.4.2    2017-10-04 local         
##  stats        * 3.4.2    2017-10-04 local         
##  stats4         3.4.2    2017-10-04 local         
##  stringdist     0.9.4.6  2017-07-31 CRAN (R 3.4.1)
##  stringi        1.1.6    2017-11-17 CRAN (R 3.4.2)
##  stringr      * 1.2.0    2017-02-18 CRAN (R 3.4.0)
##  survival       2.41-3   2017-04-04 CRAN (R 3.4.0)
##  tibble       * 1.3.4    2017-08-22 CRAN (R 3.4.1)
##  tidyr        * 0.7.2    2017-10-16 CRAN (R 3.4.2)
##  tidyselect     0.2.3    2017-11-06 CRAN (R 3.4.2)
##  tidyverse    * 1.2.1    2017-11-14 CRAN (R 3.4.2)
##  timeDate       3042.101 2017-11-16 CRAN (R 3.4.2)
##  tools          3.4.2    2017-10-04 local         
##  utils        * 3.4.2    2017-10-04 local         
##  withr          2.1.0    2017-11-01 CRAN (R 3.4.2)
##  xml2           1.1.1    2017-01-24 CRAN (R 3.4.0)
##  xtable         1.8-2    2016-02-05 CRAN (R 3.4.0)
##  yaml           2.1.15   2017-12-01 CRAN (R 3.4.3)

To leave a comment for the author, please follow the link and comment on their blog: Shirin's playgRound.
