Multiclass predictive modeling for #TidyTuesday NBER papers

[This article was first published on rstats | Julia Silge, and kindly contributed to R-bloggers]. (You can report issue about the content on this page here)
Want to share your content on R-bloggers? click here if you have a blog, or here if you don't.

This is the latest in my series of screencasts demonstrating how to use the tidymodels packages, from just getting started to tuning more complex models. Today’s screencast walks through how to build, tune, and evaluate a multiclass predictive model with text features and lasso regularization, with this week’s #TidyTuesday dataset on NBER working papers. 📑


Here is the code I used in the video, for those who prefer reading instead of or in addition to video.

Explore data

Our modeling goal is to predict the category of National Bureau of Economic Research working papers from the titles and years of the papers.

library(tidyverse)

papers <- readr::read_csv("https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2021/2021-09-28/papers.csv")
programs <- readr::read_csv("https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2021/2021-09-28/programs.csv")
paper_authors <- readr::read_csv("https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2021/2021-09-28/paper_authors.csv")
paper_programs <- readr::read_csv("https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2021/2021-09-28/paper_programs.csv")

Let’s start by joining up these datasets to find the info we need.

papers_joined <-
  paper_programs %>%
  left_join(programs) %>%
  left_join(papers) %>%
  filter(!is.na(program_category)) %>%
  distinct(paper, program_category, year, title)

papers_joined %>%
  count(program_category)

## # A tibble: 3 × 2
##   program_category        n
##   <chr>               <int>
## 1 Finance              4336
## 2 Macro/International 12012
## 3 Micro               18527

The papers are in three categories (finance, microeconomics, and macroeconomics) so we’ll be training a multiclass predictive model, not a binary classification model as we often see or use.

Let’s create one exploratory plot before we move on to modeling.

library(tidytext)
library(tidylo)

title_log_odds <-
  papers_joined %>%
  unnest_tokens(word, title) %>%
  filter(!is.na(program_category)) %>%
  count(program_category, word, sort = TRUE) %>%
  bind_log_odds(program_category, word, n)


title_log_odds %>%
  group_by(program_category) %>%
  slice_max(log_odds_weighted, n = 10) %>%
  ungroup() %>%
  ggplot(aes(log_odds_weighted,
    fct_reorder(word, log_odds_weighted),
    fill = program_category
  )) +
  geom_col(show.legend = FALSE) +
  facet_wrap(vars(program_category), scales = "free_y") +
  labs(x = "Log odds (weighted)", y = NULL)

These type of relationships between category and title words are what we want to use in our predictive model.

Build and tune a model

Let’s start our modeling by setting up our “data budget.” We’ll stratify by our outcome program_category.

library(tidymodels)

set.seed(123)
nber_split <- initial_split(papers_joined, strata = program_category)
nber_train <- training(nber_split)
nber_test <- testing(nber_split)

set.seed(234)
nber_folds <- vfold_cv(nber_train, strata = program_category)
nber_folds

## #  10-fold cross-validation using stratification 
## # A tibble: 10 × 2
##    splits               id    
##    <list>               <chr> 
##  1 <split [23539/2617]> Fold01
##  2 <split [23539/2617]> Fold02
##  3 <split [23540/2616]> Fold03
##  4 <split [23540/2616]> Fold04
##  5 <split [23540/2616]> Fold05
##  6 <split [23541/2615]> Fold06
##  7 <split [23541/2615]> Fold07
##  8 <split [23541/2615]> Fold08
##  9 <split [23541/2615]> Fold09
## 10 <split [23542/2614]> Fold10

Next, let’s set up our feature engineering. We will need to transform our text data into features useful for our model by tokenizing and computing (in this case) tf-idf. Let’s also downsample since our dataset is imbalanced, with many more of some of the categories than others.

library(themis)
library(textrecipes)

nber_rec <-
  recipe(program_category ~ year + title, data = nber_train) %>%
  step_tokenize(title) %>%
  step_tokenfilter(title, max_tokens = 200) %>%
  step_tfidf(title) %>%
  step_downsample(program_category)

nber_rec

## Recipe
## 
## Inputs:
## 
##       role #variables
##    outcome          1
##  predictor          2
## 
## Operations:
## 
## Tokenization for title
## Text filtering for title
## Term frequency-inverse document frequency with title
## Down-sampling based on program_category

Then, let’s create our model specification for a lasso model. We need to use a model specification that can handle multiclass data, in this case multinom_reg().

multi_spec <-
  multinom_reg(penalty = tune(), mixture = 1) %>%
  set_mode("classification") %>%
  set_engine("glmnet")

multi_spec

## Multinomial Regression Model Specification (classification)
## 
## Main Arguments:
##   penalty = tune()
##   mixture = 1
## 
## Computational engine: glmnet

Now it’s time to put the preprocessing and model together in a workflow().

nber_wf <- workflow(nber_rec, multi_spec)
nber_wf

## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: multinom_reg()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 4 Recipe Steps
## 
## • step_tokenize()
## • step_tokenfilter()
## • step_tfidf()
## • step_downsample()
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## Multinomial Regression Model Specification (classification)
## 
## Main Arguments:
##   penalty = tune()
##   mixture = 1
## 
## Computational engine: glmnet

Since the lasso regularization penalty is a hyperparameter of the model (we can’t find the best value from fitting the model a single time), let’s tune over a grid of possible penalty parameters.

nber_grid <- grid_regular(penalty(range = c(-5, 0)), levels = 20)

doParallel::registerDoParallel()
set.seed(2021)
nber_rs <-
  tune_grid(
    nber_wf,
    nber_folds,
    grid = nber_grid
  )

nber_rs

