Learning from Learning Curves
by Bob Horton, Senior Data Scientist, Microsoft
This is a follow-up to my earlier post on learning curves. A learning curve is a plot of predictive error for training and validation sets over a range of training set sizes. Here we’re using simulated data to explore some fundamental relationships between training set size, model complexity, and prediction error.
Start by simulating a dataset:
sim_data <- function(N, num_inputs=8, input_cardinality=10){
  inputs <- rep(input_cardinality, num_inputs)
  names(inputs) <- paste0("X", seq_along(inputs))
  as.data.frame(lapply(inputs, function(cardinality)
    sample(LETTERS[1:cardinality], N, replace=TRUE)))
}
The input columns are named X1, X2, etc.; these are all categorical variables, with single capital letters representing the different categories. Cardinality is the number of possible values in the column; our default cardinality of 10 means we sample from the capital letters A through J.
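To get a quick feel for what the generator produces (a small check of my own, not part of the original post), you can print a tiny sample:

# Illustrative only: five rows with three inputs, each drawn from the letters A through D
head(sim_data(5, num_inputs=3, input_cardinality=4))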
Next we’ll add an outcome variable (y); it has a base level of 100, but if the values in the first two X variables are equal, this is increased by 10. On top of this we add some normally distributed noise.
set.seed(123)

data <- sim_data(3e4, input_cardinality=10)

noise <- 2
data <- transform(data, y = ifelse(X1 == X2, 110, 100) + rnorm(nrow(data), sd=noise))
With linear models, we handle an interaction between two categorical variables by adding an interaction term; the number of possibilities in this interaction term is basically the product of the cardinalities. In this simulated data set, only the first two columns affect the outcome, and the other input columns don’t contain any useful information. We’ll use it to demonstrate how adding non-informative variables affects overfitting and training set size requirements.
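To see that product of cardinalities concretely (a quick check of my own, not from the original post), count the columns of the design matrix that lm would build for a single interaction:

# With two 10-level inputs, the X1:X2 interaction alone contributes 9 * 9 = 81
# dummy columns under the default treatment contrasts, on top of the main effects.
ncol(model.matrix(y ~ X1 * X2, data))   # 1 (intercept) + 9 + 9 + 81 = 100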
As in the earlier post, I’ll use the root mean squared error of the predictions as the error function because RMSE is essentially the same as standard deviation. No model should be able to make predictions with a root mean squared error less than the standard deviation of the random noise we added.
rmse <- function(actual, predicted) sqrt(mean((actual - predicted)^2))
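As a sanity check (my addition, not in the original post), an “oracle” that predicts the exact noise-free signal should have an RMSE close to the noise standard deviation of 2:

# The noise floor: predicting the true signal leaves only the added noise,
# so this should come out near `noise`.
rmse(data$y, ifelse(data$X1 == data$X2, 110, 100))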
The cross-validation function trains a model using the supplied formula and modeling function, then tests its performance on a held-out test set. The training set will be sampled from the data available for training; to use approximately a 10% sample of the training data, set prob_train to 0.1.
cross_validate <- function(model_formula, fit_function, error_function,
                           validation_set, training_data, prob_train=1){
  training_set <- training_data[runif(nrow(training_data)) < prob_train,]
  tss <- nrow(training_set)
  outcome_var <- as.character(model_formula[[2]])
  fit <- fit_function(model_formula, training_set)
  training_error <- error_function(training_set[[outcome_var]],
                                   predict(fit, training_set))
  validation_error <- error_function(validation_set[[outcome_var]],
                                     predict(fit, validation_set))
  data.frame(tss=tss, formula=deparse(model_formula),
             training=training_error, validation=validation_error,
             stringsAsFactors=FALSE)
}
Construct a family of formulas, then use expand.grid to make a data frame with all the combinations of formulas and sampling probabilities:
generate_formula <- function(num_inputs, degree=2, outcome="y"){
  inputs <- paste0("X", 1:num_inputs)
  rhs <- paste0("(", paste(inputs, collapse=" + "), ") ^ ", degree)
  paste(outcome, rhs, sep=" ~ ")
}

formulae <- lapply(2:(ncol(data) - 1), generate_formula)

prob <- 2^(seq(0, -6, by=-0.5))

parameter_table <- expand.grid(formula=formulae, sampling_probability=prob,
                               stringsAsFactors=FALSE)
Separate the training and validation data:
validation_fraction <- 0.25
in_validation_set <- runif(nrow(data)) < validation_fraction
vset <- data[in_validation_set,]
tdata <- data[!in_validation_set,]

run_param_row <- function(i){
  param <- parameter_table[i,]
  cross_validate(formula(param$formula[[1]]), lm, rmse, vset, tdata,
                 param$sampling_probability[[1]])
}
Now call the cross_validate function on each row of the parameter table. The foreach package makes it easy to process these jobs in parallel:
library(foreach)
library(doParallel)
registerDoParallel()  # automatically manages cluster

learning_curve_results <- foreach(i=1:nrow(parameter_table)) %dopar% run_param_row(i)

learning_curve_table <- data.table::rbindlist(learning_curve_results)
The rbindlist() function from the data.table package puts the results together into a single data frame; this is both cleaner and dramatically faster than the old do.call("rbind", ...) approach (though we’re just combining a small number of rows here, so speed is not an issue).
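For reference (my addition, not part of the original post), the base-R version would look like this:

# Equivalent result, but do.call("rbind", ...) scales poorly when there are many pieces
lc_table_base <- do.call("rbind", learning_curve_results)
identical(dim(lc_table_base), dim(learning_curve_table))   # same rows and columns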
Now plot the results. Since we’ll do another plot later, I’ll wrap the plotting code in a function to make it more reusable.
plot_learning_curve <- function(lct, title, base_error,
                                plot_training_error=TRUE, ...){
  library(dplyr)
  library(tidyr)
  library(ggplot2)
  # reshape to long form: one row per (training set size, formula, error type)
  lct_long <- lct %>% gather(type, error, -tss, -formula)
  # make "validation" the first level so it gets the solid line type
  lct_long$type <- relevel(factor(lct_long$type), "validation")
  plot_me <- if (plot_training_error) lct_long else lct_long[lct_long$type == "validation", ]
  ggplot(plot_me, aes(x=log10(tss), y=error, col=formula, linetype=type)) +
    ggtitle(title) +
    geom_hline(yintercept=base_error, linetype=2) +
    geom_line(size=1) +
    xlab("log10(training set size)") +
    coord_cartesian(...)
}

plot_learning_curve(learning_curve_table,
                    title="Extraneous variables are distracting",
                    base_error=noise, ylim=c(0,4))
This illustrates the phenomenon that adding more inputs to a model increases the requirements for training data. This is true even if the extra inputs do not contain any information. The cases where the training error is zero are actually rank-deficient (like having fewer equations than unknowns), and if you try this at home you will get warnings to that effect; this is an extreme kind of overfitting. Other learning algorithms might handle this better than lm, but the general idea is that those extra columns are distracting, and it takes more examples to falsify all the spurious correlations that get dredged up from those distractors.
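One way to see that rank deficiency directly (an illustrative check of my own, not part of the original post) is to fit the most complex formula on a deliberately tiny training set and count the coefficients lm could not estimate:

# The full two-way interaction formula asks for far more coefficients than a
# 200-row sample can support, so many come back as NA (aliased terms).
tiny_fit <- lm(formula(formulae[[length(formulae)]]), data[sample(nrow(data), 200), ])
length(coef(tiny_fit))        # coefficients the formula asks for
sum(is.na(coef(tiny_fit)))    # coefficients lm could not estimate
tiny_fit$rank                 # effective rank of the design matrix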
But what if the additional columns considered by the more complex formulas actually did contain predictive information? Keeping the same X values, we can modify y so that these other columns matter:
data <- transform(data, y = 100 + (X1 == X2) * 10 +
                                  (X2 == X3) * 3 +
                                  (X3 == X4) * 3 +
                                  (X4 == X5) * 3 +
                                  (X5 == X6) * 3 +
                                  (X6 == X7) * 3 +
                                  (X7 == X8) * 3 +
                                  rnorm(nrow(data), sd=noise))

validation_fraction <- 0.25
in_validation_set <- runif(nrow(data)) < validation_fraction
vset <- data[in_validation_set,]
tdata <- data[!in_validation_set,]

run_param_row <- function(i){
  param <- parameter_table[i,]
  formula_string <- param$formula[[1]]
  prob <- param$sampling_probability[[1]]
  cross_validate(formula(formula_string), lm, rmse, vset, tdata, prob)
}

learning_curve_results <- foreach(i=1:nrow(parameter_table)) %dopar% run_param_row(i)

lct <- data.table::rbindlist(learning_curve_results)
This time we’ll leave the training errors off the plot to focus on the validation error; this is what really matters when you are trying to generalize predictions.
plot_learning_curve(lct, title="Crossover Point", base_error=noise,
                    plot_training_error=FALSE, ylim=c(1.5, 5))
Now we see another important phenomenon: the simple models that work best with small training sets are outperformed by more complex models on larger training sets. But these complex models are only usable if they are given sufficient data; plotting a learning curve makes it clear whether or not you have used sufficient data.
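If you would rather read the crossover off a table than off the plot (a sketch of my own, not from the original post), you can summarize the results directly:

# For each formula, report the best validation error achieved and the
# training set size at which that best error occurred.
library(dplyr)
lct %>%
  group_by(formula) %>%
  summarize(best_validation = min(validation),
            tss_at_best     = tss[which.min(validation)]) %>%
  arrange(best_validation)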
Learning curves give valuable insights into the model training process. In some cases this can help you decide to expend effort or expense on gathering more data. In other cases you may discover that your models have learned all they can from just a fraction of the data that is already available. This might encourage you to investigate more complex models that may be capable of learning the finer details of the dataset, possibly leading to better predictions.
Generating these curves is computationally intensive, as is fitting even a single model on a large dataset in R. Parallelization helped here, but in a future post I’ll show similar patterns in learning curves for much bigger data sets (using real data rather than synthetic) by taking advantage of the scalable tools of Microsoft R Server.