A Grid Search for the Optimal Setting in Feed-Forward Neural Networks


The feed-forward neural network is a very powerful classification model in the machine learning context. Since the goodness-of-fit of a neural network is largely driven by its model complexity, it is very tempting for a modeler to over-parameterize the network with too many hidden layers and/or hidden units.

As pointed out by Brian Ripley in his classic book “Modern Applied Statistics with S”, the complexity of a neural network can be regulated by a hyper-parameter called “weight decay”, which penalizes the size of the network weights. Per Ripley, weight decay both helps the optimization process and guards against over-fitting.
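
As a quick illustration of what the penalty does (a minimal sketch on simulated data of my own; neither the data nor the variable names come from the original post), fitting the same small network with and without decay shows the shrinkage of the weights:

> # SKETCH: WEIGHT DECAY SHRINKS THE NETWORK WEIGHTS (ILLUSTRATIVE DATA)
> set.seed(1)
> sim <- data.frame(x1 = rnorm(200), x2 = rnorm(200))
> sim$y <- as.numeric(sim$x1 + sim$x2 + rnorm(200) > 0)
> # SAME NETWORK, WITHOUT AND WITH THE PENALTY
> net0 <- nnet::nnet(y ~ x1 + x2, data = sim, size = 5, entropy = TRUE, decay = 0, trace = FALSE)
> net1 <- nnet::nnet(y ~ x1 + x2, data = sim, size = 5, entropy = TRUE, decay = 0.1, trace = FALSE)
> # THE SUM OF SQUARED WEIGHTS SHOULD BE SMALLER UNDER decay = 0.1
> c(sum(net0$wts ^ 2), sum(net1$wts ^ 2))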

By now it should be clear that the balance between the network complexity and the size of the weight decay determines the optimal setting for a neural network. The only question remaining is how to identify such a combination. In practice, one would use either v-fold cross-validation or split-sample validation. However, given the expensive computing cost of training a neural network, split-sample validation is more efficient than the v-fold scheme. In addition, due to the presence of local minima in the optimization, a validation result based on the average of several models, instead of a single model, is deemed more reliable.
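
For comparison, a v-fold version of the validation for a single (decay, size) setting might look like the sketch below; this is my own illustration with v = 5 and assumes the df3 data frame constructed in the code that follows. Each candidate setting now costs v model fits instead of one, which is exactly the overhead the split-sample approach avoids:

> # SKETCH: 5-FOLD CROSS-VALIDATION FOR ONE SETTING (ILLUSTRATIVE)
> v <- 5
> # RANDOMLY ASSIGN EACH ROW TO ONE OF v FOLDS
> folds <- sample(rep(1:v, length.out = nrow(df3)))
> aucs <- sapply(1:v, function(k) {
+   net <- nnet::nnet(Y ~ X, data = df3[folds != k, ], size = 5, decay = 0.05,
+                     entropy = TRUE, maxit = 1000, skip = TRUE, trace = FALSE)
+   verification::roc.area(df3$Y[folds == k], as.numeric(predict(net, df3[folds == k, ])))$A
+ })
> # AVERAGE AUC ACROSS THE v HOLD-OUT FOLDS
> mean(aucs)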

The example below shows a grid search strategy for the optimal setting of a neural network with split-sample validation. As suggested by Ripley, the weight decay for an entropy fit is approximately in the range between 0.01 and 0.1. For simplicity, only a few values for the number of hidden units are tried. However, with more computing power available, a finer grid search over combinations of weight decay and the number of hidden units would be highly recommended (see the sketch after the results).

> # DATA PREPARATION
> df1 <- read.csv('credit_count.csv')
> # KEEP CARD HOLDERS ONLY
> df2 <- df1[df1$CARDHLDR == 1, 2:12]
> # STANDARDIZE THE PREDICTORS
> X <- I(as.matrix(df2[-1]))
> st.X <- scale(X)
> Y <- I(as.matrix(df2[1]))
> # WRAP st.X IN I() SO IT STAYS A SINGLE MATRIX COLUMN AND Y ~ X WORKS BELOW
> df3 <- data.frame(X = I(st.X), Y)
> 
> # DIVIDE DATA INTO TESTING AND TRAINING SETS
> set.seed(2013)
> rows <- sample(1:nrow(df3), nrow(df3) - 1000)
> set1 <- df3[rows, ]
> set2 <- df3[-rows, ]
> 
> # PLACEHOLDER TO COLLECT THE VALIDATION RESULTS
> result <- NULL
> # NUMBER OF NETS TO AVERAGE FOR EACH SETTING
> n_nets <- 10
> # SEARCH FOR OPTIMAL WEIGHT DECAY
> for (w in c(0.01, 0.05, 0.1))
+ {
+   # SEARCH FOR OPTIMAL NUMBER OF HIDDEN UNITS
+   for (n in c(1, 5, 10, 20))
+   {
+     # CREATE A VECTOR OF RANDOM SEEDS
+     rv <- round(runif(n_nets) * 100)
+     # FOR EACH SETTING, RUN NEURAL NET MULTIPLE TIMES
+     for (i in 1:n_nets)
+     {
+       # INITIATE THE RANDOM STATE FOR EACH NET
+       set.seed(rv[i])
+       # TRAIN NEURAL NETS
+       net <- nnet::nnet(Y ~ X, size = n, data = set1, entropy = TRUE, maxit = 1000, decay = w, skip = TRUE, trace = FALSE)
+       # COLLECT PREDICTIONS TO DO MODEL AVERAGING
+       if (i == 1) prob <- predict(net, set2) else prob <- prob + predict(net, set2)
+     }
+     # CALCULATE AREA UNDER CURVE OF THE MODEL AVERAGING PREDICTION
+     roc <- verification::roc.area(set2$Y, prob / n_nets)$A
+     # COLLECT RESULTS
+     result <- rbind(result, c(w, n, roc, round(mean(prob / n_nets), 4), round(mean(set2$Y), 4)))
+   }
+ } 
> result2 <- data.frame(wt_decay = result[, 1], n_units = result[, 2], auc = result[, 3],
+                       pred_rate = result[, 4], obsv_rate = result[, 5])
> result2[order(result2$auc, decreasing = T), ]
   wt_decay n_units       auc pred_rate obsv_rate
1      0.01       1 0.6638209    0.0923     0.095
9      0.10       1 0.6625414    0.0923     0.095
5      0.05       1 0.6557022    0.0922     0.095
3      0.01      10 0.6530154    0.0938     0.095
8      0.05      20 0.6528293    0.0944     0.095
6      0.05       5 0.6516662    0.0917     0.095
2      0.01       5 0.6498284    0.0928     0.095
7      0.05      10 0.6456063    0.0934     0.095
4      0.01      20 0.6446176    0.0940     0.095
10     0.10       5 0.6434545    0.0927     0.095
12     0.10      20 0.6415935    0.0938     0.095
11     0.10      10 0.6348822    0.0928     0.095
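
As a sketch of the finer grid search recommended above (my own extension of the code in this post, reusing set1, set2, and n_nets from earlier; the grid values are illustrative), expand.grid enumerates a denser set of (decay, size) pairs:

> # SKETCH: A FINER GRID OVER WEIGHT DECAY AND THE NUMBER OF HIDDEN UNITS
> grid <- expand.grid(w = seq(0.01, 0.1, by = 0.01), n = 1:10)
> grid$auc <- apply(grid, 1, function(g) {
+   rv <- round(runif(n_nets) * 100)
+   # AVERAGE THE PREDICTIONS OF n_nets NETS FOR EACH SETTING
+   prob <- rowMeans(sapply(rv, function(s) {
+     set.seed(s)
+     net <- nnet::nnet(Y ~ X, data = set1, size = as.integer(g['n']), decay = g['w'],
+                       entropy = TRUE, maxit = 1000, skip = TRUE, trace = FALSE)
+     predict(net, set2)
+   }))
+   verification::roc.area(set2$Y, prob)$A
+ })
> # SHOW THE TOP 5 SETTINGS BY AUC
> grid[order(grid$auc, decreasing = TRUE), ][1:5, ]

With 100 settings and n_nets fits per setting, this grid requires 1,000 model fits, so the earlier remark about computing power applies in earnest.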
