Evaluating model performance – A practical example of the effects of overfitting and data size on prediction



Following my last post on decision making trees and machine learning, where I presented some tips gathered from the “Pragmatic Programming Techniques” blog, I have again been impressed by its clear presentation of strategies regarding the evaluation of model performance. I have seen some of these topics presented elsewhere – especially graphics showing the link between model complexity and prediction error (i.e. “overfitting“) – but this particular presentation made me want to go back to this topic and try to make a practical example in R that I could use when teaching.

Effect of overfitting on prediction
The above graph shows polynomial fits of various degrees to an artificial data set. The "real" underlying model is a 3rd-degree polynomial (y ~ b3*x^3 + b2*x^2 + b1*x + a). One gets a good idea that the higher-degree models are incorrect given the single-term significance tests provided by the summary function (e.g. for the 5th-degree polynomial model):

Call:
lm(formula = ye ~ poly(x, degree = 5), data = df)

Residuals:
    Min      1Q  Median      3Q     Max
-4.4916 -2.0382 -0.4417  2.2340  8.1518

Coefficients:
                     Estimate Std. Error t value Pr(>|t|)
(Intercept)           29.3696     0.4304  68.242  < 2e-16 ***
poly(x, degree = 5)1  74.4980     3.0432  24.480  < 2e-16 ***
poly(x, degree = 5)2  54.0712     3.0432  17.768  < 2e-16 ***
poly(x, degree = 5)3  23.5394     3.0432   7.735 9.72e-10 ***
poly(x, degree = 5)4  -3.0043     3.0432  -0.987    0.329
poly(x, degree = 5)5   1.1392     3.0432   0.374    0.710
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 3.043 on 44 degrees of freedom
Multiple R-squared: 0.9569, Adjusted R-squared: 0.952
F-statistic: 195.2 on 5 and 44 DF, p-value: < 2.2e-16
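
A complementary check, not shown in the original post, is a nested-model comparison with anova(), which jointly tests whether the 4th- and 5th-degree terms improve on the 3rd-degree fit (this assumes the df data frame constructed in the code at the end of the post):

# Nested model comparison (complementary check; assumes df from the code below)
fit3 <- lm(ye ~ poly(x, degree=3), data=df)
fit5 <- lm(ye ~ poly(x, degree=5), data=df)
anova(fit3, fit5) # F-test of the added 4th- and 5th-degree terms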


Nevertheless, a more robust analysis of prediction error is through cross-validation, i.e. splitting the data into training and validation subsets. The following example uses a 50% training / 50% validation split, repeated over 500 permutations.
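
To make the splitting explicit, a minimal sketch of a single permutation (again using the df data frame constructed in the code at the end of the post) might look like this; the full loop over degrees and permutations is given below:

# One 50/50 train/validation split (single permutation; full loop given below)
train <- sample(nrow(df), nrow(df)*0.5)
fit <- lm(ye ~ poly(x, degree=5), data=df[train,])
mean(abs(df$ye[-train] - predict(fit, df[-train,]))) # validation MAE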


So, here we have the typical trend of increasing prediction error with model complexity (as estimated via cross-validation, CV) when the model is overfit (i.e. > 3rd-degree polynomial, vertical grey dashed line). As a reference, the horizontal grey dashed line shows the original amount of error added, which is where the CV error reaches a minimum.

Effect of data size on prediction
Another interesting aspect presented in the post is the use of CV to estimate the relationship between prediction error and the amount of data used in the model fitting (credit given to Andrew Ng from Stanford). This is a helpful concept when determining what the benefit to prediction would be from an investment in additional data sampling:


Here we see that, given a fixed model complexity, training error and CV error converge. Again, the horizontal grey dashed line indicates the actual measurement error of the response variable. So, in this example, there is not much improvement in prediction error beyond a data size of ca. 100. Interestingly, the example also demonstrates that even with an overfit model containing a 7th-degree polynomial, the increased prediction error is overcome with a larger data set. For comparison, the same exercise done with the correct 3rd-degree model shows that a relatively low prediction error is achieved even when the data size is small (MAE of 2.6 for the 3rd-degree polynomial vs. 3.7 for the 7th-degree polynomial):


Code to reproduce example:

### Data and model fitting
set.seed(1111)
n <- 50
x <- sort(runif(n, -2, 2))
y <- 3*x^3 + 5*x^2 + 0.5*x + 20 # the "true" 3rd-degree polynomial model
err <- rnorm(n, sd=3)
ye <- y + err
df <- data.frame(x, ye)
nterm <- c(1,2,3,5,7,9)
 
png("model_fit~terms.png", width=5, height=4, units="in", res=400, type="cairo")
par(mar=c(4,4,1,1))
plot(ye~x, df, ylab="y")
PAL <- colorRampPalette(c("blue", "cyan", "yellow", "red"))
COLS <- PAL(length(nterm))
for(i in seq(nterm)){
 fit <- lm(ye ~ poly(x, degree=nterm[i]), data=df)
 newdat <- data.frame(x=seq(min(df$x), max(df$x),,100))
 lines(newdat$x, predict(fit, newdat), col=COLS[i])
}
legend("topleft", legend=paste0(nterm, c("", "", "*", "", "", "")), title="polynomial degrees", bty="n", col=COLS, lty=1)
dev.off()
 
 
### Term significance
fit <- lm(ye ~ poly(x, degree=5), data=df)
summary(fit)
 
 
### Error as a function of model complexity
set.seed(1111)
n <- 50
nterm <- seq(12)
perms <- 500
frac.train <- 0.5 #training fraction of data
run <- data.frame(n, nterm, train.err=NaN, cv.err=NaN)
run
x <- sort(runif(n, -2, 2))
y <- 3*x^3 + 5*x^2 + 0.5*x + 20 # the "true" 3rd-degree polynomial model
err <- rnorm(n, sd=3)
ye <- y + err
df <- data.frame(x, ye)
 
for(i in seq(nrow(run))){
 pred.train <- matrix(NaN, nrow=nrow(df), ncol=perms)
 pred.valid <- matrix(NaN, nrow=nrow(df), ncol=perms)
 for(j in seq(perms)){
  train <- sample(nrow(df), nrow(df)*frac.train)
  valid <- seq(nrow(df))[-train]
  dftrain <- df[train,]
  dfvalid <- df[valid,]
  fit <- lm(ye ~ poly(x, degree=run$nterm[i]), data=dftrain)
  pred.train[train,j] <- predict(fit)
  pred.valid[valid,j] <- predict(fit, dfvalid)
 }
 run$train.err[i] <- mean(abs(df$ye - pred.train), na.rm=TRUE) # sqrt(mean((df$ye - pred.train)^2, na.rm=TRUE))
 run$cv.err[i] <- mean(abs(df$ye - pred.valid), na.rm=TRUE) # sqrt(mean((df$ye - pred.valid)^2, na.rm=TRUE))
 print(i)
}
 
png("error~complexity.png", width=5, height=4, units="in", res=400, type="cairo")
par(mar=c(4,4,1,1))
ylim <- range(run$train.err, run$cv.err)
plot(run$nterm, run$train.err, log="y", col=1, t="o", ylim=ylim, xlab="Model complexity [polynomial degrees]", ylab="Mean absolute error [MAE]")
lines(run$nterm, run$cv.err, col=2, t="o")
abline(v=3, lty=2, col=8)
abline(h=mean(abs(err)), lty=2, col=8)
legend("top", legend=c("Training error", "Cross-validation error"), bty="n", col=1:2, lty=1, pch=1)
dev.off()
 
 
 
### Error as a function of data size
set.seed(1111)
n <- round(exp(seq(log(50), log(500),, 10)))
nterm <- 7
perms <- 500
frac.train <- 0.5 #training fraction of data
run <- data.frame(n, nterm, train.err=NaN, cv.err=NaN)
run
x <- sort(runif(max(n), -2, 2))
y <- 3*x^3 + 5*x^2 + 0.5*x + 20 # the "true" 3rd-degree polynomial model
err <- rnorm(max(n), sd=3)
ye <- y + err
DF <- data.frame(x, ye)
 
for(i in seq(nrow(run))){
 df <- DF[1:run$n[i],]
 pred.train <- matrix(NaN, nrow=nrow(df), ncol=perms)
 pred.valid <- matrix(NaN, nrow=nrow(df), ncol=perms)
 for(j in seq(perms)){
  train <- sample(nrow(df), nrow(df)*frac.train)
  valid <- seq(nrow(df))[-train]
  dftrain <- df[train,]
  dfvalid <- df[valid,]
  fit <- lm(ye ~ poly(x, degree=run$nterm[i]), data=dftrain)
  pred.train[train,j] <- predict(fit)
  pred.valid[valid,j] <- predict(fit, dfvalid)
 }
 run$train.err[i] <- mean(abs(df$ye - pred.train), na.rm=TRUE) # sqrt(mean((df$ye - pred.train)^2, na.rm=TRUE))
 run$cv.err[i] <- mean(abs(df$ye - pred.valid), na.rm=TRUE) # sqrt(mean((df$ye - pred.valid)^2, na.rm=TRUE))
 print(i)
}
 
png(paste0("error~data_size_", paste0(nterm, "term"), ".png"), width=5, height=4, units="in", res=400, type="cairo")
par(mar=c(4,4,1,1))
ylim <- range(run$train.err, run$cv.err)
plot(run$n, run$train.err, log="xy", col=1, t="o", ylim=ylim, xlab="Data size [n]", ylab="Mean absolute error [MAE]")
lines(run$n, run$cv.err, col=2, t="o")
abline(h=mean(abs(err)), lty=2, col=8)
legend("bottomright", legend=paste0("No. of polynomial degrees = ", nterm), bty="n")
legend("top", legend=c("Training error", "Cross-validation error"), bty="n", col=1:2, lty=1, pch=1)
dev.off()
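
To reproduce the 3rd-degree comparison mentioned above, the same "### Error as a function of data size" block can simply be rerun with the degree changed; the output file name adjusts automatically via paste0():

# Comparison run with the correctly specified model
nterm <- 3 # then rerun the data-size loop and plotting code; writes "error~data_size_3term.png"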

