Explaining Keras image classification models with lime

[This article was first published on Shirin's playgRound, and kindly contributed to R-bloggers.]

Last week I published a blog post about how easy it is to train image classification models with Keras.

What I did not show in that post was how to use the model for making predictions. This, I will do here. But predictions alone are boring, so I’m adding explanations for the predictions using the lime package.

I have already written a few blog posts (here, here and here) about LIME and have given talks (here and here) about it, too.

None of them apply LIME to image classification models, though. And with the new(ish) March release of Thomas Lin Pedersen’s lime package, lime is now not only on CRAN but also natively supports Keras and image classification models.

Thomas wrote a very nice article about how to use keras and lime in R! Here, I am following that article to make and explain predictions for fruit images with a VGG16 model pretrained on Imagenet. I am then extending the analysis to last week’s model and comparing it with the pretrained net.

Loading libraries and models

library(keras)   # for working with neural nets
library(lime)    # for explaining models
library(magick)  # for preprocessing images
library(ggplot2) # for additional plotting
  • Loading the pretrained Imagenet model
model <- application_vgg16(weights = "imagenet", include_top = TRUE)
model
## Model
## ___________________________________________________________________________
## Layer (type)                     Output Shape                  Param #     
## ===========================================================================
## input_1 (InputLayer)             (None, 224, 224, 3)           0           
## ___________________________________________________________________________
## block1_conv1 (Conv2D)            (None, 224, 224, 64)          1792        
## ___________________________________________________________________________
## block1_conv2 (Conv2D)            (None, 224, 224, 64)          36928       
## ___________________________________________________________________________
## block1_pool (MaxPooling2D)       (None, 112, 112, 64)          0           
## ___________________________________________________________________________
## block2_conv1 (Conv2D)            (None, 112, 112, 128)         73856       
## ___________________________________________________________________________
## block2_conv2 (Conv2D)            (None, 112, 112, 128)         147584      
## ___________________________________________________________________________
## block2_pool (MaxPooling2D)       (None, 56, 56, 128)           0           
## ___________________________________________________________________________
## block3_conv1 (Conv2D)            (None, 56, 56, 256)           295168      
## ___________________________________________________________________________
## block3_conv2 (Conv2D)            (None, 56, 56, 256)           590080      
## ___________________________________________________________________________
## block3_conv3 (Conv2D)            (None, 56, 56, 256)           590080      
## ___________________________________________________________________________
## block3_pool (MaxPooling2D)       (None, 28, 28, 256)           0           
## ___________________________________________________________________________
## block4_conv1 (Conv2D)            (None, 28, 28, 512)           1180160     
## ___________________________________________________________________________
## block4_conv2 (Conv2D)            (None, 28, 28, 512)           2359808     
## ___________________________________________________________________________
## block4_conv3 (Conv2D)            (None, 28, 28, 512)           2359808     
## ___________________________________________________________________________
## block4_pool (MaxPooling2D)       (None, 14, 14, 512)           0           
## ___________________________________________________________________________
## block5_conv1 (Conv2D)            (None, 14, 14, 512)           2359808     
## ___________________________________________________________________________
## block5_conv2 (Conv2D)            (None, 14, 14, 512)           2359808     
## ___________________________________________________________________________
## block5_conv3 (Conv2D)            (None, 14, 14, 512)           2359808     
## ___________________________________________________________________________
## block5_pool (MaxPooling2D)       (None, 7, 7, 512)             0           
## ___________________________________________________________________________
## flatten (Flatten)                (None, 25088)                 0           
## ___________________________________________________________________________
## fc1 (Dense)                      (None, 4096)                  102764544   
## ___________________________________________________________________________
## fc2 (Dense)                      (None, 4096)                  16781312    
## ___________________________________________________________________________
## predictions (Dense)              (None, 1000)                  4097000     
## ===========================================================================
## Total params: 138,357,544
## Trainable params: 138,357,544
## Non-trainable params: 0
## ___________________________________________________________________________
  • Loading my own model (from last week’s post)
model2 <- load_model_hdf5(filepath = "/Users/shiringlander/Documents/Github/DL_AI/Tutti_Frutti/fruits-360/keras/fruits_checkpoints.h5")
model2
## Model
## ___________________________________________________________________________
## Layer (type)                     Output Shape                  Param #     
## ===========================================================================
## conv2d_1 (Conv2D)                (None, 20, 20, 32)            896         
## ___________________________________________________________________________
## activation_1 (Activation)        (None, 20, 20, 32)            0           
## ___________________________________________________________________________
## conv2d_2 (Conv2D)                (None, 20, 20, 16)            4624        
## ___________________________________________________________________________
## leaky_re_lu_1 (LeakyReLU)        (None, 20, 20, 16)            0           
## ___________________________________________________________________________
## batch_normalization_1 (BatchNorm (None, 20, 20, 16)            64          
## ___________________________________________________________________________
## max_pooling2d_1 (MaxPooling2D)   (None, 10, 10, 16)            0           
## ___________________________________________________________________________
## dropout_1 (Dropout)              (None, 10, 10, 16)            0           
## ___________________________________________________________________________
## flatten_1 (Flatten)              (None, 1600)                  0           
## ___________________________________________________________________________
## dense_1 (Dense)                  (None, 100)                   160100      
## ___________________________________________________________________________
## activation_2 (Activation)        (None, 100)                   0           
## ___________________________________________________________________________
## dropout_2 (Dropout)              (None, 100)                   0           
## ___________________________________________________________________________
## dense_2 (Dense)                  (None, 16)                    1616        
## ___________________________________________________________________________
## activation_3 (Activation)        (None, 16)                    0           
## ===========================================================================
## Total params: 167,300
## Trainable params: 167,268
## Non-trainable params: 32
## ___________________________________________________________________________
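The parameter counts in the two summaries above can be checked by hand. The sketch below is only a sanity check and not part of the original analysis: it assumes 3 x 3 convolution kernels for both models (which is consistent with the printed counts), and the helper names conv_params() and dense_params() are made up for this example.

```r
# Parameters of a 2D convolution: one weight per kernel cell and input
# channel, plus one bias, for every output filter.
conv_params <- function(kernel, channels_in, filters) {
  (kernel * kernel * channels_in + 1) * filters
}

# Parameters of a dense layer: a weight matrix plus one bias per unit.
dense_params <- function(units_in, units_out) {
  units_in * units_out + units_out
}

# Two layers of VGG16 (3 x 3 kernels assumed)
vgg_block1_conv1 <- conv_params(3, 3, 64)      # summary shows 1792
vgg_fc1          <- dense_params(25088, 4096)  # summary shows 102764544

# All layers with parameters in the small fruit model (3 x 3 kernels assumed)
fruit_params <- c(
  conv2d_1 = conv_params(3, 3, 32),
  conv2d_2 = conv_params(3, 32, 16),
  # gamma, beta, moving mean and moving variance for each of the 16 channels
  batch_normalization_1 = 4 * 16,
  dense_1 = dense_params(1600, 100),
  dense_2 = dense_params(100, 16)
)
fruit_params
sum(fruit_params)  # should match "Total params" of the second summary
```

The 32 non-trainable parameters of the small model are the moving mean and moving variance of the batch normalization layer (2 x 16).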

Load and prepare images

Here, I am loading and preprocessing two images of fruits (and yes, I am cheating a bit because I am choosing images where I expect my model to work as they are similar to the training images…).

  • Banana
test_image_files_path <- "/Users/shiringlander/Documents/Github/DL_AI/Tutti_Frutti/fruits-360/Test"

img <- image_read('https://upload.wikimedia.org/wikipedia/commons/thumb/8/8a/Banana-Single.jpg/272px-Banana-Single.jpg')
img_path <- file.path(test_image_files_path, "Banana", 'banana.jpg')
image_write(img, img_path)
#plot(as.raster(img))
  • Clementine
img2 <- image_read('https://cdn.pixabay.com/photo/2010/12/13/09/51/clementine-1792_1280.jpg')
img_path2 <- file.path(test_image_files_path, "Clementine", 'clementine.jpg')
image_write(img2, img_path2)
#plot(as.raster(img2))

Superpixels

The segmentation of an image into superpixels is an important step in generating explanations for image models. It is important both that the segmentation is correct and follows meaningful patterns in the picture, and that the size/number of superpixels is appropriate. If the important features in the image are chopped into too many segments, the permutations will probably damage the picture beyond recognition in almost all cases, leading to a poor or failing explanation model. As the size of the object of interest varies, it is impossible to set up hard rules for the number of superpixels to segment into: the larger the object is relative to the size of the image, the fewer superpixels should be generated. Using plot_superpixels it is possible to evaluate the superpixel parameters before starting the time-consuming explanation function. (help(plot_superpixels))

plot_superpixels(img_path, n_superpixels = 35, weight = 10)

plot_superpixels(img_path2, n_superpixels = 50, weight = 20)

From the superpixel plots we can see that the clementine image has a higher resolution than the banana image.
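To get a feeling for why too many superpixels hurt, here is a small back-of-the-envelope calculation. It is a simplified sketch and not how lime itself works internally: it assumes that every superpixel is kept independently with probability 0.5 in each permutation, and that the object stays recognisable as long as at least 80% of the superpixels covering it survive. Both numbers and the function name p_recognisable() are illustrative assumptions.

```r
# Assumption: each superpixel is kept independently with probability 0.5.
p_keep <- 0.5

# Probability that at least `min_fraction` of the k superpixels covering the
# object survive a single permutation (upper tail of a binomial distribution).
p_recognisable <- function(k, min_fraction = 0.8) {
  # round() guards against floating point noise before taking the ceiling
  needed <- ceiling(round(min_fraction * k, 8))
  pbinom(needed - 1, size = k, prob = p_keep, lower.tail = FALSE)
}

segments <- c(2, 5, 10, 20, 40)
data.frame(
  segments_on_object = segments,
  p_mostly_intact    = signif(sapply(segments, p_recognisable), 3)
)
```

Under these assumptions the share of permutations in which the object is still mostly intact drops quickly as the object is split into more segments, which matches the advice to use fewer superpixels for objects that fill most of the image.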

Prepare images for Imagenet

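image_prep <- function(x) {
  arrays <- lapply(x, function(path) {
    # load the image and resize it to the 224 x 224 input expected by VGG16
    img <- image_load(path, target_size = c(224,224))
    x <- image_to_array(img)
    # add a leading batch dimension: (1, 224, 224, 3)
    x <- array_reshape(x, c(1, dim(x)))
    # apply the preprocessing that belongs to the Imagenet weights
    x <- imagenet_preprocess_input(x)
  })
  # stack the single-image arrays into one batch along the first dimension
  do.call(abind::abind, c(arrays, list(along = 1)))
}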
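To make the individual steps of image_prep() a bit more tangible, here is a base R sketch of what happens to a single image. It uses a tiny made-up 2 x 2 "image" instead of a real file, and it assumes the "caffe"-style preprocessing that, as far as I know, is the default of imagenet_preprocess_input(): the channels are reordered from RGB to BGR and a fixed mean is subtracted from each channel. The function caffe_preprocess() is a hypothetical stand-in, not the keras implementation.

```r
# Toy "image": height 2, width 2, 3 channels (RGB), values between 0 and 255.
img_rgb <- array(seq(0, 220, by = 20), dim = c(2, 2, 3))

# Add a leading batch dimension of size 1: (1, height, width, channels).
batch <- array(img_rgb, dim = c(1, dim(img_rgb)))

# Assumed channel means in BGR order (no rescaling to 0..1 in this mode).
bgr_means <- c(103.939, 116.779, 123.68)

caffe_preprocess <- function(x) {
  # reverse the channel order: RGB -> BGR
  x <- x[, , , 3:1, drop = FALSE]
  # zero-center each channel with its assumed mean
  for (ch in 1:3) {
    x[, , , ch] <- x[, , , ch] - bgr_means[ch]
  }
  x
}

prepped <- caffe_preprocess(batch)
dim(prepped)
```

In image_prep(), abind::abind(..., along = 1) then stacks these single-image arrays into one batch of shape (n, 224, 224, 3).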
  • test predictions
res <- predict(model, image_prep(c(img_path, img_path2)))
imagenet_decode_predictions(res)
## [[1]]
##   class_name class_description        score
## 1  n07753592            banana 0.9929747581
## 2  n03532672              hook 0.0013420776
## 3  n07747607            orange 0.0010816186
## 4  n07749582             lemon 0.0010625814
## 5  n07716906  spaghetti_squash 0.0009176208
## 
## [[2]]
##   class_name class_description      score
## 1  n07747607            orange 0.78233224
## 2  n07753592            banana 0.04653566
## 3  n07749582             lemon 0.03868873
## 4  n03134739      croquet_ball 0.03350329
## 5  n07745940        strawberry 0.01862431
  • load labels and train explainer
model_labels <- readRDS(system.file('extdata', 'imagenet_labels.rds', package = 'lime'))
explainer <- lime(c(img_path, img_path2), as_classifier(model, model_labels), image_prep)

Generating the explanations (with the explain() function) can take quite a while. It will be much faster with the smaller images for my own model, but with the bigger Imagenet model it takes a few minutes to run.

explanation <- explain(c(img_path, img_path2), explainer, 
                       n_labels = 2, n_features = 35,
                       n_superpixels = 35, weight = 10,
                       background = "white")
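For readers who wonder what explain() roughly does with the superpixels, here is a heavily simplified base R sketch of the LIME idea: hide random subsets of superpixels, ask the model for a score, and fit a weighted linear model to see which superpixels matter. Everything in it is an assumption made for illustration: the black_box() function stands in for the neural net, the distance and kernel are simple choices of mine, and the real lime package differs in its details (for example in how it selects features and fits the surrogate model).

```r
set.seed(42)

n_superpixels  <- 8
n_permutations <- 500

# Each row is one permuted image: 1 = superpixel kept, 0 = superpixel hidden.
z <- matrix(rbinom(n_permutations * n_superpixels, 1, 0.5),
            ncol = n_superpixels,
            dimnames = list(NULL, paste0("sp", seq_len(n_superpixels))))

# Hypothetical black box: the class score depends on superpixels 2 and 5 only.
black_box <- function(z) plogis(-2 + 4 * z[, "sp2"] + 2 * z[, "sp5"])
score <- black_box(z)

# Permutations that are closer to the original image (fewer hidden
# superpixels) get a larger weight.
distance      <- 1 - rowMeans(z)
kernel_weight <- exp(-(distance^2) / 0.75^2)

# Weighted linear surrogate model; its coefficients are the explanation.
surrogate <- lm(score ~ ., data = data.frame(score = score, z),
                weights = kernel_weight)
feature_weight <- sort(coef(surrogate)[-1], decreasing = TRUE)
round(feature_weight, 3)
```

In this toy setup the two superpixels that drive the black box end up with the largest weights. The real explain() has to run the actual network on every permuted image, which is why it takes so much longer for VGG16 than for a small model.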
  • plot_image_explanation() only supports showing one case at a time
plot_image_explanation(explanation)

clementine <- explanation[explanation$case == "clementine.jpg",]
plot_image_explanation(clementine)
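Because only one case can be shown per plot, it can be handy to split the explanation table by case and loop over the pieces. The sketch below shows the pattern on a made-up data frame (explanation_toy is hypothetical and only has two of the columns a real explanation contains); the actual plotting call is left as a comment because it needs a real explanation object.

```r
# Hypothetical stand-in for the table returned by explain():
# one row per (case, superpixel) with a weight.
explanation_toy <- data.frame(
  case = rep(c("banana.jpg", "clementine.jpg"), each = 3),
  feature_weight = c(0.4, 0.1, -0.2, 0.3, 0.05, -0.1)
)

# One data frame per image; split() orders the pieces by case name.
by_case <- split(explanation_toy, explanation_toy$case)

for (case_name in names(by_case)) {
  one_case <- by_case[[case_name]]
  cat(case_name, ":", nrow(one_case), "rows\n")
  # With a real explanation this is where the plot call would go, wrapped in
  # print() because ggplot objects are not drawn automatically inside a loop:
  # print(plot_image_explanation(one_case))
}
```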

Prepare images for my own model

  • test predictions (analogous to training and validation images)
test_datagen <- image_data_generator(rescale = 1/255)

test_generator <- flow_images_from_directory(
        test_image_files_path,
        test_datagen,
        target_size = c(20, 20),
        class_mode = 'categorical',
        shuffle = FALSE) # keep the file order so rows can be matched to images

predictions <- as.data.frame(predict_generator(model2, test_generator, steps = 1))

load("/Users/shiringlander/Documents/Github/DL_AI/Tutti_Frutti/fruits-360/fruits_classes_indices.RData")
fruits_classes_indices_df <- data.frame(indices = unlist(fruits_classes_indices))
fruits_classes_indices_df <- fruits_classes_indices_df[order(fruits_classes_indices_df$indices), , drop = FALSE]
colnames(predictions) <- rownames(fruits_classes_indices_df)

t(round(predictions, digits = 2))
##             [,1] [,2]
## Kiwi           0 0.00
## Banana         1 0.11
## Apricot        0 0.00
## Avocado        0 0.00
## Cocos          0 0.00
## Clementine     0 0.87
## Mandarine      0 0.00
## Orange         0 0.00
## Limes          0 0.00
## Lemon          0 0.00
## Peach          0 0.00
## Plum           0 0.00
## Raspberry      0 0.00
## Strawberry     0 0.01
## Pineapple      0 0.00
## Pomegranate    0 0.00
for (i in 1:nrow(predictions)) {
  cat(i, ":")
  print(unlist(which.max(predictions[i, ])))
}
## 1 :Banana 
##      2 
## 2 :Clementine 
##          6
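The same lookup can be done without a loop. The following sketch uses a small made-up table (predictions_toy, with a few values copied from the output above) and base R's max.col() to find the best class per row; it is just an alternative way to write the step, not what was run for this post.

```r
# Hypothetical prediction table: one row per image, one column per class.
predictions_toy <- data.frame(
  Kiwi       = c(0.00, 0.00),
  Banana     = c(1.00, 0.11),
  Clementine = c(0.00, 0.87),
  Strawberry = c(0.00, 0.01)
)
pred_matrix <- as.matrix(predictions_toy)

# Column index of the largest probability in each row ...
best_index <- max.col(pred_matrix, ties.method = "first")

# ... and the matching class names and probabilities.
best_class <- data.frame(
  image       = seq_len(nrow(pred_matrix)),
  class       = colnames(pred_matrix)[best_index],
  probability = pred_matrix[cbind(seq_len(nrow(pred_matrix)), best_index)]
)
best_class
```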

The image generator approach seems to be incompatible with lime, though (if someone knows how to make it work, please let me know), so I prepared the images in the same way as for the Imagenet model.

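image_prep2 <- function(x) {
  arrays <- lapply(x, function(path) {
    # load the image and resize it to the 20 x 20 input of my own model
    img <- image_load(path, target_size = c(20, 20))
    x <- image_to_array(img)
    # add a leading batch dimension: (1, 20, 20, 3)
    x <- reticulate::array_reshape(x, c(1, dim(x)))
    # same rescaling as in the image data generator (rescale = 1/255)
    x <- x / 255
  })
  # stack the single-image arrays into one batch along the first dimension
  do.call(abind::abind, c(arrays, list(along = 1)))
}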
  • prepare labels
fruits_classes_indices_l <- rownames(fruits_classes_indices_df)
# use the sorted indices as names so that every name matches its label
names(fruits_classes_indices_l) <- fruits_classes_indices_df$indices
fruits_classes_indices_l
##             0             1             2             3             4 
##        "Kiwi"      "Banana"     "Apricot"     "Avocado"       "Cocos" 
##             5             6             7             8             9 
##  "Clementine"   "Mandarine"      "Orange"       "Limes"       "Lemon" 
##            10            11            12            13            14 
##       "Peach"        "Plum"   "Raspberry"  "Strawberry"   "Pineapple" 
##            15 
## "Pomegranate"
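What matters for the explainer, as far as I understand it, is that the labels are in the same order as the output units of the model. Here is a minimal sketch of how such a label vector can be built from a class index list. The list class_indices_toy is made up for this example (a shortened, deliberately unordered version of the real one), with zero-based indices as Keras reports them.

```r
# Hypothetical class index list: names are class labels, values are the
# zero-based positions of the output units, and the list order is arbitrary.
class_indices_toy <- list(Clementine = 5, Kiwi = 0, Apricot = 2,
                          Banana = 1, Cocos = 4, Avocado = 3)

idx <- unlist(class_indices_toy)

# Sort the class names by index so that element i of `labels` describes
# output unit i of the model (counting from one, as R does).
labels <- names(idx)[order(idx)]

# Optional: keep the zero-based index as the name of each label.
names(labels) <- sort(idx)
labels
```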
  • train explainer
explainer2 <- lime(c(img_path, img_path2), as_classifier(model2, fruits_classes_indices_l), image_prep2)
explanation2 <- explain(c(img_path, img_path2), explainer2, 
                        n_labels = 1, n_features = 20,
                        n_superpixels = 35, weight = 10,
                        background = "white")
  • plot the distribution of feature weights to find a good threshold for the block display (see below)
explanation2 %>%
  ggplot(aes(x = feature_weight)) +
    facet_wrap(~ case, scales = "free") +
    geom_density()

  • plot explanations
plot_image_explanation(explanation2, display = 'block', threshold = 5e-07)

clementine2 <- explanation2[explanation2$case == "clementine.jpg",]
plot_image_explanation(clementine2, display = 'block', threshold = 0.16)
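The two thresholds above differ by several orders of magnitude because the feature weights of the two images are on very different scales. Instead of reading a threshold off the density plot, a candidate can also be computed per image. The sketch below does this on made-up weights (weights_toy is hypothetical) and uses the median of the absolute weights, on the assumption that the threshold is compared against absolute weights, which is how I read the lime documentation. The median is just one possible choice.

```r
# Hypothetical explanation table: one row per (case, superpixel).
weights_toy <- data.frame(
  case = rep(c("banana.jpg", "clementine.jpg"), each = 5),
  feature_weight = c(2e-07, 6e-07, -9e-07, 1e-07, 3e-06,
                     0.05, 0.20, -0.30, 0.10, 0.45)
)

# Candidate threshold per image: the median absolute weight, so that roughly
# the stronger half of the superpixels would be shown.
thresholds <- tapply(abs(weights_toy$feature_weight),
                     weights_toy$case, median)
thresholds
```

Each value could then be passed as threshold to plot_image_explanation() for the corresponding case.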


sessionInfo()
## R version 3.5.0 (2018-04-23)
## Platform: x86_64-apple-darwin15.6.0 (64-bit)
## Running under: macOS High Sierra 10.13.5
## 
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] de_DE.UTF-8/de_DE.UTF-8/de_DE.UTF-8/C/de_DE.UTF-8/de_DE.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
## [1] ggplot2_2.2.1 magick_1.9    lime_0.4.0    keras_2.1.6  
## 
## loaded via a namespace (and not attached):
##  [1] stringdist_0.9.5.1 reticulate_1.8     xfun_0.2          
##  [4] lattice_0.20-35    colorspace_1.3-2   htmltools_0.3.6   
##  [7] yaml_2.1.19        base64enc_0.1-3    rlang_0.2.1       
## [10] pillar_1.2.3       later_0.7.3        foreach_1.4.4     
## [13] plyr_1.8.4         tensorflow_1.8     stringr_1.3.1     
## [16] munsell_0.5.0      blogdown_0.6       gtable_0.2.0      
## [19] htmlwidgets_1.2    codetools_0.2-15   evaluate_0.10.1   
## [22] labeling_0.3       knitr_1.20         httpuv_1.4.4.1    
## [25] tfruns_1.3         parallel_3.5.0     curl_3.2          
## [28] Rcpp_0.12.17       xtable_1.8-2       scales_0.5.0      
## [31] backports_1.1.2    promises_1.0.1     jsonlite_1.5      
## [34] abind_1.4-5        mime_0.5           digest_0.6.15     
## [37] stringi_1.2.3      bookdown_0.7       shiny_1.1.0       
## [40] grid_3.5.0         rprojroot_1.3-2    tools_3.5.0       
## [43] magrittr_1.5       lazyeval_0.2.1     shinythemes_1.1.1 
## [46] glmnet_2.0-16      tibble_1.4.2       whisker_0.3-2     
## [49] zeallot_0.1.0      Matrix_1.2-14      gower_0.1.2       
## [52] assertthat_0.2.0   rmarkdown_1.10     iterators_1.0.9   
## [55] R6_2.2.2           compiler_3.5.0
