# How to: Multinomial regression models in R

April 5, 2011

In my last post I looked at binomial choice modelling in R, i.e. how to predict a yes/no decision from other data. Now, however, I want to look at modelling a more complicated choice between more than two options. This is known as multinomial choice modelling, and R can fit these models with the `multinom()` function from the nnet package.

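Let’s start by making up some data. The following code creates 1000 data points with two uniformly distributed predictors, x1 and x2, and assigns each point an arbitrary three-way choice (1, 2 or 3) using nested if-else statements with some added random noise.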

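```r
# Make up 1000 data points with two predictors on [0, 100]
n <- 1000
df1 <- data.frame(x1 = runif(n, 0, 100),
                  x2 = runif(n, 0, 100))

# Assign a noisy three-way choice (y = 1, 2 or 3) with nested if-else rules
df1 <- transform(df1,
                 y = 1 + ifelse(100 - x1 - x2 + rnorm(n, sd = 10) < 0, 0,
                         ifelse(100 - 2 * x2 + rnorm(n, sd = 10) < 0, 1, 2)),
                 set = "Original")
```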
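To make the nested if-else logic easier to follow, the sketch below restates the same rule without the noise terms. The function name `choiceRule` is hypothetical and not part of the original code. Dropping the noise makes the region boundaries exact: option 1 when x1 + x2 > 100, otherwise option 2 when x2 > 50, and option 3 everywhere else.

```r
# Hypothetical helper: the data-generating rule above with the
# rnorm() noise terms removed, so the region boundaries are exact
choiceRule <- function(x1, x2) {
  1 + ifelse(100 - x1 - x2 < 0, 0,
             ifelse(100 - 2 * x2 < 0, 1, 2))
}

# One point from each region
choiceRule(x1 = c(80, 20, 20), x2 = c(60, 70, 20))
#> [1] 1 2 3
```

In the real data the noise blurs these boundaries, so points near the lines x1 + x2 = 100 and x2 = 50 can end up with either neighbouring choice.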

Here’s a plot of the data:

The original data set. Colours represent different choices as a function of variables x1 and x2.

Next we need to fit the model. This is done with the multinom function from R’s nnet package for neural networks: internally, the multinomial logit model is fitted as a neural network with no hidden layer and a softmax output. The code below says that we believe the probability of choosing each option is a function of the variables x1 and x2. For simple choice problems like this one, you’re probably safe to use the formula interface with the default options.

```r
# Load the neural network package and fit the model
library(nnet)
mod <- multinom(y ~ x1 + x2, df1)
```
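For a multinomial logit, `multinom()` uses the first level of the response (here option 1) as the reference category and estimates one row of coefficients (intercept and slopes) for each remaining option, so `coef(mod)` is a 2 x 3 matrix here. The sketch below rebuilds the choice probabilities by hand with a softmax over the linear predictors and compares them with `predict()`. It recreates the example data with an arbitrarily chosen seed so the block runs on its own; `trace = FALSE` only silences the fitting log.

```r
library(nnet)

# Recreate the example data and model (seed chosen arbitrarily)
set.seed(1)
n <- 1000
df1 <- data.frame(x1 = runif(n, 0, 100), x2 = runif(n, 0, 100))
df1 <- transform(df1,
                 y = 1 + ifelse(100 - x1 - x2 + rnorm(n, sd = 10) < 0, 0,
                         ifelse(100 - 2 * x2 + rnorm(n, sd = 10) < 0, 1, 2)))
mod <- multinom(y ~ x1 + x2, df1, trace = FALSE)

# One row of coefficients per non-reference option (options 2 and 3)
B <- coef(mod)

# Linear predictors; the reference option (option 1) is fixed at 0
X <- cbind(1, df1$x1, df1$x2)
eta <- cbind(0, X %*% t(B))

# Softmax across options, subtracting each row's maximum for stability
eta <- eta - apply(eta, 1, max)
p.manual <- exp(eta) / rowSums(exp(eta))

# Compare with the probabilities computed by nnet itself
p.nnet <- predict(mod, df1, type = "probs")
all.equal(unname(p.manual), unname(p.nnet), tolerance = 1e-6)
#> [1] TRUE
```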

As with the binomial choice model, predicting new values from such a model can be a little tricky. If we naively use the predict method, which by default (type = "class") returns the single most probable option for each data point, the results will be the same every time we run the following command:

```r
predict(mod)
```
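To see why the output never changes, the sketch below checks that the default prediction is simply the option with the highest predicted probability in each row. As before, it assumes the example data are recreated with an arbitrary seed so the block is self-contained.

```r
library(nnet)

# Recreate the example data and model (seed chosen arbitrarily)
set.seed(1)
n <- 1000
df1 <- data.frame(x1 = runif(n, 0, 100), x2 = runif(n, 0, 100))
df1 <- transform(df1,
                 y = 1 + ifelse(100 - x1 - x2 + rnorm(n, sd = 10) < 0, 0,
                         ifelse(100 - 2 * x2 + rnorm(n, sd = 10) < 0, 1, 2)))
mod <- multinom(y ~ x1 + x2, df1, trace = FALSE)

# Default type = "class": a factor holding the most probable option
cls <- predict(mod)

# The same labels, picked directly from the probability matrix
probs <- predict(mod, df1, type = "probs")
top <- colnames(probs)[max.col(probs, ties.method = "first")]

all(as.character(cls) == top)
identical(predict(mod), cls)  # repeated calls give identical answers
```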

But what we really want is to predict the probabilities associated with each option and then draw a random number to make our actual selection. That way, each time the code runs we get slightly different choices, which better reflects how people actually make decisions. The choice probabilities can be predicted as follows:

```r
predict(mod, newdata = df1, type = "probs")
```

The result of this command is an n by k matrix, where n is the number of data points being predicted and k is the number of options.
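A quick sanity check of that matrix, again assuming the example data and model are recreated with an arbitrary seed: it has one row per data point and one column per option, the columns are named after the levels of y, and every row sums to 1.

```r
library(nnet)

# Recreate the example data and model (seed chosen arbitrarily)
set.seed(1)
n <- 1000
df1 <- data.frame(x1 = runif(n, 0, 100), x2 = runif(n, 0, 100))
df1 <- transform(df1,
                 y = 1 + ifelse(100 - x1 - x2 + rnorm(n, sd = 10) < 0, 0,
                         ifelse(100 - 2 * x2 + rnorm(n, sd = 10) < 0, 1, 2)))
mod <- multinom(y ~ x1 + x2, df1, trace = FALSE)

probs <- predict(mod, newdata = df1, type = "probs")
dim(probs)       # 1000 rows (data points) by 3 columns (options)
colnames(probs)  # "1" "2" "3", the levels of y
head(round(probs, 3))

# Each row is a probability distribution over the options
all.equal(unname(rowSums(probs)), rep(1, nrow(probs)))
```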

Notice that each row sums to 1, because each matrix entry gives the probability of selecting a given option. To make a choice, we therefore calculate the cumulative probabilities across the options in each row and draw a random value between 0 and 1. Our choice is the first option whose cumulative probability exceeds the draw; equivalently, it is one plus the number of options whose cumulative probability falls below the draw. This can be written into a function for easier use.
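A worked example of this cumulative-probability rule for a single, made-up data point, followed by a check that repeating the draw many times reproduces the original probabilities. The probabilities and the seed are illustrative assumptions.

```r
# Choice probabilities for a single (made-up) data point
p <- c(0.2, 0.5, 0.3)
cp <- cumsum(p)          # cumulative probabilities: 0.2 0.7 1.0

# A draw of 0.45 lies above the first cumulative probability only,
# so the first option whose cumulative probability exceeds it is option 2
u <- 0.45
1 + sum(cp[-length(cp)] < u)
#> [1] 2

# Repeating the draw many times reproduces p (seed chosen arbitrarily)
set.seed(42)
u.many <- runif(1e5)
choices <- 1 + rowSums(outer(u.many, cp[-length(cp)], ">"))
round(prop.table(table(choices)), 2)   # close to 0.2, 0.5 and 0.3
```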

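```r
# Function to predict multinomial logit choice model outcomes
# model   = nnet class multinomial model (from nnet::multinom, 3+ options)
# newdata = data frame containing new values to predict (2+ rows)
# Returns the chosen option index (1 to k) for each row; here the
# indices coincide with the choice values 1, 2 and 3
predictMNL <- function(model, newdata) {

  # Only works for neural network models; fail loudly otherwise
  if (!inherits(model, "nnet")) {
    stop("predictMNL() needs a model fitted with nnet::multinom()")
  }

  # Calculate the individual and cumulative probabilities (n x k matrices)
  probs <- predict(model, newdata, type = "probs")
  cum.probs <- t(apply(probs, 1, cumsum))

  # Draw one random value per row
  vals <- runif(nrow(newdata))

  # For each row, the choice index is 1 plus the number of cumulative
  # probabilities below the draw. The last column (always 1) is skipped
  # so that rounding error can never give an index above k.
  k <- ncol(probs)
  ids <- 1 + rowSums(cum.probs[, -k, drop = FALSE] < vals)

  # Return the values
  return(ids)
}
```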
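Another way to make the same draw, shown here only as a sketch, is to let base R’s `sample.int()` pick one option per row using that row’s probabilities as weights. The helper name `drawChoices` and the probability matrix are made up for illustration. It should give choices with the same distribution as the cumulative-probability approach, but it calls an R function once per row, so the vectorised comparison in predictMNL is likely to be faster on large data sets.

```r
# Hypothetical helper: one random option index per row of a
# probability matrix, using the row's probabilities as weights
drawChoices <- function(probs) {
  apply(probs, 1, function(p) sample.int(length(p), size = 1, prob = p))
}

# Made-up probabilities: 5000 copies each of two different rows
probs <- rbind(matrix(c(0.2, 0.5, 0.3), nrow = 5000, ncol = 3, byrow = TRUE),
               matrix(c(0.6, 0.1, 0.3), nrow = 5000, ncol = 3, byrow = TRUE))

set.seed(7)
ids <- drawChoices(probs)

# Observed shares for each row type should be close to its probabilities
prop.table(table(ids[1:5000]))
prop.table(table(ids[5001:10000]))
```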

This function can now be used to predict the outcomes for some new data.

```r
# Make up some new data
n <- 200
df <- data.frame(x1 = runif(n, 0, 100),
                 x2 = runif(n, 0, 100),
                 set = "Model")
y2 <- predictMNL(mod, df)
df2 <- cbind(df, y = y2)
```


The results are plotted here:

Multinomial choice model results. Note how the model captures the uncertainty lying at the interface of the different choice regions.

This method might not give sufficiently robust results for more complicated choice models involving many options and lots of predictor variables. But for most basic problems, you can easily simulate choices from a multinomial choice model using R’s neural network package and the predictMNL function above.
