twidlr: data.frame-based API for model and predict functions

May 2, 2017
By @drsimonj

(This article was first published on blogR, and kindly contributed to R-bloggers)

@drsimonj here to introduce my latest tidy-modelling package for R, “twidlr”. twidlr wraps model and predict functions you already know and love with a consistent data.frame-based API!

All models wrapped by twidlr can be fit to data and used to make predictions as follows:

library(twidlr)

fit <- model(data, formula, ...)
predict(fit, data, ...)
  • data is a data.frame (or an object that can be coerced to one) and is required
  • formula describes the model to be fit
  • ... holds any additional arguments for the wrapped model or predict function
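To make the shape of this API concrete, here's a minimal base-R sketch of what a data-first wrapper could look like. It assumes nothing about twidlr's internals: `fit_lm()` and `predict_df()` are hypothetical names, and the only goal is to show data coming first, being coerced with `as.data.frame()`, and being required when predicting.

```r
# Hypothetical data-first wrappers (illustration only, not twidlr's source).
fit_lm <- function(data, formula, ...) {
  if (missing(data)) stop("`data` is required")
  # Coerce anything data.frame-like (e.g. a matrix) before fitting
  stats::lm(formula, data = as.data.frame(data), ...)
}

predict_df <- function(fit, data, ...) {
  # Refuse to fall back on the training data: predictions need explicit data
  if (missing(data)) stop("`data` is required")
  stats::predict(fit, newdata = as.data.frame(data), ...)
}

fit  <- fit_lm(mtcars, hp ~ .)
fit2 <- fit_lm(as.matrix(mtcars), hp ~ .)  # a matrix is coerced first
pred <- predict_df(fit, mtcars[1:5, ])
```

Making `data` mandatory in the predict wrapper is a deliberate choice: it rules out the "silently reuse whatever the training data variable now holds" behaviour discussed below.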

 The motivation

The APIs of model and predict functions in R are inconsistent and messy.

Some models like linear regression want a formula and data.frame:

lm(hp ~ ., mtcars)

Models like gradient-boosted decision trees want vectors and matrices:

library(xgboost)

y <- mtcars$hp
x <- as.matrix(mtcars[names(mtcars) != "hp"])

xgboost(x, y, nrounds = 5)

Models like regularized generalized linear models (glmnet) make you do extra work, for example to create interactions and dummy-coded variables:

library(glmnet)

y <- iris$Petal.Length
x <- model.matrix(Petal.Length ~ Sepal.Width * Sepal.Length + Species, iris)

glmnet(x, y)
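Much of that extra work is turning a formula into the `x`/`y` pair these functions expect. Base R already has the building blocks (`model.frame()`, `model.matrix()`, `model.response()`), so a hypothetical helper along these lines covers the example above. It's a sketch, not necessarily how glmnet or twidlr handles it.

```r
# Hypothetical helper: turn (data, formula) into the (x, y) pair that
# matrix-based fitters such as glmnet() or xgboost() expect.
formula_to_xy <- function(data, formula) {
  mf <- stats::model.frame(formula, data = as.data.frame(data))
  list(
    # Expand interactions (*) and dummy-code factors; drop the intercept
    # column because glmnet fits its own intercept by default
    x = stats::model.matrix(attr(mf, "terms"), mf)[, -1, drop = FALSE],
    y = stats::model.response(mf)
  )
}

xy <- formula_to_xy(iris, Petal.Length ~ Sepal.Width * Sepal.Length + Species)
colnames(xy$x)
```

The result can then be passed on as `glmnet(xy$x, xy$y)`.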

Some models like k-means don’t have a corresponding predict function:

fit <- kmeans(iris[1:120,-5], centers = 3)
predict(fit, iris[121:150,])

## Error in UseMethod("predict") : 
##   no applicable method for 'predict' applied to an object of class "kmeans"

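Some predict functions are impure and return unexpected results: when called without new data, they quietly re-evaluate the data variable from the original model call, so the output changes whenever that variable does. For example, linear discriminant analysis: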

library(MASS)

d <- iris
fit <- lda(Species ~ ., d)

table(predict(fit)$class)
#> 
#>     setosa versicolor  virginica 
#>         50         49         51

d <- d[1:10,]

table(predict(fit)$class)
#> 
#>     setosa versicolor  virginica 
#>         10          0          0
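As I understand it, the second call returns only 10 predictions because, with no `newdata`, `predict()` for lda objects rebuilds the model frame from the original call, which picks up the current value of `d`. A simple way to keep predictions pure is to always pass the data explicitly. This sketch (again using MASS) shows the result no longer depends on what `d` is bound to:

```r
library(MASS)

d <- iris
fit <- lda(Species ~ ., d)

# Pass the data explicitly instead of relying on predict() to find it
before <- predict(fit, newdata = iris)$class
d <- d[1:10, ]                              # rebinding d no longer matters
after  <- predict(fit, newdata = iris)$class

identical(before, after)
#> [1] TRUE
```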

 twidlr

twidlr helps to solve these problems by wrapping the same model and predict functions in a consistent data.frame-based API.

Once twidlr is loaded, your favourite models can be fit to a data.frame with a formula and any additional arguments. To demonstrate, compare this API with the examples above:

library(twidlr)

lm(mtcars, hp ~ .)
xgboost(mtcars, hp ~ ., nrounds = 5)
glmnet(iris, Petal.Length ~ Sepal.Width * Sepal.Length + Species)

What’s more, predictions can be made with all fitted models via predict and a data.frame. This even works for models that don’t traditionally have a predict method:

library(twidlr)

fit <- kmeans(iris[1:140,-5], centers = 3)
predict(fit, iris[141:150,])
#>  [1] 3 3 3 3 3 3 3 3 3 3
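One common way to give k-means a predict method is to assign each new observation to its nearest cluster centre. The sketch below does that with squared Euclidean distance. `predict_kmeans()` is a hypothetical name, and nearest-centre assignment is my assumption about a reasonable approach, not a description of twidlr's code. Non-numeric columns such as `Species` are ignored by keeping only the columns the centres were fitted on.

```r
# Hypothetical sketch: nearest-centre prediction for a kmeans fit.
predict_kmeans <- function(fit, data) {
  centers <- fit$centers
  # Keep only the columns the centres were fitted on (drops e.g. Species)
  x <- as.matrix(as.data.frame(data)[, colnames(centers), drop = FALSE])
  # Squared Euclidean distance from every row of x to every centre
  d2 <- vapply(seq_len(nrow(centers)),
               function(k) colSums((t(x) - centers[k, ])^2),
               numeric(nrow(x)))
  d2 <- matrix(d2, nrow = nrow(x))  # keep n x k shape even when n == 1
  # Index of the closest centre for each row
  max.col(-d2, ties.method = "first")
}

set.seed(1)
fit <- kmeans(iris[1:140, -5], centers = 3)
clusters <- predict_kmeans(fit, iris[141:150, ])
clusters
```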

 Bonus example

Although useful in itself, a consistent data.frame-based API expands the capabilities of other tidy and data.frame-based packages like the tidyverse packages and pipelearner.

For the motivated, the following example fits multiple models and compares their RMSE on new data. It’s streamlined because purrr’s map functions can exploit the consistent API of every model and predict function.

library(twidlr)
library(purrr)

train <- cars[ 1:40, ]
test  <- cars[41:50, ]
f <- c("lm", "randomForest", "rpart")

# Fit each model to training data and compute RMSE on test data
rmse <- invoke_map(f, data = train, formula = speed ~ dist) %>%
  map(predict, data = test) %>% 
  map_dbl(~ sqrt(mean((. - test$speed)^2)))

setNames(rmse, f)
#>           lm randomForest        rpart 
#>     3.832426     6.129539     6.034932

If you can’t see the value, try doing this without twidlr.
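For comparison, here's roughly what that takes in base R: you first write your own (data, formula) wrappers so that one loop can treat every model alike. This sketch sticks to models from the stats package (`lm`, a Gaussian `glm`, and a quadratic `lm` via `poly()`) so it runs without extra installs; the `fitters` list and its entries are hypothetical.

```r
train <- cars[ 1:40, ]
test  <- cars[41:50, ]

# Hypothetical (data, formula) wrappers so every model shares one signature
fitters <- list(
  lm   = function(data, formula) lm(formula, data = data),
  glm  = function(data, formula) glm(formula, family = gaussian(), data = data),
  # update() swaps the right-hand side for a quadratic polynomial of it
  quad = function(data, formula) lm(update(formula, . ~ poly(., 2)), data = data)
)

# Fit each model to training data and compute RMSE on test data
rmse <- vapply(fitters, function(fit_fun) {
  fit  <- fit_fun(train, speed ~ dist)
  pred <- predict(fit, newdata = test)
  sqrt(mean((pred - test$speed)^2))
}, numeric(1))

rmse
```

Because a Gaussian `glm` fits the same model as `lm`, their RMSEs should match, which makes a handy sanity check.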

 Take home messages

twidlr attempts to bring the following to modelling in R:

  • consistent and tidy model APIs
  • pure and available predict functions
  • the power of formula operators
  • the tidyverse philosophy (e.g., keep piping)

But twidlr is new, and needs your help to grow. So if your favourite model isn’t listed here, fork twidlr on GitHub and add it to help improve modelling in R! Advice for contributing can be found here.

Thanks already to Joran Elias and Mathew Ling for their contributions!

 Sign off

Thanks for reading and I hope this was useful for you.

For updates of recent blog posts, follow @drsimonj on Twitter, or email me at [email protected] to get in touch.

If you’d like the code that produced this blog, check out the blogR GitHub repository.
