# How Could Classification Trees Be So Fast on Categorical Variables?

December 8, 2015

Want to share your content on R-bloggers? click here if you have a blog, or here if you don't.

I think that over the past months, I have been saying incorrect things about classification with categorical covariates, because I never took the time to look at it carefully. Consider a simulated dataset, generated from a logistic regression,

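```
> n=1e3
> set.seed(1)
> X1=runif(n)                          # latent continuous covariate (not observed)
> q=quantile(X1,(0:26)/26)             # 27 breaks, i.e. 26 bins with roughly equal counts
> q[1]=0                               # so that the smallest X1 falls in the first bin
> X2=cut(X1,q,labels=LETTERS[1:26])    # observed categorical covariate, A to Z
> p=exp(-.1+qnorm(2*(abs(.5-X1))))/(1+exp(-.1+qnorm(2*(abs(.5-X1)))))  # logistic link
> Y=rbinom(n,size=1,p)
> df=data.frame(X1=X1,X2=X2,p=p,Y=Y)
```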

Here, we start from a continuous covariate $X_1$, except that $X_1$ is considered as unobserved. Instead, we only have a categorical covariate $X_2$ with 26 categories (the 26 equal-frequency bins of $X_1$, labelled A to Z). The (theoretical) relationship between the covariate and the probability is given below,

```
> vx1=seq(0,1,by=.001)
> vp=exp(-.1+qnorm(2*(abs(.5-vx1))))/(1+exp(-.1+qnorm(2*(abs(.5-vx1)))))
> plot(vx1,vp,type="l")
```

and we can compare it with the empirical probability for each modality, i.e. the proportion of $Y=1$ within each category.
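A minimal sketch to compute (and plot) those empirical probabilities. The helper name `empirical_prob` is hypothetical, not part of any package; it simply averages the 0/1 response within each level of the factor, using base R `tapply()` and `barplot()`.

```r
# Hypothetical helper: empirical P(Y = 1) within each level of a factor.
# y: numeric 0/1 vector; x: factor (or vector coerced to a factor) of the same length.
empirical_prob <- function(y, x, plot = TRUE) {
  x <- as.factor(x)
  freq <- tapply(y, x, mean)          # one average per level (NA if a level is empty)
  if (plot) {
    barplot(freq, ylim = c(0, 1), xlab = "category",
            ylab = "empirical probability of Y = 1")
  }
  invisible(freq)                     # return the averages without printing them
}
```

With the simulated data above, `empirical_prob(df$Y, df$X2)` draws one bar per letter; the bars should roughly follow the theoretical curve, with low values for the middle letters and high values at both ends.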

If we run a classification tree, we get

```
> library(rpart)
> tree=rpart(Y~X2,data=df)
> library(rpart.plot)
> prp(tree, type=2, extra=1)
```

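To be more specific, the output is here. Since $Y$ is numeric, rpart actually fits a regression tree by default (`method="anova"`): each line gives the node, the split, the number of observations, the deviance and the average of $Y$ in the node. With a 0/1 response, the deviance of a node is $n\,p(1-p)$ (e.g. $1000\times 0.49\times 0.51=249.9$ at the root), which is proportional to its Gini impurity, so the splits are chosen with the same criterion as below.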

```
> tree
1) root 1000 249.90000 0.4900000
  2) X2=F,G,H,I,J,K,L,M,N,O,P,Q,R 499 105.3 0.302
    4) X2=J,K,L,M,N,O,P,Q,R 346  65.12 0.25144  *
    5) X2=F,G,H,I 153  37.22876 0.4183007       *
  3) X2=A,B,C,D,E,S,T,U,V,W,X,Y,Z 501 109.61 0.67
    6) X2=B,C,D,E,S,T,U,V,W,X 385  90.38 0.623  *
    7) X2=A,Y,Z 116  14.50862 0.8534483         *
```

Note that it takes less than a second to get that output, so clearly, rpart did not try all possible combinations of modalities. For the first node alone, there are $2^{26}$ possible subsets of categories, i.e.

```
> 2^26
[1] 67108864
```
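Strictly speaking, $2^{26}$ counts subsets of categories. A subset and its complement define the same split, and the empty and full subsets do not split anything, so the number of distinct binary splits of $k$ categories is $2^{k-1}-1$. A quick check (the function name is only illustrative):

```r
# Number of distinct ways to split k categories into two non-empty groups:
# 2^k subsets, halved (subset and complement give the same split),
# minus the single split where one group would be empty.
n_binary_splits <- function(k) 2^(k - 1) - 1

n_binary_splits(26)   # 33554431 candidate splits at the root
n_binary_splits(13)   # 4095 for a child node with 13 categories
```

Still tens of millions at the root, and the search would have to be repeated in every node.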

It is big… not huge, but too big to try all combinations, especially since that is only the first node: we have to do it again on the two child nodes, and so on. Antoine (aka @ly_antoine) told me, while we were having a coffee after lunch today, the trick to get a fast algorithm on categorical variables. And as usual, the idea is very clever…

First, we need a function to compute the Gini index

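```
> gini=function(y,classe){
+    tab=table(y,classe)                       # counts of y (rows) by group (columns); y must be binary
+    nx=apply(tab,2,sum)                       # group sizes n_c
+    n=sum(tab)                                # total number of observations
+    pxy=tab/matrix(rep(nx,each=2),nrow=2)     # P(y | group), column by column
+    omega=matrix(rep(nx,each=2),nrow=2)/n     # group weights n_c / n
+    g=-sum(omega*pxy*(1-pxy))                 # minus the weighted Gini impurity
+    return(g)}
```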
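For a binary response, what `gini()` returns can be written explicitly. Write $n_c$ for the size of group $c$, $n$ for the total, and $p_c$ for the proportion of $Y=1$ in group $c$, so that the column of `pxy` for group $c$ is $(1-p_c,\,p_c)$:

$$
g \;=\; -\sum_{c}\frac{n_c}{n}\sum_{y\in\{0,1\}}\hat p_{y\mid c}\,\big(1-\hat p_{y\mid c}\big)
\;=\; -\sum_{c}\frac{n_c}{n}\,2\,p_c\,(1-p_c),
$$

i.e. minus the weighted Gini impurity of the groups. Maximizing $g$ over candidate partitions therefore amounts to minimizing the impurity.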

For the first node, the idea is very simple:

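• Compute the empirical averages $p_i$, estimates of $\mathbb{E}[Y \vert X=x_i]$ (the proportion of $Y=1$ in category $x_i$),
`> cond_prob=aggregate(df$Y,by=list(df$X2),mean)`
• Then sort those values, $p_{1:k}\leq p_{2:k} \leq \cdots \leq p_{k:k}$,
• Based on that ordering, relabel the categories $x_{1:k},x_{2:k},\cdots,x_{k:k}$ (where $x_{i:k}$ is the category with the $i$-th smallest average),
`> Group_Letters=cond_prob[order(cond_prob$x),1]`

• Then consider (only) $k$ possible partitions,

$\{x_{1:k},x_{2:k},\cdots,x_{i:k}\}$ against $\{x_{i+1:k},x_{i+2:k},\cdots,x_{k:k}\}$, for $i=1,\cdots,k$ (for $i=k$ the second group is empty, which corresponds to no split at all)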

```
> v_gini=rep(NA,26)
> for(v in 1:26){
+   CLASSE=df$X2 %in% Group_Letters[1:v]    # first v categories (sorted order) vs the rest
+   v_gini[v]=gini(y=df$Y,classe=CLASSE)    # v=26 is the trivial "no split" baseline
+ }
```

If we plot them, we get

`> plot(1:26,v_gini,type="b")`

As for continuous variables, we look for the maximum value (`gini()` returns minus the weighted impurity, so the maximum corresponds to the purest partition), and then we have our two groups,

```
> sort(Group_Letters[1:which.max(v_gini)])
[1] F G H I J K L M N O P Q R
```

That’s exactly what we got with the rpart function in R,

```
1) root 1000 249.90000 0.4900000
  2) X2=F,G,H,I,J,K,L,M,N,O,P,Q,R 499 105.30 0.30
```
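The steps above can be wrapped into one reusable function. This is a minimal sketch for a 0/1 response, assuming every category is observed at least once; the name `best_categorical_split` is hypothetical. Instead of calling `gini()`, it computes the weighted Gini impurity directly (to be minimized, i.e. minus the value returned by `gini()`), and it only evaluates the $k-1$ ordered splits.

```r
# Hypothetical helper: best binary split of a categorical x for a 0/1 response y,
# obtained by sorting categories by their average y and trying only k - 1 cuts.
best_categorical_split <- function(y, x) {
  x <- droplevels(as.factor(x))                 # keep only observed categories
  lv <- levels(x)
  k <- length(lv)
  if (k < 2) stop("need at least two observed categories")
  means <- tapply(y, x, mean)                   # p_i for each category, in level order
  ordered_lv <- lv[order(means)]                # x_{1:k}, ..., x_{k:k}

  # weighted Gini impurity of {x in left} vs {x not in left}
  impurity <- function(left) {
    g <- x %in% left
    w <- c(mean(g), mean(!g))                   # group weights n_c / n
    p <- c(mean(y[g]), mean(y[!g]))             # proportion of y = 1 in each group
    sum(w * 2 * p * (1 - p))
  }

  imp <- vapply(seq_len(k - 1), function(i) impurity(ordered_lv[1:i]), numeric(1))
  i_best <- which.min(imp)
  list(left = sort(ordered_lv[1:i_best]),       # categories with the lowest averages
       right = sort(ordered_lv[(i_best + 1):k]),
       impurity = imp[i_best])
}
```

Applied to the simulated data, `best_categorical_split(df$Y, df$X2)$left` should return the same thirteen letters as the first split above, since minimizing this impurity is the same as maximizing `gini()`.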

Now, consider (for instance) the node on the left,

`> sub_df=df[df$X2 %in% sort(Group_Letters[1:which.max(v_gini)]),]`

Then use the same algorithm as before: sort the conditional means,

```
> cond_prob=aggregate(sub_df$Y,by=
+ list(sub_df$X2),mean)
> s_Group_Letters=cond_prob[order(cond_prob$x),1]
```

Then compute Gini indices based on groups obtained from that ordering,

```
> v_gini=rep(NA,length(s_Group_Letters))
> for(v in 1:length(s_Group_Letters)){
+   CLASSE=sub_df$X2 %in% s_Group_Letters[1:v]
+   v_gini[v]=gini(y=sub_df$Y,classe=CLASSE)
+ }
```

If we plot it, we get our two groups,

```
> plot(1:length(s_Group_Letters),v_gini,type="b")
```

And the first group is

```
> sort(s_Group_Letters[1:which.max(v_gini)])
[1] J K L M N O P Q R
```

Again, that’s exactly what we got with the rpart function

```
1) root 1000 249.90000 0.4900000
  2) X2=F,G,H,I,J,K,L,M,N,O,P,Q,R 499 105.30 0.30
    4) X2=J,K,L,M,N,O,P,Q,R 346  65.12 0.25144  *
```

Clever, isn’t it?
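Why is it safe to look at only $k-1$ splits? For a binary response and a concave impurity such as the Gini index, it is a known result (it appears in Breiman et al.'s CART book) that an optimal split can always be found among the splits that respect the ordering of the category averages. A quick empirical check, on a smaller hypothetical problem with $k=8$ categories (simulated data, arbitrary seed and probabilities), where all $2^{7}-1=127$ splits can still be enumerated:

```r
set.seed(123)
k <- 8
n <- 400
lv <- letters[1:k]
x <- factor(sample(lv, n, replace = TRUE), levels = lv)
p_true <- runif(k)                                  # hypothetical P(Y = 1) per category
y <- rbinom(n, size = 1, prob = p_true[as.integer(x)])
stopifnot(all(table(x) > 0))                        # every category must be observed

# weighted Gini impurity of {x in left} vs {x not in left}
impurity <- function(left) {
  g <- x %in% left
  w <- c(mean(g), mean(!g))
  p <- c(mean(y[g]), mean(y[!g]))
  sum(w * 2 * p * (1 - p))
}

# 1) exhaustive search: subsets containing lv[1] (so each split is counted once),
#    excluding the full set, i.e. 2^(k-1) - 1 distinct splits
best_exhaustive <- Inf
for (m in 0:(2^(k - 1) - 2)) {
  bits <- as.integer(intToBits(as.integer(m)))[1:(k - 1)] == 1L
  left <- c(lv[1], lv[-1][bits])
  best_exhaustive <- min(best_exhaustive, impurity(left))
}

# 2) ordering trick: sort categories by their average y, try the k - 1 ordered cuts
ordered_lv <- lv[order(tapply(y, x, mean))]
best_sorted <- min(vapply(1:(k - 1), function(i) impurity(ordered_lv[1:i]), numeric(1)))

cat("exhaustive:", best_exhaustive, " sorted:", best_sorted, "\n")
```

The two minima should coincide, while the second search evaluates 7 splits instead of 127.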