## # Tuning results
## # 10-fold cross-validation using stratification 
## # A tibble: 10 × 4
##    splits               id     .metrics          .notes          
##    <list>               <chr>  <list>            <list>          
##  1 <split [23539/2617]> Fold01 <tibble [40 × 5]> <tibble [0 × 1]>
##  2 <split [23539/2617]> Fold02 <tibble [40 × 5]> <tibble [0 × 1]>
##  3 <split [23540/2616]> Fold03 <tibble [40 × 5]> <tibble [0 × 1]>
##  4 <split [23540/2616]> Fold04 <tibble [40 × 5]> <tibble [0 × 1]>
##  5 <split [23540/2616]> Fold05 <tibble [40 × 5]> <tibble [0 × 1]>
##  6 <split [23541/2615]> Fold06 <tibble [40 × 5]> <tibble [0 × 1]>
##  7 <split [23541/2615]> Fold07 <tibble [40 × 5]> <tibble [0 × 1]>
##  8 <split [23541/2615]> Fold08 <tibble [40 × 5]> <tibble [0 × 1]>
##  9 <split [23541/2615]> Fold09 <tibble [40 × 5]> <tibble [0 × 1]>
## 10 <split [23542/2614]> Fold10 <tibble [40 × 5]> <tibble [0 × 1]>

This is a pretty fast model to fit, since it is linear. How did it turn out?

autoplot(nber_rs)

show_best(nber_rs)

## # A tibble: 5 × 7
##    penalty .metric .estimator  mean     n std_err .config              
##      <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
## 1 0.00234  roc_auc hand_till  0.784    10 0.00249 Preprocessor1_Model10
## 2 0.00428  roc_auc hand_till  0.783    10 0.00244 Preprocessor1_Model11
## 3 0.00127  roc_auc hand_till  0.783    10 0.00251 Preprocessor1_Model09
## 4 0.000695 roc_auc hand_till  0.782    10 0.00253 Preprocessor1_Model08
## 5 0.000379 roc_auc hand_till  0.782    10 0.00254 Preprocessor1_Model07

Choose and evaluate a final model

We could use the numerically best model with select_best() by often with regularized models we would rather choose a simpler model within some limits of performance. We can choose using the “one-standard error rule” with select_by_one_std_err() and then use last_fit() to fit one time to the training data and evaluate one time to the testing data.

final_penalty <-
  nber_rs %>%
  select_by_one_std_err(metric = "roc_auc", desc(penalty))

final_penalty

## # A tibble: 1 × 9
##   penalty .metric .estimator  mean     n std_err .config            .best .bound
##     <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>              <dbl>  <dbl>
## 1 0.00428 roc_auc hand_till  0.783    10 0.00244 Preprocessor1_Mod… 0.784  0.781

final_rs <-
  nber_wf %>%
  finalize_workflow(final_penalty) %>%
  last_fit(nber_split)

final_rs

## # Resampling results
## # Manual resampling 
## # A tibble: 1 × 6
##   splits               id               .metrics  .notes  .predictions .workflow
##   <list>               <chr>            <list>    <list>  <list>       <list>   
## 1 <split [26156/8719]> train/test split <tibble … <tibbl… <tibble [8,… <workflo…

How did our final model perform on the training data?

collect_metrics(final_rs)

## # A tibble: 2 × 4
##   .metric  .estimator .estimate .config             
##   <chr>    <chr>          <dbl> <chr>               
## 1 accuracy multiclass     0.609 Preprocessor1_Model1
## 2 roc_auc  hand_till      0.779 Preprocessor1_Model1

We can visualize the difference in performance across classes with a confusion matrix.

collect_predictions(final_rs) %>%
  conf_mat(program_category, .pred_class) %>%
  autoplot()

We can also visualize the ROC curves for each class.

collect_predictions(final_rs) %>%
  roc_curve(truth = program_category, .pred_Finance:.pred_Micro) %>%
  ggplot(aes(1 - specificity, sensitivity, color = .level)) +
  geom_abline(slope = 1, color = "gray50", lty = 2, alpha = 0.8) +
  geom_path(size = 1.5, alpha = 0.7) +
  labs(color = NULL) +
  coord_fixed()

It looks like the finance and microeconomics papers were easier to identify than the macroeconomics papers.

Finally, we can extract (and save, if we like) the fitted workflow from our results to use for predicting on new data.

final_fitted <- extract_workflow(final_rs)
## can save this for prediction later with readr::write_rds()

predict(final_fitted, nber_test[111, ], type = "prob")

## # A tibble: 1 × 3
##   .pred_Finance `.pred_Macro/International` .pred_Micro
##           <dbl>                       <dbl>       <dbl>
## 1         0.104                       0.531       0.365

We can even make up new paper titles and see how our model classifies them.

predict(final_fitted, tibble(year = 2021, title = "Pricing Models for Corporate Responsibility"), type = "prob")

## # A tibble: 1 × 3
##   .pred_Finance `.pred_Macro/International` .pred_Micro
##           <dbl>                       <dbl>       <dbl>
## 1         0.598                       0.158       0.244

predict(final_fitted, tibble(year = 2021, title = "Teacher Health and Medicaid Expansion"), type = "prob")

## # A tibble: 1 × 3
##   .pred_Finance `.pred_Macro/International` .pred_Micro
##           <dbl>                       <dbl>       <dbl>
## 1         0.288                       0.141       0.571

To leave a comment for the author, please follow the link and comment on their blog: rstats | Julia Silge.

R-bloggers.com offers daily e-mail updates about R news and tutorials about learning R and many other topics. Click here if you're looking to post or find an R/data-science job.
Want to share your content on R-bloggers? click here if you have a blog, or here if you don't.

Never miss an update!
Subscribe to R-bloggers to receive
e-mails with the latest R posts.
(You will not see this message again.)

Click here to close (This popup will not appear again)