In the real world, it's not uncommon to come across unbalanced data sets, where you might have class A with 90 observations and class B with 10. A common rule of thumb in machine learning is to balance the data set, or at least get it close to balanced, before training a classifier. In layman's terms, the main reason is to give each class equal priority.
Consider the example above, with 90 observations in class A and 10 in class B. If we predict every observation as class A, we achieve an accuracy of 90%, which looks respectable for a classification model. In reality this model is very poor, because it never identifies a single class B observation. To put it in practical terms: if a quality control model predicts every part as good, the defective parts ship anyway and we end up with a lot of warranty claims.
So, what metrics should we look at to determine how our model is performing? Let's walk through two examples.
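The 90/10 scenario above can be reproduced in a few lines of base R. This is only a sketch: the labels and the "always predict A" model are made up for illustration and are not taken from a real data set.

```r
# Hypothetical labels matching the example: 90 of class A, 10 of class B
actual <- factor(c(rep("A", 90), rep("B", 10)), levels = c("A", "B"))

# A "model" that ignores its input and always predicts the majority class
predicted <- factor(rep("A", length(actual)), levels = c("A", "B"))

# Overall accuracy: share of observations labelled correctly
accuracy <- mean(predicted == actual)

# Recall per class: share of each true class that the model recovered
recall_A <- mean(predicted[actual == "A"] == "A")
recall_B <- mean(predicted[actual == "B"] == "B")

print(c(accuracy = accuracy, recall_A = recall_A, recall_B = recall_B))
```

The accuracy comes out at 0.9 while the recall for class B is 0, which is exactly the failure that plain accuracy hides.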
Example 1: Balanced Data Set
For this example we will use the iris data set. Since iris is already balanced, with 50 observations for each of its three species as shown in the figure below, we expect the model to perform well.
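If you want to check the class balance yourself rather than rely on a figure, a quick count per species is enough. This sketch uses only base R and the built-in iris data set.

```r
# Load the built-in iris data set
data("iris")

# Count the observations per species
class_counts <- table(iris$Species)
print(class_counts)

# The same counts as proportions of the whole data set
print(prop.table(class_counts))
```

Each species has 50 observations, so every class makes up one third of the data.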
# load the libraries
library(e1071)
library(caret)

# load iris dataset
data("iris")

# create a stratified 80/20 sample for train and test
sample = createDataPartition(iris$Species, p = 0.8)

# create test and train set
train = iris[sample$Resample1,]
test = iris[-sample$Resample1,]

# build a svm model
svm_model = svm(x = train[,1:4], y = train[,5], type = "C-classification")

# perform prediction on test data set
predictions = predict(svm_model, test[,1:4])

# print confusion matrix (caret expects the predictions first, then the true labels)
confusionMatrix(predictions, test$Species)
In the confusion matrix below, every class was predicted correctly, so the accuracy is 100%. The No Information Rate (NIR) is the accuracy we would get without a model at all, by always predicting the most frequent class; it equals the proportion of the largest class. A useful model should have an accuracy above the NIR. Here the NIR is 0.33, well below the accuracy, which indicates a good model. An NIR of 0.33 with three classes also tells us that the classes are evenly balanced.
The next metric we want to look at is balanced accuracy, which is used in binary and multi-class classification problems to deal with imbalanced data sets. It is defined as the average of the recall obtained on each class; in the per-class statistics that caret prints, each class is treated one-vs-rest and its balanced accuracy is the mean of sensitivity and specificity. The best value is 1 and the worst value is 0. In the result below, balanced accuracy is 100% for all 3 classes, indicating that the model performs equally well on each class.
Finally, looking at the detection rate, which is the share of all test observations that are correctly identified as a given class, we notice that it is the same for setosa, versicolor and virginica at 0.33.
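To make the NIR concrete, here is a minimal sketch that computes it by hand. The helper `no_information_rate` is a hypothetical name, not a caret function, and both label vectors are made-up test sets rather than output from the model above.

```r
# Hypothetical helper (not part of caret): the share of the most frequent class,
# i.e. the accuracy of always predicting that class
no_information_rate <- function(reference) {
  max(table(reference)) / length(reference)
}

species <- c("setosa", "versicolor", "virginica")

# Illustrative balanced test set: 10 observations per class
balanced_ref <- rep(species, times = c(10, 10, 10))

# Illustrative skewed test set: 10, 6 and 2 observations per class
skewed_ref <- rep(species, times = c(10, 6, 2))

print(no_information_rate(balanced_ref))
print(no_information_rate(skewed_ref))
```

The balanced labels give one third, while the skewed labels give 10/18, so the bar a model has to clear rises as the data gets more lopsided.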
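The per-class balanced accuracy can also be worked out by hand from a confusion matrix. The sketch below uses a made-up matrix with one misclassified observation (it is not the output of the model in this article) and treats each class one-vs-rest, taking the mean of sensitivity and specificity.

```r
classes <- c("setosa", "versicolor", "virginica")

# Hypothetical confusion matrix: rows are predictions, columns are true labels
cm <- matrix(c(10, 0, 0,
                0, 5, 0,
                0, 1, 2),
             nrow = 3, byrow = TRUE,
             dimnames = list(predicted = classes, actual = classes))

# One-vs-rest balanced accuracy for each class
balanced_accuracy <- sapply(classes, function(cl) {
  tp <- cm[cl, cl]              # predicted cl and actually cl
  fn <- sum(cm[, cl]) - tp      # actually cl, predicted as something else
  fp <- sum(cm[cl, ]) - tp      # predicted cl, actually something else
  tn <- sum(cm) - tp - fn - fp  # everything not involving cl
  sensitivity <- tp / (tp + fn)
  specificity <- tn / (tn + fp)
  (sensitivity + specificity) / 2
})

print(round(balanced_accuracy, 3))
```

A single versicolor observation mistaken for virginica lowers the score of both classes: versicolor loses sensitivity and virginica loses specificity, while setosa stays at 1.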
Example 2: Unbalanced Data Set
For the next example, we will take the same iris data set and create an unbalanced data set as shown below.
# create an unbalanced data set (50 setosa, 30 versicolor, 10 virginica)
data_1 = rbind(iris[1:50,], iris[51:80, ], iris[101:110, ])

# create a stratified 80/20 sample for train and test
sample = createDataPartition(data_1$Species, p = 0.8)

# create test and train set
train = data_1[sample$Resample1,]
test = data_1[-sample$Resample1,]

# build a svm model
svm_model = svm(x = train[,1:4], y = train[,5], type = "C-classification")

# perform prediction on test data set
predictions = predict(svm_model, test[,1:4])

# print confusion matrix (caret expects the predictions first, then the true labels)
confusionMatrix(predictions, test$Species)
In the results below, the accuracy is 94%, which looks good on its own. The NIR is 0.555, which is indeed less than the accuracy. But the NIR is much higher than in the previous example: on an unbalanced data set, simply predicting the majority class for everything already achieves a higher accuracy than it would on a balanced one, so accuracy alone flatters the model.
Next, we look at balanced accuracy: setosa scores 1, meaning every setosa observation was classified correctly and nothing else was mistaken for it, while versicolor and virginica score less than 1.
Finally, looking at the detection rate, we notice that it is higher for setosa than for versicolor or virginica. Ideally, the detection rates for the classes should be the same or close to each other. Part of this gap comes from the class sizes themselves, since a class's detection rate can never exceed its share of the data.
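The detection rate is easy to compute by hand next to the prevalence, which is the share of each class in the data. The confusion matrix in this sketch is made up for illustration, with class sizes of 10, 6 and 2; it is not the output of the model above.

```r
classes <- c("setosa", "versicolor", "virginica")

# Hypothetical confusion matrix: rows are predictions, columns are true labels
cm <- matrix(c(10, 0, 0,
                0, 5, 0,
                0, 1, 2),
             nrow = 3, byrow = TRUE,
             dimnames = list(predicted = classes, actual = classes))

# Detection rate: correctly identified observations of a class over all observations
detection_rate <- diag(cm) / sum(cm)

# Prevalence: how large a share of the data each class makes up
prevalence <- colSums(cm) / sum(cm)

print(round(rbind(detection_rate, prevalence), 3))
```

Even a perfect classifier can only reach a detection rate equal to the prevalence, so with uneven class sizes the detection rates will differ regardless of how well the model does.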
From the above examples, we notice that the balanced data set gave us higher accuracy, higher balanced accuracy and an even detection rate across classes. Hence, it's important to have a balanced data set for a classification model.
Let me know in the comments below if you have ever dealt with unbalanced data sets and how you handled them.
If you liked the above article, check out my other articles as well.