Build an MNIST Classifier With Random Forests
Simple image classification tasks don’t require deep learning models. Today you’ll learn how to build a handwritten digit classifier from scratch with R and Random Forests – and which “gotchas” to watch out for along the way.
Are you completely new to machine learning? Start with this guide on linear regression.
MNIST is the “hello world” of image classification datasets. It contains tens of thousands of handwritten digits ranging from zero to nine. Each image is of size 28×28 pixels.
The following image displays a couple of handwritten digits from the dataset:
As you can see, these images should be relatively easy to classify. The most common approach is to use a neural network with a couple of convolutional layers (to detect patterns), followed by a couple of fully connected layers and an output layer (with ten nodes), but there’s a simpler approach.
MNIST is a toy dataset, so you can replace the neural network architecture with something simpler, like random forests. This will require image flattening – from 28×28 to 1×784. In a nutshell, you’ll end up with a tabular dataset of 784 columns (one for each pixel). More on the pros and cons of this approach in a bit.
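To make the flattening idea concrete, here's a tiny illustration – this isn't part of the pipeline, just a sketch with a random matrix standing in for one image:

```r
# Illustration only: flatten one 28x28 "image" (a random matrix here)
# into a single row of 784 values
img <- matrix(runif(28 * 28), nrow = 28, ncol = 28)
flat <- as.vector(t(img))  # concatenate the rows one after another
length(flat)               # 784
```

Each image becomes one row of the tabular dataset, with one column per pixel.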
Let’s load the dataset next and talk strategy afterward.
You can download both the training and testing sets from this link. The data comes in CSV format instead of PNG images, which eliminates the image-to-tabular transformation step.
Keep in mind – the CSVs don’t contain column names, so you’ll have to specify col_names = FALSE when loading the files.
The following code snippet loads both sets and extracts the labels (actual digit class). Further, the snippet prints the first 20 labels from the training set:
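A sketch of that snippet, assuming the downloaded files were saved as mnist_train.csv and mnist_test.csv (adjust the paths to your setup):

```r
library(readr)

# The CSVs ship without a header row, hence col_names = FALSE;
# readr then names the columns X1 through X785 automatically
mnist_train <- read_csv("mnist_train.csv", col_names = FALSE)
mnist_test  <- read_csv("mnist_test.csv",  col_names = FALSE)

# The first column holds the label (the actual digit) – store it as a factor
mnist_train$X1 <- factor(mnist_train$X1)
mnist_test$X1  <- factor(mnist_test$X1)

# Print the first 20 labels from the training set
head(mnist_train$X1, 20)
```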
The results are shown in the following image:
Note the Levels line in the output – it’s there because you’ve converted the digits to factors with the factor() function.
Finally, let’s see how many records there are for each digit:
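One way to get the counts, assuming the training labels live in the X1 column as in the loading step:

```r
# Number of training records per digit class
table(mnist_train$X1)
```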
Here are the results:
As you can see, the counts aren’t identical – they range from roughly 5,400 to 6,700 per digit – but that mild imbalance shouldn’t be an issue for the classifier.
Next, let’s see how you can train the model. Spoiler alert – it will require only a single line of code.
You’ll use the Random Forests algorithm to build a handwritten digit classifier. As discussed before, this approach has pros and cons compared with neural network classifiers.
The biggest pro is the training speed – the training process will finish in a minute or so on CPU, whereas the training process for neural networks can take anywhere from minutes (GPU) to hours (CPU) – depending on the model architecture and your hardware.
The downside of using Random Forests (or any other classical machine learning algorithm) is the loss of 2D information. When you flatten the image (going from 28×28 to 1×784), you lose the information carried by neighboring pixels. That’s why a convolution operation is the go-to approach for any more demanding image classification problem.
Still, the Random Forest classifier should suit you fine on the MNIST dataset.
The following code snippet shows you how to import the library, train the model, and print the results. The execution will take a minute or so, depending on your hardware:
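A minimal sketch of that step, assuming the training set is stored in mnist_train with the label in its first column, X1:

```r
library(randomForest)

set.seed(42)  # for reproducibility

# Train a random forest on all 784 pixel columns; printing the fitted
# model shows the OOB error estimate and the confusion matrix
model <- randomForest(X1 ~ ., data = mnist_train)
model
```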
The results are shown in the image below:
As you can see, the image above shows the confusion matrix for the training set, alongside the per-class classification errors.
We’ll talk more about model evaluation in the next section.
The first metric you’ll check is the overall accuracy. The random forest model gives you access to the error rates for all of the classes, so you can calculate their mean and subtract the result from 1 – accuracy is simply 1 minus the error rate. You can use the following code snippet to get the overall accuracy:
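One way to compute it, assuming the fitted model from the training step is stored in model – err.rate holds the OOB and per-class error rates for every tree, so the last row gives the final values:

```r
# Final per-class error rates (drop the aggregate OOB column)
final_errors <- model$err.rate[nrow(model$err.rate), -1]

# Accuracy is 1 minus the mean error rate
accuracy <- 1 - mean(final_errors)
accuracy
```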
The results are shown in the following image:
As you can see, the accuracy is around 95% overall. Not bad for a random forest classifier model.
Next, let’s explore the error rate for every digit. Maybe some numbers are easier to classify than others – let’s find out. You’ll need the dplyr package for this calculation: you’ll use it to select the appropriate columns and then calculate their means with the colMeans() function. Here’s the entire code snippet:
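A sketch of that calculation, again assuming the fitted model is stored in model:

```r
library(dplyr)

# Average each digit's error rate across all trees,
# dropping the aggregate OOB column first
model$err.rate %>%
  as.data.frame() %>%
  select(-OOB) %>%
  colMeans()
```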
The results are shown below:
As you can see, zeros and ones seem to be the easiest to classify, and fives and eights the hardest. It makes sense if you think about it – fives and eights share curves and strokes with several other digits, while zeros and ones have more distinctive shapes.
This article demonstrated how you could use a simple machine learning algorithm for image classification. Keep in mind – this shouldn’t be a go-to approach for more complex images. Just imagine you had 512×512 images. Flattening them would result in a dataset with more than 260K columns.
You should always use convolution operations when dealing with more complex image classification tasks, as convolutions detect the features of objects far more accurately than a model trained on flattened pixels.
If you want to implement machine learning in your organization, you can always reach out to Appsilon for help.
- Machine Learning with R: A Complete Guide to Linear Regression
- Machine Learning with R: A Complete Guide to Logistic Regression
- Machine Learning with R: A Complete Guide to Decision Trees
- What Can I Do With R? 6 Essential R Packages for Programmers
- YOLO Algorithm and YOLO Object Detection: An Introduction
Appsilon is hiring for remote roles! See our Careers page for all open positions, including R Shiny Developers, Fullstack Engineers, Frontend Engineers, a Senior Infrastructure Engineer, and a Community Manager. Join Appsilon and work on groundbreaking projects with the world’s most influential Fortune 500 companies.