Kaggle Competition Walkthrough: Fitting a model

May 12, 2011

(This article was first published on Modern Tool Making, and kindly contributed to R-bloggers)

Now that we've got the data we need into R, it is very easy to fit a model using the caret package. Caret's workhorse function is called 'train', and it lets you fit a wide variety of models using the same syntax. Furthermore, many models have 'hyperparameters' that require tuning, such as the number of neighbors for a KNN model or the regularization parameters for an elastic net model. Caret tunes these parameters using a grid search, and will either define a reasonable grid for you or allow you to specify one yourself. For example, you might want to investigate KNN models with 3, 5, and 10 neighbors. 'Train' uses cross-validation or bootstrap resampling to estimate the predictive performance of each candidate model, and then refits the best-performing model on your complete training dataset.
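To make the grid idea concrete, here is a minimal sketch (not code from the original post) that asks caret to compare KNN models with 3, 5, and 10 neighbors. The simulated data frame `sim` and its columns `x1`, `x2`, and `Class` are hypothetical stand-ins, and the grid column name `k` is the one used by recent caret releases.

```r
library(caret)

set.seed(42)
# Hypothetical simulated data: two numeric predictors and a two-level outcome
n <- 200
sim <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
sim$Class <- factor(ifelse(sim$x1 + sim$x2 + rnorm(n, sd = 0.5) > 0, "yes", "no"))

# A user-specified grid: one candidate model per value of k
knnGrid <- data.frame(k = c(3, 5, 10))

# train() scores each k with 10-fold cross-validation, then refits the
# best-scoring k on all of 'sim'
knnFit <- train(Class ~ x1 + x2, data = sim, method = "knn",
                tuneGrid = knnGrid,
                trControl = trainControl(method = "cv", number = 10))

knnFit$results[, c("k", "Accuracy")]  # resampled accuracy for each k
knnFit$bestTune                       # the k that was selected
```

Because no `metric` is given, caret ranks the classification candidates by accuracy here; the walkthrough below switches to AUC.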

First, we define some parameters to pass to caret's 'train' function:

We use the 'trainControl' function to define train parameters. In this case, we want a repeated cross-validation resampling strategy, so we use 10-fold cross-validation and repeat it 5 times. Furthermore, we want train to return class probabilities, because this competition is scored using AUC (the area under the ROC curve), which is computed from predicted probabilities rather than hard class labels. Finally, we use the twoClassSummary function, which calculates the area under the ROC curve, the sensitivity, and the specificity of a 2-class classification model. This final setting is very important, as it allows caret to evaluate our model using AUC.
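The settings described above could look roughly like this in recent versions of caret. The object name `ctrl` is a placeholder chosen for this sketch; the author's own object name is not shown here.

```r
library(caret)

# Resampling plan: 10-fold cross-validation, repeated 5 times
ctrl <- trainControl(
  method = "repeatedcv",
  number = 10,                       # folds per repetition
  repeats = 5,                       # number of repetitions
  classProbs = TRUE,                 # return class probabilities (needed for AUC)
  summaryFunction = twoClassSummary  # report ROC (AUC), Sens and Spec
)
```

Nothing is fitted yet; `ctrl` is just a list of options that gets handed to `train` through its `trControl` argument.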

Next, we fit our model:

You may recall that in my last post I defined a formula called 'FL', which we are now using to specify our model. You could also specify the model in x/y form, where x is a matrix of independent variables and y is your dependent variable. This second method is a little bit faster, but I find that the formula interface makes my code easier to read and modify, which is FAR more important. Next, we specify that we want to fit the model to the training set, and that we want to use the 'glmnet' model. Glmnet fits an elastic net (a blend of lasso and ridge penalties) and performs quite well on this dataset. Also, the Kaggle benchmark uses glmnet, so it is a good place to start. Next, we specify the metric by which the candidate models will be ranked: ROC, which in caret really means the area under the ROC curve (AUC). Then, we specify a custom tuning grid, which I found produces some nice results. You could instead specify tuneLength=5 here to let train build its own grid, but in this case I prefer finer control over the hyperparameters that get passed to 'train'. Finally, we pass the control object we defined earlier to the train function, and we're off!
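Here is a self-contained sketch of that `train` call. Since the competition data and the formula 'FL' come from the previous post, this version simulates a stand-in `trainset` with hypothetical columns `var_1` to `var_10` and builds a hypothetical `FL` from them; the tuning grid values are illustrative, not the author's grid.

```r
library(caret)   # train(), trainControl(), twoClassSummary()
library(glmnet)  # elastic-net engine used by method = "glmnet"
# Note: in recent caret versions twoClassSummary() also needs the pROC package installed

set.seed(1)
# Hypothetical stand-in for the competition data: 10 numeric predictors
n <- 250
trainset <- as.data.frame(matrix(rnorm(n * 10), ncol = 10,
                                 dimnames = list(NULL, paste0("var_", 1:10))))
signal <- trainset$var_1 - trainset$var_2 + 0.5 * trainset$var_3
# Class labels must be valid R names ("X0"/"X1") when classProbs = TRUE
trainset$Target_Practice <- factor(ifelse(signal + rnorm(n) > 0, "X1", "X0"))

# Hypothetical formula; the original 'FL' was built in the previous post
FL <- as.formula(paste("Target_Practice ~",
                       paste(paste0("var_", 1:10), collapse = " + ")))

ctrl <- trainControl(method = "repeatedcv", number = 10, repeats = 5,
                     classProbs = TRUE, summaryFunction = twoClassSummary)

# Illustrative grid: alpha mixes lasso (1) and ridge (0), lambda is the penalty size
glmnetGrid <- expand.grid(alpha = c(0, 0.5, 1),
                          lambda = seq(0.01, 0.2, length.out = 5))

model <- train(FL, data = trainset,
               method = "glmnet",
               metric = "ROC",        # rank candidate models by resampled AUC
               tuneGrid = glmnetGrid,
               trControl = ctrl)

model$bestTune  # the alpha/lambda pair with the highest resampled AUC
```

Older caret releases expected dot-prefixed grid column names such as `.alpha` and `.lambda`; recent releases use the plain names shown here.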

After fitting the model, it is useful to look at how the various glmnet hyperparameters affected its performance. Additionally, we can score our model on the test set, because we chose 'Target_Practice' as our target, and the labels of the evaluation set are known for this target.
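A minimal sketch of both steps follows, again on simulated stand-in data (the helper `makeData` and the `testset` it produces are hypothetical). The original post does not show which AUC function was used, so this sketch computes AUC directly with the rank-based (Mann-Whitney) formula in base R.

```r
library(caret)
# Note: in recent caret versions twoClassSummary() also needs the pROC package installed

set.seed(2)
# Hypothetical data generator: 5 predictors, outcome driven by var_1 - var_2
makeData <- function(n) {
  d <- as.data.frame(matrix(rnorm(n * 5), ncol = 5,
                            dimnames = list(NULL, paste0("var_", 1:5))))
  d$Target_Practice <- factor(ifelse(d$var_1 - d$var_2 + rnorm(n) > 0, "X1", "X0"))
  d
}
trainset <- makeData(250)
testset  <- makeData(250)  # stands in for the evaluation set with known labels

model <- train(Target_Practice ~ ., data = trainset, method = "glmnet",
               metric = "ROC",
               tuneGrid = expand.grid(alpha = c(0, 0.5, 1),
                                      lambda = c(0.01, 0.05, 0.1)),
               trControl = trainControl(method = "cv", number = 5,
                                        classProbs = TRUE,
                                        summaryFunction = twoClassSummary))

plot(model)  # resampled ROC across lambda, one line per alpha

# Predicted probability of the "X1" class for each holdout row
probs <- predict(model, newdata = testset, type = "prob")[, "X1"]

# Rank-based AUC: probability that a random "X1" case scores higher than a
# random "X0" case (ties count as one half via average ranks)
auc <- function(scores, labels) {
  pos <- labels == "X1"
  r <- rank(scores)
  n1 <- sum(pos)
  n0 <- sum(!pos)
  (sum(r[pos]) - n1 * (n1 + 1) / 2) / (n1 * n0)
}
testAUC <- auc(probs, testset$Target_Practice)
testAUC
```

The rank formula gives the same number as integrating the ROC curve, so any standard AUC function should agree with it on the same scores.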

We get an AUC of 0.8682391, which is pretty much equivalent to the benchmark of 0.87355. In my next post, I will show you how to beat the benchmark (and the majority of the other competitors) by a significant margin.

Update: if you are on a Linux machine and wish to speed this process up, run the following code snippet after you define the parameters for your training function:
This will replace 'lapply' with 'mclapply' in caret's train function, and will allow you to simultaneously fit as many models as your machine has cores. Isn't technology great? Also note that this code will not run in the Mac GUI. You need to open up the terminal and start R by typing 'R' to run your script in the console. The Mac GUI does not handle 'multicore' well...
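The 'multicore' package this update relies on has since been retired from CRAN (its `mclapply` now lives in the base `parallel` package), and current caret versions parallelise resampling through the foreach package rather than swapping `lapply` for `mclapply`. The sketch below is therefore a swapped-in, present-day equivalent using the doParallel backend, not the original snippet; `nCores` and `ctrl` are names chosen for this illustration.

```r
library(caret)
library(doParallel)  # foreach backend; on Linux/macOS it forks worker processes

# One worker per core (fall back to 1 if the core count cannot be detected)
nCores <- parallel::detectCores()
if (is.na(nCores)) nCores <- 1L
registerDoParallel(cores = nCores)

# train() picks up the registered backend via foreach;
# allowParallel = TRUE is the default and is shown only for clarity
ctrl <- trainControl(method = "repeatedcv", number = 10, repeats = 5,
                     classProbs = TRUE, summaryFunction = twoClassSummary,
                     allowParallel = TRUE)

# Pass ctrl to train() as before; the resampling fits are then spread over nCores workers
foreach::getDoParWorkers()
```

Forking has the same caveat the post mentions: it is unreliable inside GUI front ends, so running the script from a terminal session remains the safer choice.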
