CHAID and caret – a good combo – June 6, 2018

[This article was first published on Chuck Powell, and kindly contributed to R-bloggers].

In an earlier post I focused on
an in-depth visit with CHAID (chi-square automatic interaction
detection). There are lots of tools that can help you predict an
outcome, or classify, but CHAID is especially good at helping you
explain to any audience how the model arrives at its prediction or
classification. It’s also incredibly robust from a statistical
perspective, making almost no assumptions about the distribution or
normality of your data. In this post I’ll focus on marrying CHAID with
the awesome caret package
to make our predicting easier and hopefully more accurate. Although not
strictly necessary, you’re probably best served by reading the original
post first.

We’ve been using a dataset that comes to us from the IBM Watson
Project and is packaged with the rsample library. It’s a very practical
and understandable dataset, and a great use case for a tree-based
algorithm. Imagine yourself in a fictional company faced with the task
of trying to figure out which employees you are going to “lose”, a.k.a.
attrition or turnover. There’s a steep cost involved in keeping good
employees, and training and on-boarding can be expensive. Being able to
predict attrition even a little bit better would save you lots of money
and make the company better, especially if you can understand exactly
what you have to “watch out for” that might indicate a person is at
high risk of leaving.

Setup and library loading

If you’ve never used CHAID before you may also not have partykit.
CHAID isn’t on CRAN, so it has to be installed from R-Forge; I have
commented out the install commands below. You’ll also get a variety of
messages on loading, none of which are relevant to this example, so
I’ve suppressed them.

# install.packages("partykit")
# install.packages("CHAID", repos="http://R-Forge.R-project.org")
require(rsample)    # for dataset and splitting; also loads broom and tidyr
require(dplyr)
require(CHAID)
require(purrr)      # we'll use it to consolidate some data
require(caret)
require(kableExtra) # just to make the output nicer

Predicting attrition in a fictional company

Last time I spent a great deal
of time explaining the mechanics of loading the data. This time we’ll
race right through. If you need an explanation of what’s going on please
refer back. I’ve embedded some comments in the code so you can follow
along. Changing the data frame name to newattrit is not strictly
necessary; it just mimics the last post.

str(attrition) # included in rsample
## 'data.frame':    1470 obs. of  31 variables:
##  $ Age                     : int  41 49 37 33 27 32 59 30 38 36 ...
##  $ Attrition               : Factor w/ 2 levels "No","Yes": 2 1 2 1 1 1 1 1 1 1 ...
##  $ BusinessTravel          : Factor w/ 3 levels "Non-Travel","Travel_Frequently",..: 3 2 3 2 3 2 3 3 2 3 ...
##  $ DailyRate               : int  1102 279 1373 1392 591 1005 1324 1358 216 1299 ...
##  $ Department              : Factor w/ 3 levels "Human_Resources",..: 3 2 2 2 2 2 2 2 2 2 ...
##  $ DistanceFromHome        : int  1 8 2 3 2 2 3 24 23 27 ...
##  $ Education               : Ord.factor w/ 5 levels "Below_College"<..: 2 1 2 4 1 2 3 1 3 3 ...
##  $ EducationField          : Factor w/ 6 levels "Human_Resources",..: 2 2 5 2 4 2 4 2 2 4 ...
##  $ EnvironmentSatisfaction : Ord.factor w/ 4 levels "Low"<"Medium"<..: 2 3 4 4 1 4 3 4 4 3 ...
##  $ Gender                  : Factor w/ 2 levels "Female","Male": 1 2 2 1 2 2 1 2 2 2 ...
##  $ HourlyRate              : int  94 61 92 56 40 79 81 67 44 94 ...
##  $ JobInvolvement          : Ord.factor w/ 4 levels "Low"<"Medium"<..: 3 2 2 3 3 3 4 3 2 3 ...
##  $ JobLevel                : int  2 2 1 1 1 1 1 1 3 2 ...
##  $ JobRole                 : Factor w/ 9 levels "Healthcare_Representative",..: 8 7 3 7 3 3 3 3 5 1 ...
##  $ JobSatisfaction         : Ord.factor w/ 4 levels "Low"<"Medium"<..: 4 2 3 3 2 4 1 3 3 3 ...
##  $ MaritalStatus           : Factor w/ 3 levels "Divorced","Married",..: 3 2 3 2 2 3 2 1 3 2 ...
##  $ MonthlyIncome           : int  5993 5130 2090 2909 3468 3068 2670 2693 9526 5237 ...
##  $ MonthlyRate             : int  19479 24907 2396 23159 16632 11864 9964 13335 8787 16577 ...
##  $ NumCompaniesWorked      : int  8 1 6 1 9 0 4 1 0 6 ...
##  $ OverTime                : Factor w/ 2 levels "No","Yes": 2 1 2 2 1 1 2 1 1 1 ...
##  $ PercentSalaryHike       : int  11 23 15 11 12 13 20 22 21 13 ...
##  $ PerformanceRating       : Ord.factor w/ 4 levels "Low"<"Good"<"Excellent"<..: 3 4 3 3 3 3 4 4 4 3 ...
##  $ RelationshipSatisfaction: Ord.factor w/ 4 levels "Low"<"Medium"<..: 1 4 2 3 4 3 1 2 2 2 ...
##  $ StockOptionLevel        : int  0 1 0 0 1 0 3 1 0 2 ...
##  $ TotalWorkingYears       : int  8 10 7 8 6 8 12 1 10 17 ...
##  $ TrainingTimesLastYear   : int  0 3 3 3 3 2 3 2 2 3 ...
##  $ WorkLifeBalance         : Ord.factor w/ 4 levels "Bad"<"Good"<"Better"<..: 1 3 3 3 3 2 2 3 3 2 ...
##  $ YearsAtCompany          : int  6 10 0 8 2 7 1 1 9 7 ...
##  $ YearsInCurrentRole      : int  4 7 0 7 2 7 0 0 7 7 ...
##  $ YearsSinceLastPromotion : int  0 1 0 3 2 3 0 0 1 7 ...
##  $ YearsWithCurrManager    : int  5 7 0 0 2 6 0 0 8 7 ...
# the easy ones to convert: integers with fewer than 10 distinct values
attrition <- attrition %>%
  mutate_if(function(col) length(unique(col)) <= 10 & is.integer(col), as.factor)

# more difficult: bin YearsSinceLastPromotion into 4 levels
attrition$YearsSinceLastPromotion <- cut(
  attrition$YearsSinceLastPromotion,
  breaks = c(-1, 0.9, 1.9, 2.9, 30),
  labels = c("Less than 1", "1", "2", "More than 2")
)

# everything else: cut into five more or less even levels
attrition <- attrition %>%
  mutate_if(is.numeric, funs(cut_number(., n = 5)))
dim(attrition)
## [1] 1470   31
str(attrition)
## 'data.frame':    1470 obs. of  31 variables:
##  $ Age                     : Factor w/ 5 levels "[18,29]","(29,34]",..: 4 5 3 2 1 2 5 2 3 3 ...
##  $ Attrition               : Factor w/ 2 levels "No","Yes": 2 1 2 1 1 1 1 1 1 1 ...
##  $ BusinessTravel          : Factor w/ 3 levels "Non-Travel","Travel_Frequently",..: 3 2 3 2 3 2 3 3 2 3 ...
##  $ DailyRate               : Factor w/ 5 levels "[102,392]","(392,656]",..: 4 1 5 5 2 4 5 5 1 5 ...
##  $ Department              : Factor w/ 3 levels "Human_Resources",..: 3 2 2 2 2 2 2 2 2 2 ...
##  $ DistanceFromHome        : Factor w/ 5 levels "[1,2]","(2,5]",..: 1 3 1 2 1 1 2 5 5 5 ...
##  $ Education               : Ord.factor w/ 5 levels "Below_College"<..: 2 1 2 4 1 2 3 1 3 3 ...
##  $ EducationField          : Factor w/ 6 levels "Human_Resources",..: 2 2 5 2 4 2 4 2 2 4 ...
##  $ EnvironmentSatisfaction : Ord.factor w/ 4 levels "Low"<"Medium"<..: 2 3 4 4 1 4 3 4 4 3 ...
##  $ Gender                  : Factor w/ 2 levels "Female","Male": 1 2 2 1 2 2 1 2 2 2 ...
##  $ HourlyRate              : Factor w/ 5 levels "[30,45]","(45,59]",..: 5 3 5 2 1 4 4 3 1 5 ...
##  $ JobInvolvement          : Ord.factor w/ 4 levels "Low"<"Medium"<..: 3 2 2 3 3 3 4 3 2 3 ...
##  $ JobLevel                : Factor w/ 5 levels "1","2","3","4",..: 2 2 1 1 1 1 1 1 3 2 ...
##  $ JobRole                 : Factor w/ 9 levels "Healthcare_Representative",..: 8 7 3 7 3 3 3 3 5 1 ...
##  $ JobSatisfaction         : Ord.factor w/ 4 levels "Low"<"Medium"<..: 4 2 3 3 2 4 1 3 3 3 ...
##  $ MaritalStatus           : Factor w/ 3 levels "Divorced","Married",..: 3 2 3 2 2 3 2 1 3 2 ...
##  $ MonthlyIncome           : Factor w/ 5 levels "[1.01e+03,2.7e+03]",..: 4 3 1 2 2 2 1 1 4 3 ...
##  $ MonthlyRate             : Factor w/ 5 levels "[2.09e+03,6.89e+03]",..: 4 5 1 5 3 3 2 3 2 3 ...
##  $ NumCompaniesWorked      : Factor w/ 10 levels "0","1","2","3",..: 9 2 7 2 10 1 5 2 1 7 ...
##  $ OverTime                : Factor w/ 2 levels "No","Yes": 2 1 2 2 1 1 2 1 1 1 ...
##  $ PercentSalaryHike       : Factor w/ 5 levels "[11,12]","(12,13]",..: 1 5 3 1 1 2 5 5 5 2 ...
##  $ PerformanceRating       : Ord.factor w/ 4 levels "Low"<"Good"<"Excellent"<..: 3 4 3 3 3 3 4 4 4 3 ...
##  $ RelationshipSatisfaction: Ord.factor w/ 4 levels "Low"<"Medium"<..: 1 4 2 3 4 3 1 2 2 2 ...
##  $ StockOptionLevel        : Factor w/ 4 levels "0","1","2","3": 1 2 1 1 2 1 4 2 1 3 ...
##  $ TotalWorkingYears       : Factor w/ 5 levels "[0,5]","(5,8]",..: 2 3 2 2 2 2 4 1 3 4 ...
##  $ TrainingTimesLastYear   : Factor w/ 7 levels "0","1","2","3",..: 1 4 4 4 4 3 4 3 3 4 ...
##  $ WorkLifeBalance         : Ord.factor w/ 4 levels "Bad"<"Good"<"Better"<..: 1 3 3 3 3 2 2 3 3 2 ...
##  $ YearsAtCompany          : Factor w/ 5 levels "[0,2]","(2,5]",..: 3 4 1 4 1 3 1 1 4 3 ...
##  $ YearsInCurrentRole      : Factor w/ 5 levels "[0,1]","(1,2]",..: 3 4 1 4 2 4 1 1 4 4 ...
##  $ YearsSinceLastPromotion : Factor w/ 4 levels "Less than 1",..: 1 2 1 4 3 4 1 1 2 4 ...
##  $ YearsWithCurrManager    : Factor w/ 5 levels "[0,1]","(1,2]",..: 4 4 1 1 2 4 1 1 5 4 ...
newattrit <- attrition %>%
  select_if(is.factor)
dim(newattrit)
## [1] 1470   31

Okay, we have data on 1,470 employees. We have 30 potential predictors
(features) or independent variables, and the all-important attrition
variable, which gives us a yes or no answer to the question of whether
or not the employee left. Our task is to build the most accurate
predictive model we can that is also simple (parsimonious) and
explainable. The predictors we have seem to be the sorts of data we
might have on hand in our HR files and, thank goodness, are labelled in
a way that makes them pretty self-explanatory.

Last post we explored the control options and built predictive models
like the one below. For a review of what the output means and how CHAID
works, please refer back.

# explore the control options
ctrl <- chaid_control(minsplit = 200, minprob = 0.05)
ctrl
## $alpha2
## [1] 0.05
## 
## $alpha3
## [1] -1
## 
## $alpha4
## [1] 0.05
## 
## $minsplit
## [1] 200
## 
## $minbucket
## [1] 7
## 
## $minprob
## [1] 0.05
## 
## $stump
## [1] FALSE
## 
## $maxheight
## [1] -1
## 
## attr(,"class")
## [1] "chaid_control"
full_data <- chaid(Attrition ~ ., data = newattrit, control = ctrl)
print(full_data)
## 
## Model formula:
## Attrition ~ Age + BusinessTravel + DailyRate + Department + DistanceFromHome + 
##     Education + EducationField + EnvironmentSatisfaction + Gender + 
##     HourlyRate + JobInvolvement + JobLevel + JobRole + JobSatisfaction + 
##     MaritalStatus + MonthlyIncome + MonthlyRate + NumCompaniesWorked + 
##     OverTime + PercentSalaryHike + PerformanceRating + RelationshipSatisfaction + 
##     StockOptionLevel + TotalWorkingYears + TrainingTimesLastYear + 
##     WorkLifeBalance + YearsAtCompany + YearsInCurrentRole + YearsSinceLastPromotion + 
##     YearsWithCurrManager
## 
## Fitted party:
## [1] root
## |   [2] OverTime in No
## |   |   [3] YearsAtCompany in [0,2]
## |   |   |   [4] Age in [18,29], (29,34]: No (n = 129, err = 32.6%)
## |   |   |   [5] Age in (34,38], (38,45], (45,60]: No (n = 109, err = 6.4%)
## |   |   [6] YearsAtCompany in (2,5], (5,7], (7,10], (10,40]
## |   |   |   [7] WorkLifeBalance in Bad: No (n = 45, err = 22.2%)
## |   |   |   [8] WorkLifeBalance in Good, Better, Best
## |   |   |   |   [9] JobSatisfaction in Low: No (n = 153, err = 12.4%)
## |   |   |   |   [10] JobSatisfaction in Medium, High, Very_High
## |   |   |   |   |   [11] Age in [18,29], (29,34], (34,38], (38,45]
## |   |   |   |   |   |   [12] BusinessTravel in Non-Travel, Travel_Rarely
## |   |   |   |   |   |   |   [13] JobInvolvement in Low: No (n = 25, err = 12.0%)
## |   |   |   |   |   |   |   [14] JobInvolvement in Medium, High, Very_High
## |   |   |   |   |   |   |   |   [15] RelationshipSatisfaction in Low: No (n = 81, err = 3.7%)
## |   |   |   |   |   |   |   |   [16] RelationshipSatisfaction in Medium, High: No (n = 198, err = 0.0%)
## |   |   |   |   |   |   |   |   [17] RelationshipSatisfaction in Very_High: No (n = 105, err = 4.8%)
## |   |   |   |   |   |   [18] BusinessTravel in Travel_Frequently: No (n = 95, err = 8.4%)
## |   |   |   |   |   [19] Age in (45,60]: No (n = 114, err = 11.4%)
## |   [20] OverTime in Yes
## |   |   [21] JobLevel in 1: Yes (n = 156, err = 47.4%)
## |   |   [22] JobLevel in 2, 3, 4, 5
## |   |   |   [23] MaritalStatus in Divorced, Married: No (n = 188, err = 10.6%)
## |   |   |   [24] MaritalStatus in Single: No (n = 72, err = 34.7%)
## 
## Number of inner nodes:    11
## Number of terminal nodes: 13
plot(
  full_data,
  main = "newattrit dataset, minsplit = 200, minprob = 0.05",
  gp = gpar(
    lty = "solid",
    lwd = 2,
    fontsize = 10
  )
)

Over-fitting

Okay, we have a working predictive model. At this point, however, we’ve
been cheating to a certain degree! We’ve been using every available
piece of data we have to develop the best possible model. We’ve told the
powerful all-knowing algorithms to squeeze every last bit of accuracy
they can out of the data. We’ve told them to fit the best possible
model. The problem is that we may have done that at the cost of being
able to generalize our model to new data or to new situations. That’s
the problem of over-fitting in a nutshell. If you want a fuller
understanding please consider reading this post on EliteDataScience.
I’m going to move on to a solution for dealing with this limitation,
and that’s where caret comes in.
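Before moving on, here is a minimal sketch (mine, not from the post) of what over-fitting looks like in practice: grow an unpruned rpart tree on pure noise and compare its accuracy on the data it memorized against fresh data from the same process.

```r
library(rpart) # ships with R; rpart makes it easy to grow an unpruned tree

set.seed(42)
make_noise <- function(n) data.frame(
  y  = factor(sample(c("No", "Yes"), n, replace = TRUE)),
  x1 = rnorm(n), # predictors deliberately unrelated to y
  x2 = rnorm(n)
)
train_df <- make_noise(200)
fresh_df <- make_noise(200)

# cp = 0 and minsplit = 2 remove the usual guardrails, so the tree
# keeps splitting until it has effectively memorized train_df
fit <- rpart(y ~ ., data = train_df,
             control = rpart.control(minsplit = 2, cp = 0))

acc <- function(df) mean(predict(fit, df, type = "class") == df$y)
acc(train_df) # typically near 1: the tree memorized the noise
acc(fresh_df) # typically near 0.5: there was no signal to generalize
```

Resubstitution accuracy looks spectacular; holdout accuracy reveals the truth. That gap is exactly what cross-validation is designed to expose.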

We’re going to use caret to employ cross-validation, a.k.a. cv, to
solve this challenge for us, or more accurately to mitigate the
problem. The same article explains it well, so I won’t repeat that
explanation here; I’ll simply show you how to run the steps in R.

This is also a good time to point out that caret has extraordinarily
comprehensive documentation, which I used extensively; here I’m
limiting myself to the basics.

As a first step, let’s just take 30% of our data and put it aside for a
minute. We’re not going to let chaid see it or know about it as we
build the model. In some scenarios you have subsequent data at hand for
checking your model (data from another company or another year or …). We
don’t, so we’re going to self-impose this restraint. Why 30%? It
doesn’t have to be; it could be as low as 20% or as high as 40%. It
really depends on how conservative you want to be and how much data you
have at hand. Since this is just a tutorial we’ll simply use 30% as a
representative number. We’ve already loaded both rsample and caret,
either of which is quite capable of making this split for us. I’m
arbitrarily going to use rsample syntax, which is the line with
initial_split(newattrit, prop = .7, strata = "Attrition") in it. That
takes our dataset newattrit and makes a 70/30 split, ensuring that our
outcome variable Attrition is split as close to 70/30 within each level
as we can. This is important because our data is already pretty
lop-sided for outcomes. The two subsequent lines take the data
contained in split and produce two separate dataframes, test and
train, with 440 and 1030 staff members respectively. We’ll set test
aside for now and focus on train.

# Create training (70%) and test (30%) sets for the attrition data.
# Use set.seed for reproducibility
#####
set.seed(1234)
split <- initial_split(newattrit, prop = .7, strata = "Attrition")
train <- training(split)
test  <- testing(split)
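To see why strata = "Attrition" matters, here is a self-contained sketch (my addition, not from the post) of a stratified 70/30 split done by hand on a factor with the same lop-sided 1233/237 mix as the attrition data; with the real data you would compare prop.table(table(train$Attrition)) against the same table for test.

```r
set.seed(1234)
y <- factor(rep(c("No", "Yes"), times = c(1233, 237))) # ~84/16, like Attrition

# sample 70% within each level so both pieces keep the original mix
strat_idx <- unlist(lapply(split(seq_along(y), y),
                           function(i) sample(i, round(0.7 * length(i)))))
train_y <- y[strat_idx]
test_y  <- y[-strat_idx]

round(prop.table(table(train_y)), 3) # ~0.839 / 0.161
round(prop.table(table(test_y)),  3) # ~0.839 / 0.161
```

Both halves keep roughly the original 16% “Yes” rate, which is the whole point of stratifying on a rare outcome.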

The next step is a little counter-intuitive but quite practical. It
turns out that many models do not perform well when you feed them a
formula for the model, even if they claim to support a formula
interface (as CHAID does). Here’s an SO link that discusses this in
detail, but my suggestion is to always separate them and avoid the
problem altogether. We’re just taking our predictors or features and
putting them in x while we put our outcome in y.

# create response and feature data
features <- setdiff(names(train), "Attrition")
x <- train[, features]
y <- train$Attrition

Alright, let’s get back on track. trainControl is the function within
caret we need to use. Chapter 5 in the caret doco covers it in great
detail. I’m simply going to pluck out a few sane and safe options:
method = "cv" gets us cross-validation, and number = 10 is pretty
obvious. I happen to like seeing the progress in case I want to go for
coffee, so verboseIter = TRUE, and I play it safe and explicitly save
my predictions with savePredictions = "final". We put everything in
train_control, which we’ll use in a minute.

# set up 10-fold cross validation procedure
train_control <- trainControl(method = "cv",
                              number = 10,
                              verboseIter = TRUE,
                              savePredictions = "final")

Not surprisingly, the train function in caret trains our model! It
wants to know what our x and y are, as well as our training control
parameters, which we’ve parked in train_control. At this point we
could successfully unleash the dogs of war (sorry, Shakespeare) and
train our model, since we know we want to use chaid. But let’s change
one other useful thing, and that is metric, which controls what
measure we use to pick the “best” model. Instead of the default
“Accuracy” we’ll use Kappa, which, as you may remember from the last
post, is a more conservative measure of how well we did.
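To make that concrete, here’s a quick hand computation of Cohen’s kappa (my illustration, with made-up counts) showing why it is more conservative than accuracy on lop-sided outcomes like ours:

```r
# rows = predicted, columns = actual; the counts are invented for illustration
cm <- matrix(c(850, 50,
                30, 70),
             nrow = 2, byrow = TRUE,
             dimnames = list(pred = c("No", "Yes"), actual = c("No", "Yes")))

n  <- sum(cm)
po <- sum(diag(cm)) / n                    # observed accuracy: 0.92
pe <- sum(rowSums(cm) * colSums(cm)) / n^2 # agreement expected by chance: 0.804
kappa <- (po - pe) / (1 - pe)              # (0.92 - 0.804) / 0.196, about 0.59

c(accuracy = po, kappa = kappa)
```

Because most employees stay, a model can post high accuracy just by predicting “No” most of the time; kappa discounts that chance-level agreement.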

If you’re running this code yourself, this is a good time to take a
coffee break. I’ll tell you later how to find out more or less exactly
how long it took. But there’s no getting around it: we’re building the
model many more times, so it takes longer.

# train model
chaid.m1 <- train(
  x = x,
  y = y,
  method = "chaid",
  metric = "Kappa",
  trControl = train_control
)
## + Fold01: alpha2=0.05, alpha3=-1, alpha4=0.05 
## - Fold01: alpha2=0.05, alpha3=-1, alpha4=0.05 
## + Fold01: alpha2=0.03, alpha3=-1, alpha4=0.03 
## - Fold01: alpha2=0.03, alpha3=-1, alpha4=0.03 
## + Fold01: alpha2=0.01, alpha3=-1, alpha4=0.01 
## - Fold01: alpha2=0.01, alpha3=-1, alpha4=0.01 
## + Fold02: alpha2=0.05, alpha3=-1, alpha4=0.05 
## - Fold02: alpha2=0.05, alpha3=-1, alpha4=0.05 
## + Fold02: alpha2=0.03, alpha3=-1, alpha4=0.03 
## - Fold02: alpha2=0.03, alpha3=-1, alpha4=0.03 
## + Fold02: alpha2=0.01, alpha3=-1, alpha4=0.01 
## - Fold02: alpha2=0.01, alpha3=-1, alpha4=0.01 
## + Fold03: alpha2=0.05, alpha3=-1, alpha4=0.05 
## - Fold03: alpha2=0.05, alpha3=-1, alpha4=0.05 
## + Fold03: alpha2=0.03, alpha3=-1, alpha4=0.03 
## - Fold03: alpha2=0.03, alpha3=-1, alpha4=0.03 
## + Fold03: alpha2=0.01, alpha3=-1, alpha4=0.01 
## - Fold03: alpha2=0.01, alpha3=-1, alpha4=0.01 
## + Fold04: alpha2=0.05, alpha3=-1, alpha4=0.05 
## - Fold04: alpha2=0.05, alpha3=-1, alpha4=0.05 
## + Fold04: alpha2=0.03, alpha3=-1, alpha4=0.03 
## - Fold04: alpha2=0.03, alpha3=-1, alpha4=0.03 
## + Fold04: alpha2=0.01, alpha3=-1, alpha4=0.01 
## - Fold04: alpha2=0.01, alpha3=-1, alpha4=0.01 
## + Fold05: alpha2=0.05, alpha3=-1, alpha4=0.05 
## - Fold05: alpha2=0.05, alpha3=-1, alpha4=0.05 
## + Fold05: alpha2=0.03, alpha3=-1, alpha4=0.03 
## - Fold05: alpha2=0.03, alpha3=-1, alpha4=0.03 
## + Fold05: alpha2=0.01, alpha3=-1, alpha4=0.01 
## - Fold05: alpha2=0.01, alpha3=-1, alpha4=0.01 
## + Fold06: alpha2=0.05, alpha3=-1, alpha4=0.05 
## - Fold06: alpha2=0.05, alpha3=-1, alpha4=0.05 
## + Fold06: alpha2=0.03, alpha3=-1, alpha4=0.03 
## - Fold06: alpha2=0.03, alpha3=-1, alpha4=0.03 
## + Fold06: alpha2=0.01, alpha3=-1, alpha4=0.01 
## - Fold06: alpha2=0.01, alpha3=-1, alpha4=0.01 
## + Fold07: alpha2=0.05, alpha3=-1, alpha4=0.05 
## - Fold07: alpha2=0.05, alpha3=-1, alpha4=0.05 
## + Fold07: alpha2=0.03, alpha3=-1, alpha4=0.03 
## - Fold07: alpha2=0.03, alpha3=-1, alpha4=0.03 
## + Fold07: alpha2=0.01, alpha3=-1, alpha4=0.01 
## - Fold07: alpha2=0.01, alpha3=-1, alpha4=0.01 
## + Fold08: alpha2=0.05, alpha3=-1, alpha4=0.05 
## - Fold08: alpha2=0.05, alpha3=-1, alpha4=0.05 
## + Fold08: alpha2=0.03, alpha3=-1, alpha4=0.03 
## - Fold08: alpha2=0.03, alpha3=-1, alpha4=0.03 
## + Fold08: alpha2=0.01, alpha3=-1, alpha4=0.01 
## - Fold08: alpha2=0.01, alpha3=-1, alpha4=0.01 
## + Fold09: alpha2=0.05, alpha3=-1, alpha4=0.05 
## - Fold09: alpha2=0.05, alpha3=-1, alpha4=0.05 
## + Fold09: alpha2=0.03, alpha3=-1, alpha4=0.03 
## - Fold09: alpha2=0.03, alpha3=-1, alpha4=0.03 
## + Fold09: alpha2=0.01, alpha3=-1, alpha4=0.01 
## - Fold09: alpha2=0.01, alpha3=-1, alpha4=0.01 
## + Fold10: alpha2=0.05, alpha3=-1, alpha4=0.05 
## - Fold10: alpha2=0.05, alpha3=-1, alpha4=0.05 
## + Fold10: alpha2=0.03, alpha3=-1, alpha4=0.03 
## - Fold10: alpha2=0.03, alpha3=-1, alpha4=0.03 
## + Fold10: alpha2=0.01, alpha3=-1, alpha4=0.01 
## - Fold10: alpha2=0.01, alpha3=-1, alpha4=0.01 
## Aggregating results
## Selecting tuning parameters
## Fitting alpha2 = 0.05, alpha3 = -1, alpha4 = 0.05 on full training set

And…. we’re done. Turns out in this case the best solution was what
chaid uses as its defaults. The very last line of the output tells us
that. But let’s use what we have used in the past for printing and
plotting the results…

chaid.m1 # equivalent to print(chaid.m1)
## CHi-squared Automated Interaction Detection 
## 
## 1030 samples
##   30 predictor
##    2 classes: 'No', 'Yes' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 928, 927, 927, 926, 928, 926, ... 
## Resampling results across tuning parameters:
## 
##   alpha2  alpha4  Accuracy   Kappa    
##   0.01    0.01    0.8223292  0.1522392
##   0.03    0.03    0.8349699  0.1579585
##   0.05    0.05    0.8213958  0.1692826
## 
## Tuning parameter 'alpha3' was held constant at a value of -1
## Kappa was used to select the optimal model using the largest value.
## The final values used for the model were alpha2 = 0.05, alpha3 = -1
##  and alpha4 = 0.05.
<span class="n">plot</span><span class="p">(</span><span class="n">chaid.m1</span><span class="p">)</span><span class="w">
</span>

Wait. What? This isn’t the output we’re used to. caret has changed
the output from its work (an improvement actually) but we’ll have to
change how we get the information out. Before we do that however, let’s
inspect what we have so far. The output gives us a nice concise summary:
1030 cases with 30 predictors. It also tells us how many of the
1030 cases were used in the individual folds (Summary of sample
sizes: 928, 927, 927, 926, 928, 926, ...).
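That arithmetic is easy to check by hand. With 10-fold cross-validation each
fold holds out roughly a tenth of the 1030 cases, leaving about 927 for
training. A quick base R sketch of the idea (unstratified for simplicity;
caret actually stratifies its folds on the outcome, which is why the sizes it
reports wobble between 926 and 928):

```r
set.seed(42)
# split the 1030 row indices into 10 roughly equal folds
folds <- split(sample(1030), rep(1:10, length.out = 1030))

# the training set for each fold is everything NOT held out
train_sizes <- sapply(folds, function(held_out) 1030 - length(held_out))
train_sizes  # 1030 divides evenly here, so every fold trains on 927 cases
```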

The bit about alpha2, alpha4, and alpha3 is somewhat mysterious.
We saw those names when we looked at the chaid_control documentation
in the last post, but why are they here? We’ll come back to that in a
moment. What is clear is that caret judged the model with a Kappa of
0.1692826 to be best.

The plot isn’t what we’re used to seeing, but it’s easy to understand.
Kappa is on the y axis, alpha2 is on the x axis, and the lines are
shaded/colored by alpha4 (remember we left alpha3 out of the mix). The
plot is a bit of overkill for what we did, but we’ll put it to better
use later.

But what about the things we were used to seeing? If you remember
that caret reports averages across all the folds, it makes sense
that the best final model is now stored in
chaid.m1$finalModel, so we need to use that when we print or plot.
So in the next block of code let’s:

  1. Print the final model from chaid (chaid.m1$finalModel)
  2. Plot the final model from chaid (plot(chaid.m1$finalModel))
  3. Produce the confusionMatrix across all folds
    (confusionMatrix(chaid.m1))
  4. Produce the confusionMatrix using the final model
    (confusionMatrix(predict(chaid.m1), y))
  5. Check on variable importance (varImp(chaid.m1))
  6. The best tuning parameters are stored in chaid.m1$bestTune
  7. How long did it take? Look in chaid.m1$times
  8. In case you forgot what method you used look here chaid.m1$method
  9. We’ll look at model info in a bit chaid.m1$modelInfo
  10. The summarized results are here in a nice format if needed later
    chaid.m1$results

Many of these you’ll never need, but I wanted to at least give you a hint
of how complete the chaid.m1 object is.

<span class="n">chaid.m1</span><span class="o">$</span><span class="n">finalModel</span><span class="w">
</span>
## 
## Model formula:
## .outcome ~ Age + BusinessTravel + DailyRate + Department + DistanceFromHome + 
##     Education + EducationField + EnvironmentSatisfaction + Gender + 
##     HourlyRate + JobInvolvement + JobLevel + JobRole + JobSatisfaction + 
##     MaritalStatus + MonthlyIncome + MonthlyRate + NumCompaniesWorked + 
##     OverTime + PercentSalaryHike + PerformanceRating + RelationshipSatisfaction + 
##     StockOptionLevel + TotalWorkingYears + TrainingTimesLastYear + 
##     WorkLifeBalance + YearsAtCompany + YearsInCurrentRole + YearsSinceLastPromotion + 
##     YearsWithCurrManager
## 
## Fitted party:
## [1] root
## |   [2] OverTime in No
## |   |   [3] YearsAtCompany in [0,2]
## |   |   |   [4] Age in [18,29], (29,34]
## |   |   |   |   [5] StockOptionLevel in 0: No (n = 43, err = 48.8%)
## |   |   |   |   [6] StockOptionLevel in 1, 2, 3
## |   |   |   |   |   [7] RelationshipSatisfaction in Low: Yes (n = 7, err = 42.9%)
## |   |   |   |   |   [8] RelationshipSatisfaction in Medium, High, Very_High: No (n = 38, err = 7.9%)
## |   |   |   [9] Age in (34,38], (38,45], (45,60]: No (n = 77, err = 7.8%)
## |   |   [10] YearsAtCompany in (2,5], (5,7], (7,10], (10,40]
## |   |   |   [11] WorkLifeBalance in Bad: No (n = 36, err = 19.4%)
## |   |   |   [12] WorkLifeBalance in Good, Better, Best
## |   |   |   |   [13] Department in Human_Resources, Sales
## |   |   |   |   |   [14] Age in [18,29], (29,34], (34,38], (38,45]
## |   |   |   |   |   |   [15] WorkLifeBalance in Bad, Good: No (n = 37, err = 16.2%)
## |   |   |   |   |   |   [16] WorkLifeBalance in Better, Best: No (n = 119, err = 4.2%)
## |   |   |   |   |   [17] Age in (45,60]: No (n = 27, err = 25.9%)
## |   |   |   |   [18] Department in Research_Development: No (n = 347, err = 4.0%)
## |   [19] OverTime in Yes
## |   |   [20] JobLevel in 1
## |   |   |   [21] JobRole in Healthcare_Representative, Human_Resources, Laboratory_Technician, Manager, Manufacturing_Director, Research_Director, Sales_Executive, Sales_Representative
## |   |   |   |   [22] JobInvolvement in Low, Medium: Yes (n = 19, err = 10.5%)
## |   |   |   |   [23] JobInvolvement in High, Very_High: Yes (n = 45, err = 44.4%)
## |   |   |   [24] JobRole in Research_Scientist: No (n = 53, err = 35.8%)
## |   |   [25] JobLevel in 2, 3, 4, 5
## |   |   |   [26] Gender in Female: No (n = 86, err = 9.3%)
## |   |   |   [27] Gender in Male
## |   |   |   |   [28] MaritalStatus in Divorced, Married: No (n = 71, err = 18.3%)
## |   |   |   |   [29] MaritalStatus in Single: No (n = 25, err = 44.0%)
## 
## Number of inner nodes:    14
## Number of terminal nodes: 15
<span class="n">plot</span><span class="p">(</span><span class="n">chaid.m1</span><span class="o">$</span><span class="n">finalModel</span><span class="p">)</span><span class="w">
</span>

<span class="n">confusionMatrix</span><span class="p">(</span><span class="n">chaid.m1</span><span class="p">)</span><span class="w">
</span>
## Cross-Validated (10 fold) Confusion Matrix 
## 
## (entries are percentual average cell counts across resamples)
##  
##           Reference
## Prediction   No  Yes
##        No  79.0 13.0
##        Yes  4.9  3.1
##                             
##  Accuracy (average) : 0.8214
<span class="n">confusionMatrix</span><span class="p">(</span><span class="n">predict</span><span class="p">(</span><span class="n">chaid.m1</span><span class="p">),</span><span class="w"> </span><span class="n">y</span><span class="p">)</span><span class="w">
</span>
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  No Yes
##        No  839 120
##        Yes  25  46
##                                           
##                Accuracy : 0.8592          
##                  95% CI : (0.8365, 0.8799)
##     No Information Rate : 0.8388          
##     P-Value [Acc > NIR] : 0.03938         
##                                           
##                   Kappa : 0.3228          
##  Mcnemar's Test P-Value : 5.89e-15        
##                                           
##             Sensitivity : 0.9711          
##             Specificity : 0.2771          
##          Pos Pred Value : 0.8749          
##          Neg Pred Value : 0.6479          
##              Prevalence : 0.8388          
##          Detection Rate : 0.8146          
##    Detection Prevalence : 0.9311          
##       Balanced Accuracy : 0.6241          
##                                           
##        'Positive' Class : No              
##
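Every statistic in that printout can be recovered by hand from the 2x2 table
at the top. A base R sketch, with the cell counts copied from the output
above (remember the ‘positive’ class here is No):

```r
# cell counts copied from the confusion matrix output above
cm <- matrix(c(839, 25, 120, 46), nrow = 2,
             dimnames = list(Prediction = c("No", "Yes"),
                             Reference  = c("No", "Yes")))
n <- sum(cm)  # 1030 cases

accuracy <- sum(diag(cm)) / n                     # observed agreement
chance   <- sum(rowSums(cm) * colSums(cm)) / n^2  # agreement expected by chance
kappa    <- (accuracy - chance) / (1 - chance)    # accuracy corrected for chance

sensitivity <- cm["No", "No"]   / sum(cm[, "No"])   # how well we catch 'No'
specificity <- cm["Yes", "Yes"] / sum(cm[, "Yes"])  # how well we catch 'Yes'

round(c(accuracy, kappa, sensitivity, specificity), 4)
# 0.8592 0.3228 0.9711 0.2771 -- matching the printout above
```

Notice how much daylight there is between accuracy (0.8592) and Kappa
(0.3228): with 84% of employees staying, a model that always predicted
“No” would already be right 84% of the time, and Kappa discounts exactly
that kind of chance agreement.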
<span class="n">varImp</span><span class="p">(</span><span class="n">chaid.m1</span><span class="p">)</span><span class="w">
</span>
## ROC curve variable importance
## 
##   only 20 most important variables shown (out of 30)
## 
##                         Importance
## OverTime                    100.00
## YearsInCurrentRole           90.81
## YearsAtCompany               90.41
## MonthlyIncome                87.08
## JobLevel                     84.36
## TotalWorkingYears            80.04
## YearsWithCurrManager         79.78
## StockOptionLevel             69.51
## MaritalStatus                65.96
## Age                          59.31
## JobSatisfaction              44.86
## JobInvolvement               44.27
## DistanceFromHome             36.80
## EnvironmentSatisfaction      32.15
## WorkLifeBalance              31.63
## DailyRate                    30.23
## JobRole                      29.94
## NumCompaniesWorked           28.67
## Department                   25.79
## HourlyRate                   19.81
<span class="n">chaid.m1</span><span class="o">$</span><span class="n">bestTune</span><span class="w">
</span>
##   alpha2 alpha3 alpha4
## 3   0.05     -1   0.05
<span class="n">chaid.m1</span><span class="o">$</span><span class="n">times</span><span class="w">
</span>
## $everything
##    user  system elapsed 
## 247.218   1.581 248.999 
## 
## $final
##    user  system elapsed 
##   9.612   0.055   9.674 
## 
## $prediction
## [1] NA NA NA
<span class="n">chaid.m1</span><span class="o">$</span><span class="n">method</span><span class="w">
</span>
## [1] "chaid"
<span class="n">chaid.m1</span><span class="o">$</span><span class="n">modelInfo</span><span class="w">
</span>
## $label
## [1] "CHi-squared Automated Interaction Detection"
## 
## $library
## [1] "CHAID"
## 
## $loop
## NULL
## 
## $type
## [1] "Classification"
## 
## $parameters
##   parameter   class
## 1    alpha2 numeric
## 2    alpha3 numeric
## 3    alpha4 numeric
##                                                                                     label
## 1                                                                       Merging Threshold
## 2                                                       Splitting former Merged Threshold
## 3 \n                                                    Splitting former Merged Threshold
## 
## $grid
## function (x, y, len = NULL, search = "grid") 
## {
##     if (search == "grid") {
##         out <- data.frame(alpha2 = seq(from = 0.05, to = 0.01, 
##             length = len), alpha3 = -1, alpha4 = seq(from = 0.05, 
##             to = 0.01, length = len))
##     }
##     else {
##         out <- data.frame(alpha2 = runif(len, min = 1e-06, max = 0.1), 
##             alpha3 = runif(len, min = -0.1, max = 0.1), alpha4 = runif(len, 
##                 min = 1e-06, max = 0.1))
##     }
##     out
## }
## 
## $fit
## function (x, y, wts, param, lev, last, classProbs, ...) 
## {
##     dat <- if (is.data.frame(x)) 
##         x
##     else as.data.frame(x)
##     dat$.outcome <- y
##     theDots <- list(...)
##     if (any(names(theDots) == "control")) {
##         theDots$control$alpha2 <- param$alpha2
##         theDots$control$alpha3 <- param$alpha3
##         theDots$control$alpha4 <- param$alpha4
##         ctl <- theDots$control
##         theDots$control <- NULL
##     }
##     else ctl <- chaid_control(alpha2 = param$alpha2, alpha3 = param$alpha3, 
##         alpha4 = param$alpha4)
##     if (!is.null(wts)) 
##         theDots$weights <- wts
##     modelArgs <- c(list(formula = as.formula(".outcome ~ ."), 
##         data = dat, control = ctl), theDots)
##     out <- do.call(CHAID::chaid, modelArgs)
##     out
## }
## <bytecode: 0x7ff7fd0b48a8>
## 
## $predict
## function (modelFit, newdata, submodels = NULL) 
## {
##     if (!is.data.frame(newdata)) 
##         newdata <- as.data.frame(newdata)
##     predict(modelFit, newdata)
## }
## <bytecode: 0x7ff7f6851190>
## 
## $prob
## function (modelFit, newdata, submodels = NULL) 
## {
##     if (!is.data.frame(newdata)) 
##         newdata <- as.data.frame(newdata)
##     predict(modelFit, newdata, type = "prob")
## }
## 
## $levels
## function (x) 
## x$obsLevels
## 
## $predictors
## function (x, surrogate = TRUE, ...) 
## {
##     predictors(terms(x))
## }
## 
## $tags
## [1] "Tree-Based Model"           "Implicit Feature Selection"
## [3] "Two Class Only"             "Accepts Case Weights"      
## 
## $sort
## function (x) 
## x[order(-x$alpha2, -x$alpha4, -x$alpha3), ]
<span class="n">chaid.m1</span><span class="o">$</span><span class="n">results</span><span class="w">
</span>
##   alpha2 alpha3 alpha4  Accuracy     Kappa AccuracySD   KappaSD
## 1   0.01     -1   0.01 0.8223292 0.1522392 0.01887938 0.1278739
## 2   0.03     -1   0.03 0.8349699 0.1579585 0.02503052 0.1093852
## 3   0.05     -1   0.05 0.8213958 0.1692826 0.03353654 0.1180522
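Since we asked for metric = "Kappa", the ‘largest value’ selection rule
operates on that column alone. That is easy to verify by hand with a base R
sketch using the results table just printed; note that the winning row
actually has the lowest Accuracy of the three:

```r
# values copied from the chaid.m1$results table above
res <- data.frame(alpha2   = c(0.01, 0.03, 0.05),
                  alpha4   = c(0.01, 0.03, 0.05),
                  Accuracy = c(0.8223292, 0.8349699, 0.8213958),
                  Kappa    = c(0.1522392, 0.1579585, 0.1692826))

best <- res[which.max(res$Kappa), ]  # metric = "Kappa" picks the largest Kappa
best  # alpha2 = 0.05, alpha4 = 0.05 -- matching the 'final values' line
```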

Let’s tune it up a little

Having mastered the basics of using caret and chaid, let’s explore a
little deeper. By default caret allows us to adjust three parameters
in our chaid model: alpha2, alpha3, and alpha4. As a matter of
fact, it will let us build a grid of those parameters and test all
the permutations we like, using the same cross-validation process. I’m a
bit worried that we’re not being conservative enough. I’d like to train
our model using p values for alpha that are not .05, .03, and .01 but
instead the de facto levels in my discipline: .05, .01, and .001. The
relevant argument to train in caret is tuneGrid. We’ll use the base
R function expand.grid to build a dataframe with all the combinations
and then feed it to caret in our next training run.

Therefore search_grid will hold the values, and we’ll add the line
tuneGrid = search_grid to our call to train. We’ll call the results
chaid.m2 and see how we did (I’m turning off the verbose iteration
output since you’ve seen it on screen once already)…

<span class="c1"># set up tuning grid default</span><span class="w">
</span><span class="n">search_grid</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">expand.grid</span><span class="p">(</span><span class="w">
  </span><span class="n">alpha2</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="m">.05</span><span class="p">,</span><span class="w"> </span><span class="m">.01</span><span class="p">,</span><span class="w"> </span><span class="m">.001</span><span class="p">),</span><span class="w">
  </span><span class="n">alpha4</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="m">.05</span><span class="p">,</span><span class="w"> </span><span class="m">.01</span><span class="p">,</span><span class="w"> </span><span class="m">.001</span><span class="p">),</span><span class="w">
  </span><span class="n">alpha3</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">-1</span><span class="w">
</span><span class="p">)</span><span class="w">

</span><span class="c1"># no verbose</span><span class="w">
</span><span class="n">train_control</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">trainControl</span><span class="p">(</span><span class="n">method</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"cv"</span><span class="p">,</span><span class="w">
                              </span><span class="n">number</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">10</span><span class="p">,</span><span class="w">
                              </span><span class="n">savePredictions</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"final"</span><span class="p">)</span><span class="w">

</span><span class="c1"># train model</span><span class="w">
</span><span class="n">chaid.m2</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">train</span><span class="p">(</span><span class="w">
  </span><span class="n">x</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">x</span><span class="p">,</span><span class="w">
  </span><span class="n">y</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w">
  </span><span class="n">method</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"chaid"</span><span class="p">,</span><span class="w">
  </span><span class="n">metric</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"Kappa"</span><span class="p">,</span><span class="w">
  </span><span class="n">trControl</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">train_control</span><span class="p">,</span><span class="w">
  </span><span class="n">tuneGrid</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">search_grid</span><span class="w">
</span><span class="p">)</span><span class="w">

</span><span class="n">chaid.m2</span><span class="w">
</span>
## CHi-squared Automated Interaction Detection 
## 
## 1030 samples
##   30 predictor
##    2 classes: 'No', 'Yes' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 926, 927, 928, 928, 928, 926, ... 
## Resampling results across tuning parameters:
## 
##   alpha2  alpha4  Accuracy   Kappa    
##   0.001   0.001   0.8378522  0.2755221
##   0.001   0.010   0.8329691  0.2039261
##   0.001   0.050   0.8231655  0.2026735
##   0.010   0.001   0.8378522  0.2755221
##   0.010   0.010   0.8358914  0.2185542
##   0.010   0.050   0.8280863  0.2231160
##   0.050   0.001   0.8407648  0.2992935
##   0.050   0.010   0.8387949  0.2487845
##   0.050   0.050   0.8280296  0.2324447
## 
## Tuning parameter 'alpha3' was held constant at a value of -1
## Kappa was used to select the optimal model using the largest value.
## The final values used for the model were alpha2 = 0.05, alpha3 = -1
##  and alpha4 = 0.001.
<span class="n">plot</span><span class="p">(</span><span class="n">chaid.m2</span><span class="p">)</span><span class="w">
</span>

<span class="n">chaid.m2</span><span class="o">$</span><span class="n">finalModel</span><span class="w">
</span>
## 
## Model formula:
## .outcome ~ Age + BusinessTravel + DailyRate + Department + DistanceFromHome + 
##     Education + EducationField + EnvironmentSatisfaction + Gender + 
##     HourlyRate + JobInvolvement + JobLevel + JobRole + JobSatisfaction + 
##     MaritalStatus + MonthlyIncome + MonthlyRate + NumCompaniesWorked + 
##     OverTime + PercentSalaryHike + PerformanceRating + RelationshipSatisfaction + 
##     StockOptionLevel + TotalWorkingYears + TrainingTimesLastYear + 
##     WorkLifeBalance + YearsAtCompany + YearsInCurrentRole + YearsSinceLastPromotion + 
##     YearsWithCurrManager
## 
## Fitted party:
## [1] root
## |   [2] OverTime in No
## |   |   [3] YearsAtCompany in [0,2]: No (n = 165, err = 20.6%)
## |   |   [4] YearsAtCompany in (2,5], (5,7], (7,10], (10,40]: No (n = 566, err = 6.9%)
## |   [5] OverTime in Yes
## |   |   [6] JobLevel in 1: Yes (n = 117, err = 47.9%)
## |   |   [7] JobLevel in 2, 3, 4, 5: No (n = 182, err = 17.6%)
## 
## Number of inner nodes:    3
## Number of terminal nodes: 4
<span class="n">plot</span><span class="p">(</span><span class="n">chaid.m2</span><span class="o">$</span><span class="n">finalModel</span><span class="p">)</span><span class="w">
</span>

<span class="n">confusionMatrix</span><span class="p">(</span><span class="n">chaid.m2</span><span class="p">)</span><span class="w">
</span>
## Cross-Validated (10 fold) Confusion Matrix 
## 
## (entries are percentual average cell counts across resamples)
##  
##           Reference
## Prediction   No  Yes
##        No  79.0 11.1
##        Yes  4.9  5.0
##                             
##  Accuracy (average) : 0.8408
<span class="n">confusionMatrix</span><span class="p">(</span><span class="n">predict</span><span class="p">(</span><span class="n">chaid.m2</span><span class="p">),</span><span class="w"> </span><span class="n">y</span><span class="p">)</span><span class="w">
</span>
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  No Yes
##        No  808 105
##        Yes  56  61
##                                         
##                Accuracy : 0.8437        
##                  95% CI : (0.82, 0.8653)
##     No Information Rate : 0.8388        
##     P-Value [Acc > NIR] : 0.354533      
##                                         
##                   Kappa : 0.3436        
##  Mcnemar's Test P-Value : 0.000155      
##                                         
##             Sensitivity : 0.9352        
##             Specificity : 0.3675        
##          Pos Pred Value : 0.8850        
##          Neg Pred Value : 0.5214        
##              Prevalence : 0.8388        
##          Detection Rate : 0.7845        
##    Detection Prevalence : 0.8864        
##       Balanced Accuracy : 0.6513        
##                                         
##        'Positive' Class : No            
##
<span class="n">chaid.m2</span><span class="o">$</span><span class="n">times</span><span class="w">
</span>
## $everything
##    user  system elapsed 
## 524.972   3.729 529.873 
## 
## $final
##    user  system elapsed 
##   2.173   0.013   2.191 
## 
## $prediction
## [1] NA NA NA
<span class="n">chaid.m2</span><span class="o">$</span><span class="n">results</span><span class="w">
</span>
##   alpha2 alpha4 alpha3  Accuracy     Kappa AccuracySD    KappaSD
## 1  0.001  0.001     -1 0.8378522 0.2755221 0.02253555 0.09552095
## 2  0.001  0.010     -1 0.8329691 0.2039261 0.02263752 0.09977861
## 3  0.001  0.050     -1 0.8231655 0.2026735 0.03187552 0.12676157
## 4  0.010  0.001     -1 0.8378522 0.2755221 0.02253555 0.09552095
## 5  0.010  0.010     -1 0.8358914 0.2185542 0.02240334 0.10717030
## 6  0.010  0.050     -1 0.8280863 0.2231160 0.03056971 0.08137926
## 7  0.050  0.001     -1 0.8407648 0.2992935 0.02523390 0.10729121
## 8  0.050  0.010     -1 0.8387949 0.2487845 0.02277103 0.10696016
## 9  0.050  0.050     -1 0.8280296 0.2324447 0.03157911 0.13890292

Very nice! Some key points here. Even though our model got more
conservative and has far fewer nodes, our results have improved as
measured both by traditional accuracy and by Kappa. That applies at the
average fold level but, more importantly, at the best model
prediction stage as well. Later on we’ll start using our models to predict
against the data we held out in test.

The plot is also more useful now. No matter what we do with alpha2, it
pays to keep alpha4 conservative at .001 (the blue line is always on
top); the choice of alpha2 matters far less.

This goes to the heart of our conversation about over-fitting. While it
may seem like 1,400+ cases is a lot of data, we are at great risk of
over-fitting if we try to build too complex a model, so sometimes a
conservative tack is warranted.
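The size of the gain is easy to pin down from the two results printouts. A
base R sketch, using the best cross-validated rows copied from chaid.m1 and
chaid.m2 above:

```r
# best tuning rows copied from the two results tables above
m1 <- c(Accuracy = 0.8213958, Kappa = 0.1692826)  # alpha2 = .05, alpha4 = .05
m2 <- c(Accuracy = 0.8407648, Kappa = 0.2992935)  # alpha2 = .05, alpha4 = .001

round(m2 - m1, 4)
# Accuracy gains about 2 points while Kappa jumps by 0.13 --
# the more conservative alpha4 nearly doubles Kappa
```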

A Custom caret model

Earlier I printed the results of chaid.m1$modelInfo and then pretty
much skipped over discussing them. Under the covers, one of the strengths
of caret is that it keeps default information about how to tune
various types of algorithms. These defaults are visible at
https://github.com/topepo/caret/tree/master/models/files.

My experience is that they are quite comprehensive and let you get
your modelling done. But sometimes you want to do something your own way
or differently, and caret has provisions for that. If you look at the
default model setup for CHAID here on
GitHub

you can see that it only allows you to tune on alpha2, alpha3, and
alpha4 by default. That is not a comprehensive list of all the
parameters we can work with in chaid_control; see ?chaid_control for
a listing and brief description of what they all are.

What if, for example, we wanted to tune based upon minsplit,
minbucket, minprob, and maxheight instead? How would we go about using
all the built-in functionality in caret but have it our way? There’s a
section in the caret documentation called “Using Your Own Model In
Train”

that does a great job of walking you through the steps. At first it
looked a little too complicated for my tastes, but I found that with a
bit of trial and error I was able to hack up the existing list I
found on GitHub and convert it into a list in my local environment that
worked perfectly for my needs.

I won’t bore you with all the details, and the documentation is quite
good, so it wound up being mainly a search and replace operation plus
adding one parameter. I decided to call my version cgpCHAID, and here’s
what that version looks like.

<span class="c1"># hack up my own</span><span class="w">

</span><span class="n">cgpCHAID</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">list</span><span class="p">(</span><span class="n">label</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"CGP CHAID"</span><span class="p">,</span><span class="w">
                 </span><span class="n">library</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"CHAID"</span><span class="p">,</span><span class="w">
                 </span><span class="n">loop</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">NULL</span><span class="p">,</span><span class="w">
                 </span><span class="n">type</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="s2">"Classification"</span><span class="p">),</span><span class="w">
                 </span><span class="n">parameters</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">data.frame</span><span class="p">(</span><span class="n">parameter</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="s1">'minsplit'</span><span class="p">,</span><span class="w"> </span><span class="s1">'minbucket'</span><span class="p">,</span><span class="w"> </span><span class="s1">'minprob'</span><span class="p">,</span><span class="w"> </span><span class="s1">'maxheight'</span><span class="p">),</span><span class="w">
                                         </span><span class="n">class</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nf">rep</span><span class="p">(</span><span class="s1">'numeric'</span><span class="p">,</span><span class="w"> </span><span class="m">4</span><span class="p">),</span><span class="w">
                                         </span><span class="n">label</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="s1">'Numb obs in response where no further split'</span><span class="p">,</span><span class="w"> 
                                                   </span><span class="s2">"Minimum numb obs in terminal nodes"</span><span class="p">,</span><span class="w"> 
                                                   </span><span class="s2">"Minimum freq of obs in terminal nodes."</span><span class="p">,</span><span class="w">
                                                   </span><span class="s2">"Maximum height for the tree"</span><span class="p">)</span><span class="w">
                 </span><span class="p">),</span><span class="w">
                 </span><span class="n">grid</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="k">function</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">len</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">NULL</span><span class="p">,</span><span class="w"> </span><span class="n">search</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">"grid"</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">
                   </span><span class="k">if</span><span class="p">(</span><span class="n">search</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="s2">"grid"</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">
                     </span><span class="n">out</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">data.frame</span><span class="p">(</span><span class="n">minsplit</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="m">20</span><span class="p">,</span><span class="m">30</span><span class="p">),</span><span class="w">
                                       </span><span class="n">minbucket</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">7</span><span class="p">,</span><span class="w">
                                       </span><span class="n">minprob</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="m">0.05</span><span class="p">,</span><span class="m">0.01</span><span class="p">),</span><span class="w">
                                       </span><span class="n">maxheight</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">-1</span><span class="p">)</span><span class="w">
                   </span><span class="p">}</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="p">{</span><span class="w">
                     </span><span class="n">out</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">data.frame</span><span class="p">(</span><span class="n">minsplit</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="m">20</span><span class="p">,</span><span class="m">30</span><span class="p">),</span><span class="w">
                                       </span><span class="n">minbucket</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">7</span><span class="p">,</span><span class="w">
                                       </span><span class="n">minprob</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="m">0.05</span><span class="p">,</span><span class="m">0.01</span><span class="p">),</span><span class="w">
                                       </span><span class="n">maxheight</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="m">-1</span><span class="p">)</span><span class="w">
                   </span><span class="p">}</span><span class="w">
                   </span><span class="n">out</span><span class="w">
                 </span><span class="p">},</span><span class="w">
                 </span><span class="n">fit</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="k">function</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">y</span><span class="p">,</span><span class="w"> </span><span class="n">wts</span><span class="p">,</span><span class="w"> </span><span class="n">param</span><span class="p">,</span><span class="w"> </span><span class="n">lev</span><span class="p">,</span><span class="w"> </span><span class="n">last</span><span class="p">,</span><span class="w"> </span><span class="n">classProbs</span><span class="p">,</span><span class="w"> </span><span class="n">...</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">
                   </span><span class="n">dat</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="k">if</span><span class="p">(</span><span class="n">is.data.frame</span><span class="p">(</span><span class="n">x</span><span class="p">))</span><span class="w"> </span><span class="n">x</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="n">as.data.frame</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="w">
                   </span><span class="n">dat</span><span class="o">$</span><span class="n">.outcome</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">y</span><span class="w">
                   </span><span class="n">theDots</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">list</span><span class="p">(</span><span class="n">...</span><span class="p">)</span><span class="w">
                   </span><span class="k">if</span><span class="p">(</span><span class="nf">any</span><span class="p">(</span><span class="nf">names</span><span class="p">(</span><span class="n">theDots</span><span class="p">)</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="s2">"control"</span><span class="p">))</span><span class="w"> </span><span class="p">{</span><span class="w">
                     </span><span class="n">theDots</span><span class="o">$</span><span class="n">control</span><span class="o">$</span><span class="n">minsplit</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">param</span><span class="o">$</span><span class="n">minsplit</span><span class="w">
                     </span><span class="n">theDots</span><span class="o">$</span><span class="n">control</span><span class="o">$</span><span class="n">minbucket</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">param</span><span class="o">$</span><span class="n">minbucket</span><span class="w">
                     </span><span class="n">theDots</span><span class="o">$</span><span class="n">control</span><span class="o">$</span><span class="n">minprob</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">param</span><span class="o">$</span><span class="n">minprob</span><span class="w">
                     </span><span class="n">theDots</span><span class="o">$</span><span class="n">control</span><span class="o">$</span><span class="n">maxheight</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">param</span><span class="o">$</span><span class="n">maxheight</span><span class="w">
                     </span><span class="n">ctl</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">theDots</span><span class="o">$</span><span class="n">control</span><span class="w">
                     </span><span class="n">theDots</span><span class="o">$</span><span class="n">control</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="kc">NULL</span><span class="w">
                   </span><span class="p">}</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="n">ctl</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">chaid_control</span><span class="p">(</span><span class="n">minsplit</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">param</span><span class="o">$</span><span class="n">minsplit</span><span class="p">,</span><span class="w">
                                               </span><span class="n">minbucket</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">param</span><span class="o">$</span><span class="n">minbucket</span><span class="p">,</span><span class="w">
                                               </span><span class="n">minprob</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">param</span><span class="o">$</span><span class="n">minprob</span><span class="p">,</span><span class="w">
                                               </span><span class="n">maxheight</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">param</span><span class="o">$</span><span class="n">maxheight</span><span class="p">)</span><span class="w">
                   </span><span class="c1">## pass in any model weights</span><span class="w">
                   </span><span class="k">if</span><span class="p">(</span><span class="o">!</span><span class="nf">is.null</span><span class="p">(</span><span class="n">wts</span><span class="p">))</span><span class="w"> </span><span class="n">theDots</span><span class="o">$</span><span class="n">weights</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">wts</span><span class="w">
                   </span><span class="n">modelArgs</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="nf">c</span><span class="p">(</span><span class="w">
                     </span><span class="nf">list</span><span class="p">(</span><span class="w">
                       </span><span class="n">formula</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">as.formula</span><span class="p">(</span><span class="s2">".outcome ~ ."</span><span class="p">),</span><span class="w">
                       </span><span class="n">data</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">dat</span><span class="p">,</span><span class="w">
                       </span><span class="n">control</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">ctl</span><span class="p">),</span><span class="w">
                     </span><span class="n">theDots</span><span class="p">)</span><span class="w">
                   </span><span class="n">out</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">do.call</span><span class="p">(</span><span class="n">CHAID</span><span class="o">::</span><span class="n">chaid</span><span class="p">,</span><span class="w"> </span><span class="n">modelArgs</span><span class="p">)</span><span class="w">
                   </span><span class="n">out</span><span class="w">
                 </span><span class="p">},</span><span class="w">
                 </span><span class="n">predict</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="k">function</span><span class="p">(</span><span class="n">modelFit</span><span class="p">,</span><span class="w"> </span><span class="n">newdata</span><span class="p">,</span><span class="w"> </span><span class="n">submodels</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="kc">NULL</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w">
                    </span><span class="k">if</span><span class="p">(</span><span class="o">!</span><span class="n">is.data.frame</span><span class="p">(</span><span class="n">newdata</span><span class="p">))</span><span class="w"> </span><span class="n">newdata</span><span class="w"> </span><span class="o"><-</span><span class="w"> </span><span class="n">as.data.frame</span><span class="p">(</span><span class="n">newdata</span><span class="p">)</span><span class="w">
                    </span><span class="n">predict</span><span class="p">(</span><span class="n">modelFit</span><span class="p">,</span><span class="w"> </span><span class="n">newdata</span><span class="p">)</span><span class="w">
                  </span><span class="p">})</span>
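With the custom model list built, it can be handed straight to `caret::train()` in place of a built-in method string. The sketch below shows the general pattern; the object name `cgcontrol`, the data frame `train_data`, and the 5-fold cross-validation settings are illustrative assumptions, not taken from the original post.

```r
# Minimal usage sketch (assumed names): pass the custom CHAID model
# definition to caret::train() as the `method` argument.
library(caret)

set.seed(1234)  # reproducible resampling

# CHAID only splits on nominal/ordinal predictors, so every column of
# `train_data` (besides the factor outcome `Attrition`) should be a factor.
chaid_fit <- train(
  Attrition ~ .,
  data      = train_data,
  method    = cgcontrol,   # the custom model list defined above (name assumed)
  trControl = trainControl(method = "cv", number = 5),
  metric    = "Accuracy"
)
```

Because the custom `grid` and `fit` functions expose `minsplit`, `minbucket`, `minprob`, and `maxheight` as tuning parameters, `train()` will resample over those values and report the combination with the best accuracy.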

To leave a comment for the author, please follow the link and comment on their blog: Chuck Powell.
