Cross-Validation for Predictive Analytics Using R

[This article was first published on MilanoR, and kindly contributed to R-bloggers.]

Introduction

Since ancient times, humankind has avidly sought ways to predict the future. One of the best-known examples is the Oracle of Delphi, who dispensed previews of the future to her petitioners in the form of divinely inspired prophecies[1]. Today many of us are still eager to know the future, although my feeling is that the ever-faster pace of technological innovation we observe every day has somewhat dulled this instinct: things that seemed futuristic only a few years ago (e.g., the World Wide Web) are now available to everyone.

Business decision making is one of the many areas of human activity where predictions are most needed. The tools for formulating predictions about quantities of interest are commonly known as predictive analytics, an essential part of data science. At the heart of any prediction there is always a model, which typically depends on some unknown structural parameters (e.g., the coefficients of a regression model) as well as on one or more tuning parameters (e.g., the number of basis functions in a smoothing spline or the degree of a polynomial). The former are usually estimated from a sample of data, while the latter have to be chosen so that the model provides sufficiently accurate predictions. Tuning parameters usually regulate the model complexity and are hence a key ingredient of any predictive task. In this blog entry we focus on the most common strategy for choosing reasonable values for the tuning parameters: cross-validation.

The Bias-Variance Dilemma

One should care about the choice of the tuning parameter values because they are intimately linked with the accuracy of the predictions returned by the model. What an analyst typically wants is a model that predicts well on samples that were not used to estimate the structural parameters (the so-called training sample). In other words, a predictive model is considered good when it predicts previously unseen samples with high accuracy. The accuracy of a model’s predictions is usually gauged using a loss function. Popular choices are the mean-squared error for continuous outcomes and the 0-1 loss for categorical outcomes[2].
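As a minimal illustration of the two loss functions just mentioned, the sketch below writes them as plain R functions. The names mse_loss() and zero_one_loss() are hypothetical helpers introduced here for illustration; they do not come from any package, and the toy inputs are made up.

```r
# Hypothetical helper: squared-error loss averaged over observations
# (mean-squared error), used for continuous outcomes
mse_loss <- function(y, y_hat) {
  mean((y - y_hat)^2)
}

# Hypothetical helper: 0-1 loss averaged over observations (misclassification
# rate), used for categorical outcomes; a wrong label costs 1, a right one 0
zero_one_loss <- function(y, y_hat) {
  mean(as.character(y) != as.character(y_hat))
}

# Toy usage with made-up values
mse_loss(c(1, 2, 3), c(1.5, 2, 2))                                  # 1.25 / 3
zero_one_loss(c("good", "bad", "good"), c("good", "good", "good"))  # 1 / 3
```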

At this point, it is important to distinguish between different prediction error concepts:

  • the training error, which is the average loss over the training sample,
  • the test error, the prediction error over an independent test sample.

The training error is small when the predicted responses are close to the observed responses, and large when, for some observations, the predicted and observed responses differ substantially. Because the training error is computed on the same sample used to fit the model, it says little about the model’s predictive accuracy on new data. What we really want to assess is the model’s ability to predict observations never seen during estimation, and the test error provides a measure of this ability. In general, one should select the model with the lowest test error.

The R code below illustrates these ideas with simulated data. In particular, I simulate 100 training sets, each of size 50, from a cubic polynomial regression model, and for each training set I fit a sequence of natural cubic spline models (the ns() function of the splines package) with 1 to 30 degrees of freedom.

# Generate the training and test samples
seed <- 1809
set.seed(seed)

gen_data <- function(n, beta, sigma_eps) {
    eps <- rnorm(n, 0, sigma_eps)
    x <- sort(runif(n, 0, 100))
    X <- cbind(1, poly(x, degree = (length(beta) - 1), raw = TRUE))
    y <- as.numeric(X %*% beta + eps)
    
    return(data.frame(x = x, y = y))
}

# Fit the models
require(splines)

n_rep <- 100
n_df <- 30
df <- 1:n_df
beta <- c(5, -0.1, 0.004, -3e-05)
n_train <- 50
n_test <- 10000
sigma_eps <- 0.5

xy <- res <- list()
xy_test <- gen_data(n_test, beta, sigma_eps)
for (i in 1:n_rep) {
    xy[[i]] <- gen_data(n_train, beta, sigma_eps)
    x <- xy[[i]][, "x"]
    y <- xy[[i]][, "y"]
    res[[i]] <- apply(t(df), 2, function(degf) lm(y ~ ns(x, df = degf)))
}

The next plot shows the first simulated training sample together with three fitted models corresponding to cubic splines with 1 (green line), 4 (orange line) and 25 (blue line) degrees of freedom, respectively. These numbers have been chosen to show the range of possibilities one may encounter in practice: a model with low variance but high bias (1 degree of freedom), a model with high variance but low bias (25 degrees of freedom), or a model that strikes a compromise between bias and variance (4 degrees of freedom).

# Plot the data
x <- xy[[1]]$x
X <- cbind(1, poly(x, degree = (length(beta) - 1), raw = TRUE))
y <- xy[[1]]$y
plot(y ~ x, col = "gray", lwd = 2)
lines(x, X %*% beta, lwd = 3, col = "black")
lines(x, fitted(res[[1]][[1]]), lwd = 3, col = "palegreen3")
lines(x, fitted(res[[1]][[4]]), lwd = 3, col = "darkorange")
lines(x, fitted(res[[1]][[25]]), lwd = 3, col = "steelblue")
legend(x = "topleft", legend = c("True function", "Linear fit (df = 1)", "Best model (df = 4)", 
    "Overfitted model (df = 25)"), lwd = rep(3, 4), col = c("black", "palegreen3", 
    "darkorange", "steelblue"), text.width = 32, cex = 0.85)

[Figure: first simulated training sample, true function, and spline fits with 1, 4 and 25 degrees of freedom]

Then, for each training sample and fitted model, I compute the training error and the test error, the latter using a large test sample generated from the same (known!) population. These errors are shown in the following plot together with their averages, drawn as thicker lines[3]. The solid points mark the three models illustrated in the previous plot.

# Compute the training and test errors for each model
pred <- list()
mse <- te <- matrix(NA, nrow = n_df, ncol = n_rep)
for (i in 1:n_rep) {
    mse[, i] <- sapply(res[[i]], function(obj) deviance(obj)/nobs(obj))
    pred[[i]] <- mapply(function(obj, degf) predict(obj, data.frame(x = xy_test$x)), 
        res[[i]], df)
    te[, i] <- sapply(as.list(data.frame(pred[[i]])), function(y_hat) mean((xy_test$y - 
        y_hat)^2))
}

# Compute the average training and test errors
av_mse <- rowMeans(mse)
av_te <- rowMeans(te)

# Plot the errors
plot(df, av_mse, type = "l", lwd = 2, col = gray(0.4), ylab = "Prediction error", 
    xlab = "Flexibility (spline's degrees of freedom [log scaled])", ylim = c(0, 
        1), log = "x")
# Irreducible error: for squared-error loss this is the noise variance sigma_eps^2
abline(h = sigma_eps^2, lty = 2, lwd = 0.5)
for (i in 1:n_rep) {
    lines(df, te[, i], col = "lightpink")
}
for (i in 1:n_rep) {
    lines(df, mse[, i], col = gray(0.8))
}
lines(df, av_mse, lwd = 2, col = gray(0.4))
lines(df, av_te, lwd = 2, col = "darkred")
points(df[1], av_mse[1], col = "palegreen3", pch = 17, cex = 1.5)
points(df[1], av_te[1], col = "palegreen3", pch = 17, cex = 1.5)
points(df[which.min(av_te)], av_mse[which.min(av_te)], col = "darkorange", pch = 16, 
    cex = 1.5)
points(df[which.min(av_te)], av_te[which.min(av_te)], col = "darkorange", pch = 16, 
    cex = 1.5)
points(df[25], av_mse[25], col = "steelblue", pch = 15, cex = 1.5)
points(df[25], av_te[25], col = "steelblue", pch = 15, cex = 1.5)
legend(x = "top", legend = c("Training error", "Test error"), lwd = rep(2, 2), 
    col = c(gray(0.4), "darkred"), text.width = 0.3, cex = 0.85)

[Figure: training and test error curves versus spline degrees of freedom, with their averages]

One can see that the training error decreases monotonically as the model gets more complex (and less smooth). The test error, on the other hand, initially decreases but starts increasing again beyond a certain level of flexibility. The change point occurs at the orange model, that is, the model that provides a good compromise between bias and variance. The test error increases for more than 3 or 4 degrees of freedom because of the so-called overfitting problem. Overfitting is the tendency of a model to adapt too closely to the training data, at the expense of generalization to previously unseen data points. In other words, an overfitted model fits the noise in the data rather than the actual underlying relationships among the variables. Overfitting usually occurs when a model is unnecessarily complex.

It is possible to show that the (expected) test error for a given observation in the test set can be decomposed into the sum of three components, namely

$$\text{Expected Test Error} = \text{Irreducible Noise} + (\text{Model Bias})^2 + \text{Model Variance},$$
which is known as the bias-variance decomposition. The first term is the variance of the data generating process. This term is unavoidable because we live in a noisy, stochastic world, where even the best ideal model has non-zero error. The second term stems from the difficulty of capturing the correct functional form of the relationship between the dependent and independent variables (it is sometimes also called the approximation bias). The last term is due to the fact that we estimate our models using only a limited amount of data. Fortunately, this term shrinks toward zero as we collect more and more training data. Typically, the more complex (i.e., flexible) we make the model, the lower the bias but the higher the variance. This general phenomenon is known as the bias-variance trade-off, and the challenge is to find a model that provides a good compromise between the two.
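The decomposition can also be checked numerically. The sketch below is a Monte Carlo approximation under stated assumptions: it reuses the cubic truth, noise level and training size of the simulation above, but looks at a single hypothetical test point x0 = 50 and 200 replicates. For each of the three highlighted models it estimates the squared bias and the variance of the fitted value at x0, and adds the known noise variance sigma_eps^2.

```r
library(splines)

set.seed(1809)
beta <- c(5, -0.1, 0.004, -3e-05)  # cubic polynomial used in the simulation above
sigma_eps <- 0.5                   # noise standard deviation
n_train <- 50
n_rep <- 200
x0 <- 50                           # hypothetical test point (assumption)
degrees <- c(1, 4, 25)             # the three models highlighted earlier

# True regression function f(x) = beta_0 + beta_1 x + beta_2 x^2 + beta_3 x^3
f_true <- function(x) drop(cbind(1, x, x^2, x^3) %*% beta)

# Fitted value at x0 for every replicate (rows) and every model (columns)
y_hat0 <- matrix(NA_real_, nrow = n_rep, ncol = length(degrees))
for (r in seq_len(n_rep)) {
  x <- runif(n_train, 0, 100)
  y <- f_true(x) + rnorm(n_train, 0, sigma_eps)
  for (j in seq_along(degrees)) {
    fit <- lm(y ~ ns(x, df = degrees[j]))
    y_hat0[r, j] <- predict(fit, newdata = data.frame(x = x0))
  }
}

# Monte Carlo estimates of the three components at x0
bias2 <- (colMeans(y_hat0) - f_true(x0))^2
variance <- colMeans(sweep(y_hat0, 2, colMeans(y_hat0))^2)
expected_err <- sigma_eps^2 + bias2 + variance

round(rbind(df = degrees, bias2 = bias2, variance = variance,
            expected_err = expected_err), 4)
```

By construction, bias2 + variance equals the average squared distance between the fitted values and the true f(x0) over the replicates; the flexible 25-degree-of-freedom fit should show a much larger variance than the 1-degree-of-freedom fit.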

Clearly, the situation illustrated above is only ideal, because in practice:

  • We do not know the true model that generates the data. In fact, our models are typically misspecified to some degree.
  • We only have a limited amount of data.

One way to overcome these hurdles and approximate the search for the optimal model is to use the cross-validation approach.

A Solution: Cross-Validation

In essence, all these ideas bring us to the conclusion that it is not advisable to compare the predictive accuracy of a set of models using the same observations used for estimating the models. Therefore, for assessing the models’ predictive performance we should use an independent set of data (the test sample). Then, the model showing the lowest error on the test sample (i.e., the lowest test error) is identified as the best.

Unfortunately, in many cases it is not possible to draw a (possibly large) independent set of observations for testing the models’ performance, because collecting data is typically expensive. An immediate workaround is to split the available data into two sets, one used for training and the other for testing. The split is usually performed at random so that the two parts have approximately the same distribution[4].
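A random split of this kind takes only a few lines of base R. In the sketch below, holdout_split() is a hypothetical helper, the 70/30 proportion is an arbitrary choice, and the toy data and linear model merely stand in for whatever data and model one is actually evaluating.

```r
# Hypothetical helper: randomly split row indices into a training part
# (proportion p_train) and a test part (the remaining rows)
holdout_split <- function(n, p_train = 0.7) {
  train_i <- sample.int(n, size = floor(p_train * n))
  list(train = sort(train_i), test = setdiff(seq_len(n), train_i))
}

set.seed(1809)
# Toy data standing in for the available sample
xy <- data.frame(x = runif(100, 0, 100))
xy$y <- 5 - 0.1 * xy$x + rnorm(100, 0, 0.5)

idx <- holdout_split(nrow(xy), p_train = 0.7)
train_xy <- xy[idx$train, ]
test_xy <- xy[idx$test, ]

# Fit on the training part only, then estimate the test error on the other part
fit <- lm(y ~ x, data = train_xy)
test_mse <- mean((test_xy$y - predict(fit, newdata = test_xy))^2)
test_mse
```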

Although data splitting provides an unbiased estimate of the test error, this estimate is often quite noisy. A possible solution[5] is cross-validation (CV). In its basic version, the so-called k-fold cross-validation, the samples are randomly partitioned into k sets (called folds) of roughly equal size. A model is fit using all the samples except those in the first fold, and its prediction error is then computed on the held-out samples of the first fold. The same operation is repeated for each fold, and the model’s performance is obtained by averaging the errors across the k held-out folds. k is usually set to 5 or 10. Cross-validation thus provides an estimate of the test error for each model[6], and it is one of the most widely used methods for model selection and for choosing tuning parameter values.
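In formulas, assuming squared-error loss and writing F_j for the j-th fold, n_j for its size and \hat{f}^{(-j)} for the model fit on all folds except the j-th, the k-fold CV estimate described above is the plain average of the fold-level mean-squared errors:

$$
\mathrm{CV}_{(k)} = \frac{1}{k} \sum_{j=1}^{k} \mathrm{MSE}_j,
\qquad
\mathrm{MSE}_j = \frac{1}{n_j} \sum_{i \in F_j} \left( y_i - \hat{f}^{(-j)}(x_i) \right)^2 .
$$

In the R code that follows, each row of cv_tmp stores the MSE_j values of every model for one fold, and cv <- colMeans(cv_tmp) computes CV_(k).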

The code below illustrates k-fold cross-validation on the same simulated data as above, this time without assuming knowledge of the data generating process. In particular, I generate 100 observations and choose k = 10. Together with the training error curve, the plot reports both the CV and the test error curves. It also shows standard error bars, computed as the standard deviation of the k = 10 fold-level prediction errors divided by the square root of 10.

set.seed(seed)

n_train <- 100
xy <- gen_data(n_train, beta, sigma_eps)
x <- xy$x
y <- xy$y

fitted_models <- apply(t(df), 2, function(degf) lm(y ~ ns(x, df = degf)))
mse <- sapply(fitted_models, function(obj) deviance(obj)/nobs(obj))

n_test <- 10000
xy_test <- gen_data(n_test, beta, sigma_eps)
pred <- mapply(function(obj, degf) predict(obj, data.frame(x = xy_test$x)), 
    fitted_models, df)
te <- sapply(as.list(data.frame(pred)), function(y_hat) mean((xy_test$y - y_hat)^2))

n_folds <- 10
folds_i <- sample(rep(1:n_folds, length.out = n_train))
cv_tmp <- matrix(NA, nrow = n_folds, ncol = length(df))
for (k in 1:n_folds) {
    test_i <- which(folds_i == k)
    train_xy <- xy[-test_i, ]
    test_xy <- xy[test_i, ]
    x <- train_xy$x
    y <- train_xy$y
    fitted_models <- apply(t(df), 2, function(degf) lm(y ~ ns(x, df = degf)))
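    # Predict the held-out fold: newdata needs a column named x so that
    # predict() reuses the spline knots learned on the training folds
    x <- test_xy$x
    y <- test_xy$y
    pred <- mapply(function(obj, degf) predict(obj, data.frame(x = x)), 
        fitted_models, df)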
    cv_tmp[k, ] <- sapply(as.list(data.frame(pred)), function(y_hat) mean((y - 
        y_hat)^2))
}
cv <- colMeans(cv_tmp)

require(Hmisc)

plot(df, mse, type = "l", lwd = 2, col = gray(0.4), ylab = "Prediction error", 
    xlab = "Flexibility (spline's degrees of freedom [log scaled])", main = paste0(n_folds, 
        "-fold Cross-Validation"), ylim = c(0.1, 0.8), log = "x")
lines(df, te, lwd = 2, col = "darkred", lty = 2)
cv_sd <- apply(cv_tmp, 2, sd)/sqrt(n_folds)
errbar(df, cv, cv + cv_sd, cv - cv_sd, add = TRUE, col = "steelblue2", pch = 19, 
    lwd = 0.5)
lines(df, cv, lwd = 2, col = "steelblue2")
points(df, cv, col = "steelblue2", pch = 19)
legend(x = "topright", legend = c("Training error", "Test error", "Cross-validation error"), 
    lty = c(1, 2, 1), lwd = rep(2, 3), col = c(gray(0.4), "darkred", "steelblue2"), 
    text.width = 0.4, cex = 0.85)

[Figure: 10-fold cross-validation, with training, test and CV error curves and CV standard error bars]

Often a “one-standard error” rule is used with cross-validation, according to which one should choose the most parsimonious model whose error is no more than one standard error above the error of the best model. In the example above, the best model (that for which the CV error is minimized) uses 3 degrees of freedom, which also satisfies the requirement of the one-standard error rule.
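The rule is straightforward to code. The sketch below assumes the CV errors and their standard errors have already been computed, one per candidate model, and that a smaller complexity value means a more parsimonious model. one_se_rule() is a hypothetical helper, and the input numbers are made up for illustration rather than taken from the plot above.

```r
# Hypothetical helper implementing the one-standard-error rule:
# cv_err[i] and cv_se[i] are the CV error and its standard error for the
# model whose complexity is complexity[i]
one_se_rule <- function(complexity, cv_err, cv_se) {
  best <- which.min(cv_err)
  threshold <- cv_err[best] + cv_se[best]
  # Among the models within one SE of the best, keep the least complex one
  candidates <- which(cv_err <= threshold)
  complexity[candidates[which.min(complexity[candidates])]]
}

# Toy numbers (made up): the minimum is at complexity 4, but complexity 3
# is within one standard error of it
df_toy <- 1:6
cv_toy <- c(0.60, 0.31, 0.27, 0.26, 0.29, 0.30)
se_toy <- rep(0.02, 6)
one_se_rule(df_toy, cv_toy, se_toy)
```

With the objects produced by the 10-fold CV code above, the analogous call would be one_se_rule(df, cv, cv_sd).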

The case k = n corresponds to the so-called leave-one-out cross-validation (LOOCV) method, in which each test set contains a single observation. The advantages of LOOCV are: 1) it does not require random numbers to select the test observations, so it produces the same results when applied repeatedly, and 2) it has less bias than k-fold CV, because it uses larger training sets containing n − 1 observations each. On the other hand, LOOCV also has some drawbacks: 1) it can be computationally intensive, since the model must be fit n times, and 2) because any two training sets share n − 2 points, the models fit to those training sets tend to be strongly correlated with each other, which increases the variance of the resulting error estimate.
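Regarding the computational cost, least-squares fits admit a standard shortcut: the LOOCV error can be obtained from a single fit as the average of ((y_i − ŷ_i) / (1 − h_i))^2, where h_i is the leverage of observation i (returned by hatvalues()). The sketch below assumes a cubic polynomial fit on toy data similar to the simulation above and checks the shortcut against brute-force refitting. The identity is exact only when dropping a point does not change the design matrix of the remaining points; the ns() fits in this post recompute their knots from each training set, so for them the shortcut would only be an approximation.

```r
set.seed(1809)
n <- 100
d <- data.frame(x = runif(n, 0, 100))
# Toy data from the same cubic polynomial used earlier in the post
d$y <- 5 - 0.1 * d$x + 0.004 * d$x^2 - 3e-05 * d$x^3 + rnorm(n, 0, 0.5)

# Least-squares cubic fit: its design matrix does not depend on which
# observation is left out, so the shortcut below is exact
form <- y ~ x + I(x^2) + I(x^3)
fit <- lm(form, data = d)

# Shortcut: LOOCV error from a single fit, using the leverages h_i
h <- hatvalues(fit)
loocv_shortcut <- mean((residuals(fit) / (1 - h))^2)

# Brute force: refit n times, each time predicting the left-out observation
loo_err <- sapply(seq_len(n), function(i) {
  fit_i <- lm(form, data = d[-i, ])
  (d$y[i] - predict(fit_i, newdata = d[i, ]))^2
})
loocv_brute <- mean(loo_err)

c(shortcut = loocv_shortcut, brute_force = loocv_brute)
```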

The code below implements LOOCV on the same example. The next plot shows that, in this example, LOOCV does not give dramatically different results from 10-fold CV.

require(splines)

loocv_tmp <- matrix(NA, nrow = n_train, ncol = length(df))
for (k in 1:n_train) {
  train_xy <- xy[-k, ]
  test_xy <- xy[k, ]
  x <- train_xy$x
  y <- train_xy$y
  fitted_models <- apply(t(df), 2, function(degf) lm(y ~ ns(x, df = degf)))
  pred <- mapply(function(obj, degf) predict(obj, data.frame(x = test_xy$x)),
                 fitted_models, df)
  loocv_tmp[k, ] <- (test_xy$y - pred)^2
}
loocv <- colMeans(loocv_tmp)

plot(df, mse, type = "l", lwd = 2, col = gray(.4), ylab = "Prediction error",
     xlab = "Flexibility (spline's degrees of freedom [log scaled])",
     main = "Leave-One-Out Cross-Validation", ylim = c(.1, .8), log = "x")
lines(df, cv, lwd = 2, col = "steelblue2", lty = 2)
lines(df, loocv, lwd = 2, col = "darkorange")
legend(x = "topright", legend = c("Training error", "10-fold CV error", "LOOCV error"),
       lty = c(1, 2, 1), lwd = rep(2, 3), col = c(gray(.4), "steelblue2", "darkorange"),
       text.width = .3, cex = .85)

[Figure: training, 10-fold CV and LOOCV error curves]

Doing Cross-Validation With R: the caret Package

There are many R packages that provide functions for performing different flavors of CV. In my opinion, one of the best implementations of these ideas is the caret package by Max Kuhn (see Kuhn and Johnson 2013)[7]. The aim of the caret package (short for Classification And REgression Training) is to provide a very general and efficient suite of commands for building and assessing predictive models. It makes it possible to compare the predictive accuracy of a multitude of models (currently more than 200), including the most recent ones from machine learning. Models can be compared using cross-validation as well as other approaches. The package also provides many options for data pre-processing. It is not my aim to give a thorough presentation of all the package features here. Rather, I will focus on a handful of its functions, namely those that perform CV. For more details on the other functions, see the package documentation and its website. To illustrate these features I will use data from a credit scoring application; the code below reads them directly from a GitHub repository.

Since credit scoring is a classification problem, I will use the proportion of misclassified observations (i.e., the 0-1 loss) as the loss measure. The original data set contains information about 4,455 individuals on the following variables:

Variable    Description
Status      credit status
Seniority   job seniority (years)
Home        type of home ownership
Time        time of requested loan
Age         client’s age
Marital     marital status
Records     existence of records
Job         type of job
Expenses    amount of expenses
Income      amount of income
Assets      amount of assets
Debt        amount of debt
Amount      amount of requested loan
Price       price of good

Here I use the “cleaned” version of the data set, in which some pre-processing has already been performed (namely, removal of a few observations, imputation of missing values and categorization of continuous predictors). The tidy data are contained in the file CleanCreditScoring.csv.

require(RCurl)
require(prettyR)

url <- "https://raw.githubusercontent.com/gastonstat/CreditScoring/master/CleanCreditScoring.csv"
cs_data <- getURL(url)
cs_data <- read.csv(textConnection(cs_data))
describe(cs_data)

## Description of cs_data

## 
##  Numeric 
##              mean  median          var       sd valid.n
## Seniority    7.99    5.00        66.85     8.18    4446
## Time        46.45   48.00       214.56    14.65    4446
## Age         37.08   36.00       120.70    10.99    4446
## Expenses    55.60   51.00       381.06    19.52    4446
## Income     140.63  124.00      6428.50    80.18    4446
## Assets    5354.95 3000.00 133040726.62 11534.33    4446
## Debt       342.26    0.00   1549264.52  1244.69    4446
## Amount    1038.76 1000.00    225385.62   474.75    4446
## Price     1462.48 1400.00    395081.60   628.56    4446
## Finrat      72.62   77.10       415.78    20.39    4446
## Savings      3.86    3.12        13.89     3.73    4446
## 
##  Factor 
##          
## Status       good     bad
##   Count   3197.00 1249.00
##   Percent   71.91   28.09
## Mode good 
##          
## Home        owner   rent parents  other   priv ignore
##   Count   2106.00 973.00  782.00 319.00 246.00  20.00
##   Percent   47.37  21.88   17.59   7.17   5.53   0.45
## Mode owner 
##          
## Marital   married single separated widow divorced
##   Count   3238.00 973.00    130.00 67.00    38.00
##   Percent   72.83  21.88      2.92  1.51     0.85
## Mode married 
##          
## Records   no_rec yes_rec
##   Count   3677.0   769.0
##   Percent   82.7    17.3
## Mode no_rec 
##          
## Job         fixed freelance partime others
##   Count   2803.00   1021.00  451.00 171.00
##   Percent   63.05     22.96   10.14   3.85
## Mode fixed 
##           
## seniorityR sen (-1,1] sen (3,8] sen (14,99] sen (1,3] sen (8,14]
##    Count      1042.00       978      880.00    789.00     757.00
##    Percent      23.44        22       19.79     17.75      17.03
## Mode sen (-1,1] 
##          
## timeR     time (48,99] time (24,36] time (36,48] time (12,24] time (0,12]
##   Count        1949.00       991.00       885.00       441.00      180.00
##   Percent        43.84        22.29        19.91         9.92        4.05
## Mode time (48,99] 
##          
## ageR      age (30,40] age (40,50] age (25,30] age (0,25] age (50,99]
##   Count       1415.00      900.00      781.00     699.00      651.00
##   Percent       31.83       20.24       17.57      15.72       14.64
## Mode age (30,40] 
##          
## expensesR exp (0,40] exp (40,50] exp (50,60] exp (60,80] exp (80,1e+04]
##   Count      1219.00      999.00      979.00      798.00         451.00
##   Percent      27.42       22.47       22.02       17.95          10.14
## Mode exp (0,40] 
##          
## incomeR   inc (80,110] inc (140,190] inc (0,80] inc (110,140]
##   Count         954.00        915.00     886.00        866.00
##   Percent        21.46         20.58      19.93         19.48
##          
## incomeR   inc (190,1e+04]
##   Count            825.00
##   Percent           18.56
## Mode inc (80,110] 
##          
## assetsR   asset (-1,0] asset (3e+03,5e+03] asset (8e+03,1e+06]
##   Count        1626.00              937.00              719.00
##   Percent        36.57               21.08               16.17
##          
## assetsR   asset (0,3e+03] asset (5e+03,8e+03]
##   Count            626.00               538.0
##   Percent           14.08                12.1
## Mode asset (-1,0] 
##          
## debtR     debt (-1,0] debt (500,1.5e+03] debt (2.5e+03,1e+06] debt (0,500]
##   Count       3667.00             230.00               197.00       193.00
##   Percent       82.48               5.17                 4.43         4.34
##          
## debtR     debt (1.5e+03,2.5e+03]
##   Count                   159.00
##   Percent                   3.58
## Mode debt (-1,0] 
##          
## amountR   am (900,1.1e+03] am (1.1e+03,1.4e+03] am (600,900] am (0,600]
##   Count             945.00               925.00       911.00     895.00
##   Percent            21.26                20.81        20.49      20.13
##          
## amountR   am (1.4e+03,1e+05]
##   Count               770.00
##   Percent              17.32
## Mode am (900,1.1e+03] 
##          
## priceR    priz (1.5e+03,1.8e+03] priz (1e+03,1.3e+03] priz (0,1e+03]
##   Count                  1028.00               985.00         821.00
##   Percent                  23.12                22.15          18.47
##          
## priceR    priz (1.8e+03,1e+05] priz (1.3e+03,1.5e+03]
##   Count                 811.00                 801.00
##   Percent                18.24                  18.02
## Mode priz (1.5e+03,1.8e+03] 
##          
## finratR   finr (80,90] finr (90,100] finr (50,70] finr (70,80] finr (0,50]
##   Count         995.00        960.00       954.00       821.00       716.0
##   Percent        22.38         21.59        21.46        18.47        16.1
## Mode finr (80,90] 
##          
## savingsR  sav (2,4] sav (0,2] sav (6,99] sav (4,6] sav (-99,0]
##   Count      1396.0   1111.00      827.0    814.00       298.0
##   Percent      31.4     24.99       18.6     18.31         6.7
## Mode sav (2,4]

The caret package provides functions for splitting the data as well as functions that automatically do all the work for us, namely functions that create the resampled data sets, fit the models, and evaluate their performance.

Among the functions for data splitting I will just mention createDataPartition() and createFolds(). The former creates one or more random training/test partitions of the data, while the latter randomly splits the data into k subsets. In both functions the random sampling is done within the levels of y (when y is categorical) to balance the class distributions across the splits. These functions return vectors of indices that can then be used to subset the original sample into training and test sets.
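To make the idea of sampling within the levels of y concrete, here is a base-R sketch that assigns observations to k folds separately within each class, so that the class proportions are roughly preserved in every fold. stratified_folds() is a hypothetical helper written for illustration; it is not how caret implements createFolds() internally, and the toy outcome only mimics the good/bad proportions of the credit data.

```r
# Hypothetical base-R helper: assign each observation to one of k folds,
# shuffling balanced fold labels separately within each class of y
stratified_folds <- function(y, k = 10) {
  y <- as.factor(y)
  fold <- integer(length(y))
  for (lev in levels(y)) {
    idx <- which(y == lev)
    # Labels 1..k recycled over this class, then randomly permuted
    fold[idx] <- sample(rep(seq_len(k), length.out = length(idx)))
  }
  fold
}

set.seed(1809)
# Toy outcome with roughly the 72/28 good/bad split of the credit data
y_toy <- factor(rep(c("good", "bad"), times = c(72, 28)))
folds <- stratified_folds(y_toy, k = 5)

# Class counts per fold differ by at most one observation
table(folds, y_toy)
```

The indices of the j-th test fold are then which(folds == j), and the remaining indices form the corresponding training set.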

require(caret)

classes <- cs_data[, "Status"]
predictors <- cs_data[, -match(c("Status", "Seniority", "Time", "Age", "Expenses", 
    "Income", "Assets", "Debt", "Amount", "Price", "Finrat", "Savings"), colnames(cs_data))]

train_set <- createDataPartition(classes, p = 0.8, list = FALSE)
str(train_set)

##  int [1:3558, 1] 1 2 3 4 5 6 7 8 9 11 ...
##  - attr(*, "dimnames")=List of 2
##   ..$ : NULL
##   ..$ : chr "Resample1"

train_predictors <- predictors[train_set, ]
train_classes <- classes[train_set]
test_predictors <- predictors[-train_set, ]
test_classes <- classes[-train_set]

set.seed(seed)
cv_splits <- createFolds(classes, k = 10, returnTrain = TRUE)
str(cv_splits)

## List of 10
##  $ Fold01: int [1:4002] 1 2 3 4 5 6 7 8 9 10 ...
##  $ Fold02: int [1:4002] 2 3 4 5 6 7 8 9 10 11 ...
##  $ Fold03: int [1:4001] 1 2 3 4 5 6 7 8 9 10 ...
##  $ Fold04: int [1:4002] 1 2 3 4 5 6 7 8 9 10 ...
##  $ Fold05: int [1:4001] 1 2 3 4 7 8 10 11 12 13 ...
##  $ Fold06: int [1:4001] 1 2 3 4 5 6 7 8 9 10 ...
##  $ Fold07: int [1:4001] 1 2 4 5 6 7 9 10 11 12 ...
##  $ Fold08: int [1:4002] 1 2 3 4 5 6 7 8 9 10 ...
##  $ Fold09: int [1:4001] 1 2 3 5 6 8 9 11 12 13 ...
##  $ Fold10: int [1:4001] 1 3 4 5 6 7 8 9 10 11 ...

To automatically split the data, fit the models and assess their performance, one can use the train() function of the caret package. The code below applies train() to the credit scoring data, modeling the outcome with all the available predictors through a penalized logistic regression. More specifically, I use the glmnet package (Friedman, Hastie, and Tibshirani 2008), which fits generalized linear models via penalized maximum likelihood. The algorithm implemented in the package computes the regularization path for the elastic-net penalty over a grid of values for the regularization parameter λ. The tuning parameter λ controls the overall strength of the penalty. A second tuning parameter, called the mixing percentage and denoted by α, determines the type of elastic-net penalty (Zou and Hastie 2005). This parameter takes values in [0, 1] and bridges the gap between the lasso (α = 1) and ridge (α = 0) approaches.
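For reference, and as I understand the glmnet documentation for the binomial family, the objective being minimized combines the negative average log-likelihood of the logistic model with the elastic-net penalty, where y_i ∈ {0, 1}, x_i is the predictor vector of observation i and N is the number of observations:

$$
\min_{\beta_0,\, \beta} \; -\frac{1}{N} \sum_{i=1}^{N} \left[ y_i \left( \beta_0 + x_i^\top \beta \right) - \log\left( 1 + e^{\beta_0 + x_i^\top \beta} \right) \right] + \lambda \left[ \frac{1 - \alpha}{2} \lVert \beta \rVert_2^2 + \alpha \lVert \beta \rVert_1 \right]
$$

Setting α = 1 leaves only the lasso (L1) penalty and α = 0 only the ridge (L2) penalty; the intercept β_0 is not penalized.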

The train() function requires the model formula, the model to fit (the method argument) and the grid of tuning parameter values to explore. In the code below the grid is specified through the tuneGrid argument, while trControl specifies the method used to choose the optimal tuning parameter values (here, 10-fold cross-validation). Finally, the preProcess argument applies a series of pre-processing operations to the predictors (here, centering and scaling).

require(glmnet)

## Warning: package 'Matrix' was built under R version 3.2.5

set.seed(seed)

cs_data_train <- cs_data[train_set, ]
cs_data_test <- cs_data[-train_set, ]

glmnet_grid <- expand.grid(alpha = c(0,  .1,  .2, .4, .6, .8, 1),
                           lambda = seq(.01, .2, length = 20))
glmnet_ctrl <- trainControl(method = "cv", number = 10)
glmnet_fit <- train(Status ~ ., data = cs_data_train,
                    method = "glmnet",
                    preProcess = c("center", "scale"),
                    tuneGrid = glmnet_grid,
                    trControl = glmnet_ctrl)
glmnet_fit

## glmnet 
## 
## 3558 samples
##   26 predictor
##    2 classes: 'bad', 'good' 
## 
## Pre-processing: centered (68), scaled (68) 
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 3202, 3203, 3202, 3203, 3202, 3202, ... 
## Resampling results across tuning parameters:
## 
##   alpha  lambda  Accuracy   Kappa       
##   0.0    0.01    0.8021427  0.4613907413
##   0.0    0.02    0.7998916  0.4520486081
##   0.0    0.03    0.7976412  0.4402614685
##   0.0    0.04    0.7987633  0.4407093800
##   0.0    0.05    0.7982015  0.4355350784
##   0.0    0.06    0.7979182  0.4313111542
##   0.0    0.07    0.7953893  0.4205306747
##   0.0    0.08    0.7931413  0.4105376360
##   0.0    0.09    0.7922978  0.4050557210
##   0.0    0.10    0.7892072  0.3920192662
##   0.0    0.11    0.7841454  0.3722554927
##   0.0    0.12    0.7824600  0.3640031420
##   0.0    0.13    0.7807739  0.3557226473
##   0.0    0.14    0.7793694  0.3482341774
##   0.0    0.15    0.7807746  0.3491274474
##   0.0    0.16    0.7810571  0.3472621824
##   0.0    0.17    0.7796511  0.3411817028
##   0.0    0.18    0.7796526  0.3373484610
##   0.0    0.19    0.7807746  0.3374411775
##   0.0    0.20    0.7785267  0.3288713594
##   0.1    0.01    0.8015794  0.4596697325
##   0.1    0.02    0.8010160  0.4506840219
##   0.1    0.03    0.7987672  0.4392696923
##   0.1    0.04    0.7962367  0.4270137579
##   0.1    0.05    0.7953909  0.4179208158
##   0.1    0.06    0.7922986  0.4039787586
##   0.1    0.07    0.7892056  0.3879423318
##   0.1    0.08    0.7880804  0.3782808312
##   0.1    0.09    0.7841454  0.3596294655
##   0.1    0.10    0.7807723  0.3432057135
##   0.1    0.11    0.7748718  0.3176779656
##   0.1    0.12    0.7726223  0.3058039618
##   0.1    0.13    0.7706544  0.2940236563
##   0.1    0.14    0.7672804  0.2743480275
##   0.1    0.15    0.7641905  0.2598076390
##   0.1    0.16    0.7613815  0.2443603691
##   0.1    0.17    0.7585718  0.2288614463
##   0.1    0.18    0.7549153  0.2107358484
##   0.1    0.19    0.7518231  0.1940038500
##   0.1    0.20    0.7495751  0.1809109004
##   0.2    0.01    0.8021396  0.4602812440
##   0.2    0.02    0.7982070  0.4415693991
##   0.2    0.03    0.7951147  0.4239150748
##   0.2    0.04    0.7917424  0.4066236599
##   0.2    0.05    0.7911742  0.3954568138
##   0.2    0.06    0.7875218  0.3766126319
##   0.2    0.07    0.7813388  0.3485492065
##   0.2    0.08    0.7765588  0.3236774842
##   0.2    0.09    0.7723398  0.3024505090
##   0.2    0.10    0.7706583  0.2883765872
##   0.2    0.11    0.7664425  0.2668721852
##   0.2    0.12    0.7627868  0.2474032680
##   0.2    0.13    0.7554779  0.2144108884
##   0.2    0.14    0.7509796  0.1882593847
##   0.2    0.15    0.7445165  0.1570413702
##   0.2    0.16    0.7419869  0.1374198064
##   0.2    0.17    0.7372108  0.1088243186
##   0.2    0.18    0.7346811  0.0884942046
##   0.2    0.19    0.7304645  0.0666304068
##   0.2    0.20    0.7293401  0.0587304848
##   0.4    0.01    0.7998916  0.4500521540
##   0.4    0.02    0.7951147  0.4262079087
##   0.4    0.03    0.7894920  0.3986179726
##   0.4    0.04    0.7833091  0.3669902064
##   0.4    0.05    0.7793694  0.3406658720
##   0.4    0.06    0.7765596  0.3202875649
##   0.4    0.07    0.7726246  0.2966774741
##   0.4    0.08    0.7639128  0.2564191601
##   0.4    0.09    0.7546352  0.2079469868
##   0.4    0.10    0.7456401  0.1554338724
##   0.4    0.11    0.7369299  0.1037171598
##   0.4    0.12    0.7324316  0.0745712422
##   0.4    0.13    0.7296218  0.0585154515
##   0.4    0.14    0.7282173  0.0515741506
##   0.4    0.15    0.7270937  0.0460444276
##   0.4    0.16    0.7237213  0.0275193624
##   0.4    0.17    0.7203474  0.0079998432
##   0.4    0.18    0.7189429  0.0000000000
##   0.4    0.19    0.7189429  0.0000000000
##   0.4    0.20    0.7189429  0.0000000000
##   0.6    0.01    0.7998940  0.4479646834
##   0.6    0.02    0.7900538  0.4056859000
##   0.6    0.03    0.7830274  0.3693670539
##   0.6    0.04    0.7765572  0.3347009487
##   0.6    0.05    0.7757161  0.3177186986
##   0.6    0.06    0.7655990  0.2676239096
##   0.6    0.07    0.7549153  0.2082278135
##   0.6    0.08    0.7422693  0.1330653578
##   0.6    0.09    0.7332758  0.0786721840
##   0.6    0.10    0.7296218  0.0585273347
##   0.6    0.11    0.7279364  0.0501825990
##   0.6    0.12    0.7223176  0.0195861846
##   0.6    0.13    0.7189429  0.0000000000
##   0.6    0.14    0.7189429  0.0000000000
##   0.6    0.15    0.7189429  0.0000000000
##   0.6    0.16    0.7189429  0.0000000000
##   0.6    0.17    0.7189429  0.0000000000
##   0.6    0.18    0.7189429  0.0000000000
##   0.6    0.19    0.7189429  0.0000000000
##   0.6    0.20    0.7189429  0.0000000000
##   0.8    0.01    0.7959582  0.4342802453
##   0.8    0.02    0.7869623  0.3901784671
##   0.8    0.03    0.7776832  0.3436570802
##   0.8    0.04    0.7745925  0.3177549651
##   0.8    0.05    0.7650348  0.2651758625
##   0.8    0.06    0.7518270  0.1887158977
##   0.8    0.07    0.7363681  0.0981410803
##   0.8    0.08    0.7299027  0.0607058454
##   0.8    0.09    0.7270937  0.0452185157
##   0.8    0.10    0.7189429  0.0008725808
##   0.8    0.11    0.7189429  0.0000000000
##   0.8    0.12    0.7189429  0.0000000000
##   0.8    0.13    0.7189429  0.0000000000
##   0.8    0.14    0.7189429  0.0000000000
##   0.8    0.15    0.7189429  0.0000000000
##   0.8    0.16    0.7189429  0.0000000000
##   0.8    0.17    0.7189429  0.0000000000
##   0.8    0.18    0.7189429  0.0000000000
##   0.8    0.19    0.7189429  0.0000000000
##   0.8    0.20    0.7189429  0.0000000000
##   1.0    0.01    0.7920241  0.4209892004
##   1.0    0.02    0.7807794  0.3655088353
##   1.0    0.03    0.7754360  0.3289251638
##   1.0    0.04    0.7686897  0.2864096482
##   1.0    0.05    0.7526713  0.1943914570
##   1.0    0.06    0.7349620  0.0898775655
##   1.0    0.07    0.7290600  0.0557404582
##   1.0    0.08    0.7209100  0.0108497047
##   1.0    0.09    0.7189429  0.0000000000
##   1.0    0.10    0.7189429  0.0000000000
##   1.0    0.11    0.7189429  0.0000000000
##   1.0    0.12    0.7189429  0.0000000000
##   1.0    0.13    0.7189429  0.0000000000
##   1.0    0.14    0.7189429  0.0000000000
##   1.0    0.15    0.7189429  0.0000000000
##   1.0    0.16    0.7189429  0.0000000000
##   1.0    0.17    0.7189429  0.0000000000
##   1.0    0.18    0.7189429  0.0000000000
##   1.0    0.19    0.7189429  0.0000000000
##   1.0    0.20    0.7189429  0.0000000000
## 
## Accuracy was used to select the optimal model using  the largest value.
## The final values used for the model were alpha = 0 and lambda = 0.01.
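The summary line says that caret kept the candidate with the largest resampled accuracy. Below is a minimal base-R sketch of that "largest value" rule. It assumes a tuning grid consistent with the values printed above (α from 0 to 1 in steps of 0.2, λ from 0.01 to 0.20 in steps of 0.01); the grid actually passed to train() is defined earlier in the post. The accuracies are a handful of rows copied from the table. The α = 0 rows, where caret found its optimum, lie outside this excerpt, so on this subset the rule picks α = 0.2 and λ = 0.01 instead. The sketch only illustrates the mechanics.

```r
# Tuning grid consistent with the printed output. This is an assumption:
# the grid actually used in the post is defined earlier on.
grid <- expand.grid(alpha  = seq(0, 1, by = 0.2),
                    lambda = seq(0.01, 0.2, by = 0.01))
nrow(grid)  # 6 values of alpha x 20 values of lambda = 120 candidate models

# A few resampled accuracies copied from the table above (lambda = 0.01 to 0.03)
res <- data.frame(
  alpha    = rep(c(0.2, 0.4, 0.6, 0.8, 1.0), each = 3),
  lambda   = rep(c(0.01, 0.02, 0.03), times = 5),
  Accuracy = c(0.8021396, 0.7982070, 0.7951147,
               0.7998916, 0.7951147, 0.7894920,
               0.7998940, 0.7900538, 0.7830274,
               0.7959582, 0.7869623, 0.7776832,
               0.7920241, 0.7807794, 0.7754360)
)

# "Largest value" rule: keep the candidate with the highest accuracy
best <- res[which.max(res$Accuracy), c("alpha", "lambda")]
best
```

caret also lets you swap this rule for a more conservative one through the `selectionFunction` argument of trainControl() (for example `"oneSE"`), which favours a simpler model whose performance is within one standard error of the best.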

# Use caret's lattice theme and draw lambda on a log2 scale
trellis.par.set(caretTheme())
plot(glmnet_fit, scales = list(x = list(log = 2)))

[Figure: cross-validated accuracy profiles of the penalized logistic regression model across the tuning grid]

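The previous plot shows the accuracy, that is, the proportion of correctly classified observations, of the penalized logistic regression model for each combination of the two tuning parameters $\alpha$ and $\lambda$. The optimal tuning parameter values are $\alpha = 0$ and $\lambda = 0.01$; with $\alpha = 0$ the elastic net penalty reduces to a pure ridge penalty. Note also that, for large values of $\lambda$, the accuracy levels off at 0.7189 with a Kappa of 0, which is consistent with a heavily penalized model that predicts the same class for every observation.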
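The resampling table reports two metrics for each candidate model, Accuracy and Cohen's Kappa. As a sketch of how they relate, here is a hypothetical base-R helper, `class_metrics()` (not a caret function), that computes both from a cross-tabulation of observed and predicted labels on a small toy example. Kappa compares the observed agreement with the agreement expected by chance, so a classifier that always predicts the majority class gets a Kappa of 0 even though its accuracy equals the majority-class share.

```r
# Accuracy and Cohen's kappa from observed and predicted class labels.
# class_metrics() is a hypothetical helper written for illustration.
class_metrics <- function(obs, pred) {
  lev <- union(levels(factor(obs)), levels(factor(pred)))
  tab <- table(factor(obs, levels = lev), factor(pred, levels = lev))
  n   <- sum(tab)
  p_o <- sum(diag(tab)) / n                       # observed agreement (accuracy)
  p_e <- sum(rowSums(tab) * colSums(tab)) / n^2   # agreement expected by chance
  c(Accuracy = p_o, Kappa = (p_o - p_e) / (1 - p_e))
}

# Toy data: 7 "good" and 3 "bad" observations
obs <- factor(c(rep("good", 7), rep("bad", 3)), levels = c("bad", "good"))

# A degenerate classifier that always predicts the majority class
always_good <- factor(rep("good", 10), levels = c("bad", "good"))
class_metrics(obs, always_good)   # Accuracy 0.7, Kappa 0

# A classifier that misclassifies one "good" observation as "bad"
one_error <- factor(c(rep("good", 6), rep("bad", 4)), levels = c("bad", "good"))
class_metrics(obs, one_error)     # Accuracy 0.9, Kappa (0.9 - 0.54) / (1 - 0.54)
```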

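We can then predict new samples with the selected model using the predict() method. Note that train() has already refitted the final model on the whole training set using the optimal tuning parameter values: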

# Predicted class labels for the test set (type = "raw" is the default)
pred_classes <- predict(glmnet_fit, newdata = cs_data_test)
table(pred_classes)

## pred_classes
##  bad good 
##  172  716

# Predicted class probabilities for the test set, one column per class
pred_probs <- predict(glmnet_fit, newdata = cs_data_test, type = "prob")
head(pred_probs)

##          bad      good
## 1 0.07142151 0.9285785
## 2 0.04231067 0.9576893
## 3 0.03736701 0.9626330
## 4 0.14796622 0.8520338
## 5 0.12416939 0.8758306
## 6 0.42359516 0.5764048
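With two classes, predicting the more probable class amounts to labelling a case "bad" whenever its probability of being bad exceeds 0.5. The sketch below applies that rule to the six probabilities printed above, using a hypothetical helper `classify()`. It also shows how a lower cutoff of 0.3, an illustrative value that is not part of this analysis, would flag more applicants as "bad", which may be worth considering when granting credit to a bad applicant is costlier than refusing a good one.

```r
# First six predicted probabilities, copied from the output above
pred_probs_head <- data.frame(
  bad  = c(0.07142151, 0.04231067, 0.03736701, 0.14796622, 0.12416939, 0.42359516),
  good = c(0.9285785, 0.9576893, 0.9626330, 0.8520338, 0.8758306, 0.5764048)
)

# Hypothetical helper: label a case "bad" when P(bad) exceeds the cutoff.
# With cutoff = 0.5 this is the same as picking the more probable class.
classify <- function(p_bad, cutoff = 0.5) {
  factor(ifelse(p_bad > cutoff, "bad", "good"), levels = c("bad", "good"))
}

classify(pred_probs_head$bad)                 # all six cases are "good"
classify(pred_probs_head$bad, cutoff = 0.3)   # the sixth case becomes "bad"
```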

If you want to deepen your knowledge of predictive analytics, you may find the R course Data Mining with R useful.

Stay tuned for the next article on the MilanoR blog!

References

Efron, B., and R. Tibshirani. 1993. An Introduction to the Bootstrap. CRC Press.

Friedman, J., T. Hastie, and R. Tibshirani. 2008. “Regularization Paths for Generalized Linear Models via Coordinate Descent.” Journal of Statistical Software 33 (1): 1–22.

Hastie, T., R. Tibshirani, and J. Friedman. 2009. The Elements of Statistical Learning. 2nd ed. Springer.

James, G., D. Witten, T. Hastie, and R. Tibshirani. 2013. An Introduction to Statistical Learning. Springer.

Kuhn, M., and K. Johnson. 2013. Applied Predictive Modeling. Springer.

Zou, H., and T. Hastie. 2005. “Regularization and Variable Selection via the Elastic Net.” Journal of the Royal Statistical Society: Series B (Statistical Methodology) 67 (2): 301–20.


  1. By the way, the oracle’s powers appear to have been associated with hallucinogenic gases rising from the temple floor.
  2. You can find a thorough formal illustration of all these concepts in Hastie, Tibshirani, and Friedman (2009), Chapter 7. A somewhat simpler presentation can be found in James et al. (2013).
  3. More precisely, the light red curves correspond to what is called the conditional test error, meaning that each curve is conditional on the corresponding training sample. The heavier red curve corresponds to the expected test error. In general, we would like to focus on the conditional test error for the particular training sample we have. However, this quantity is very difficult to estimate, and in practice the expected test error is typically targeted instead. As we will see, cross-validation is a method for estimating the expected test error. For more details see Hastie, Tibshirani, and Friedman (2009).
  4. A variant of the purely random split is to use stratified random sampling in order to create subsets that are balanced with respect to the outcome. This is particularly useful in classification problems where one class is much less frequent than the others.
  5. An alternative approach for the same objective is the bootstrap, which won’t be illustrated here (see Efron and Tibshirani (1993)).
  6. More precisely, cross-validation provides an estimate of the expected test error.
  7. The boot package also contains a nice function called cv.glm, which implements $k$-fold cross-validation for generalized linear models.
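A minimal sketch of cv.glm from the boot package (a recommended package shipped with R), applied to the built-in mtcars data rather than to the credit data used in this post; the logistic model for transmission type as a function of car weight is purely illustrative. The 0-1 cost function follows the example given in the cv.glm help page: the first argument receives the observed responses and the second the predicted probabilities.

```r
library(boot)  # provides cv.glm()

# Illustrative logistic regression: transmission (am: 0 = automatic,
# 1 = manual) as a function of car weight, on the built-in mtcars data
fit <- glm(am ~ wt, family = binomial, data = mtcars)

# 0-1 loss: a case is misclassified when the predicted probability
# falls on the wrong side of 0.5
cost_01 <- function(r, pi = 0) mean(abs(r - pi) > 0.5)

set.seed(123)  # the assignment of observations to folds is random
cv_res <- cv.glm(mtcars, fit, cost = cost_01, K = 5)

# delta[1]: raw 5-fold estimate of the misclassification rate
# delta[2]: version adjusted for the bias of not using leave-one-out
cv_res$delta
```

Omitting `cost` makes cv.glm fall back to the average squared error, the natural choice for a continuous outcome.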

The post Cross-Validation for Predictive Analytics Using R appeared first on MilanoR.
