Using caret to compare models

June 2, 2016

(This article was first published on Revolutions, and kindly contributed to R-bloggers)

by Joseph Rickert

The model table on the caret package website lists more than 200 variations of predictive analytics models that are available within the caret framework. All of these models may be prepared, tuned, fit and evaluated with a common set of caret functions. All on its own, the table is an impressive testament to the utility and scope of the R language as a data science tool.

For the past year or so xgboost, the extreme gradient boosting algorithm, has been getting a lot of attention. The code below compares gbm with xgboost using the segmentationData set that comes with caret. The analysis presented here is far from the last word on comparing these models, but it does show how one might go about setting up a serious comparison: using caret's functions to sweep through parameter space with parallel processing, and then using synchronized resamples (cross-validation folds in the code below) to make a detailed comparison.

After reading in the data and dividing it into training and test data sets, caret's trainControl() and expand.grid() functions are used to set up training of the gbm model on all of the combinations of tuning parameters represented in the data frame built by expand.grid(). Then the train() function does the actual training and fitting of the model. Notice that all of this happens in parallel. The Task Manager on my Windows 10 laptop shows all four cores maxed out at 100%.
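To make the search space concrete, here is a minimal base R sketch of what expand.grid() produces for the gbm grid used in this post. It needs no packages; each row of the resulting data frame is one candidate parameter combination for train() to evaluate.

```r
# Base R only: expand.grid() returns one row per combination of the supplied values.
grid <- expand.grid(interaction.depth = c(1, 2),   # two tree depths
                    n.trees = c(10, 20),           # two ensemble sizes
                    shrinkage = c(0.01, 0.1),      # two learning rates
                    n.minobsinnode = 20)           # held fixed

nrow(grid)      # 2 * 2 * 2 * 1 = 8 candidate combinations
head(grid, 3)   # the first argument varies fastest
```

Adding one more value to any of the varied parameters multiplies the number of rows, which is why fuller grids benefit from parallel processing.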

After model fitting, predictions on the test data are computed and an ROC curve is drawn in the usual way. The AUC for gbm was computed to be 0.8731. Here is the ROC curve.
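For readers who want to see what the AUC measures, the following base R sketch computes it from ranks as the probability that a randomly chosen positive case scores higher than a randomly chosen negative one. The function name auc_rank and the toy scores are invented for illustration; this is not how the post computes its numbers (it uses pROC), although the two approaches should agree on the same data.

```r
# Toy data, invented for illustration: 1 = positive class, 0 = negative class
labels <- c(1, 1, 0, 1, 0, 0)
scores <- c(0.9, 0.8, 0.7, 0.6, 0.4, 0.2)

# Hypothetical helper: AUC via the rank-sum (Mann-Whitney) identity
auc_rank <- function(scores, labels) {
  r     <- rank(scores)          # average ranks, so ties count as one half
  n_pos <- sum(labels == 1)
  n_neg <- sum(labels == 0)
  (sum(r[labels == 1]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

auc_rank(scores, labels)   # 8 of the 9 positive/negative pairs are ordered correctly
```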


Next, a similar process for xgboost computes the AUC to be 0.8857, an improvement of about 0.013 over gbm. The following plot shows how the ROC measure behaves with increasing tree depth for the two different values of the shrinkage parameter.



The final section of code shows how caret can be used to compare the two models using the resamples that were created in the process of tuning the two models. The boxplots show xgboost has the edge, although gbm has a tighter distribution.
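The reason matched resamples matter is that each model is scored on exactly the same held-out data, so the per-resample differences can be analyzed directly. The base R sketch below illustrates the idea with a paired t-test. The ROC values are made up purely for illustration and are not results from this analysis.

```r
# Made-up per-fold ROC values for two models scored on the SAME five folds
xgb_roc <- c(0.90, 0.88, 0.91, 0.87, 0.89)
gbm_roc <- c(0.88, 0.87, 0.88, 0.86, 0.88)

d <- xgb_roc - gbm_roc      # paired differences, one per fold
mean(d)                     # average advantage of the first model

# Paired t-test on the fold-wise differences
tt <- t.test(xgb_roc, gbm_roc, paired = TRUE)
tt$statistic
tt$p.value
```

With real caret models, applying diff() to the object returned by resamples() and then calling summary() on the result carries out a comparison along these lines.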


The next step, which I hope to take soon, is to rerun the analysis with more complete grids of tuning parameters. For a very accessible introduction to caret have a look at Max Kuhn's 2013 useR! tutorial.

### Packages Required
library(caret)        # Model training and tuning framework
library(corrplot)     # plot correlations
library(doParallel)   # parallel processing
library(dplyr)        # Used by caret
library(gbm)          # GBM Models
library(pROC)         # plot the ROC curve
library(xgboost)      # Extreme Gradient Boosting
### Get the Data
# Load the data and construct indices to divide it into training and test data sets.
data(segmentationData)  	# Load the segmentation data set
trainIndex <- createDataPartition(segmentationData$Case,p=.5,list=FALSE)
trainData <- segmentationData[trainIndex,-c(1,2)]   # Drop the Cell and Case columns
testData  <- segmentationData[-trainIndex,-c(1,2)]
trainX <- trainData[,-1]       # Drop the response (Class), keeping the predictors
testX <- testData[,-1]
sapply(trainX,summary) # Look at a summary of the training data
# Set up training control
ctrl <- trainControl(method = "repeatedcv",             # repeated k-fold cross validation
                     number = 5,                        # 5 folds (repeats defaults to 1)
                     summaryFunction = twoClassSummary, # Use AUC to pick the best model
                     classProbs = TRUE,                 # class probabilities, needed by twoClassSummary
                     allowParallel = TRUE)
# Use the expand.grid to specify the search space	
# Note that the default search grid selects multiple values of each tuning parameter
grid <- expand.grid(interaction.depth=c(1,2), # Depth of variable interactions
                    n.trees=c(10,20),	        # Num trees to fit
                    shrinkage=c(0.01,0.1),		# Try 2 values for learning rate 
                    n.minobsinnode = 20)
set.seed(1951)  # set the seed
# Set up to do parallel processing
registerDoParallel(4)		# Register a parallel backend for train
gbm.tune <- train(x=trainX,y=trainData$Class,
                              method = "gbm",
                              metric = "ROC",
                              trControl = ctrl,
                              tuneGrid = grid,   # search the grid defined above
                              verbose = FALSE)   # suppress gbm's iteration log
# Look at the tuning results
# Note that ROC was the performance criterion used to select the optimal model.   
plot(gbm.tune)  		# Plot the performance of the training models
res <- gbm.tune$results
### GBM Model Predictions and Performance
# Make predictions using the test data set
gbm.pred <- predict(gbm.tune,testX)
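# Look at the confusion matrix
confusionMatrix(gbm.pred,testData$Class)
# Draw the ROC curve
gbm.probs <- predict(gbm.tune,testX,type="prob")
gbm.ROC <- roc(predictor=gbm.probs$PS,
               response=testData$Class,
               levels=rev(levels(testData$Class)))
gbm.ROC$auc
# Area under the curve: 0.8731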
plot(gbm.ROC,main="GBM ROC")
# Plot the probability of poor segmentation
histogram(~gbm.probs$PS|testData$Class,xlab="Probability of Poor Segmentation")
# Some stackexchange guidance for xgboost
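# Set up for parallel processing
set.seed(1951)          # same seed as the gbm run, so the resamples are paired
registerDoParallel(4)
# Train xgboost
# Note: more recent caret releases also require gamma, colsample_bytree,
# min_child_weight and subsample columns in the xgbTree tuning grid.
xgb.grid <- expand.grid(nrounds = 500, #the maximum number of iterations
                        eta = c(0.01,0.1), # shrinkage
                        max_depth = c(2,6,10))
xgb.tune <-train(x=trainX,y=trainData$Class,
                 method = "xgbTree",
                 metric = "ROC",
                 trControl = ctrl,
                 tuneGrid = xgb.grid)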
plot(xgb.tune)  		# Plot the performance of the training models
res <- xgb.tune$results
### xgboostModel Predictions and Performance
# Make predictions using the test data set
xgb.pred <- predict(xgb.tune,testX)
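# Look at the confusion matrix
confusionMatrix(xgb.pred,testData$Class)
# Draw the ROC curve
xgb.probs <- predict(xgb.tune,testX,type="prob")
xgb.ROC <- roc(predictor=xgb.probs$PS,
               response=testData$Class,
               levels=rev(levels(testData$Class)))
xgb.ROC$auc
# Area under the curve: 0.8857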
plot(xgb.ROC,main="xgboost ROC")
# Plot the probability of poor segmentation
histogram(~xgb.probs$PS|testData$Class,xlab="Probability of Poor Segmentation")
# Comparing Multiple Models
# Having set the same seed before running gbm.tune and xgb.tune
# we have generated paired samples and are in a position to compare models 
# using a resampling technique.
# (See Hothorn et al., "The design and analysis of benchmark experiments",
# Journal of Computational and Graphical Statistics (2005) vol 14 (3),
# pp 675-699)
rValues <- resamples(list(xgb=xgb.tune,gbm=gbm.tune))
bwplot(rValues,metric="ROC",main="GBM vs xgboost")	# boxplot
dotplot(rValues,metric="ROC",main="GBM vs xgboost")	# dotplot






