# K-fold cross-validation in Stan

*Originally published on **R Programming – DataScience+** and contributed to R-bloggers.*


Comparing multiple models is one of the core, but also one of the trickiest, elements of data analysis. Under a Bayesian framework, the loo package in R allows you to derive (among other things) leave-one-out cross-validation metrics to compare the predictive abilities of different models.

Cross-validation basically consists of: (i) separating the data into chunks, (ii) fitting the model while holding out one chunk at a time, (iii) evaluating the probability density of the held-out chunk of data based on the parameter estimates, and (iv) deriving some metrics from the likelihood of the held-out data. The basic idea is that if your model does a good job, it should predict the values of the held-out data pretty well.
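To make the four steps concrete, here is a minimal base-R sketch of K-fold cross-validation on simulated data. It is only an illustration under simplifying assumptions, not the Stan workflow of this post: it fits `lm()` and plugs in the point estimates instead of averaging over posterior draws, and the names `n_folds`, `fold` and `elpd_plugin` are made up for the example.

```r
set.seed(1)
N <- 100
n_folds <- 5
x <- runif(N, -2, 2)
y <- rnorm(N, 1 + 0.5 * x, 1)
d <- data.frame(x = x, y = y)

# (i) split the data into chunks: one random fold label per observation
fold <- sample(rep(seq_len(n_folds), length.out = N))

lpd <- numeric(N)
for (k in seq_len(n_folds)) {
  test <- fold == k
  # (ii) fit the model without the held-out chunk
  fit <- lm(y ~ x, data = d[!test, ])
  # (iii) log density of the held-out points under the fitted model
  mu <- predict(fit, newdata = d[test, ])
  sigma <- summary(fit)$sigma
  lpd[test] <- dnorm(d$y[test], mean = mu, sd = sigma, log = TRUE)
}

# (iv) derive a metric: the sum of the pointwise held-out log densities
elpd_plugin <- sum(lpd)
elpd_plugin
```

The Bayesian version used below replaces the single plug-in estimate in step (iii) by an average over posterior draws.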

We can separate the data into any number of chunks, or folds, from 2 up to N, the number of data points. Holding out one data point at a time is called leave-one-out cross-validation and can be computationally costly (you need to fit the model N times). Fortunately, Vehtari and colleagues have developed methods to estimate leave-one-out cross-validation without this expensive computation. However, as with every approximation, it sometimes cannot give reliable results (loo flags this through high Pareto k diagnostic values), and one could, or should, then fall back to standard K-fold cross-validation by fitting the model multiple times.

The aim of this post is to show one simple example of K-fold cross-validation in Stan via R, so that when loo cannot give you reliable estimates, you may still derive metrics to compare models.

## The Stan code

Below is the Stan code for a simple linear normal regression allowing K-fold cross-validation:

```stan
/* Standard normal regression for any number of predictor variables
   with weakly informative priors on the betas and on the standard deviation */
data {
  int<lower=0> N;                          // number of observations
  int K;                                   // number of predictor variables
  matrix[N, K] X;                          // the model matrix including intercept
  vector[N] y;                             // the response variable
  array[N] int<lower=0, upper=1> holdout;  // whether the observation is held out (1) or used (0)
}
parameters {
  vector[K] beta;                          // the regression parameters
  real<lower=0> sigma;                     // the residual standard deviation
}
model {
  vector[N] mu = X * beta;                 // the linear predictor
  // priors
  beta[1] ~ normal(0, 10);
  beta[2:K] ~ normal(0, 5);
  sigma ~ normal(0, 10);
  // likelihood, holding out some data
  for (n in 1:N) {
    if (holdout[n] == 0) {
      target += normal_lpdf(y[n] | mu[n], sigma);
    }
  }
}
generated quantities {
  vector[N] log_lik;                       // log density of every observation, held out or not
  for (n in 1:N) {
    log_lik[n] = normal_lpdf(y[n] | X[n] * beta, sigma);
  }
}
```

The key element here is to pass as data a vector of 0s and 1s, as long as the observed data, indicating whether each point should be held out during model estimation or not. The likelihood is then computed by going through each data point and asking whether it should be held out; if not, this data point is used to increment the log-likelihood (`target`) given the model parameters.
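The masked likelihood in the Stan `model` block can be mirrored in a few lines of R, which is a handy way to convince yourself that the `if (holdout[n] == 0)` loop simply sums the log densities of the observations kept for fitting. This is a sketch only: the values of `beta`, `sigma` and `holdout` below are arbitrary placeholders, not estimates.

```r
set.seed(2)
N <- 20
X <- cbind(1, runif(N, -2, 2))      # model matrix with an intercept column
y <- rnorm(N, X %*% c(1, 0.5), 1)
beta <- c(0.9, 0.6)                  # placeholder parameter values
sigma <- 1.1
holdout <- rep(0, N)
holdout[sample(N, 4)] <- 1           # 1 = held out, 0 = used for fitting

mu <- as.vector(X %*% beta)

# Loop version, mirroring the Stan model block
target_loop <- 0
for (n in seq_len(N)) {
  if (holdout[n] == 0) {
    target_loop <- target_loop + dnorm(y[n], mu[n], sigma, log = TRUE)
  }
}

# Vectorised version: only the rows with holdout == 0 contribute
keep <- holdout == 0
target_vec <- sum(dnorm(y[keep], mu[keep], sigma, log = TRUE))

all.equal(target_loop, target_vec)
```

Both versions include the normalising constant, just like `normal_lpdf` does, so they agree exactly up to floating-point error.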

Further down, in the `generated quantities` block, we compute the log probability density (the log of the height of the probability distribution) of each data point, including the held-out ones.

If you want to try out the code, the easiest way is to copy/paste the model code into a .stan file.
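For intuition, the `generated quantities` block produces, for every posterior draw, the log density of every observation, held out or not. The R sketch below builds the same S x N matrix from a handful of invented draws; `beta_draws` and `sigma_draws` are hypothetical stand-ins for the draws you would extract from a fitted model.

```r
set.seed(3)
N <- 10
S <- 4
X <- cbind(1, runif(N, -2, 2))
y <- rnorm(N, X %*% c(1, 0.5), 1)

# Invented posterior draws: one row (or element) per draw
beta_draws <- cbind(rnorm(S, 1, 0.1), rnorm(S, 0.5, 0.1))
sigma_draws <- abs(rnorm(S, 1, 0.1))

# log_lik[s, n] = log Normal(y[n] | X[n, ] %*% beta_s, sigma_s)
log_lik <- matrix(NA_real_, nrow = S, ncol = N)
for (s in seq_len(S)) {
  mu_s <- as.vector(X %*% beta_draws[s, ])
  log_lik[s, ] <- dnorm(y, mu_s, sigma_draws[s], log = TRUE)
}
dim(log_lik)
```

Each fold's fit returns one such matrix; only the columns of the points held out in that fold are used later on.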

## Simulating some data and fitting the models

Let's start by setting some basic quantities and creating the held-out indices.

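```r
# first load the libraries
library(rstan)
rstan_options(auto_write = TRUE)
library(pbmcapply)
library(loo)

N <- 100      # sample size
K <- 2        # number of predictors
n_fold <- 10  # number of folds

# create 10 folds of data (assumes N is a multiple of n_fold)
hh <- sample(1:N, size = N, replace = FALSE)  # random permutation of the indices
fold_size <- N / n_fold
id <- seq(1, N, by = fold_size)               # first position of each fold in hh
holdout_10 <- matrix(0, nrow = N, ncol = n_fold)
for (i in 1:n_fold) {
  holdout_10[hh[id[i]:(id[i] + fold_size - 1)], i] <- 1
}

# some sanity checks
# apply(holdout_10, 1, sum)  # every observation is held out once
# apply(holdout_10, 2, sum)  # every fold holds out N / n_fold observations

# turn into a list with one holdout vector per fold
holdout_10 <- split(holdout_10, rep(1:ncol(holdout_10), each = nrow(holdout_10)))
```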

Randomly assigning each data point to a fold is the trickiest part of the data preparation in K-fold cross-validation. What I basically did is randomly sample N times without replacement from the data point indices (the object `hh`), then put the first 10 indices in the first fold, the next 10 in the second fold, and so on.
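An alternative way to do the same random assignment, sketched below assuming you only want base R, is to shuffle a vector of fold labels. Using `rep(..., length.out = N)` also copes with an N that is not a multiple of the number of folds (fold sizes then differ by at most one). The helper name `make_holdout_list` is hypothetical.

```r
# Hypothetical helper: returns a list of 0/1 holdout vectors, one per fold
make_holdout_list <- function(N, n_fold) {
  fold <- sample(rep(seq_len(n_fold), length.out = N))
  lapply(seq_len(n_fold), function(k) as.integer(fold == k))
}

set.seed(4)
holdout_list <- make_holdout_list(N = 103, n_fold = 10)

# Every observation is held out exactly once
table(Reduce(`+`, holdout_list))
# Fold sizes
sapply(holdout_list, sum)
```

The resulting list has the same shape as `holdout_10`, so it could be plugged into the data lists in the same way.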

Note that if there is some kind of grouping structure in your data that you want to take into account when splitting it, things can get a bit more complex; there are certainly some clever functions out there that can help you (such as in the caret package).

Now we can simulate some data and create the data object to be passed to Stan. Note that we will need one data object per fold, so a list is an easy way to combine them:
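If, for instance, observations come in groups (plots, sites, subjects) that should never be split across folds, one simple option is to assign the folds to the groups rather than to the individual observations. The base-R sketch below assumes a grouping vector `group`; the function name `make_grouped_holdout` is hypothetical and this is not the caret implementation.

```r
# Hypothetical helper: all observations of a group land in the same fold
make_grouped_holdout <- function(group, n_fold) {
  groups <- unique(group)
  if (length(groups) < n_fold) stop("need at least as many groups as folds")
  group_fold <- sample(rep(seq_len(n_fold), length.out = length(groups)))
  fold <- group_fold[match(group, groups)]
  lapply(seq_len(n_fold), function(k) as.integer(fold == k))
}

set.seed(5)
group <- rep(letters[1:20], each = 5)   # 20 groups of 5 observations
holdout_grp <- make_grouped_holdout(group, n_fold = 5)

# Number of distinct groups held out in each fold
sapply(holdout_grp, function(h) length(unique(group[h == 1])))
```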

```r
X <- cbind(rep(1, N), runif(N, -2, 2))
y <- rnorm(N, X %*% c(1, 0.5), 1)

# the basic data object
data_m <- list(N = N, K = K, X = X, y = y)

# create a list of data lists, one per fold
data_l <- rep(list(data_m), n_fold)

# add the holdout index to each of them
for (i in 1:n_fold) data_l[[i]]$holdout <- holdout_10[[i]]
```

We are now ready to fit the model to each fold. We could just loop through the folds, but that runs the fits one after the other; instead, I am going to use the functions from the StanCon 2017 talk materials referenced at the end of this post (slightly modified and pasted there), which run the chains of all folds in parallel. Briefly, `stan_kfold` outputs a list of `stanfit` objects (one for each fold); `extract_log_lik_K` outputs an S x N matrix, where S is the number of posterior draws, in which each column holds the log-likelihood of a data point computed from the fit in which that point was held out; and `kfold` computes the expected log pointwise predictive density (`elpd`). In words, the `elpd` is the log of the height (density) of the predictive distribution, averaged over posterior draws, at the data points (pointwise) that were held out (predictive), summed over all data points.

```r
# run the functions
ss <- stan_kfold(file = "Documents/PostDoc_Ghent/STAN_stuff/Models/normal_model_basic_cv.stan",
                 data_l, chains = 4, cores = 2)
ee <- extract_log_lik_K(ss, holdout_10)
kk <- kfold(ee)
# compare with official loo results
ll <- loo(ee)
```

The `ee` matrix can actually also be used as input to `loo` to get some more metrics.
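The `pointwise` element of the `kfold` output is also what you would use to contrast two models on the same held-out points: take the difference of the pointwise elpd vectors, sum it, and compute its standard error from the spread of the differences (the same kind of formula as `se_elpd_kfold`). This is a sketch under the assumption that both vectors come from fits using the same folds; `compare_pointwise` is a hypothetical helper and the input values are simulated placeholders.

```r
# Hypothetical helper: difference in summed elpd (model B minus model A) and its SE
compare_pointwise <- function(pw_a, pw_b) {
  pw_a <- as.vector(pw_a)   # kfold() returns a one-column matrix
  pw_b <- as.vector(pw_b)
  stopifnot(length(pw_a) == length(pw_b))
  d <- pw_b - pw_a
  c(elpd_diff = sum(d), se_diff = sqrt(length(d) * var(d)))
}

set.seed(7)
pw_a <- rnorm(100, -1.5, 0.8)          # placeholder pointwise elpd, model A
pw_b <- pw_a + rnorm(100, -0.1, 0.2)   # model B, slightly worse on average
compare_pointwise(pw_a, pw_b)
```

With this sign convention, a negative `elpd_diff` means that model B predicts the held-out data worse than model A; a difference of only one or two `se_diff` is weak evidence either way.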

All this is nice and fine, but these metrics are only relative and only truly make sense when comparing different models fitted to the same data. So let's fit two additional models, one overly complex and one overly simple:

```r
# fit a too complex and a too simple model
X_comp <- cbind(X, runif(N, -2, 2))  # adds a predictor unrelated to y
X_simp <- X[, 1, drop = FALSE]       # intercept only

# new data
data_comp <- data_l
for (i in 1:n_fold) {
  data_comp[[i]]$X <- X_comp
  data_comp[[i]]$K <- 3
}
data_simp <- data_l
for (i in 1:n_fold) {
  data_simp[[i]]$X <- X_simp
  data_simp[[i]]$K <- 1
}

# fit the new models
ss_comp <- stan_kfold(file = "Documents/PostDoc_Ghent/STAN_stuff/Models/normal_model_basic_cv.stan",
                      data_comp, chains = 4, cores = 2)
ss_simp <- stan_kfold(file = "Documents/PostDoc_Ghent/STAN_stuff/Models/normal_model_basic_cv.stan",
                      data_simp, chains = 4, cores = 2)
ee_comp <- extract_log_lik_K(ss_comp, holdout_10)
ee_simp <- extract_log_lik_K(ss_simp, holdout_10)

# compare the models
# (in recent versions of loo, compare() is deprecated in favour of loo_compare())
compare(loo(ee), loo(ee_comp), loo(ee_simp))
# output:
#              elpd_diff elpd_loo se_elpd_loo p_loo se_p_loo looic se_looic
# loo(ee)            0.0   -148.9         9.5   4.3      1.0  297.7     19.1
# loo(ee_comp)      -1.8   -150.7         9.6   5.5      1.3  301.4     19.2
# loo(ee_simp)     -11.2   -160.1         8.9   2.8      0.8  320.1     17.8
```

The `compare` function outputs quite a lot of information, so let's go through it.

The first column is the difference in the summed expected log pointwise predictive density: the difference between the second and the first model is -1.8, meaning that the first model has a slightly higher predictive density. The second column is the summed expected log pointwise predictive density itself; the values in the first column are differences in this quantity relative to the best model. The third column is the standard error of the expected log pointwise predictive density. The fourth and fifth columns are the effective number of parameters in the model and its standard error. The sixth column is -2 times the second column, which puts the elpd on the deviance scale and makes it comparable to other information criteria such as AIC or DIC. The final column is the standard error of this information criterion.

From this output one could argue that the model with two predictors is definitely superior to the intercept-only model, and is slightly better than the model with three predictors.

Voilà. Do remember that model selection is very tricky: I would not encourage using information criteria or cross-validation metrics to decide whether to include that particular interaction or this extra covariate. To my mind, this type of comparison is relevant for comparing different model structures, such as simple regression vs hierarchical models vs spatial effects vs autoregressive structures, or for comparing models based on different sets of covariates, for instance to test whether plant growth is better predicted by trait information or by environmental information.

Happy folding!

## The function code

Below is the code of the functions used to fit the models and extract the information:

```r
# functions slightly modified from:
# https://github.com/stan-dev/stancon_talks/blob/master/2017/Contributed-Talks/07_nicenboim/kfold.Rmd

# function to parallelize all computations
# need at least two chains !!!
stan_kfold <- function(file, list_of_datas, chains, cores, ...) {
  library(pbmcapply)
  badRhat <- 1.1  # not used below (leftover from the original code)
  n_fold <- length(list_of_datas)
  model <- stan_model(file = file)
  # First parallelize all chains: one job per (fold, chain) pair
  sflist <- pbmclapply(1:(n_fold * chains), mc.cores = cores, function(i) {
    # Fold number:
    k <- ceiling(i / chains)
    s <- sampling(model, data = list_of_datas[[k]], chains = 1, chain_id = i, ...)
    return(s)
  })
  # Then merge the K * chains to create K stanfits:
  stanfit <- list()
  for (k in 1:n_fold) {
    inchains <- (chains * k - (chains - 1)):(chains * k)
    # Merge `chains` of each fold
    stanfit[[k]] <- sflist2stanfit(sflist[inchains])
  }
  return(stanfit)
}

# extract log-likelihoods of held-out data
extract_log_lik_K <- function(list_of_stanfits, list_of_holdout, ...) {
  require(loo)
  K <- length(list_of_stanfits)
  list_of_log_liks <- lapply(1:K, function(k) {
    extract_log_lik(list_of_stanfits[[k]], ...)
  })
  # `log_lik_heldout` will include the loglike of all the held out data of all the folds.
  # We define `log_lik_heldout` as a (samples x N_obs) matrix
  # (similar to each log_lik matrix)
  log_lik_heldout <- list_of_log_liks[[1]] * NA
  for (k in 1:K) {
    log_lik <- list_of_log_liks[[k]]
    samples <- dim(log_lik)[1]
    N_obs <- dim(log_lik)[2]
    # This is a matrix with the same size as log_lik_heldout
    # with 1 if the data was held out in the fold k
    heldout <- matrix(rep(list_of_holdout[[k]], each = samples), nrow = samples)
    # Sanity check that the previous log_lik is not being overwritten:
    if (any(!is.na(log_lik_heldout[heldout == 1]))) {
      warning("Heldout log_lik has been overwritten!!!!")
    }
    # We save here the log_lik of the fold k in the matrix:
    log_lik_heldout[heldout == 1] <- log_lik[heldout == 1]
  }
  return(log_lik_heldout)
}

# compute ELPD
kfold <- function(log_lik_heldout) {
  library(matrixStats)
  logColMeansExp <- function(x) {
    # should be more stable than log(colMeans(exp(x)))
    S <- nrow(x)
    colLogSumExps(x) - log(S)
  }
  # See equation (20) of @VehtariEtAl2016
  pointwise <- matrix(logColMeansExp(log_lik_heldout), ncol = 1)
  colnames(pointwise) <- "elpd"
  # See equation (21) of @VehtariEtAl2016
  elpd_kfold <- sum(pointwise)
  se_elpd_kfold <- sqrt(ncol(log_lik_heldout) * var(as.vector(pointwise)))
  out <- list(
    pointwise = pointwise,
    elpd_kfold = elpd_kfold,
    se_elpd_kfold = se_elpd_kfold)
  # structure(out, class = "loo")
  return(out)
}
```
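In `stan_kfold`, every (fold, chain) pair becomes one job for `pbmclapply`: the job index `i` is mapped to its fold with `ceiling(i / chains)`, and after the parallel call the chains of fold `k` are collected as the jobs `(chains*k - (chains - 1)):(chains*k)`. The small check below needs only base R and uses arbitrary values for `chains` and `n_fold`; it is a sketch confirming that the two index formulas agree, not part of the original functions.

```r
chains <- 4
n_fold <- 10
jobs <- seq_len(n_fold * chains)

# Fold assigned to each job inside the parallel call
fold_of_job <- ceiling(jobs / chains)

# Jobs merged back into fold k after the parallel call
jobs_of_fold <- lapply(seq_len(n_fold), function(k)
  (chains * k - (chains - 1)):(chains * k))

# Every job is merged into the fold it was fitted on
all(mapply(function(k, idx) all(fold_of_job[idx] == k),
           seq_len(n_fold), jobs_of_fold))
```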
