DART: Dropout Regularization in Boosting Ensembles

August 20, 2017

This article was first published on S+/R – Yet Another Blog in Statistical Computing, and kindly contributed to R-bloggers.

The dropout approach developed by Hinton has been widely employed in deep learning to prevent deep neural networks from overfitting, as shown in https://statcompute.wordpress.com/2017/01/02/dropout-regularization-in-deep-neural-networks.

The paper http://proceedings.mlr.press/v38/korlakaivinayak15.pdf shows that dropout can also be used to address the overfitting in boosted tree ensembles, e.g. MART, caused by the so-called “over-specialization”. In particular, while the first few trees added to the ensemble dominate the model performance, the trees added later tend to improve the prediction for only a small subset of observations, which increases the risk of overfitting. The idea of DART is to drop a random subset of the existing trees at each boosting iteration, fit the new tree to the residual of the remaining trees, and then rescale the new tree and the dropped trees so that their combined contribution does not overshoot the target. The dropout rate determines the degree of regularization for the ensemble.
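
To make the mechanics concrete, here is a minimal base-R sketch of the dropout and rescaling steps for squared-error regression with one-split trees. It is an illustration under stated assumptions, not the xgboost implementation: the function names (fit_stump, predict_stump, predict_trees, toy_dart) and the toy data are hypothetical, no learning rate is used, the new tree is scaled by 1 / (k + 1) and the k dropped trees by k / (k + 1), and one tree is dropped at random whenever dropout is on but no tree was selected.

```r
# Toy DART for squared-error regression on one feature, base R only.
# All function names are hypothetical and not part of xgboost.
set.seed(2017)
x <- runif(200)
y <- sin(2 * pi * x) + rnorm(200, sd = 0.3)

# Fit a stump (depth-1 tree) to the residual r by scanning decile cut points
fit_stump <- function(x, r) {
  cuts <- quantile(x, probs = seq(0.1, 0.9, by = 0.1))
  sse <- sapply(cuts, function(ct) {
    sum((r[x <= ct] - mean(r[x <= ct]))^2) + sum((r[x > ct] - mean(r[x > ct]))^2)
  })
  best <- cuts[[which.min(sse)]]
  list(cut = best, left = mean(r[x <= best]), right = mean(r[x > best]))
}

predict_stump <- function(s, x) ifelse(x <= s$cut, s$left, s$right)

# Weighted sum of the selected trees (all trees by default)
predict_trees <- function(trees, w, x, which_trees = seq_along(trees)) {
  pred <- rep(0, length(x))
  for (j in which_trees) pred <- pred + w[j] * predict_stump(trees[[j]], x)
  pred
}

toy_dart <- function(x, y, n_rounds = 50, rate_drop = 0.1) {
  trees <- list()
  w <- numeric(0)
  for (m in seq_len(n_rounds)) {
    n_old <- length(trees)
    # 1. Drop each existing tree with probability rate_drop
    dropped <- which(runif(n_old) < rate_drop)
    # Assumed rule: with dropout on, drop one tree at random if none was selected
    if (rate_drop > 0 && n_old > 0 && length(dropped) == 0) {
      dropped <- sample.int(n_old, 1)
    }
    kept <- setdiff(seq_len(n_old), dropped)
    # 2. Fit the new tree to the residual of the kept trees only
    new_tree <- fit_stump(x, y - predict_trees(trees, w, x, kept))
    # 3. Rescale: dropped trees by k / (k + 1), new tree by 1 / (k + 1)
    k <- length(dropped)
    w[dropped] <- w[dropped] * k / (k + 1)
    trees[[n_old + 1]] <- new_tree
    w[n_old + 1] <- 1 / (k + 1)
  }
  list(trees = trees, w = w)
}

mart_fit <- toy_dart(x, y, rate_drop = 0)    # no dropout: plain boosting
dart_fit <- toy_dart(x, y, rate_drop = 0.1)  # with dropout
range(mart_fit$w)
range(dart_fit$w)
```

With rate_drop = 0 every tree keeps a weight of 1, whereas with dropout the weights of earlier trees shrink each time they are dropped, which limits how much any single tree can dominate the ensemble.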

Below is a demonstration of DART with the R xgboost package. After importing the data, we split it into two halves, one for training and the other for testing.

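# Load the required packages
pkgs <- c('pROC', 'xgboost')
lapply(pkgs, require, character.only = TRUE)
# Import the data and keep only the card holders
df1 <- read.csv("Downloads/credit_count.txt")
df2 <- df1[df1$CARDHLDR == 1, ]
# Split the rows randomly into two halves and drop the first column of df2
set.seed(2017)
n <- nrow(df2)
idx <- sample(seq(n), size = floor(n / 2), replace = FALSE)
train <- df2[idx, -1]
test <- df2[-idx, -1]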

For comparison, we first developed a boosted tree ensemble without dropout, as shown below. For simplicity, all parameters were chosen heuristically. The max_depth is set to 3 because boosting tends to work well with so-called “weak” learners, e.g. simple trees. While the AUC (area under the ROC curve) for the training set is as high as 0.95, the AUC for the testing set is only 0.60 in our case, indicating an overfitting issue.

# Boosted tree ensemble without dropout; the first column of train is the DEFAULT label
mart.parm <- list(booster = "gbtree", nthread = 4, eta = 0.1, max_depth = 3, subsample = 1, eval_metric = "auc")
mart <- xgboost(data = as.matrix(train[, -1]), label = train[, 1], params = mart.parm, nrounds = 500, verbose = 0, seed = 2017)
# Score the training and testing sets
pred1 <- predict(mart, as.matrix(train[, -1]))
pred2 <- predict(mart, as.matrix(test[, -1]))
# Compare the AUC on both sets
roc(as.factor(train$DEFAULT), pred1)
# Area under the curve: 0.9459
roc(as.factor(test$DEFAULT), pred2)
# Area under the curve: 0.6046
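
For readers wondering what the "Area under the curve" figure measures, the sketch below computes it from ranks in base R: it is the share of (defaulter, non-defaulter) pairs in which the defaulter receives the higher score. The helper name auc_rank is hypothetical, and labels are assumed to be coded 0/1. Note that pROC's roc() chooses the comparison direction automatically by default, so the two can differ when scores are inversely related to the label.

```r
# Rank-based AUC (Mann-Whitney U statistic); auc_rank is a hypothetical helper
auc_rank <- function(label, score) {
  r <- rank(score)                 # average ranks handle tied scores
  n_pos <- sum(label == 1)
  n_neg <- sum(label == 0)
  (sum(r[label == 1]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

auc_rank(c(0, 0, 1, 1), c(0.1, 0.4, 0.35, 0.8))
# 0.75: three of the four (positive, negative) pairs are ordered correctly
```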

With the same set of parameters, we refitted the ensemble with dropout, i.e. DART. As shown below, with a dropout rate of 10% (rate_drop = 0.1), the AUC for the testing set increases from 0.60 to 0.65. In addition, the gap between the training and testing AUC shrinks from about 0.34 (0.95 vs. 0.60) to about 0.12 (0.77 vs. 0.65).

# Same settings as before, but with the DART booster and a 10% dropout rate
dart.parm <- list(booster = "dart", rate_drop = 0.1, nthread = 4, eta = 0.1, max_depth = 3, subsample = 1, eval_metric = "auc")
dart <- xgboost(data = as.matrix(train[, -1]), label = train[, 1], params = dart.parm, nrounds = 500, verbose = 0, seed = 2017)
# Score the training and testing sets
pred1 <- predict(dart, as.matrix(train[, -1]))
pred2 <- predict(dart, as.matrix(test[, -1]))
# Compare the AUC on both sets
roc(as.factor(train$DEFAULT), pred1)
# Area under the curve: 0.7734
roc(as.factor(test$DEFAULT), pred2)
# Area under the curve: 0.6517

Besides rate_drop = 0.1, a wide range of dropout rates was also tested. In most cases, DART outperformed its counterpart without dropout regularization.
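
A comparison across dropout rates could be set up along the following lines. This is only a sketch under assumptions: it uses synthetic data rather than the credit data, goes through xgb.train() with an explicit binary:logistic objective, runs 100 rounds to keep it quick, and relies on a hypothetical auc_rank helper instead of pROC. The AUC values it prints say nothing about the results reported above.

```r
library(xgboost)

set.seed(2017)
# Synthetic stand-in data (an assumption, not the credit_count data)
n <- 1000
X <- matrix(rnorm(n * 5), ncol = 5, dimnames = list(NULL, paste0("x", 1:5)))
y <- rbinom(n, 1, plogis(X[, 1] - 0.5 * X[, 2]))
idx <- sample(seq_len(n), size = n / 2)
dtrain <- xgb.DMatrix(data = X[idx, ], label = y[idx])
dtest <- xgb.DMatrix(data = X[-idx, ], label = y[-idx])

# Rank-based AUC; hypothetical helper assuming 0/1 labels
auc_rank <- function(label, score) {
  r <- rank(score)
  n_pos <- sum(label == 1)
  n_neg <- sum(label == 0)
  (sum(r[label == 1]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Refit the DART ensemble for each dropout rate and record the testing AUC
rates <- c(0, 0.05, 0.1, 0.2)
test_auc <- sapply(rates, function(r) {
  parm <- list(booster = "dart", rate_drop = r, eta = 0.1, max_depth = 3,
               subsample = 1, objective = "binary:logistic", nthread = 1)
  fit <- xgb.train(params = parm, data = dtrain, nrounds = 100, verbose = 0)
  auc_rank(y[-idx], predict(fit, dtest))
})
names(test_auc) <- rates
print(round(test_auc, 4))
```

Setting rate_drop = 0 in the grid gives a no-dropout baseline within the same loop, so every rate is compared against the same data split.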
