KNN vs. XGBoost Rivalry: Women Employment in Management

[This article was first published on DataGeeek, and kindly contributed to R-bloggers.]

Landing a high-profile job has long been very hard for women, especially for those living in countries that offer few opportunities to acquire the relevant qualifications. The problem can stem from many causes, such as contextual factors and accessibility dimensions. This article examines the share of women in high-profile positions worldwide from that perspective.

I use the female share of employment in senior and middle management (%) indicator as the target variable of the model. The explanatory variables I chose, along with links to their definitions, are listed in the code block below. In addition to those, the region and income variables are included.

library(tidyverse)
library(WDI)

#Female share of employment in senior and middle management (%)
df_profile <- 
  WDI(indicator = "SL.EMP.SMGT.FE.ZS", extra = TRUE) %>% 
  as_tibble() %>% 
  rename(senior_middle_management = SL.EMP.SMGT.FE.ZS) %>% 
  crosstable::remove_labels()

#Employment in industry, female (% of female employment)
#(https://data.worldbank.org/indicator/SL.IND.EMPL.FE.ZS)
df_industry <- 
  WDI(indicator = "SL.IND.EMPL.FE.ZS", extra = TRUE) %>% 
  as_tibble() %>% 
  rename(industry_employment = SL.IND.EMPL.FE.ZS) %>% 
  crosstable::remove_labels()

#Women Business and the Law Index Score (scale 1-100)
#(https://data.worldbank.org/indicator/SG.LAW.INDX)
df_wbl <- 
  WDI(indicator = "SG.LAW.INDX", extra = TRUE) %>% 
  as_tibble() %>% 
  rename(wbl = SG.LAW.INDX) %>% 
  crosstable::remove_labels()

#Proportion of seats held by women in national parliaments (%)
#(https://data.worldbank.org/indicator/SG.GEN.PARL.ZS)
df_seats <- 
  WDI(indicator = "SG.GEN.PARL.ZS", extra = TRUE) %>% 
  as_tibble() %>% 
  rename(seats = SG.GEN.PARL.ZS) %>% 
  crosstable::remove_labels()
  

#Merging all the datasets
df_merged <- 
  reduce(list(df_profile,
              df_industry,
              df_wbl,
              df_seats),
         inner_join,
         by = c("country","year","region","income")) %>% 
  select(region, 
         income,
         wbl,
         seats,
         senior_middle_management,
         industry_employment) %>%
  mutate(across(wbl:industry_employment,  ~ round(., digits = 2))) %>% 
  drop_na()

Before the modeling, we will examine the distributions of the target variable by region and income level to understand its characteristics.

#Exploratory Data Analysis (EDA)
library(kableExtra)

#Female management by region
female_management_region <- 
  split(df_merged$senior_middle_management, 
        df_merged$region)

#Row labels taken from the split list so they match the plot order
inline_plot_region <- 
  data.frame(var = names(female_management_region), 
             Boxplot = "", 
             Histogram = "")


#Female management by income
female_management_income <- 
  split(df_merged$senior_middle_management, 
        df_merged$income)

#Row labels taken from the split list so they match the plot order
inline_plot_income <- 
  data.frame(var = names(female_management_income), 
             Boxplot = "", 
             Histogram = "")


#Kable table of the distributions of female share of employment 
#in senior and middle management rate by region and income
plot_df <- 
  inline_plot_region %>% 
  rbind(inline_plot_income)

plot_df %>% 
  kbl(
    col.names = c("", "Boxplot", "Histogram"),
    caption = "<center><b>The Distributions of Female Management (%)<br>by Region and Income</b></center>") %>%
  kable_paper("striped", full_width = F) %>%
  pack_rows("Region", 1, 7) %>%
  pack_rows("Income", 8, 11) %>% 
  column_spec(2, image = c(spec_boxplot(female_management_region),
                           spec_boxplot(female_management_income))) %>% 
  column_spec(3, image = c(spec_hist(female_management_region), 
                           spec_hist(female_management_income))) %>% 
  kable_classic(full_width = F, 
                html_font = "Bricolage Grotesque") %>% 
  kable_styling(position = "center")

Looking at the High income group, we can see that its distribution is right-skewed, which means that the majority of the observations lie below the mean. The same holds for Latin America & Caribbean.
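The link between right skew and the mean can be checked on a small example. The sketch below uses a made-up vector (`x` is hypothetical and not taken from the WDI data) and base R only: in a right-skewed sample the mean sits above the median, so more than half of the observations fall below the mean.

```r
# Hypothetical right-skewed sample (illustrative values, not WDI data)
x <- c(10, 12, 13, 15, 16, 18, 20, 45, 60)

mean_x   <- mean(x)    # pulled upward by the two large values
median_x <- median(x)  # robust to the long right tail

# Share of observations lying below the mean
share_below_mean <- mean(x < mean_x)

# Moment-based sample skewness; positive values indicate a right skew
skewness <- mean((x - mean_x)^3) / mean((x - mean_x)^2)^1.5

print(c(mean = mean_x, median = median_x,
        share_below = share_below_mean, skewness = skewness))
```

Here seven of the nine values are below the mean, which is the pattern described for the High income group.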

Now we can move on to the modeling phase, where we will compare the algorithms below.

#Modeling
library(tidymodels)

##Split into a train and test set
set.seed(12345)
df_split <- initial_split(df_merged, 
                          prop = 0.9,
                          strata = region)

df_train <- training(df_split)
df_test <- testing(df_split)

#Resampling for comparing many models
set.seed(12345)
df_folds <- 
  vfold_cv(df_train, strata = region)

#Bayesian additive regression trees (BART)
spec_bart <- 
  parsnip::bart(trees = 20) %>%
  set_mode("regression") %>% 
  set_engine("dbarts")

#Boosted trees via xgboost
spec_boost <- 
  boost_tree(trees = 20) %>%
  set_mode("regression") %>% 
  set_engine("xgboost")

#K-nearest neighbors via kknn
spec_knn <- 
  nearest_neighbor(neighbors = 5, 
                   weight_func = "triangular") %>%
  set_mode("regression") %>%
  set_engine("kknn")
 
#Generalized additive models via mgcv
spec_gen_add <- 
  gen_additive_mod(select_features = FALSE, 
                   adjust_deg_free = 10) %>% 
  set_mode("regression") %>% 
  set_engine("mgcv")

#Linear regression via keras/tensorflow
spec_linreg <- 
  linear_reg(penalty = 0.1) %>% 
  set_engine("keras")

#Workflow sets

#Workflow set with only a basic formula
no_pre_proc <- 
  workflow_set(
    preproc = list("formula" = 
                     senior_middle_management ~ 
                     region + income + wbl + seats + industry_employment),
    models = list(BART = spec_bart)
  )


#Preprocessing with features
rec_features <- 
  recipe(senior_middle_management ~ ., data = df_train) %>% 
  step_dummy(all_nominal_predictors(), one_hot = TRUE) %>% 
  step_zv(all_predictors())

#Workflow set with features
with_features <- 
  workflow_set(
    preproc = list(dummy = rec_features),
    models = list(
      Linear = spec_linreg,
      XGBoost = spec_boost,
      KNN = spec_knn
    )
  )

#Workflow set for GAM
wflwset_gam <- 
  workflow() %>% 
  add_model(spec_gen_add, 
            formula = senior_middle_management ~ 
              region + income + wbl + seats + industry_employment) %>% 
  add_formula(senior_middle_management ~ 
                region + 
                income +
                wbl + 
                seats + 
                industry_employment) %>% 
  as_workflow_set(GAM = .)

#Combining all workflow sets
all_wflws <- 
  bind_rows(no_pre_proc,
            with_features,
            wflwset_gam) %>% 
  mutate(wflow_id = gsub("(formula_)|(dummy_)", "", wflow_id))

#Evaluating the models
resamples_ctrl <-
  control_grid(
    save_pred = TRUE,
    parallel_over = "everything",
    save_workflow = TRUE
  )

mods_results <-
  all_wflws %>%
  workflow_map(
    seed = 98765,
    resamples = df_folds,
    control = resamples_ctrl
  )

#Accuracy ranking of the models
df_rmse <- 
  mods_results %>% 
  rank_results() %>%
  select(wflow_id, .metric, mean) %>% 
  filter(.metric == "rmse") %>% 
  select(Models = wflow_id, RMSE = mean)
  
df_rsq <- 
  mods_results %>% 
  rank_results() %>%
  select(wflow_id, .metric, mean) %>% 
  filter(.metric == "rsq") %>% 
  select(Models = wflow_id, RSQ = mean)


#Accuracy table
library(gt)

df_acc <- 
  df_rmse %>% 
  full_join(df_rsq, by = "Models")

df_acc %>% 
  #rounding decimal digits down to 2 for all numeric variables
  mutate(across(where(is.numeric), ~ round(., 2))) %>% 
  gt() %>% 
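  #RMSE: lower is better, so the palette is reversed for that column
  data_color(
    columns = RMSE,
    method = "numeric",
    palette = c("red", "green"),
    reverse = TRUE
  ) %>% 
  #R-squared: higher is better
  data_color(
    columns = RSQ,
    method = "numeric",
    palette = c("red", "green")
  ) %>% 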
  cols_align(align = "center", columns = -Models) %>% 
  #separating alignment of column names from cells-alignment
  tab_style(
    style = cell_text(align = "left"),
    locations = cells_body(
      columns = Models
    )) %>% 
  #separating cell body from each other
  tab_style(
    style = cell_borders(sides = "all", 
                         color = "white",
                         weight = px(12), 
                         style = "solid"),
    locations = cells_body(columns = everything())) %>% 
  tab_header(title = "Accuracy")%>% 
  opt_table_font(font = "Bricolage Grotesque")

According to the table above, the K-nearest neighbors (KNN) algorithm has the best accuracy metrics. We will therefore base the explanatory model analysis on that model and its parameters.
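For readers who want to see what the two columns of the table measure, here is a minimal base-R sketch on made-up numbers (`truth` and `estimate` are hypothetical values, not output from the models above). It assumes the tidymodels defaults for regression, where RMSE is the root of the mean squared error and `rsq` is the squared correlation between observed and predicted values. A lower RMSE and a higher R-squared indicate a better fit.

```r
# Hypothetical observed and predicted values (not from the models above)
truth    <- c(30, 35, 28, 40, 33)
estimate <- c(32, 34, 27, 38, 35)

# Root mean squared error: typical size of a prediction error,
# expressed in the units of the target (percentage points here)
rmse <- sqrt(mean((truth - estimate)^2))

# R-squared computed as the squared correlation of truth and estimate
rsq <- cor(truth, estimate)^2

print(c(rmse = rmse, rsq = rsq))
```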

#Variable importance
library(DALEXtra)

#Fitted workflow for KNN
set.seed(98765)

knn_wflow_fitted <- 
  workflow() %>% 
  add_recipe(rec_features) %>% 
  add_model(spec_knn) %>% 
  fit(df_train)

knn_wflow_fitted %>% 
  extract_fit_parsnip()

#Processed data frame for variable importance calculation
imp_data <- 
  rec_features %>% 
  prep() %>% 
  bake(new_data = NULL) 

#Explainer object
explainer_knn <- 
  explain_tidymodels(
    knn_wflow_fitted %>% extract_fit_parsnip(), 
    data = imp_data %>% select(-senior_middle_management), 
    y = imp_data$senior_middle_management,
    label = "",
    verbose = FALSE
  )


#Calculating permutation-based variable importance 
set.seed(1983)
vip_knn <- model_parts(explainer_knn, 
                       loss_function = loss_root_mean_square,
                       type = "difference",
                       B = 100, #the number of permutations
                       label = "")

plot(vip_knn)

#Variable importance plot
vip_knn %>% 
  mutate(variable = str_remove_all(variable, "region_|income_")) %>%
  #replaces (...) with ' and '
  mutate(variable = str_replace_all(variable, "\\.{3}"," and ")) %>% 
  #replaces (.) with a blank space
  mutate(variable = str_replace_all(variable, "\\.", " ")) %>%
  mutate(variable = case_when(
    variable == "industry_employment" ~ "Industry Employment",
    variable == "seats" ~ "Parliaments Seats",
    variable == "wbl" ~ "WBL",
    TRUE ~ variable
  )) %>%
  plot() +
  labs(color = "",
       x = "",
       y = "",
       subtitle = "Higher indicates more important",
       title = "Factors Affecting Female Employment in Senior and Middle Management") +
  theme_minimal(base_family = "Bricolage Grotesque",
                base_size = 16) +
  theme(legend.position = "none",
        plot.title = element_text(hjust = 0.5, 
                                  size = 14,
                                  face = "bold"),
        plot.subtitle = element_text(hjust = 0.5, size = 12),
        panel.grid.minor.x = element_blank(),
        panel.grid.major.y = element_blank(),
        plot.background = element_rect(fill = "#FFEBFE"))

As seen in the plot above, the share of female employment in industry is the most decisive factor, followed by the share of parliamentary seats held by women and the Women, Business and the Law (WBL) score.
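The idea behind the permutation-based importance used above can be sketched in a few lines of base R. This is an illustration under stated assumptions, not the DALEX implementation: it fits a plain linear model on the built-in `mtcars` data instead of the KNN workflow, and the helper names `rmse()` and `perm_importance()` are hypothetical. For each predictor, the column is shuffled, the RMSE is recomputed, and the average increase over the baseline RMSE is reported, which mirrors the `type = "difference"` setting.

```r
# Root mean squared error of a fitted model on a data set
rmse <- function(model, data, target) {
  sqrt(mean((data[[target]] - predict(model, newdata = data))^2))
}

# Hypothetical helper: mean increase in RMSE after shuffling each predictor
perm_importance <- function(model, data, target, predictors, B = 50) {
  baseline <- rmse(model, data, target)
  sapply(predictors, function(v) {
    losses <- replicate(B, {
      shuffled <- data
      shuffled[[v]] <- sample(shuffled[[v]])  # break the link to the target
      rmse(model, shuffled, target)
    })
    mean(losses) - baseline
  })
}

set.seed(1983)
fit <- lm(mpg ~ wt + hp, data = mtcars)  # stand-in model, not the KNN fit
imp <- perm_importance(fit, mtcars, "mpg", c("wt", "hp"))
print(sort(imp, decreasing = TRUE))
```

A larger value means that shuffling the variable hurts the predictions more, so the model relies on it more heavily.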

But how do those factors affect our target variable? To answer this, we will calculate the partial dependence profiles.

#A function for plotting partial dependence profiles (PDP)
library(rlang)

plot_pdp <- function(var){
  
  #Partial dependence profiles
  set.seed(1983)
  pdp_obj <- model_profile(explainer_knn, 
                           N = NULL, #for all observations
                           variables = var)
  
  x_title <- 
    var %>%
    #removes region_ or income_
    gsub("region_|income_", "", .) %>% 
    #replaces (...) with ' and '
    gsub("\\.{3}"," and ", .) %>% 
    #replaces (.) or (_) with a blank space
    gsub("[._]", " ", .)
  
  pdp_obj$agr_profiles %>% 
    as_tibble() %>%
    ggplot(aes(`_x_`, `_yhat_`)) +
    geom_line(color = "#ffb8fb", 
              linewidth = 1.2) +
    labs(x = glue::glue('{x_title %>% str_to_title()}'),
         y = "Female Senior and Middle-Level Management (%)") +
    theme_minimal(base_family = "Bricolage Grotesque")
}

#Combining the plots
library(patchwork)

p_industry_employment <- plot_pdp("industry_employment")
p_parliaments_seats <- plot_pdp("seats")
p_wbl <- plot_pdp("wbl")
p_lower_middle_income <- plot_pdp("income_Lower.middle.income")

p_industry_employment + 
  p_parliaments_seats +
  p_wbl + 
  p_lower_middle_income + 
  plot_layout(nrow = 2, 
              axis_titles = "collect")

According to the graphs above, the effects of those factors are mostly positive. The interesting part is that the effect of the WBL score turns positive only above a score of 80; the positive effect of the lower-middle-income level is also a bit surprising.
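A partial dependence profile of the kind discussed above can also be computed by hand. The sketch below is a simplified illustration, not the DALEX implementation: it uses a linear model on the built-in `mtcars` data as a stand-in for the KNN workflow, and `partial_dependence()` is a hypothetical helper. For every value on a grid, the chosen variable is overwritten in all rows, the model predicts, and the predictions are averaged.

```r
# Hypothetical helper: average prediction when one variable is fixed
# at each grid value for every row of the data
partial_dependence <- function(model, data, variable, grid) {
  yhat <- sapply(grid, function(value) {
    modified <- data
    modified[[variable]] <- value  # overwrite the variable for all rows
    mean(predict(model, newdata = modified))
  })
  data.frame(x = grid, yhat = yhat)
}

fit  <- lm(mpg ~ wt + hp, data = mtcars)  # stand-in model, not the KNN fit
grid <- seq(min(mtcars$wt), max(mtcars$wt), length.out = 5)
pdp  <- partial_dependence(fit, mtcars, "wt", grid)
print(pdp)
```

For a linear model the profile is a straight line whose slope equals the fitted coefficient. A flexible model such as KNN is not restricted to straight lines, which is how an effect can change direction along the range of a variable.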
