relgam: Fitting reluctant generalized additive models

[This article was first published on R – Statistical Odds & Ends, and kindly contributed to R-bloggers.]

I’m proud to announce that my latest research project, reluctant generalized additive modeling (RGAM), is complete (for now)! In this post, I give a brief overview of the method: what it is trying to do and how you can fit such a model in R. (This project is joint work with my advisor, Rob Tibshirani.)

  • For an in-depth description of the method, please see our arXiv preprint.
  • You can download the CRAN version of the package, relgam, here. The latest version of the package is on Github.
  • For more details on how to use the package, please see the package’s vignette.

Introduction and motivation

tl;dr: Reluctant generalized additive modeling (RGAM) produces highly interpretable sparse models that allow non-linear relationships between the response and individual features. However, a non-linear relationship is only included if it is deemed important for improving prediction performance. RGAMs work with quantitative, binary, count and survival responses, and are computationally efficient.

Consider the supervised learning setting, where we have n observations of p features {\bf X} = \{x_{ij} \} for i = 1,2, \dots,n and j = 1,2 \dots,p, along with n responses y = (y_1, \dots, y_n). Let X_j \in \mathbb{R}^n denote the values of the jth feature. Generalized linear models (GLMs) assume that the relationship between the response and the features is

\begin{aligned} \eta(y) = \sum_{j=1}^p \beta_j X_j +\epsilon, \end{aligned}

where \eta is a link function and \epsilon is a mean-zero error term. Generalized additive models (GAMs) are a more flexible class of models, assuming the true relationship to be

\begin{aligned} \eta(y) = \sum_{j=1}^p f_j(X_j) +\epsilon, \end{aligned}

where the f_j‘s are unknown functions to be determined by the model.

These two classes of models include all p features in the model, which is often undesirable, especially when we have tons of features. (We usually expect only a small fraction of features to have any influence on the response variable.) This is especially problematic with GAMs, as overfitting can occur much more easily. A host of methods have arisen to create sparse GAMs, i.e. GAMs that involve only a handful of features. Earlier examples of such methods include COSSO (Lin & Zhang 2006) and SpAM (Ravikumar et al. 2007).

While providing sparsity, these methods dictated that any feature included in the model had to have a non-linear relationship with the response, even if a linear relationship would have sufficed. More sophisticated methods were developed to give both sparsity and the possibility of either linear or non-linear relationships between the features and the response. Examples of such methods are GAMSEL (Chouldechova & Hastie 2015), SPLAM (Lou et al. 2016) and SPLAT (Petersen & Witten 2019). GAMSEL is available in R in the gamsel package (see my unofficial vignette here); I was not able to find R packages for the other two methods.

Reluctant generalized additive models (RGAMs) fall into the same class as this last group of methods. They are available in R in the relgam package. RGAMs are computationally fast and work with quantitative, binary, count and survival response variables. (To my knowledge, existing software only works for quantitative and binary responses.)

RGAM: What is it?

Reluctant generalized additive modeling was inspired by reluctant interaction modeling (Yu et al. 2019). The guiding principle is:

One should prefer a linear term over a non-linear term if all else is equal.

That is, we prefer a model to contain only effects that are linear in the original set of features: non-linearities are only included thereafter if they add to predictive performance.

We operationalize this principle with a three-step process that closely mimics that of reluctant interaction modeling. At a high level:

  1. Fit the response as well as we can using only the main effects (i.e. original features).
  2. For each original feature X_j, construct a non-linear feature associated with it.
  3. Refit the response on all the main effects and the additional features from Step 2.

Now for a little bit more detail:

  1. Fit the lasso of y on \bf X, with the hyperparameter \lambda chosen by cross-validation, to get coefficients \hat{\beta}. Compute the residuals r = y - {\bf X}\hat{\beta}.
  2. For each j = 1, \dots, p, fit a smoothing spline with d degrees of freedom of r on X_j, which we denote by \hat{f}_j. Rescale \hat{f}_j so that \text{sd}(\hat{f}_j) = \gamma \cdot \overline{\text{sd}(X)}, where \overline{\text{sd}(X)} is the mean of the standard deviations of the features X_1, \dots, X_p. Let \bf F denote the matrix whose columns are the \hat{f}_j‘s.
  3. Fit the lasso of y on \bf X and \bf F for a path of tuning parameters \lambda_1 > \dots > \lambda_m \geq 0.
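The three steps above can be sketched in a few lines of base R. This is a simplified illustration, not the package's implementation: ordinary least squares stands in for the cross-validated lasso fits of Steps 1 and 3 (the real method uses lasso paths), and the data, seed, and hyperparameter values are my own choices.

```r
set.seed(1)
n <- 100; p <- 5
X <- matrix(rnorm(n * p), n, p)
y <- X[, 1] + 2 * X[, 2]^2 + rnorm(n)

# Step 1: fit the main effects and take residuals.
# (OLS stands in here for the cross-validated lasso.)
fit1 <- lm(y ~ X)
r <- residuals(fit1)

# Step 2: for each feature, fit a smoothing spline of the residuals
# on X_j, then rescale so that sd(f_j) = gamma * mean(sd(X_j)).
gamma <- 0.8; d <- 4
mean_sd <- mean(apply(X, 2, sd))
Fmat <- sapply(seq_len(p), function(j) {
  sp <- smooth.spline(X[, j], r, df = d)
  fj <- predict(sp, X[, j])$y      # fitted values at the original points
  fj * gamma * mean_sd / sd(fj)    # rescale to satisfy the gamma constraint
})

# Step 3: refit on the main effects plus the non-linear features.
# (Again OLS in place of a lasso fit over a path of lambda values.)
fit3 <- lm(y ~ X + Fmat)
```

After Step 2, every column of Fmat has standard deviation exactly gamma times the mean feature standard deviation, which is what lets gamma control how easily the non-linear features enter the model in Step 3.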

There are three hyperparameters here: \lambda (just like the lasso), d for the smoothing spline degrees of freedom in Step 2, and \gamma for scaling the non-linear features. The role of \gamma might be a bit hard to understand from the technical description above. Informally, \gamma = 1 means that the linear and non-linear features are on the same scale. \gamma < 1 means that the non-linear features are on a smaller scale: as a result, the coefficient associated with them is less likely to survive variable selection by the lasso in Step 3.

A simple example

The CRAN vignette is the best place to start learning how to fit RGAMs in practice. Below I give an example of the type of model that can come out of RGAM. (Code for this example can be found here.)

We simulate data with n = 100 observations and p = 12 features. Each entry in the \bf X matrix is an independent draw from the standard normal distribution, and the true response is

\begin{aligned} y = X_1 + X_2 + X_3 + \left( 2X_4^2 + 4X_4 - 2 \right) + \left( -2X_5^2 + 2 \right) + \frac{1}{2}X_6^3 + \epsilon. \end{aligned}
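This simulation setup can be reproduced in a few lines of base R (the seed and the unit noise variance are my own choices, not necessarily those of the original code):

```r
set.seed(42)
n <- 100; p <- 12
# Each entry of X is an independent standard normal draw.
X <- matrix(rnorm(n * p), nrow = n, ncol = p)

# True response: linear in X1-X3, non-linear in X4-X6,
# and no dependence on the remaining six features.
y <- X[, 1] + X[, 2] + X[, 3] +
  (2 * X[, 4]^2 + 4 * X[, 4] - 2) +
  (-2 * X[, 5]^2 + 2) +
  0.5 * X[, 6]^3 +
  rnorm(n)
```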

We fit an RGAM to this data for a sequence of \lambda values. The larger the \lambda index, the smaller the \lambda value and hence the smaller the penalty imposed in the lasso of Step 3, resulting in more flexible models.
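With X and y simulated as above and the relgam package installed, the fit looks roughly like this. This is a sketch based on my reading of the package interface; check the vignette for the exact argument names.

```r
library(relgam)

fit <- rgam(X, y, gamma = 0.8)   # fits RGAMs over a path of lambda values
cvfit <- cv.rgam(X, y)           # cross-validation for choosing lambda
preds <- predict(fit, X)         # predictions at each lambda in the path
```

Here gamma = 0.8 keeps the non-linear features on a slightly smaller scale than the linear ones; smaller gamma values make the model even more reluctant to include non-linearities.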

For each \lambda value, RGAM’s predictions have the form

\begin{aligned} \hat{y} = \sum_{j=1}^p \left( \hat{\beta}_j X_j + \hat{f}_j (X_j) \right). \end{aligned}

Let \hat{g}_j(X_j) = \hat{\beta}_j X_j + \hat{f}_j (X_j). We plot the model fits for the first 30 \lambda values in the animation below. In each of the 12 panes, we see the estimated \hat{g}_j for each variable (in blue, green or red), and the true relationship g_j in black.

For small \lambda indices (i.e. large \lambda values), we have very restricted models, with most \hat{g}_j‘s being zero or linear. As the \lambda index increases, we see that the RGAM model fits get closer and closer to the true relationships. Past some \lambda index, we start to see some overfitting going on. The optimal value of \lambda can be chosen via methods like cross-validation.

Give it a try!

I think RGAM is a neat extension to GAM and other sparse additive models. It may not always perform best but I think it is a nice tool to add to your arsenal of interpretable models to try for supervised learning problems!
