Twenty Questions and Decision Trees
Want to share your content on R-bloggers? click here if you have a blog, or here if you don't.
Contrary to my kids’ attempts, the strategy should be to ask questions that split the remaining possibilities roughly in half each time.
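As a rough sketch of why halving is the right strategy: each yes/no answer can at best cut the candidates in half, so about log2(n) questions are needed for n candidates. The helper name `questions_needed` below is my own illustration, not part of the code later in this post, and 47 is the number of presidencies in the data set.

```r
# Count how many yes/no questions are needed if every answer
# halves the remaining candidates (worst case: the larger half remains).
questions_needed <- function(n) {
  k <- 0
  while (n > 1) {
    n <- ceiling(n / 2)
    k <- k + 1
  }
  k
}

questions_needed(47)  # 6
ceiling(log2(47))     # 6, the closed-form equivalent
```

So a well-played game over the Presidents should need about six questions, far fewer than twenty.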
This seems very similar to a machine learning Decision Tree, although with an interesting distinction.
A decision tree cheats. The decision tree algorithm knows the answer (df$target == 1). At each node it searches for the feature and split value that best separate df$target == 1 from df$target == 0, so it needs to know the right answer in order to ask the best questions. This is why, if the game is played several times with different US Presidents, the algorithm may choose different features and split values.
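To make the "cheating" concrete, here is a minimal sketch with toy data, assuming a Gini-based split criterion (the default in rpart for classification). The functions `gini`, `split_gain`, `target_is`, and `gains`, and the two made-up features, are illustrative assumptions and not part of the script below. Feature A splits eight items into two halves; feature B singles out item 1.

```r
# Gini impurity of a 0/1 target vector
gini <- function(y) {
  p <- mean(y)
  2 * p * (1 - p)
}

# Impurity decrease when the rows are split by a logical vector `left`
split_gain <- function(y, left) {
  n <- length(y)
  gini(y) -
    (sum(left) / n) * gini(y[left]) -
    (sum(!left) / n) * gini(y[!left])
}

feature_a <- rep(c(TRUE, FALSE), each = 4)  # items 1-4 vs items 5-8
feature_b <- c(TRUE, rep(FALSE, 7))         # item 1 vs everyone else

# One-vs-rest target: 1 for the secret item, 0 for the others
target_is <- function(i, n = 8) as.integer(seq_len(n) == i)

gains <- function(i) {
  c(A = split_gain(target_is(i), feature_a),
    B = split_gain(target_is(i), feature_b))
}

gains(1)  # B wins: it isolates the secret item immediately
gains(5)  # A wins: B says almost nothing about item 5
```

A human player who does not know the answer would always prefer the even split A. The algorithm prefers B whenever B happens to isolate the answer it already knows.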
Nevertheless, I thought it would be fun to program a decision tree model with the US Presidents. I found some data on Presidents. I decided some variables had too many values (high cardinality – there were a lot of political party names in the 1800s), so I grouped some values to reduce the number of unique values.
I began with a random integer between 1 and 47 to select a President, which selected President Hoover, but I found that a different President would create a tree closer to the questions I would have asked if I were a player. So I selected President Reagan to get a more interesting tree.
(I considered selecting President Garfield to be able to ask the question, “Is the president credited with a unique proof of the Pythagorean Theorem?”, but I decided that was a little quirky, even for me.)
Here is the resulting tree for President Reagan:
Here is the resulting variable importance plot. Note that the variables are not in the same order as the tree splits. I understand that variable importance is based on the sum of the improvements in all nodes where the variable was used as a splitter (rpart also gives a variable partial credit where it appears as a surrogate split).
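One way to see the difference between split order and importance order is to compare the two directly on a fitted tree. The sketch below uses the kyphosis data that ships with rpart as a stand-in for the presidents data; the object names `fit`, `primary`, and `imp` are only for illustration.

```r
library(rpart)

# kyphosis is a small example data set included with rpart
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis, method = "class")

# Variables used as primary splitters, in the order the nodes appear in the frame
primary <- setdiff(unique(as.character(fit$frame$var)), "<leaf>")

# Importance scores as stored by rpart on the fitted object
imp <- fit$variable.importance

primary     # variables that actually appear in the plotted tree
names(imp)  # variables with an importance score, which can include more
```

A variable that never appears as a primary splitter can still receive an importance score through surrogate splits, which is one reason the two orderings need not agree.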
Here is my R code:
library(dplyr)
library(rpart)
library(rpart.plot)
library(ggplot2)
df <- read.csv("prez.csv", header=TRUE)
# data file available at github: prez.csv
set.seed(123)
# r <- sample(1:nrow(df),1)
r <- 40 # deliberate choice to get longer tree
answer <- df$LastName[r]
print(paste("The target president is:", answer))
df$target <- rep(0, nrow(df))
df$target[r] <- 1
# Feature engineering:
df <- df %>%
  # A. Categorical Reduction
  mutate(
    Party = case_when(
      Party %in% c("Democratic") ~ "Democratic",
      Party %in% c("Republican") ~ "Republican",
      TRUE ~ "Other"),
    Occupation = case_when(
      Occupation %in% c("Businessman", "Lawyer") ~ Occupation,
      TRUE ~ "Other"),
    State = case_when(
      State %in% c("New York") ~ "NY",
      State %in% c("Ohio") ~ "OH",
      State %in% c("Virginia") ~ "VA",
      State %in% c("Massachusetts") ~ "MA",
      State %in% c("Texas") ~ "TX",
      TRUE ~ "Other"),
    Religion = case_when(
      Religion %in% c("Episcopalian", "Presbyterian", "Unitarian", "Baptist", "Methodist") ~ "Main_Prot",
      TRUE ~ "Other"),
    # B. Year/Century Binning using cut()
    DOB = cut(DOB, breaks = c(-Inf, 1800, 1900, 2000, Inf),
              labels = c("18th century", "19th century", "20th century", "21st century"), right = FALSE),
    DOD = cut(DOD, breaks = c(-Inf, 1800, 1900, 2000, Inf),
              labels = c("18th century", "19th century", "20th century", "21st century"), right = FALSE),
    Began = cut(Began, breaks = c(-Inf, 1800, 1900, 2000, Inf),
                labels = c("18th century", "19th century", "20th century", "21st century"), right = FALSE),
    Ended = cut(Ended, breaks = c(-Inf, 1800, 1900, 2000, Inf),
                labels = c("18th century", "19th century", "20th century", "21st century"), right = FALSE)
  ) %>%
  # C. Convert all new/existing binary/categorical variables to factor
  mutate(across(c(Party, State, Occupation, Religion, Assassinated, Military,
                  Terms_GT_1, Pres_During_War, Was_Veep, DOB, DOD, Began, Ended),
                as.factor))
# selected variables
formula <- as.formula(target ~ Began + State + Party + Occupation + Pres_During_War)
# Using aggressive control settings to force a maximal, unpruned tree
prez_tree <- rpart(formula, data = df, method = "class",
                   control = rpart.control(cp = 0.001, minsplit = 2, minbucket = 1))
rpart.plot(prez_tree, type = 4, extra = 101, main = "President Twenty Questions")
# check Reagan
df %>% filter(Began == "20th century" &
                !State %in% c("MA", "NY", "OH", "TX") &
                Party == "Republican" &
                !Occupation %in% c("Businessman", "Lawyer"))
variable_importance <- prez_tree$variable.importance
importance_df <- data.frame(
  Variable = names(variable_importance),
  Importance = variable_importance
)
importance_df <- importance_df[order(importance_df$Importance, decreasing = TRUE), ]
common_theme <- theme(
  legend.position = "none",  # "NULL" is not a valid position; "none" hides the legend
  plot.title = element_text(size = 15, face = "bold"),
  plot.subtitle = element_text(size = 12.5, face = "bold"),
  axis.title = element_text(size = 15, face = "bold"),
  axis.text = element_text(size = 15, face = "bold"),
  legend.title = element_text(size = 15, face = "bold"),
  legend.text = element_text(size = 15, face = "bold"))
ggplot(importance_df, aes(x = factor(Variable, levels = rev(Variable)), y = Importance)) +
  geom_col(aes(fill = Variable)) +
  coord_flip() +
  labs(title = "20 Questions Variable Importance",
       x = "Variable",
       # rpart reports summed split improvement, not randomForest's Mean Decrease Gini
       y = "Importance (summed split improvement)") +
  common_theme
# loop through all presidents to see different first split vars
library(purrr)
### 1. Define the Analysis Function ###
# The function is modified to return a data frame row for clarity
get_first_split_row <- function(df, r) {
  # Temporarily set the target for the current president
  df$target <- 0
  df$target[r] <- 1
  tree <- rpart(formula, data = df, method = "class",
                control = rpart.control(cp = 0.001, minsplit = 2, minbucket = 1))
  frame <- tree$frame
  # Determine the result: the first row of the frame is the root node
  if (nrow(frame) > 1) {
    first_split_var <- as.character(frame$var[1])
  } else {
    first_split_var <- "No split"
  }
  # Return a single-row data frame
  return(data.frame(
    President = df$LastName[r],
    First_Split_Variable = first_split_var
  ))
}
### 2. Run the Analysis and Combine Results ###
# Create a list of row indices to iterate over
indices_to_run <- 1:nrow(df)
# Use map_dfr to apply the function to every index and combine the results
# into a single data frame (dfr = data frame row bind)
first_split_results_df <- map_dfr(indices_to_run, ~ get_first_split_row(df, .x))
### 3. Display the Table and Original Analysis ###
# Display the resulting table:
print(first_split_results_df)
print(table(first_split_results_df$First_Split_Variable))
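A small aside on the century binning in the feature-engineering step: with `right = FALSE`, `cut()` builds intervals that are closed on the left and open on the right, so a boundary year such as 1800 or 1900 falls into the later bin. The vector `years` below is made-up test input, not taken from prez.csv.

```r
# Made-up years chosen to sit on either side of each break
years <- c(1799, 1800, 1899, 1900, 1999, 2000)

century <- cut(years,
               breaks = c(-Inf, 1800, 1900, 2000, Inf),
               labels = c("18th century", "19th century", "20th century", "21st century"),
               right = FALSE)

# 1800 lands in "19th century" and 2000 in "21st century"
data.frame(years, century)
```

This labels a year by its leading digits (1800 to 1899 as "19th century") rather than by the strict calendar convention, which is fine for a guessing game as long as the labels are read that way.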
End
