
You Don’t Need to Learn All the Weights on tabular data: The Case for rvflnet (a nonlinear expressive glmnet) on regression, classification and survival analysis

[This article was first published on T. Moudiki's Webpage - R, and kindly contributed to R-bloggers].

Introduction

Random Vector Functional Link (RVFL) networks offer a simple yet powerful alternative to traditional neural networks for tabular data. Instead of learning the hidden layer through backpropagation, an RVFL generates its weights randomly (or deterministically, when they come from a quasi-random sequence) and focuses all learning effort on a final, regularized linear model.

Formally, let

\[X \in \mathbb{R}^{n \times p}\]

be the input data. RVFL networks (the ones described in this blog post) construct a set of nonlinear features by projecting \(X\) onto a random matrix

\[W \in \mathbb{R}^{p \times m},\]

and applying an activation function \(g(\cdot)\):

\[H = g\left( \frac{X - \mu}{\sigma} ; W \right).\]

These random nonlinear features are then concatenated with the original inputs to form an augmented design matrix:

\[Z = [X | H].\]
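To make the construction concrete, here is a minimal base R sketch of the steps above on toy data. It assumes a Gaussian \(W\), a sigmoid activation and column-wise standardization; the variable names are illustrative, and the exact internals of rvflnet (bias terms, scaling details) may differ.

```r
set.seed(1)

# Toy inputs: n = 6 observations, p = 3 features (illustrative data only)
X <- matrix(rnorm(6 * 3), nrow = 6, ncol = 3)
m <- 4  # number of random hidden features

# Standardize the columns: (X - mu) / sigma
X_scaled <- scale(X)

# Random projection matrix W (p x m), drawn once and never trained
W <- matrix(rnorm(ncol(X) * m), nrow = ncol(X), ncol = m)

# Nonlinear features H = g(X_scaled %*% W), with g = sigmoid
sigmoid <- function(u) 1 / (1 + exp(-u))
H <- sigmoid(X_scaled %*% W)

# Augmented design matrix Z = [X | H]: original inputs plus random features
Z <- cbind(X, H)
dim(Z)  # n rows, p + m columns
```

Only \(Z\) is passed to the linear model afterwards; \(W\) stays fixed, which is why no backpropagation is needed.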

The model prediction is obtained by fitting a linear model on this expanded space (hence, a nonlinear GLM):

\[\hat{y} = Z \beta.\]

Because \(Z\) can be high-dimensional and highly redundant, these RVFL networks rely on Elastic Net regularization (glmnet) to estimate the coefficients. With glmnet's parameterization of the penalty:

\[\hat{\beta} = \arg\min_{\beta}\mathcal{L}(y, Z\beta) + \lambda \left(\alpha \|\beta\|_1 + \frac{1-\alpha}{2}\|\beta\|_2^2\right).\]

In this framework, randomness creates a rich pool of nonlinear transformations, while regularization selects and stabilizes the most useful ones. The result is a nonlinear model that combines the flexibility of neural networks with the efficiency and robustness of linear methods.
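As a rough illustration of this "random features, then Elastic Net" idea, the sketch below builds the augmented matrix by hand and fits it with glmnet directly. It is not the rvflnet implementation: the dataset (`mtcars`), the number of hidden features and `alpha = 0.5` are arbitrary choices made for the example.

```r
library(glmnet)

set.seed(1)

# Illustrative data: predict mpg from the other mtcars columns
X <- as.matrix(mtcars[, -1])
y <- mtcars$mpg

# Random nonlinear features, built as in the formulas above
m <- 20
W <- matrix(rnorm(ncol(X) * m), nrow = ncol(X), ncol = m)
H <- 1 / (1 + exp(-(scale(X) %*% W)))  # sigmoid activation
Z <- cbind(X, H)                       # 10 original + 20 random columns

# Elastic Net on the augmented matrix: alpha mixes the L1 and L2 penalties,
# lambda is chosen by 5-fold cross-validation
cv_fit <- cv.glmnet(Z, y, alpha = 0.5, nfolds = 5)

# Coefficients at the CV-optimal lambda (some may be exactly zero)
beta_hat <- coef(cv_fit, s = "lambda.min")
n_selected <- sum(beta_hat[-1, 1] != 0)

# In-sample predictions y_hat = Z %*% beta (plus intercept)
y_hat <- predict(cv_fit, newx = Z, s = "lambda.min")
```

The L1 part of the penalty is what prunes unhelpful random features, while the L2 part stabilizes the fit when columns of \(Z\) are strongly correlated.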

Of course, this blog post is not a proof of the title. It's about the R package rvflnet. But you can appreciate the high performance of RVFLs on regression, classification and survival analysis, and notably on the controversial Boston dataset (where they perform on par with Random Forest or Gradient Boosting).

0 – Install package

install.packages("survival", repos = "https://cran.r-project.org") # survival analysis

install.packages("remotes", repos = "https://cran.r-project.org")

remotes::install_github('thierrymoudiki/rvflnet') # Nonlinear glm (RVFL networks)

1 – Regression

set.seed(123)

library(glmnet)
data(Boston, package = "MASS")

# -------------------------
# Data
# -------------------------
X <- as.matrix(Boston[, -14])
y <- Boston$medv

n <- nrow(X)
idx <- sample(1:n, size = round(0.8 * n))

X_train <- X[idx, ]
y_train <- y[idx]

X_test <- X[-idx, ]
y_test <- y[-idx]

# -------------------------
# Grid
# -------------------------
grid <- expand.grid(
  n_hidden = c(175, 200, 225, 250),
  alpha = seq(0.1, 0.5, by=0.2),
  include_original = c(TRUE, FALSE),
  seed = 1,
  stringsAsFactors = FALSE
)

results <- vector("list", nrow(grid))

# -------------------------
# Loop
# -------------------------
for (i in seq_len(nrow(grid))) {

  params <- grid[i, ]

  #cat("\n========================================\n")
  #cat(sprintf("Run %d / %d\n", i, nrow(grid)))
  #print(params)

  # -------------------------
  # Fit model
  # -------------------------
  fit <- rvflnet::rvflnet(
    X_train, y_train,
    n_hidden = params$n_hidden,
    activation = "sigmoid",
    W_type = "gaussian",
    seed = params$seed,
    include_original = params$include_original, # direct link, skip connection or not
    alpha = params$alpha
  )

  # -------------------------
  # Evaluate full lambda path
  # -------------------------
  lambdas <- fit$fit$lambda

  preds <- predict(fit, newx = X_test, s = lambdas)

  rmse_path <- sqrt(colMeans((preds - y_test)^2))

  best_idx <- which.min(rmse_path)

  best_rmse <- rmse_path[best_idx]
  best_lambda <- lambdas[best_idx]

  # -------------------------
  # Sparsity
  # -------------------------
  coef_mat <- coef(fit, s = best_lambda)
  nonzero <- sum(coef_mat[-1, 1] != 0)

  # -------------------------
  # Verbose output
  # -------------------------
  #cat(sprintf("Best RMSE: %.4f\n", best_rmse))
  #cat(sprintf("Best lambda: %.6f\n", best_lambda))
  #cat(sprintf("Non-zero coeffs: %d\n", nonzero))

  # -------------------------
  # Store
  # -------------------------
  results[[i]] <- data.frame(
    n_hidden = params$n_hidden,
    alpha = params$alpha,
    include_original = params$include_original,
    seed = params$seed,
    rmse = best_rmse,
    lambda = best_lambda,
    nonzero = nonzero
  )
}

# -------------------------
# Aggregate
# -------------------------
results_df <- do.call(rbind, results)
results_df <- results_df[order(results_df$rmse), ]
print(head(results_df))

Loading required package: Matrix

Loaded glmnet 4.1-10



               n_hidden alpha include_original seed     rmse     lambda nonzero
s= 0.027561759      200   0.1             TRUE    1 2.881935 0.02756176     190
s= 0.017620327      200   0.3             TRUE    1 2.884739 0.01762033     167
s= 0.012734248      200   0.5             TRUE    1 2.889339 0.01273425     158
s= 0.036435024      175   0.1             TRUE    1 2.920012 0.03643502     165
s= 0.016833926      175   0.5             TRUE    1 2.938472 0.01683393     136
s= 0.023293035      175   0.3             TRUE    1 2.941267 0.02329304     144

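The best test-set RMSE, about 2.88, is on par with Random Forest or Gradient Boosting, with a significantly shorter computation time. Keep in mind that in this loop the best lambda is picked on the test set itself, so the figure is somewhat optimistic; cv.rvflnet, used in the next sections, selects lambda by cross-validation on the training set instead.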

2 – Classification

2 – 1 Binary Classification

set.seed(123)

data(iris)

# Binary classification: setosa vs others
y <- ifelse(iris$Species == "setosa", 1, 0)
X <- as.matrix(iris[, 1:4])

# Train/test split
n <- nrow(X)
idx <- sample(1:n, size = round(0.8 * n))

X_train <- X[idx, ]
y_train <- y[idx]

X_test <- X[-idx, ]
y_test <- y[-idx]

# -------------------------
# Fit model
# -------------------------
cv_model <- rvflnet::cv.rvflnet(
  X_train, y_train,
  n_hidden = 50,
  activation = "relu",
  W_type = "gaussian",
  family = "binomial",
  nfolds = 5
)

# -------------------------
# Predictions (probabilities)
# -------------------------
(probs <- predict(cv_model, X_test, type = "response"))

# Convert to class
y_pred <- ifelse(probs > 0.5, 1, 0)

all.equal(as.numeric(y_pred), as.numeric(predict(cv_model, X_test, type="class")))

# -------------------------
# Diagnostics
# -------------------------

# Accuracy
acc <- mean(drop(y_pred) == y_test)
cat("Accuracy:", acc, "\n")

# Confusion matrix
table(Predicted = y_pred, Actual = y_test)
A matrix: 30 × 1 of type dbl
lambda.min
0.9997617002
0.9992267955
0.9997120678
0.9997524867
0.9996600481
0.9992472082
0.9996101744
0.9999356520
0.9998139568
0.9995418762
0.0003328885
0.0003328885
0.0003328885
0.0019937012
0.0003328885
0.0005459970
0.0003328885
0.0005035848
0.0003328885
0.0003328885
0.0003328885
0.0003328885
0.0003328885
0.0003328885
0.0003328885
0.0003328885
0.0003328885
0.0003328885
0.0003328885
0.0003328885

TRUE

Accuracy: 1 



         Actual
Predicted  0  1
        0 20  0
        1  0 10

2 – 2 Multiclass Classification

set.seed(123)

data(iris)

y <- as.numeric(iris$Species)
X <- as.matrix(iris[, 1:4])

# Train/test split
n <- nrow(X)
idx <- sample(1:n, size = round(0.8 * n))

X_train <- X[idx, ]
y_train <- y[idx]

X_test <- X[-idx, ]
y_test <- y[-idx]

# -------------------------
# Fit model
# -------------------------
# No cross-validation here: rvflnet() fits the whole lambda path
fit <- rvflnet::rvflnet(
  X_train, y_train,
  n_hidden = 50,
  activation = "relu",
  W_type = "gaussian",
  family = "multinomial",
  nlambda = 25
)

# -------------------------
# Diagnostics
# -------------------------

# Test-set accuracy along the lambda path (one value per lambda)
acc <- colMeans(predict(fit, X_test, type="class") == y_test)
cat("Accuracies:", acc, "\n") # consider other metrics

Accuracies: 0.1666667 0.7666667 0.9333333 0.9666667 0.9666667 0.9666667 0.9666667 0.9666667 0.9666667 0.9666667 0.9666667 0.9666667 0.9666667 0.9666667 0.9666667 0.9666667 0.9666667 0.9666667 0.9666667 0.9666667 0.9666667 0.9666667 0.9666667 

3 – Nonlinear Cox survival analysis
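
As a reminder of what changes with `family = "cox"`: the loss \(\mathcal{L}\) in the Elastic Net objective becomes the negative log partial likelihood of the Cox model, here evaluated on the augmented features. The notation below is an assumption made for this reminder (it does not come from the package): \(z_i\) is the \(i\)-th row of \(Z\), \(t_i\) the follow-up time and \(\delta_i\) the event indicator. Up to a scaling constant and the handling of tied event times:

\[\mathcal{L}(\beta) = -\sum_{i:\, \delta_i = 1} \left[ z_i^\top \beta - \log \sum_{j:\, t_j \geq t_i} \exp\left(z_j^\top \beta\right) \right].\]

The linear predictor \(z_i^\top \beta\) acts as a risk score: a higher value means a higher hazard, hence a shorter expected survival time.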

3 – 1 Example 1

library(survival)
library(rvflnet)

# In recent versions of survival, 'ovarian' ships in the 'cancer' data file,
# so data(ovarian) only triggers a "data set not found" warning
data(cancer, package = "survival")

X <- as.matrix(ovarian[, c("age", "resid.ds", "rx", "ecog.ps")])
y <- Surv(ovarian$futime, ovarian$fustat)

set.seed(123)
n <- nrow(X)
train_idx <- sample(1:n, size = round(0.8 * n))

X_train <- X[train_idx, ]
X_test  <- X[-train_idx, ]
y_train <- y[train_idx]
y_test  <- y[-train_idx]

# -------------------------
# Fit model
# -------------------------
cv_fit <- rvflnet::cv.rvflnet(
  X_train, y_train,
  family = "cox",
  nfolds = 5,
  type.measure = "C"
)

plot(cv_fit)

# Out-of-sample C-index
print(glmnet::Cindex(pred = predict(cv_fit, X_test), y = y_test))


[1] 0.8571429
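
To see what this number measures, here is a simplified base R sketch of Harrell's concordance index for right-censored data. It is written for illustration only: `c_index` and the toy vectors are hypothetical names and data, tied event times are not handled specially, and `glmnet::Cindex` remains the function to use in practice.

```r
# Simplified concordance index for right-censored data.
# time:  follow-up times
# event: 1 = event observed, 0 = censored
# risk:  predicted risk score (higher = expected to fail sooner)
c_index <- function(time, event, risk) {
  concordant <- 0
  comparable <- 0
  n <- length(time)
  for (i in seq_len(n)) {
    if (event[i] != 1) next          # the earlier subject must have an observed event
    for (j in seq_len(n)) {
      if (time[i] < time[j]) {       # subject i failed while subject j was still at risk
        comparable <- comparable + 1
        if (risk[i] > risk[j]) {
          concordant <- concordant + 1      # correctly ranked pair
        } else if (risk[i] == risk[j]) {
          concordant <- concordant + 0.5    # tied scores count as half
        }
      }
    }
  }
  concordant / comparable
}

# Toy data (illustrative)
time  <- c(5, 8, 12, 20, 25)
event <- c(1, 1, 0, 1, 0)

risk_good <- c(2.0, 1.5, 1.0, 0.5, 0.1)  # risk decreases as survival time increases
risk_flat <- rep(0, 5)                   # constant score, no ranking information

c_index(time, event, risk_good)  # 1
c_index(time, event, risk_flat)  # 0.5
```

A constant risk score yields 0.5, which is also what a Cox model with all coefficients shrunk to zero (the largest lambda on a regularization path) produces.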

3 – 2 Example 2

library(glmnet)
library(survival)

data(pbc)
pbc2       <- pbc[!is.na(pbc$trt), ]
pbc2$event <- as.integer(pbc2$status == 2) # status: 0 = censored, 1 = transplant, 2 = dead
pbc2$sex_n <- as.integer(pbc2$sex == "f")

feat_cols <- c("trt","age","sex_n","ascites","hepato","spiders","edema",
               "bili","chol","albumin","copper","alk.phos","ast",
               "trig","platelet","protime","stage")

df <- pbc2[, c("time", "event", feat_cols)]
for (col in feat_cols)
  if (any(is.na(df[[col]])))
    df[[col]][is.na(df[[col]])] <- median(df[[col]], na.rm = TRUE)

set.seed(42)
idx_train <- sample(nrow(df), floor(0.75 * nrow(df)))
train <- df[idx_train, ]; test <- df[-idx_train, ]
X_tr  <- as.matrix(train[, feat_cols])
X_te  <- as.matrix(test[,  feat_cols])
y_tr   <- Surv(train$time, train$event)

fit <- rvflnet::rvflnet(
  X_tr, y_tr,
  family = "cox",
  alpha=0.1, lambda=0.1 # supplying a single lambda is not recommended; shown for illustration
)

y_te   <- Surv(test$time, test$event)
ci <- glmnet::Cindex(predict(fit, X_te), y_te)

cat("\n=== Test-set C-index ===\n")
print(ci)


=== Test-set C-index ===
[1] 0.8218117

fit <- rvflnet::rvflnet(
  X_tr, y_tr,
  family = "cox",
  alpha=0.1, nlambda=50
)

y_te   <- Surv(test$time, test$event)

(cis <- apply(predict(fit, X_te), 2, function(x) glmnet::Cindex(x, y_te)))

#cat("\n=== Test-set C-index ===\n")
plot(log(fit$fit$lambda), cis, type = 'l')
abline(h=0.8, lty=2, col="red")

       s0        s1        s2        s3        s4        s5        s6 
0.5000000 0.7628129 0.8021454 0.8110846 0.8116806 0.8140644 0.8158522 
       s7        s8        s9       s10       s11       s12       s13 
0.8176400 0.8200238 0.8194279 0.8176400 0.8182360 0.8182360 0.8158522 
      s14       s15       s16       s17       s18       s19       s20 
0.8146603 0.8134684 0.8134684 0.8158522 0.8146603 0.8206198 0.8194279 
      s21       s22       s23       s24       s25       s26       s27 
0.8218117 0.8206198 0.8170441 0.8176400 0.8182360 0.8146603 0.8104887 
      s28       s29       s30       s31       s32       s33       s34 
0.8039333 0.8021454 0.7997616 0.7932062 0.7890346 0.7777116 0.7711561 
      s35       s36       s37       s38       s39       s40       s41 
0.7669845 0.7568534 0.7485101 0.7431466 0.7353993 0.7288439 0.7216925 
      s42       s43       s44       s45       s46       s47       s48 
0.7181168 0.7175209 0.7169249 0.7169249 0.7157330 0.7163290 0.7151371 
      s49 
0.7139452 

