(This article was first published on **Yet Another Blog in Statistical Computing » S+/R**, and kindly contributed to R-bloggers)

The feed-forward neural network is a very powerful classification model in the machine learning context. Since the goodness-of-fit of a neural network is largely driven by model complexity, it is very tempting for a modeler to over-parameterize the network by using too many hidden layers and/or hidden units.
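To make "complexity" concrete, the small sketch below counts the free weights of a single-hidden-layer network with one output unit. The helper `n_weights` is hypothetical, and it assumes the layout nnet uses with `skip = TRUE`: biases on the hidden and output units plus direct input-to-output (skip-layer) connections. With an assumed 10 inputs, the count grows quickly with the number of hidden units:

```r
# Hypothetical helper: number of weights in a single-hidden-layer network
# with one output unit, assuming biases on hidden/output units and optional
# skip-layer connections from every input straight to the output.
n_weights <- function(p, size, skip = TRUE) {
  hidden <- (p + 1) * size      # input-to-hidden weights plus hidden biases
  output <- size + 1            # hidden-to-output weights plus output bias
  skip_w <- if (skip) p else 0  # direct input-to-output weights
  hidden + output + skip_w
}

# Assumed 10 inputs, for the hidden-layer sizes 1, 5, 10 and 20
counts <- sapply(c(1, 5, 10, 20), function(s) n_weights(p = 10, size = s))
counts  # 23 71 131 251
```

Going from 1 to 20 hidden units multiplies the number of free parameters roughly elevenfold, which is exactly the room for over-fitting that weight decay is meant to rein in.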

As pointed out by Brian Ripley in the well-known book “Modern Applied Statistics with S” (co-authored with W. N. Venables), the complexity of a neural network can be regulated by a hyper-parameter called “weight decay”, which penalizes the size of the network weights. Per Ripley, weight decay can both help the optimization process and avoid overfitting.
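Written out, the penalized criterion looks roughly as follows. This is a sketch of how we understand the entropy fit in nnet, with $\lambda$ standing for the `decay` argument, $t_k \in \{0, 1\}$ the observed outcome of case $k$, and $p_k(w)$ the fitted probability under the weights $w$:

$$
\min_{w}\; -\sum_{k}\Big[t_k \log p_k(w) + (1 - t_k)\log\big(1 - p_k(w)\big)\Big] \;+\; \lambda \sum_{i,j} w_{ij}^2
$$

The first term is the negative conditional log-likelihood; the second shrinks the weights toward zero, so a larger $\lambda$ yields a smoother fitted function at the cost of some in-sample fit.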

By now it should be clear that the optimal setting for a neural network comes from balancing network complexity against the size of the weight decay. The only remaining question is how to identify such a combination. In practice, practitioners usually use either v-fold cross-validation or cross-sample (hold-out) validation. Given the expensive computing cost of a neural network, cross-sample validation seems more efficient than v-fold cross-validation, since each candidate setting is trained once rather than v times. In addition, because the fitting can get trapped in local minima, a validation result from a set of averaged models is deemed more reliable than one from a single model.
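A quick back-of-the-envelope count shows why the cost matters. The sketch below assumes the grid used in the example (3 decay values, 4 hidden-layer sizes, 10 averaged random starts per setting) and a hypothetical 10-fold alternative:

```r
# Grid from the example: 3 decay values x 4 hidden-layer sizes
decays <- c(0.01, 0.05, 0.1)
sizes  <- c(1, 5, 10, 20)
n_nets <- 10  # random starts averaged for each setting
v      <- 10  # hypothetical number of folds for v-fold cross-validation

n_settings <- length(decays) * length(sizes)
fits <- c(holdout = n_settings * n_nets,      # one training set per setting
          vfold   = n_settings * n_nets * v)  # v training sets per setting
fits  # holdout = 120, vfold = 1200
```

Averaging over random starts already multiplies the work by `n_nets`, so adding v folds on top of that quickly becomes expensive; a single hold-out set keeps the total manageable.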

The example below shows a grid search for the optimal setting of a neural network using cross-sample validation. As suggested by Ripley, the weight decay should lie roughly between 0.01 and 0.1 for the entropy fit. For simplicity, only a few numbers of hidden units are tried. Given enough computing power, however, a finer grid search over combinations of weight decay and number of hidden units is highly recommended.
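One way to lay out such a finer grid is sketched below with base R's `expand.grid`. The specific values (five decay values evenly spaced on a log scale over Ripley's suggested range, and a wider set of hidden-layer sizes) are illustrative assumptions, not recommendations from the original post:

```r
# Hypothetical finer grid: 5 decay values evenly spaced on a log scale
# between 0.01 and 0.1, crossed with a wider set of hidden-layer sizes
decays <- signif(10^seq(-2, -1, length.out = 5), 3)  # 0.01 ... 0.1
sizes  <- c(1, 2, 3, 5, 10, 20)
grid   <- expand.grid(decay = decays, size = sizes)

nrow(grid)     # 30 settings
head(grid, 3)  # first rows: size 1 with the three smallest decay values
```

Each row of `grid` could then drive a single loop over `seq_len(nrow(grid))` in place of nested loops, which makes it easy to refine or extend the grid later.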

```r
# DATA PREPARATIONS
df1 <- read.csv('credit_count.csv')
df2 <- df1[df1$CARDHLDR == 1, 2:12]
X <- I(as.matrix(df2[-1]))
st.X <- scale(X)
Y <- I(as.matrix(df2[1]))
df3 <- data.frame(X = st.X, Y)

# DIVIDE DATA INTO TRAINING (set1) AND TESTING (set2, 1000 ROWS) SETS
set.seed(2013)
rows <- sample(1:nrow(df3), nrow(df3) - 1000)
set1 <- df3[rows, ]
set2 <- df3[-rows, ]

result <- NULL
n_nets <- 10
# SEARCH FOR OPTIMAL WEIGHT DECAY
for (w in c(0.01, 0.05, 0.1)) {
  # SEARCH FOR OPTIMAL NUMBER OF HIDDEN UNITS
  for (n in c(1, 5, 10, 20)) {
    # CREATE A VECTOR OF RANDOM SEEDS
    rv <- round(runif(n_nets) * 100)
    # FOR EACH SETTING, RUN NEURAL NET MULTIPLE TIMES
    for (i in 1:n_nets) {
      # INITIATE THE RANDOM STATE FOR EACH NET
      set.seed(rv[i])
      # TRAIN NEURAL NETS
      net <- nnet::nnet(Y ~ X, size = n, data = set1, entropy = TRUE, maxit = 1000,
                        decay = w, skip = TRUE, trace = FALSE)
      # COLLECT PREDICTIONS TO DO MODEL AVERAGING
      if (i == 1) prob <- predict(net, set2) else prob <- prob + predict(net, set2)
    }
    # CALCULATE AREA UNDER CURVE OF THE MODEL AVERAGING PREDICTION
    roc <- verification::roc.area(set2$Y, prob / n_nets)[1]
    # COLLECT RESULTS (roc IS A LIST, SO result BECOMES A LIST MATRIX)
    result <- rbind(result, c(w, n, roc, round(mean(prob / n_nets), 4), round(mean(set2$Y), 4)))
  }
}
# UNLIST EACH COLUMN OF THE LIST MATRIX INTO A NUMERIC VECTOR
result2 <- data.frame(wt_decay = unlist(result[, 1]), n_units = unlist(result[, 2]),
                      auc = unlist(result[, 3]), pred_rate = unlist(result[, 4]),
                      obsv_rate = unlist(result[, 5]))
result2[order(result2$auc, decreasing = TRUE), ]
#>    wt_decay n_units       auc pred_rate obsv_rate
#> 1      0.01       1 0.6638209    0.0923     0.095
#> 9      0.10       1 0.6625414    0.0923     0.095
#> 5      0.05       1 0.6557022    0.0922     0.095
#> 3      0.01      10 0.6530154    0.0938     0.095
#> 8      0.05      20 0.6528293    0.0944     0.095
#> 6      0.05       5 0.6516662    0.0917     0.095
#> 2      0.01       5 0.6498284    0.0928     0.095
#> 7      0.05      10 0.6456063    0.0934     0.095
#> 4      0.01      20 0.6446176    0.0940     0.095
#> 10     0.10       5 0.6434545    0.0927     0.095
#> 12     0.10      20 0.6415935    0.0938     0.095
#> 11     0.10      10 0.6348822    0.0928     0.095
```
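The AUC above comes from `verification::roc.area`. As a dependency-free cross-check, here is a sketch of a hypothetical helper `auc_rank` that computes the same quantity through the Mann-Whitney identity: the AUC equals the probability that a randomly chosen event scores higher than a randomly chosen non-event, with ties counted as one half via average ranks. We would expect it to agree with `roc.area`'s `A` component, but that equivalence is an assumption rather than something the original post checks:

```r
# Hypothetical helper: AUC via the Mann-Whitney rank-sum identity.
# obs: 0/1 outcomes; pred: predicted probabilities (vectors or 1-column matrices)
auc_rank <- function(obs, pred) {
  obs  <- as.vector(obs)
  pred <- as.vector(pred)
  r  <- rank(pred)    # average ranks, so tied scores count one half
  n1 <- sum(obs == 1) # number of events
  n0 <- sum(obs == 0) # number of non-events
  (sum(r[obs == 1]) - n1 * (n1 + 1) / 2) / (n1 * n0)
}

auc_rank(c(0, 0, 1, 1), c(0.1, 0.4, 0.35, 0.8))  # 0.75
```

Inside the loop, `auc_rank(set2$Y, prob / n_nets)` would return a plain number, which also avoids the list-valued column that `roc.area(...)[1]` produces.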
