Bayesian Naive Bayes for Classification with the Dirichlet Distribution

July 14, 2014

(This article was first published on Shifting sands, and kindly contributed to R-bloggers)

I have a classification task and was reading up on various approaches. In the specific case where all inputs are categorical, one can use “Bayesian Naïve Bayes” using the Dirichlet distribution.
Poking through the freely available text by Barber, I found a fairly detailed discussion in chapters 9 and 10, as well as example MATLAB code for the book, so I took it upon myself to port it to R as a learning exercise.
I was not immediately familiar with the Dirichlet distribution, but in this case it appeals to the intuitive counting approach to discrete event probabilities.
In a nutshell, we use the training data to learn the posterior distribution, whose parameters turn out to be the prior pseudo-counts plus counts of how often a given event occurs, grouped by class, feature and feature state.
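To make the counting concrete, here is a minimal sketch in R for a single feature. It is not the ported code: the toy data, the symmetric prior `alpha = 1` and all variable names are assumptions made up for illustration.

```r
# Toy training data: one categorical feature and a class label (both factors).
colour <- factor(c("red", "red", "blue", "green", "blue", "blue"),
                 levels = c("red", "green", "blue"))
cls    <- factor(c("a", "a", "a", "b", "b", "b"))

alpha <- 1  # assumed symmetric Dirichlet prior (one pseudo-count per state)

# Counts of each feature state within each class (rows: class, cols: state).
counts <- table(cls, colour)

# Posterior Dirichlet parameters: prior pseudo-counts plus observed counts.
post_params <- unclass(counts) + alpha

# Posterior mean probability of each state given the class (rows sum to 1).
post_mean <- post_params / rowSums(post_params)
print(post_mean)
```

With a real data set the same table would be built once per feature, which is where the "grouped by class, feature and feature state" structure comes from.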
Prediction then comes down to counting the events in the test vector and comparing them with the per-class trained counts: the more they differ, the lower the probability that the current candidate class is a match.
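A rough sketch of how training and prediction could fit together for a single test row is below. It is an illustration under stated assumptions, not Barber's code or my port: the function names `train_dnb` and `predict_dnb`, the toy data and the symmetric prior `alpha = 1` are all hypothetical, and the test row is assumed to use the same factor levels as the training data.

```r
# Train: store posterior Dirichlet parameters (prior + counts) per class,
# for the class labels themselves and for every feature.
train_dnb <- function(X, y, alpha = 1) {
  stopifnot(is.data.frame(X), is.factor(y), all(sapply(X, is.factor)))
  list(
    classes        = levels(y),
    class_params   = as.vector(table(y)) + alpha,
    feature_params = lapply(X, function(f) unclass(table(y, f)) + alpha)
  )
}

# Predict: class probabilities for one test row. For a single observation,
# the predictive probability of a state is its parameter over the row total.
predict_dnb <- function(model, x) {
  log_post <- log(model$class_params / sum(model$class_params))
  for (name in names(model$feature_params)) {
    params   <- model$feature_params[[name]]
    state    <- as.character(x[[name]])
    log_post <- log_post + log(params[, state] / rowSums(params))
  }
  probs <- exp(log_post - max(log_post))  # subtract max for numerical stability
  setNames(probs / sum(probs), model$classes)
}

# Hypothetical toy data: every column is a factor.
X <- data.frame(
  outlook = factor(c("sunny", "sunny", "rain", "rain", "sunny", "rain")),
  windy   = factor(c("no", "yes", "yes", "yes", "no", "no"))
)
y <- factor(c("play", "play", "stay", "stay", "play", "stay"))

model <- train_dnb(X, y)
new_x <- data.frame(outlook = factor("sunny", levels = levels(X$outlook)),
                    windy   = factor("no", levels = levels(X$windy)))
probs <- predict_dnb(model, new_x)
print(probs)
```

Working in log space avoids underflow when there are many features, and the pseudo-counts mean a state never seen in a class still gets a small non-zero probability.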
Anyway, there are three files. The first is a straightforward port of Barber’s code, but this wasn’t very R-like, and in particular only seemed to handle input features with the same number of states.
I developed my own version that expects everything to be represented as factors. It is all a bit rough and ready, but it appears to work, and there is a test/example script up here. As a bigger test, I ran it on a sample car evaluation data set from here; the confusion matrix is as follows:
testY   acc good unacc vgood
acc      83    3    29     0
good     16    5     0     0
unacc    17    0   346     0
vgood    13    0     0     6
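As a quick sanity check on that table, overall accuracy is the diagonal divided by the total, which does not depend on which axis holds the predictions. A small sketch, with the numbers copied from the matrix above and variable names of my own choosing:

```r
# Confusion matrix copied from the table above.
labels <- c("acc", "good", "unacc", "vgood")
cm <- matrix(c(83,  3,  29, 0,
               16,  5,   0, 0,
               17,  0, 346, 0,
               13,  0,   0, 6),
             nrow = 4, byrow = TRUE,
             dimnames = list(labels, labels))

# Overall accuracy: correctly classified cases (the diagonal) over all cases.
accuracy <- sum(diag(cm)) / sum(cm)
print(round(accuracy, 3))
```

That is 440 of 518 test cases, or roughly 85%.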
That’s it for now. Comments/feedback appreciated. You can find me on Twitter here.

Everything in one directory (with data) here
