QuickTip: Utilizing Machine Learning Methods to Identify Important Variables

February 2, 2015

(This article was first published on Proven Inconclusive, and kindly contributed to R-bloggers)

Machine Learning is the field of scientific study that concentrates on induction algorithms and on other algorithms that can be said to “learn.” [1]

Machine learning methods can be used to identify important variables in a multivariate data set. There are many machine learning algorithms for different tasks; a common one is classification, i.e. deciding which class a feature vector belongs to. This can be done with a random forest [2] classifier. The classifier first has to be trained with labelled training data; it can then be used to predict the class of other feature vectors. For demonstration purposes we will use the iris data set. The following R code loads the "randomForest" package and trains a classifier (forest) on the iris data set. The formula Species~. uses the "Species" column as the class label and all remaining columns as features.

> library(randomForest)
> forest = randomForest(Species~., data=iris, importance=TRUE)
> forest

 randomForest(formula=Species ~ ., data=iris, importance=TRUE) 
               Type of random forest: classification
                     Number of trees: 500
No. of variables tried at each split: 2

        OOB estimate of  error rate: 4.67%
Confusion matrix:
           setosa versicolor virginica class.error
setosa         50          0         0        0.00
versicolor      0         47         3        0.06
virginica       0          4        46        0.08

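In this example the classifier achieves an out-of-bag (OOB) error rate of 4.67%. Each tree is grown on a different bootstrap sample of the training data, so roughly one third of the observations are left out of any given tree. The OOB estimate classifies each observation using only the trees that did not see it during training. For this reason no separate test set or cross-validation is needed to get an unbiased estimate of the test set error [2].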
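The OOB estimate printed above is also stored in the fitted object, and the same object can classify new feature vectors with predict(). The following is a minimal sketch, assuming the randomForest package is installed. The seed value and the two flower measurements are illustrative assumptions, not data from this post; because the forest is random, the error rate will differ slightly from the 4.67% shown above.

```r
library(randomForest)

set.seed(42)  # arbitrary seed so the run is reproducible
forest <- randomForest(Species ~ ., data = iris, importance = TRUE)

# err.rate holds one row per tree; the "OOB" column of the last row is the
# OOB error rate that print(forest) reports
oob.error <- forest$err.rate[forest$ntree, "OOB"]
print(oob.error)

# OOB confusion matrix, including the class.error column
print(forest$confusion)

# Two hypothetical flowers (measurements chosen for illustration only)
new.flowers <- data.frame(
  Sepal.Length = c(5.0, 6.7),
  Sepal.Width  = c(3.4, 3.0),
  Petal.Length = c(1.5, 5.6),
  Petal.Width  = c(0.2, 2.1)
)

# Majority-vote class label for each new feature vector
predicted <- predict(forest, newdata = new.flowers)
print(predicted)

# Fraction of trees voting for each class
votes <- predict(forest, newdata = new.flowers, type = "prob")
print(votes)
```

The new data frame must use the same column names as the training features, because predict() matches the variables in the formula by name.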

Because the forest was trained with importance=TRUE, it also stores a measure of how important each feature is for the classification. We can use this information to identify potentially important variables in the data set. The following R code extracts the importance values from the classifier, averages the per-class values for each feature, and plots the mean using ggplot2.

> library(ggplot2)
> forest.importance = as.data.frame(importance(forest, scale=FALSE))
> # drop the last two columns (MeanDecreaseAccuracy, MeanDecreaseGini), keep the per-class columns
> forest.importance = forest.importance[, 1:(ncol(forest.importance)-2)]
> forest.importance$mean = rowMeans(forest.importance)
> forest.importance

Table 1: Feature importance table with calculated mean column.

feature       setosa versicolor virginica  mean
Sepal.Length   0.031      0.025     0.046 0.034
Sepal.Width    0.008      0.003     0.011 0.007
Petal.Length   0.349      0.322     0.324 0.332
Petal.Width    0.305      0.289     0.267 0.287

> ggplot(forest.importance, aes(x=row.names(forest.importance), y=mean)) +
    geom_col() +
    ylab('mean relative feature importance') +
    xlab('feature')

Figure 1: Mean relative feature importance learned by a random forest classifier on the iris data set.
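Besides the per-class values in Table 1, importance() also returns two overall measures, MeanDecreaseAccuracy and MeanDecreaseGini, which the table leaves out. The following is a minimal sketch of ranking the iris features by the unscaled mean decrease in accuracy and keeping those above a cut-off, assuming the randomForest package is installed. The seed and the 0.1 threshold are hypothetical choices for illustration, not values from the original analysis.

```r
library(randomForest)

set.seed(42)  # arbitrary seed for reproducibility
forest <- randomForest(Species ~ ., data = iris, importance = TRUE)

# type = 1 selects the permutation-based mean decrease in accuracy;
# scale = FALSE returns the raw values, as in the code above
acc.importance <- importance(forest, type = 1, scale = FALSE)

# Rank features from most to least important (names are kept)
ranked <- sort(acc.importance[, "MeanDecreaseAccuracy"], decreasing = TRUE)
print(ranked)

# Keep features whose importance exceeds a hypothetical threshold
threshold <- 0.1
selected <- names(ranked)[ranked > threshold]
print(selected)

# varImpPlot() draws both overall importance measures side by side
varImpPlot(forest)
```

A fixed threshold is only one way to pick variables; looking for a clear gap in the ranked values, as in Figure 1, is an equally reasonable heuristic.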


1. Glossary of terms. Machine Learning [Internet] 1998 Feb [cited 2015 Feb 3];30(2-3):271–274. Available from: http://link.springer.com/article/10.1023/A%3A1017181826899

2. Breiman L. Random forests. Machine Learning [Internet] 2001 Oct [cited 2015 Feb 3];45(1):5–32. Available from: http://link.springer.com/article/10.1023/A%3A1010933404324
