It’s that easy! Image classification with keras in roughly 100 lines of code.

June 14, 2018

(This article was first published on Shirin's playgRound, and kindly contributed to R-bloggers)

I’ve been using keras and TensorFlow for a while now and love their simplicity and straightforward approach to modeling. As part of the latest update to my workshop about deep learning with R and keras, I’ve added a new example analysis:

Building an image classifier to distinguish between different types of fruit

And I was (again) surprised how fast and easy it was to build the model: it took less than half an hour and only around 100 lines of code (counting only the main code; for this post I added comments and line breaks to make it easier to read)!


That’s why I wanted to share it here and spread the keras love. <3

The code

If you haven’t installed keras before, follow the instructions on RStudio’s keras website.

library(keras)

The dataset is the fruit images dataset from Kaggle. I downloaded it to my computer and unpacked it. Because I don’t want to build a model for all the different fruits, I define a list of fruits (corresponding to the folder names) that I want to include in the model.

I also define a few other parameters at the beginning to make the code as easy to adapt as possible.

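# list of fruits to model (names must match the folder names)
# note: "Plum" appears twice in this list; the output shown below reflects that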
fruit_list <- c("Kiwi", "Banana", "Plum", "Apricot", "Avocado", "Cocos", "Clementine", "Mandarine", "Orange",
                "Limes", "Lemon", "Peach", "Plum", "Raspberry", "Strawberry", "Pineapple", "Pomegranate")

# number of output classes (i.e. fruits)
output_n <- length(fruit_list)

# image size to scale down to (original images are 100 x 100 px)
img_width <- 20
img_height <- 20
target_size <- c(img_width, img_height)

# RGB = 3 channels
channels <- 3

# path to image folders (adjust to where you unpacked the dataset)
train_image_files_path <- "/Users/shiringlander/Documents/Github/DL_AI/Tutti_Frutti/fruits-360/Training/"
valid_image_files_path <- "/Users/shiringlander/Documents/Github/DL_AI/Tutti_Frutti/fruits-360/Validation/"
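Before creating the generators, it can help to sanity-check the class list. The sketch below is not part of the original post. It uses base R to flag duplicated names and names that have no matching sub-folder, since flow_images_from_directory() expects one sub-folder per class. The temporary folder layout and the names `train_dir`, `dupes`, `classes_unique` and `missing_folders` are hypothetical stand-ins for the real Training/ directory. Run on fruit_list as defined above, the check reports "Plum" twice. That is most likely why class index 2 is missing from the per-class table below and why index 12 (Plum) has roughly twice as many images as the other classes.

```r
# Sanity checks for the class list (a sketch, not part of the original code)
fruit_list <- c("Kiwi", "Banana", "Plum", "Apricot", "Avocado", "Cocos", "Clementine", "Mandarine", "Orange",
                "Limes", "Lemon", "Peach", "Plum", "Raspberry", "Strawberry", "Pineapple", "Pomegranate")

# 1. Duplicated class names
dupes <- unique(fruit_list[duplicated(fruit_list)])
if (length(dupes) > 0) {
  message("Duplicated class names: ", paste(dupes, collapse = ", "))
}
classes_unique <- unique(fruit_list)

# 2. Class names without a matching sub-folder.
# Hypothetical stand-in for the Training/ folder; with the real data,
# point train_dir at train_image_files_path instead.
train_dir <- file.path(tempdir(), "Training")
for (fruit in setdiff(classes_unique, "Kiwi")) {  # leave one folder out on purpose
  dir.create(file.path(train_dir, fruit), recursive = TRUE, showWarnings = FALSE)
}
folders <- list.dirs(train_dir, full.names = FALSE, recursive = FALSE)
missing_folders <- setdiff(classes_unique, folders)
if (length(missing_folders) > 0) {
  message("No folder found for: ", paste(missing_folders, collapse = ", "))
}
```

Using `unique(fruit_list)` before building the generators would give 16 classes and 16 output units.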

Loading images

The handy image_data_generator() and flow_images_from_directory() functions can be used to load images from a directory. If you want to use data augmentation, you can specify directly in image_data_generator() how your images should be augmented. Here I am not augmenting the data; I only rescale the pixel values so that they fall between 0 and 1.

# optional data augmentation
train_data_gen <- image_data_generator(
  rescale = 1/255 #,
  #rotation_range = 40,
  #width_shift_range = 0.2,
  #height_shift_range = 0.2,
  #shear_range = 0.2,
  #zoom_range = 0.2,
  #horizontal_flip = TRUE,
  #fill_mode = "nearest"
)

# Validation data shouldn't be augmented! But it should also be scaled.
valid_data_gen <- image_data_generator(
  rescale = 1/255
  )  
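To make the two generator settings concrete, here is a base-R sketch (not keras code) of what `rescale = 1/255` and the commented-out `horizontal_flip` option do to a single image. The image is stored as a height x width x channels array. The tiny 2 x 3 x 3 array and the helper `flip_horizontal()` are made up for illustration. Keras applies these transformations internally, and with `horizontal_flip = TRUE` it flips images at random rather than every time.

```r
# Hypothetical 2 x 3 image with 3 colour channels and 8-bit pixel values (0-255)
img <- array(c(0, 51, 102, 153, 204, 255), dim = c(2, 3, 3))

# rescale = 1/255: every pixel value is multiplied by 1/255,
# so all values end up between 0 and 1
img_scaled <- img * (1 / 255)
range(img_scaled)

# horizontal flip: mirror the image left-to-right by reversing the width (column) axis
flip_horizontal <- function(x) x[, rev(seq_len(dim(x)[2])), , drop = FALSE]
img_flipped <- flip_horizontal(img_scaled)
```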

Now we create generators that read the images from their class folders in batches and resize them to target_size.

# training images
train_image_array_gen <- flow_images_from_directory(train_image_files_path, 
                                          train_data_gen,
                                          target_size = target_size,
                                          class_mode = "categorical",
                                          classes = fruit_list,
                                          seed = 42)

# validation images
valid_image_array_gen <- flow_images_from_directory(valid_image_files_path, 
                                          valid_data_gen,
                                          target_size = target_size,
                                          class_mode = "categorical",
                                          classes = fruit_list,
                                          seed = 42)
cat("Number of images per class:")
## Number of images per class:
table(factor(train_image_array_gen$classes))
## 
##   0   1   3   4   5   6   7   8   9  10  11  12  13  14  15  16 
## 466 490 492 427 490 490 490 479 490 492 492 894 490 492 490 492
cat("\nClass label vs index mapping:\n")
## 
## Class label vs index mapping:
train_image_array_gen$class_indices
## $Lemon
## [1] 10
## 
## $Peach
## [1] 11
## 
## $Limes
## [1] 9
## 
## $Apricot
## [1] 3
## 
## $Plum
## [1] 12
## 
## $Avocado
## [1] 4
## 
## $Strawberry
## [1] 14
## 
## $Pineapple
## [1] 15
## 
## $Orange
## [1] 8
## 
## $Mandarine
## [1] 7
## 
## $Banana
## [1] 1
## 
## $Clementine
## [1] 6
## 
## $Kiwi
## [1] 0
## 
## $Cocos
## [1] 5
## 
## $Pomegranate
## [1] 16
## 
## $Raspberry
## [1] 13
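With `class_mode = "categorical"`, each label is delivered as a one-hot vector whose length equals the number of classes. class_indices records which position belongs to which fruit. The base-R sketch below rebuilds both from a few entries of the printed mapping. The helper names `one_hot()` and `index_to_label` are illustrative, not part of keras. Because the indices start at 0 and R vectors start at 1, the offset has to be handled explicitly.

```r
# A few entries copied from the class_indices output above (0-based indices)
class_indices <- list(Kiwi = 0, Banana = 1, Apricot = 3, Plum = 12, Pomegranate = 16)
output_n <- 17  # length(fruit_list) in the original code

# Invert the mapping: 0-based index (as character) -> fruit name
index_to_label <- setNames(names(class_indices), unlist(class_indices))

# One-hot encode a 0-based class index into a vector of length n
one_hot <- function(index, n) {
  v <- numeric(n)
  v[index + 1] <- 1  # shift by 1 because R vectors are 1-based
  v
}

plum_label <- one_hot(class_indices$Plum, output_n)
which(plum_label == 1) - 1  # back to the 0-based index: 12
index_to_label[["12"]]      # "Plum"
```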

Define model

Next, we define the keras model.

# number of training samples
train_samples <- train_image_array_gen$n
# number of validation samples
valid_samples <- valid_image_array_gen$n

# define batch size and number of epochs
batch_size <- 32
epochs <- 10
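As a quick check on the step arithmetic, summing the per-class counts printed above gives the number of training images, which should equal `train_image_array_gen$n`. `as.integer()` truncates the division, so one training epoch is slightly shorter than a full pass over all images. The sketch below computes both the rounded-down value used here and a rounded-up alternative with `ceiling()`. Which one to use is a design choice.

```r
# Per-class training counts as printed by table(factor(train_image_array_gen$classes))
class_counts <- c(466, 490, 492, 427, 490, 490, 490, 479, 490, 492, 492, 894, 490, 492, 490, 492)
train_samples <- sum(class_counts)   # 8156

batch_size <- 32

# as in the original code: as.integer() drops the fractional part (8156 / 32 = 254.875)
steps_floor <- as.integer(train_samples / batch_size)   # 254
# alternative: round up so that one epoch covers every image at least once
steps_ceiling <- ceiling(train_samples / batch_size)    # 255
```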

The model I am using here is a very simple sequential convolutional neural net with the following hidden layers: two convolutional layers, one max-pooling layer and one dense layer (plus batch normalization and dropout).

# initialise model
model <- keras_model_sequential()

# add layers
model %>%
  layer_conv_2d(filters = 32, kernel_size = c(3,3), padding = "same", input_shape = c(img_width, img_height, channels)) %>%
  layer_activation("relu") %>%
  
  # Second hidden layer
  layer_conv_2d(filters = 16, kernel_size = c(3,3), padding = "same") %>%
  layer_activation_leaky_relu(0.5) %>%
  layer_batch_normalization() %>%

  # Use max pooling
  layer_max_pooling_2d(pool_size = c(2,2)) %>%
  layer_dropout(0.25) %>%
  
  # Flatten max filtered output into feature vector 
  # and feed into dense layer
  layer_flatten() %>%
  layer_dense(100) %>%
  layer_activation("relu") %>%
  layer_dropout(0.5) %>%

  # Outputs from dense layer are projected onto output layer
  layer_dense(output_n) %>% 
  layer_activation("softmax")

# compile
model %>% compile(
  loss = "categorical_crossentropy",
  optimizer = optimizer_rmsprop(lr = 0.0001, decay = 1e-6),
  metrics = "accuracy"
)

Now we fit the model. Because I used image_data_generator() and flow_images_from_directory(), I also use fit_generator() to run the training.

# fit
hist <- model %>% fit_generator(
  # training data
  train_image_array_gen,
  
  # epochs
  steps_per_epoch = as.integer(train_samples / batch_size), 
  epochs = epochs, 
  
  # validation data
  validation_data = valid_image_array_gen,
  validation_steps = as.integer(valid_samples / batch_size),
  
  # print progress
  verbose = 2,
  callbacks = list(
    # save best model after every epoch
    callback_model_checkpoint("../../data/keras/fruits_checkpoints.h5", save_best_only = TRUE),
    # only needed for visualising with TensorBoard
    callback_tensorboard(log_dir = "../../data/logs/fruits_logs")
  )
)
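The checkpoint callback with `save_best_only = TRUE` overwrites the checkpoint file only when the monitored metric improves. By default, `callback_model_checkpoint()` monitors `"val_loss"`. The base-R sketch below mimics that rule on a made-up vector of validation losses to show which epochs would trigger a save. These values are not results from this model.

```r
# Hypothetical validation losses for 10 epochs (illustrative values only)
val_loss <- c(1.20, 0.85, 0.60, 0.62, 0.41, 0.43, 0.39, 0.39, 0.35, 0.37)

# Mimic save_best_only = TRUE with monitor = "val_loss": save whenever
# the loss is strictly lower than the best value seen so far
saved_epochs <- integer(0)
best <- Inf
for (epoch in seq_along(val_loss)) {
  if (val_loss[epoch] < best) {
    best <- val_loss[epoch]
    saved_epochs <- c(saved_epochs, epoch)
  }
}
saved_epochs  # epochs 1, 2, 3, 5, 7 and 9
```

After training, the file on disk holds the model from the last saving epoch, not necessarily from the final epoch.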

In RStudio, the training progress is shown as an interactive plot in the “Viewer” pane, but we can also plot the training history afterwards:

plot(hist)

As we can see, the model is quite accurate on the validation data. However, we need to keep in mind that our images are very uniform: they all have the same white background and show the fruits centered, with nothing else in the image. Thus, our model will not work well with images that don’t look similar to the ones we trained on (that’s also why we can achieve such good results with such a small neural net).

Finally, I want to have a look at the TensorFlow graph with TensorBoard.

tensorboard("../../data/logs/fruits_logs")

That’s all there is to it!

Of course, you could now save your model and/or the weights, visualize the hidden layers, run predictions on test data, etc. For now, I’ll leave it at that, though. 🙂
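As a pointer for the prediction step: for a softmax model, `predict()` is expected to return one row of class probabilities per image. The predicted fruit is the column with the highest probability, mapped back through class_indices while accounting for the 0-based offset. The sketch below uses an invented probability matrix for a hypothetical four-class model, not the 17-output model above. The label mapping is also hypothetical.

```r
# Invented softmax outputs for 3 images and 4 classes (each row sums to 1)
probs <- matrix(c(0.70, 0.10, 0.15, 0.05,
                  0.05, 0.05, 0.80, 0.10,
                  0.25, 0.40, 0.20, 0.15),
                nrow = 3, byrow = TRUE)

# Hypothetical 0-based index -> label mapping, in the style of class_indices
index_to_label <- c("0" = "Kiwi", "1" = "Banana", "2" = "Lemon", "3" = "Orange")

# Column with the highest probability per row; subtract 1 for the 0-based index
pred_index <- apply(probs, 1, which.max) - 1
pred_label <- unname(index_to_label[as.character(pred_index)])
pred_label  # "Kiwi" "Lemon" "Banana"
```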

sessionInfo()
## R version 3.5.0 (2018-04-23)
## Platform: x86_64-apple-darwin15.6.0 (64-bit)
## Running under: macOS High Sierra 10.13.5
## 
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] de_DE.UTF-8/de_DE.UTF-8/de_DE.UTF-8/C/de_DE.UTF-8/de_DE.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
## [1] keras_2.1.6
## 
## loaded via a namespace (and not attached):
##  [1] Rcpp_0.12.17     compiler_3.5.0   pillar_1.2.3     plyr_1.8.4      
##  [5] base64enc_0.1-3  tools_3.5.0      zeallot_0.1.0    digest_0.6.15   
##  [9] jsonlite_1.5     evaluate_0.10.1  tibble_1.4.2     gtable_0.2.0    
## [13] lattice_0.20-35  rlang_0.2.1      Matrix_1.2-14    yaml_2.1.19     
## [17] blogdown_0.6     xfun_0.1         stringr_1.3.1    knitr_1.20      
## [21] rprojroot_1.3-2  grid_3.5.0       reticulate_1.7   R6_2.2.2        
## [25] rmarkdown_1.9    bookdown_0.7     ggplot2_2.2.1    reshape2_1.4.3  
## [29] magrittr_1.5     whisker_0.3-2    backports_1.1.2  scales_0.5.0    
## [33] tfruns_1.3       htmltools_0.3.6  colorspace_1.3-2 labeling_0.3    
## [37] tensorflow_1.5   stringi_1.2.2    lazyeval_0.2.1   munsell_0.4.3
