A Grid Search for The Optimal Setting in Feed-Forward Neural Networks

February 3, 2013
(This article was first published on Yet Another Blog in Statistical Computing » S+/R, and kindly contributed to R-bloggers)

The feed-forward neural network is a very powerful classification model in the machine learning context. Since the goodness-of-fit of a neural network is largely driven by the model complexity, it is very tempting for a modeler to over-parameterize the network by using too many hidden layers and/or hidden units.

As pointed out by Brian Ripley in his famous book “Modern Applied Statistics with S”, the complexity of a neural network can be regulated by a hyper-parameter called “weight decay”, which penalizes the weights of the hidden units. Per Ripley, the use of weight decay can both help the optimization process and avoid over-fitting.
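
For illustration, here is a minimal sketch (not part of the original example, using made-up data) of how the weight decay enters the model through the decay argument of nnet::nnet(); a larger decay shrinks the fitted weights toward zero.

> # A TOY ILLUSTRATION OF WEIGHT DECAY (HYPOTHETICAL DATA, NOT FROM THE POST)
> set.seed(1)
> toy <- data.frame(y = rbinom(200, 1, 0.3), x1 = rnorm(200), x2 = rnorm(200))
> # FIT THE SAME NETWORK WITHOUT AND WITH A WEIGHT DECAY PENALTY
> fit0 <- nnet::nnet(y ~ x1 + x2, data = toy, size = 3, decay = 0, entropy = TRUE, maxit = 500, trace = FALSE)
> fit1 <- nnet::nnet(y ~ x1 + x2, data = toy, size = 3, decay = 0.1, entropy = TRUE, maxit = 500, trace = FALSE)
> # THE PENALIZED FIT SHOULD CARRY SMALLER WEIGHTS ON AVERAGE
> c(mean(abs(fit0$wts)), mean(abs(fit1$wts)))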

By now it should be clear that the balance between the network complexity and the size of the weight decay forms the optimal setting for a neural network. The only question remaining is how to identify such a combination. In practice, modelers usually rely on v-fold cross-validation or a single split-sample (hold-out) validation. However, given the expensive computing cost of training a neural network, split-sample validation is more efficient than v-fold cross-validation. In addition, due to the presence of local minima, the validation result from a set of averaged models, instead of a single model, is deemed more reliable.
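
For comparison, below is a rough sketch (not in the original post) of how the same evaluation could be organized as v-fold cross-validation for a single setting of weight decay and hidden units; it assumes the df3 data frame and the Y ~ X formula constructed in the example that follows, and the number of folds and fold assignment are illustrative only.

> # A ROUGH V-FOLD CROSS-VALIDATION SKETCH FOR ONE SETTING (ILLUSTRATIVE ONLY)
> v <- 5
> folds <- sample(rep(1:v, length.out = nrow(df3)))
> cv_auc <- rep(NA, v)
> for (k in 1:v)
+ {
+   # HOLD OUT THE K-TH FOLD AND TRAIN ON THE REST
+   net <- nnet::nnet(Y ~ X, size = 1, data = df3[folds != k, ], entropy = TRUE, maxit = 1000, decay = 0.01, skip = TRUE, trace = FALSE)
+   cv_auc[k] <- verification::roc.area(df3[folds == k, ]$Y, predict(net, df3[folds == k, ]))$A
+ }
> # AVERAGE AUC ACROSS THE V FOLDS
> mean(cv_auc)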

The example below shows a grid search strategy for the optimal setting in a neural network using split-sample validation. As suggested by Ripley, a reasonable weight decay lies roughly between 0.01 and 0.1 for the entropy fit. For simplicity, only a few values for the number of hidden units are tried. However, with sufficient computing power, a finer grid search over the combination of weight decay and the number of hidden units would be highly recommended (see the sketch after the results below).

> # DATA PREPARATIONS
> df1 <- read.csv('credit_count.csv')
> df2 <- df1[df1$CARDHLDR == 1, 2:12]
> # THE FIRST COLUMN OF df2 IS THE RESPONSE; THE REST ARE PREDICTORS, STANDARDIZED BELOW
> X <- I(as.matrix(df2[-1]))
> st.X <- scale(X)
> Y <- I(as.matrix(df2[1]))
> # WRAP THE SCALED MATRIX IN I() SO IT STAYS A SINGLE COLUMN AND THE FORMULA Y ~ X BELOW WORKS
> df3 <- data.frame(X = I(st.X), Y = Y)
> 
> # DIVIDE DATA INTO TESTING AND TRAINING SETS
> set.seed(2013)
> rows <- sample(1:nrow(df3), nrow(df3) - 1000)
> set1 <- df3[rows, ]
> set2 <- df3[-rows, ]
> 
> # COLLECTOR FOR THE GRID SEARCH RESULTS
> result <- NULL
> n_nets <- 10
> # SEARCH FOR OPTIMAL WEIGHT DECAY
> for (w in c(0.01, 0.05, 0.1))
+ {
+   # SEARCH FOR OPTIMAL NUMBER OF HIDDEN UNITS
+   for (n in c(1, 5, 10, 20))
+   {
+     # CREATE A VECTOR OF RANDOM SEEDS
+     rv <- round(runif(n_nets) * 100)
+     # FOR EACH SETTING, RUN NEURAL NET MULTIPLE TIMES
+     for (i in 1:n_nets)
+     {
+       # INITIATE THE RANDOM STATE FOR EACH NET
+       set.seed(rv[i]);
+       # TRAIN NEURAL NETS
+       net <- nnet::nnet(Y ~ X, size = n, data = set1, entropy = TRUE, maxit = 1000, decay = w, skip = TRUE, trace = FALSE)
+       # COLLECT PREDICTIONS TO DO MODEL AVERAGING
+       if (i == 1) prob <- predict(net, set2) else prob <- prob + predict(net, set2)
+     }
+     # CALCULATE AREA UNDER CURVE OF THE MODEL AVERAGING PREDICTION
+     roc <- verification::roc.area(set2$Y, prob / n_nets)[1]
+     # COLLECT RESULTS
+     result <- rbind(result, c(w, n, roc, round(mean(prob / n_nets), 4), round(mean(set2$Y), 4)))
+   }
+ } 
> result2 <- data.frame(wt_decay = unlist(result[, 1]), n_units = unlist(result[, 2]), auc = unlist(result[, 3]),
+                       pred_rate = unlist(result[, 4]), obsv_rate = unlist(result[, 5]))
> result2[order(result2$auc, decreasing = T), ]
   wt_decay n_units       auc pred_rate obsv_rate
1      0.01       1 0.6638209    0.0923     0.095
9      0.10       1 0.6625414    0.0923     0.095
5      0.05       1 0.6557022    0.0922     0.095
3      0.01      10 0.6530154    0.0938     0.095
8      0.05      20 0.6528293    0.0944     0.095
6      0.05       5 0.6516662    0.0917     0.095
2      0.01       5 0.6498284    0.0928     0.095
7      0.05      10 0.6456063    0.0934     0.095
4      0.01      20 0.6446176    0.0940     0.095
10     0.10       5 0.6434545    0.0927     0.095
12     0.10      20 0.6415935    0.0938     0.095
11     0.10      10 0.6348822    0.0928     0.095
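
As a follow-up to the finer-grid recommendation above, here is a hedged sketch of how the same loop could be driven by a denser grid built with expand.grid(), reusing set1, set2, and n_nets from the code above; the eval_setting helper and the grid values are illustrative and not part of the original post.

> # A DENSER GRID DRIVEN BY expand.grid() (ILLUSTRATIVE ONLY)
> eval_setting <- function(w, n)
+ {
+   # AVERAGE THE PREDICTIONS OF n_nets NETS WITH DIFFERENT RANDOM STARTS, AS ABOVE
+   prob <- 0
+   for (i in 1:n_nets)
+   {
+     set.seed(i)
+     net <- nnet::nnet(Y ~ X, size = n, data = set1, entropy = TRUE, maxit = 1000, decay = w, skip = TRUE, trace = FALSE)
+     prob <- prob + predict(net, set2)
+   }
+   # RETURN THE AUC OF THE AVERAGED PREDICTION ON THE HOLD-OUT SAMPLE
+   verification::roc.area(set2$Y, prob / n_nets)$A
+ }
> grid <- expand.grid(wt_decay = seq(0.01, 0.1, by = 0.01), n_units = c(1, 2, 5, 10, 15, 20))
> grid$auc <- mapply(eval_setting, grid$wt_decay, grid$n_units)
> # SHOW THE TOP FIVE SETTINGS BY AUC
> grid[order(grid$auc, decreasing = TRUE), ][1:5, ]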
