K-fold cross-validation in Stan

[This article was first published on R Programming – DataScience+, and kindly contributed to R-bloggers.]

Comparing multiple models is one of the core, but also one of the trickiest, elements of data analysis. Under a Bayesian framework, the loo package in R allows you to derive (among other things) leave-one-out cross-validation metrics to compare the predictive abilities of different models.

Cross-validation basically consists of: (i) separating the data into chunks, (ii) fitting the model while holding out one chunk at a time, (iii) evaluating the probability density of the held-out chunk of data based on the parameter estimates, and (iv) deriving some metrics from the likelihood of the held-out data. The basic idea is that if your model does a good job, it should predict the held-out data pretty well.
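To make steps (i) to (iv) concrete before moving to Stan, here is a minimal base-R sketch. It is not the Bayesian workflow of this post: as an assumption for illustration, a plain lm() fit with plug-in estimates stands in for the Stan model, so step (iii) uses point estimates rather than posterior draws. The function name kfold_lm_sketch() is hypothetical.

```r
# Hypothetical helper illustrating steps (i)-(iv) with a plug-in lm() fit
# (assumption: point estimates stand in for the posterior draws used in Stan)
kfold_lm_sketch <- function(x, y, n_fold = 10) {
  N <- length(y)
  # (i) randomly assign each observation to one of n_fold chunks
  fold <- sample(rep(seq_len(n_fold), length.out = N))
  lpd <- numeric(N)
  for (k in seq_len(n_fold)) {
    test <- fold == k
    # (ii) fit the model without the held-out chunk
    fit <- lm(y ~ x, subset = !test)
    sigma_hat <- summary(fit)$sigma
    mu_test <- predict(fit, newdata = data.frame(x = x[test]))
    # (iii) log density of the held-out observations under the fitted model
    lpd[test] <- dnorm(y[test], mean = mu_test, sd = sigma_hat, log = TRUE)
  }
  # (iv) summarise: sum of the held-out log predictive densities
  list(pointwise = lpd, sum_lpd = sum(lpd))
}

set.seed(1)
x_demo <- runif(100, -2, 2)
y_demo <- rnorm(100, 1 + 0.5 * x_demo, 1)
res_demo <- kfold_lm_sketch(x_demo, y_demo, n_fold = 10)
res_demo$sum_lpd
```

Every observation lands in exactly one fold, so each pointwise value comes from a fit that never saw that observation.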

We can separate the data into any number of chunks, or folds, from 2 up to N, the number of data points. Holding out one data point at a time is called leave-one-out cross-validation, and it can be computationally costly since you need to fit the model N times. Fortunately, Vehtari and colleagues have developed methods (Pareto smoothed importance sampling, implemented in the loo package) to approximate leave-one-out cross-validation from a single model fit, without this expensive computation. As with every approximation, it sometimes cannot give reliable results (loo flags this with high Pareto k diagnostic values), and one could, or should, then fall back to standard K-fold cross-validation by fitting the model multiple times.

The aim of this post is to show one simple example of K-fold cross-validation in Stan via R, so that when loo cannot give you reliable estimates, you may still derive metrics to compare models.

The Stan code

Below is the Stan code for a simple linear normal regression allowing K-fold cross-validation

/*
Standard normal regression for any number of predictor variables
with weakly informative priors on the betas and on the standard deviation
*/
data{
    int<lower=1> N; //number of observations
    int<lower=1> K; //number of predictor variables
    matrix[N,K] X; //the model matrix including intercept
    vector[N] y; //the response variable
    int<lower=0,upper=1> holdout[N];
    //index whether the observation should be held out (1) or used (0)
}
parameters{
    vector[K] beta; //the regression parameters
    real<lower=0> sigma;
}
model{
    vector[N] mu; //the linear predictor
    mu = X * beta; //the regression
    //priors
    beta[1] ~ normal(0,10);
    beta[2:K] ~ normal(0,5);
    sigma ~ normal(0,10);
    //likelihood holding out some data
    for(n in 1:N){
        if(holdout[n] == 0){
            target += normal_lpdf(y[n] | mu[n], sigma);
        }
    }
}
generated quantities{
    vector[N] log_lik;
    for (n in 1:N){
        log_lik[n] = normal_lpdf(y[n] | X[n,] * beta, sigma);
    }
}

The key element here is to pass, as data, a vector of 0s and 1s of the same length as the observed data, indicating whether each data point should be held out (1) during model estimation or not (0). The likelihood is then built by going through each data point and checking whether it should be held out; if not, its log density given the model parameters is added to the target. Note that K in this Stan code is the number of predictor variables, not the number of folds.
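The held-out logic of the model block can be mirrored in R to check what it does. In this sketch, masked_loglik() is a hypothetical helper and the toy values are made up: it adds the normal log density of an observation only when its holdout value is 0, which is the same as summing the log densities of the retained observations. Like target += normal_lpdf(), dnorm(..., log = TRUE) keeps the normalising constant.

```r
# Hypothetical helper mirroring the model block: add the normal log density
# of observation n only when holdout[n] == 0
masked_loglik <- function(y, mu, sigma, holdout) {
  total <- 0
  for (n in seq_along(y)) {
    if (holdout[n] == 0) {
      total <- total + dnorm(y[n], mean = mu[n], sd = sigma, log = TRUE)
    }
  }
  total
}

# made-up toy values; the second observation is held out
y_toy <- c(0.3, -1.2, 2.1, 0.8)
mu_toy <- c(0.5, -1.0, 1.5, 1.0)
holdout_toy <- c(0, 1, 0, 0)
loop_version <- masked_loglik(y_toy, mu_toy, sigma = 1, holdout = holdout_toy)
vectorised <- sum(dnorm(y_toy[holdout_toy == 0], mu_toy[holdout_toy == 0], 1, log = TRUE))
all.equal(loop_version, vectorised)  # TRUE
```

Changing y_toy[2] to any value leaves loop_version unchanged, since the held-out point never enters the sum.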

In the generated quantities block, we then compute the log probability density (the log of the height of the probability distribution) of every data point, including the held-out ones.
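In R terms, the log_lik output corresponds to an S x N matrix: one row per posterior draw, one column per data point. The sketch below builds such a matrix from made-up stand-ins for posterior draws (beta_draws, sigma_draws); log_lik_matrix() is a hypothetical helper, and in practice the matrix comes from loo::extract_log_lik() applied to the stanfit.

```r
# Hypothetical helper: S x N matrix of pointwise log densities, the R analogue
# of the log_lik vector computed once per posterior draw in generated quantities
log_lik_matrix <- function(beta_draws, sigma_draws, X, y) {
  S <- nrow(beta_draws)
  out <- matrix(NA_real_, nrow = S, ncol = nrow(X))
  for (s in seq_len(S)) {
    mu <- as.vector(X %*% beta_draws[s, ])  # linear predictor for draw s
    out[s, ] <- dnorm(y, mean = mu, sd = sigma_draws[s], log = TRUE)
  }
  out
}

# made-up data and "draws", only to show the shape of the result
set.seed(2)
N_sim <- 20
X_sim <- cbind(1, runif(N_sim, -2, 2))
y_sim <- rnorm(N_sim, X_sim %*% c(1, 0.5), 1)
S_sim <- 50
beta_draws <- cbind(rnorm(S_sim, 1, 0.1), rnorm(S_sim, 0.5, 0.1))
sigma_draws <- abs(rnorm(S_sim, 1, 0.05))
ll_sim <- log_lik_matrix(beta_draws, sigma_draws, X_sim, y_sim)
dim(ll_sim)  # 50 20
```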

If you want to try out the code, the easiest way is to copy/paste the model code into a .stan file (the R code below uses the name normal_model_basic_cv.stan).

Simulating some data and fitting the models

Let's start by setting some basic quantities and creating the held-out indices.

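#first load the libraries
library(rstan)       # stan_model(), sampling(), sflist2stanfit()
library(loo)         # extract_log_lik(), loo(), compare()
library(pbmcapply)   # pbmclapply() to fit the folds in parallel
library(matrixStats) # colLogSumExps(), used in kfold()
# plyr must also be installed (called as plyr::llply())
rstan_options(auto_write = TRUE)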

N <- 100 #sample size
K <- 2 #number of predictors
n_fold <- 10 #number of folds

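#create 10 folds of data
hh <- sample(1:N,size = N,replace = FALSE)
holdout_10 <- matrix(0,nrow=N,ncol=n_fold)
id <- seq(1,N,by=N/n_fold) #first position of each fold in hh
for(i in 1:n_fold){
  holdout_10[hh[id[i]:(id[i] + N/n_fold - 1)],i] <- 1
}
#some sanity checks: each fold holds out N/n_fold points,
#and each point is held out exactly once
stopifnot(all(colSums(holdout_10) == N/n_fold), all(rowSums(holdout_10) == 1))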

#turn into a list
holdout_10 <- split(holdout_10,rep(1:ncol(holdout_10),each=nrow(holdout_10)))

Randomly assigning each data point to a different fold is the trickiest part of the data preparation in K-fold cross-validation. What I basically did is randomly sample N times without replacement from the data point indices (the object hh), and put the first 10 indices in the first fold, the next 10 in the second fold, and so on.
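If N is not a multiple of the number of folds, the fixed blocks of 10 used above no longer line up. A common alternative, sketched below as a hypothetical helper (make_holdout() is not from the post), shuffles a vector of fold labels and converts it into one 0/1 holdout vector per fold, the same format as holdout_10.

```r
# Hypothetical helper: random fold labels -> list of 0/1 holdout vectors
make_holdout <- function(N, n_fold) {
  # shuffled labels 1..n_fold; fold sizes differ by at most one
  fold_id <- sample(rep(seq_len(n_fold), length.out = N))
  # N x n_fold indicator matrix: 1 if observation n is held out in fold k
  holdout <- outer(fold_id, seq_len(n_fold), FUN = "==") * 1
  # one 0/1 vector per fold, as in holdout_10 above
  lapply(seq_len(n_fold), function(k) holdout[, k])
}

set.seed(3)
folds_demo <- make_holdout(N = 103, n_fold = 10)
sapply(folds_demo, sum)            # fold sizes (11 or 10)
table(Reduce(`+`, folds_demo))     # every observation held out exactly once
```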

Note that if there is some kind of grouping structure in your data that you want to be taken into account when splitting it, things get a bit more complex; there are certainly some clever functions out there that can help you (for instance groupKFold() in the caret package).
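For grouped data, one option is to assign whole groups rather than single observations to folds, so that all observations of a group are held out together. The sketch below is a hypothetical helper under the assumption that there are at least as many groups as folds; the site labels are made up.

```r
# Hypothetical helper: assign whole groups (e.g. sites) to folds so that all
# observations of a group are held out together
make_grouped_holdout <- function(group, n_fold) {
  groups <- unique(group)
  if (length(groups) < n_fold) stop("need at least as many groups as folds")
  # shuffle fold labels over groups, then map each observation to its group's fold
  group_fold <- sample(rep(seq_len(n_fold), length.out = length(groups)))
  fold_id <- group_fold[match(group, groups)]
  lapply(seq_len(n_fold), function(k) as.integer(fold_id == k))
}

set.seed(4)
site <- rep(letters[1:12], each = 5)  # 12 made-up sites, 5 observations each
folds_g <- make_grouped_holdout(site, n_fold = 4)
# number of distinct sites held out in each fold
sapply(folds_g, function(h) length(unique(site[h == 1])))  # 3 3 3 3
```

With unequal group sizes, the folds will hold out different numbers of observations.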

Now we can simulate some data and create the data object to be passed to Stan. Note that we need one data object per fold, so a list is an easy way to combine them:

X <- cbind(rep(1,N),runif(N,-2,2))
y <- rnorm(N,X %*% c(1, 0.5),1)

#the basic data object
data_m <- list(N=N,K=K,X=X,y=y)
#create a list of data list
data_l <- rep(list(data_m),10)
#add the holdout index to it
for(i in 1:10) data_l[[i]]$holdout <- holdout_10[[i]]

We are now ready to fit the model to each fold. We could just loop through the folds, but this is inefficient; instead, I use functions adapted from a StanCon 2017 contributed talk (link and code at the end of this post). In short, stan_kfold outputs a list of stanfit objects (one per fold); extract_log_lik_K outputs an S x N matrix, where S is the number of posterior draws, in which each element is the log-likelihood of a data point computed from the fit in which that point was held out; and kfold computes the expected log pointwise predictive density (elpd). Put simply, the elpd sums, over all data points (pointwise), the log of the probability density that the model fitted without a given point (predictive) assigns to that point, with the density averaged over the posterior draws before taking the log.
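Written out, the quantity that the kfold() function in the code at the end of this post computes can be expressed as below. The notation is mine: k(i) is the fold containing data point i, and the draws with subscript -k(i), s = 1, ..., S, come from the fit in which that fold was held out. The first term is what logColMeansExp() returns for each column of the held-out matrix, and the standard error matches se_elpd_kfold in the function code (with the sample variance taken over the N pointwise values).

```latex
\widehat{\mathrm{elpd}}_i = \log\!\left(\frac{1}{S}\sum_{s=1}^{S} p\!\left(y_i \mid \theta^{(s)}_{-k(i)}\right)\right),
\qquad
\widehat{\mathrm{elpd}}_{\mathrm{kfold}} = \sum_{i=1}^{N} \widehat{\mathrm{elpd}}_i,
\qquad
\widehat{\mathrm{SE}} = \sqrt{N \,\operatorname{Var}_{i}\!\left(\widehat{\mathrm{elpd}}_i\right)}
```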

#run the functions (file is the path to the .stan file saved above)
ss <- stan_kfold(file="normal_model_basic_cv.stan",data_l,chains=4,cores=2)
ee <- extract_log_lik_K(ss,holdout_10)
kk <- kfold(ee)
#compare with official loo results
ll <- loo(ee)

The ee matrix can actually also be used as input to loo to get some more metrics.

All this is nice and fine, but these metrics are relative and only truly make sense when comparing different models fitted to the same data. So let's fit two additional models, one overly complex and one overly simple:

# fit a too complex and a too simple model
X_comp <- cbind(X,runif(N,-2,2)) #adds a covariate unrelated to y
X_simp <- X[,1,drop=FALSE] #intercept only

# new data
data_comp <- data_l
for(i in 1:10){
  data_comp[[i]]$X <- X_comp
  data_comp[[i]]$K <- 3
}
data_simp <- data_l
for(i in 1:10){
  data_simp[[i]]$X <- X_simp
  data_simp[[i]]$K <- 1
}

#fit the new models
ss_comp <- stan_kfold(file="normal_model_basic_cv.stan",data_comp,chains=4,cores=2)
ss_simp <- stan_kfold(file="normal_model_basic_cv.stan",data_simp,chains=4,cores=2)

ee_comp <- extract_log_lik_K(ss_comp,holdout_10)
ee_simp <- extract_log_lik_K(ss_simp,holdout_10)

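#compare the models
#(in recent versions of loo, compare() is deprecated in favour of loo_compare())
compare(loo(ee), loo(ee_comp), loo(ee_simp))
#              elpd_diff elpd_loo se_elpd_loo p_loo se_p_loo looic se_looic
# loo(ee)            0.0   -148.9         9.5   4.3      1.0 297.7     19.1
# loo(ee_comp)      -1.8   -150.7         9.6   5.5      1.3 301.4     19.2
# loo(ee_simp)     -11.2   -160.1         8.9   2.8      0.8 320.1     17.8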

The compare function outputs quite a lot of information; let's go through it.

The first column (elpd_diff) is the difference in summed expected log pointwise predictive density between each model and the best one (the first row): the difference between the second and the first model is -1.8, meaning that the first model has a slightly higher predictive density. The second column (elpd_loo) is the summed expected log pointwise predictive density itself, from which the differences in the first column are computed. The third column is the standard error of the expected log pointwise predictive density. The fourth and fifth columns are the effective number of parameters in the model (p_loo) and its standard error. The sixth column (looic) is -2 times the second column, which puts the elpd on the deviance scale and makes it similar to other information criteria such as AIC or DIC. The final column is the standard error of this information criterion.
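The elpd_diff and looic columns can be recomputed from the elpd_loo column as a quick check. The values below are copied from the output above; because they are rounded to one decimal, the recomputed looic can differ by 0.1 from the printed one (297.8 vs 297.7, 320.2 vs 320.1).

```r
# elpd_loo values copied from the compare() output above
elpd_loo_tab <- c(ee = -148.9, ee_comp = -150.7, ee_simp = -160.1)
# elpd_diff: each model's elpd minus the highest elpd
elpd_diff_tab <- elpd_loo_tab - max(elpd_loo_tab)
round(elpd_diff_tab, 1)  # ee 0.0, ee_comp -1.8, ee_simp -11.2
# looic: the elpd on the deviance scale
looic_tab <- -2 * elpd_loo_tab
looic_tab  # 297.8, 301.4, 320.2 (differences of 0.1 come from rounding)
```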

From this output one could argue that the model with two predictors is clearly superior to the intercept-only model, and is slightly better than the model with three predictors.
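The output above gives no standard error for the differences themselves. Because all models are evaluated on the same held-out data points, one way to gauge whether a difference such as -1.8 is meaningful is to compute it pointwise and take the standard error of the paired differences, which mirrors how recent versions of loo report se_diff. elpd_diff_se() is a hypothetical helper, and the pointwise values below are simulated, not taken from the fits in this post.

```r
# Hypothetical helper: paired comparison of two models evaluated on the same
# held-out points; inputs are pointwise elpd values (vectors or 1-column matrices)
elpd_diff_se <- function(pointwise_a, pointwise_b) {
  d <- as.vector(pointwise_a) - as.vector(pointwise_b)
  c(elpd_diff = sum(d), se_diff = sqrt(length(d) * var(d)))
}

# simulated pointwise values, only to show the call
set.seed(5)
pw_a <- rnorm(100, -1.5, 0.9)
pw_b <- pw_a - abs(rnorm(100, 0.02, 0.05))
res_diff <- elpd_diff_se(pw_a, pw_b)
res_diff
```

With the objects from this post, the call would be elpd_diff_se(kk$pointwise, kfold(ee_comp)$pointwise). An elpd_diff that is small relative to se_diff suggests the data do not clearly separate the two models.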

Voilà. Do remember that model selection is very tricky: I would not encourage using information criteria or cross-validation metrics to decide whether one should include this particular interaction or that extra covariate. To my mind, these types of comparisons are relevant for comparing different model structures, such as simple regression vs hierarchical models vs spatial effects vs autoregressive structures, or for comparing models based on different sets of covariates, for instance to ask whether plant growth is better predicted by trait information or by environmental information.

Happy folding

The function code

Below is the code of the function used to fit the models and extract the information:

#functions slightly modified from: https://github.com/stan-dev/stancon_talks/blob/master/2017/Contributed-Talks/07_nicenboim/kfold.Rmd

#function to parallelize all computations
#need at least two chains !!!
stan_kfold <- function(file, list_of_datas, chains, cores, ...){
  n_fold <- length(list_of_datas)
  model <- stan_model(file = file)
  # First parallelize all chains (pbmclapply() relies on forking, like parallel::mclapply()):
  sflist <- 
    pbmclapply(1:(n_fold*chains), mc.cores = cores, 
               function(i){
                 # Fold number:
                 k <- ceiling(i / chains)
                 s <- sampling(model, data = list_of_datas[[k]], 
                               chains = 1, chain_id = i, ...)
                 return(s)
               })
  # Then merge the K * chains to create K stanfits:
  stanfit <- list()
  for(k in 1:n_fold){
    inchains <- (chains*k - (chains - 1)):(chains*k)
    #  Merge `chains` of each fold
    stanfit[[k]] <- sflist2stanfit(sflist[inchains])
  }
  return(stanfit)
}
#extract log-likelihoods of held-out data
extract_log_lik_K <- function(list_of_stanfits, list_of_holdout, ...){
  K <- length(list_of_stanfits)
  list_of_log_liks <- plyr::llply(1:K, function(k){
    extract_log_lik(list_of_stanfits[[k]], ...)
  })
  # `log_lik_heldout` will include the loglike of all the held out data of all the folds.
  # We define `log_lik_heldout` as a (samples x N_obs) matrix
  # (similar to each log_lik matrix)
  log_lik_heldout <- list_of_log_liks[[1]] * NA
  for(k in 1:K){
    log_lik <- list_of_log_liks[[k]]
    samples <- dim(log_lik)[1] 
    N_obs <- dim(log_lik)[2]
    # This is a matrix with the same size as log_lik_heldout
    # with 1 if the data was held out in the fold k
    heldout <- matrix(rep(list_of_holdout[[k]], each = samples), nrow = samples)
    # Sanity check that the previous log_lik is not being overwritten:
    if(any(!is.na(log_lik_heldout[heldout == 1]))){
      warning("Heldout log_lik has been overwritten!!!!")
    }
    # We save here the log_lik of the fold k in the matrix:
    log_lik_heldout[heldout == 1] <- log_lik[heldout == 1]
  }
  return(log_lik_heldout)
}
#compute ELPD
#(defining kfold() here masks the kfold() generic exported by recent versions of loo)
kfold <- function(log_lik_heldout)  {
  logColMeansExp <- function(x) {
    # should be more stable than log(colMeans(exp(x)))
    S <- nrow(x)
    matrixStats::colLogSumExps(x) - log(S)
  }
  # See equation (20) of @VehtariEtAl2016
  pointwise <-  matrix(logColMeansExp(log_lik_heldout), ncol= 1)
  colnames(pointwise) <- "elpd"
  # See equation (21) of @VehtariEtAl2016
  elpd_kfold <- sum(pointwise)
  se_elpd_kfold <-  sqrt(ncol(log_lik_heldout) * as.numeric(var(pointwise)))
  out <- list(
    pointwise = pointwise,
    elpd_kfold = elpd_kfold,
    se_elpd_kfold = se_elpd_kfold)
  #structure(out, class = "loo")
  return(out)
}
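The logColMeansExp() helper relies on matrixStats::colLogSumExps() to avoid numerical underflow. The same log-mean-exp trick can be written in base R, as in this sketch (log_col_means_exp() is a hypothetical name): subtracting each column's maximum before exponentiating keeps very negative log-likelihoods from underflowing to zero.

```r
# Base-R log-mean-exp per column: log(colMeans(exp(x))) computed stably
log_col_means_exp <- function(x) {
  m <- apply(x, 2, max)                    # column maxima
  m + log(colMeans(exp(sweep(x, 2, m))))   # shift, average, shift back
}

# toy log-likelihoods (S = 3 draws, 2 data points); column 1 is very negative
x_lme <- matrix(c(-1000, -1001, -1002,
                  -2, -3, -4), nrow = 3)
naive_lme <- log(colMeans(exp(x_lme)))
stable_lme <- log_col_means_exp(x_lme)
naive_lme   # first element is -Inf because exp(-1000) underflows to 0
stable_lme  # first element is about -1000.69
```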
