Create SQL Rules from rpart model

July 19, 2013

(This article was first published on R (en) - Analytik dat, and kindly contributed to R-bloggers)

Mapping the output of an rpart tree to SQL statements is not easy. With the rpart package alone, you have to print out the rules and then write the SQL CASE statement by hand. Fortunately, we can write a function to do this job.
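The idea behind such a function is simple: every leaf becomes one WHEN branch, and its condition is the AND of the splits on the path from the root to that leaf. The minimal sketch below assumes the split conditions have already been translated to SQL predicates; `leaf_rules_to_case()` is a hypothetical helper for illustration, not part of rpart or of the author's parse_tree.

```r
# Hypothetical helper (not parse_tree): build a SQL CASE expression from a
# named list of leaves. Names are node numbers; each element is a character
# vector of SQL predicates on the root-to-leaf path.
leaf_rules_to_case <- function(rules) {
  whens <- vapply(names(rules), function(node) {
    cond <- paste(rules[[node]], collapse = " AND ")   # all splits must hold
    sprintf("  when %s then 'node_%s'", cond, node)
  }, character(1))
  paste(c("case", whens, "end"), collapse = "\n")
}

# Two leaves of the german_data tree fitted later in this post, written by hand
rules <- list(
  "2"  = "ca_status in ('A13','A14')",
  "13" = c("ca_status in ('A11','A12')", "mob < 22.5",
           "credit_history in ('A30','A31')")
)
cat(leaf_rules_to_case(rules), "\n")
```

Because the leaves of a tree do not overlap, the order of the WHEN branches does not change the result. A row with NULL in a split variable matches no branch and gets NULL, whereas rpart itself would route it with surrogate splits.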

To test the function, I will use the german_data dataset from the riv package on GitHub:

library(devtools)
# riv provides the german_data dataset; current devtools expects the "user/repo" form
install_github("tomasgreif/riv")
library(riv)

First, we fit a (rather naive) model:

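library(rpart)
x <- german_data
x$gbbin <- NULL                    # drop gbbin so it is not used as a predictor of gb
model <- rpart(gb ~ ., data = x)   # classification tree for gb using all other columns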

The resulting model has a lot of leaves:

  1) root 1000 300 good (0.7000000 0.3000000)  
    2) ca_status=A13,A14 457  60 good (0.8687090 0.1312910) *
    3) ca_status=A11,A12 543 240 good (0.5580110 0.4419890)  
      6) mob< 22.5 306 106 good (0.6535948 0.3464052)  
       12) credit_history=A32,A33,A34 278  85 good (0.6942446 0.3057554)  
         24) credit_amount< 7491.5 271  79 good (0.7084871 0.2915129)  
           48) purpose=A40,A41,A410,A42,A43,A45,A48,A49 256  69 good (0.7304688 0.2695312)  
             96) mob< 11.5 73   9 good (0.8767123 0.1232877) *
             97) mob>=11.5 183  60 good (0.6721311 0.3278689)  
              194) credit_amount>=1387.5 118  29 good (0.7542373 0.2457627) *
              195) credit_amount< 1387.5 65  31 good (0.5230769 0.4769231)  
                390) property=A121,A122 45  14 good (0.6888889 0.3111111) *
                391) property=A123,A124 20   3 bad (0.1500000 0.8500000) *
           49) purpose=A44,A46 15   5 bad (0.3333333 0.6666667) *
         25) credit_amount>=7491.5 7   1 bad (0.1428571 0.8571429) *
       13) credit_history=A30,A31 28   7 bad (0.2500000 0.7500000) *
      7) mob>=22.5 237 103 bad (0.4345992 0.5654008)  
       14) savings=A64,A65 41  12 good (0.7073171 0.2926829) *
       15) savings=A61,A62,A63 196  74 bad (0.3775510 0.6224490)  
         30) mob< 47.5 160  69 bad (0.4312500 0.5687500)  
           60) purpose=A41 23   6 good (0.7391304 0.2608696) *
           61) purpose=A40,A410,A42,A43,A45,A46,A49 137  52 bad (0.3795620 0.6204380) *
         31) mob>=47.5 36   5 bad (0.1388889 0.8611111) *
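The leaves (the nodes marked with `*`) and the chain of splits leading to each of them do not have to be read off the printout. A minimal sketch using rpart's own `frame` component and `path.rpart()` follows; iris is used only as a stand-in for german_data so the snippet runs without the GitHub package.

```r
# Sketch: list every leaf of an rpart tree with its root-to-leaf split labels.
library(rpart)

iris_tree <- rpart(Species ~ ., data = iris)

# Leaf rows of the frame have var == "<leaf>"; the row names are the node
# numbers that appear in the printed tree.
leaves <- as.numeric(rownames(iris_tree$frame)[iris_tree$frame$var == "<leaf>"])

# path.rpart() returns a list named by node number; each element holds the
# split labels from the root down to that node, starting with "root".
paths <- path.rpart(iris_tree, nodes = leaves, print.it = FALSE)
str(paths)
```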

Now we can call the parse_tree function:

parse_tree(x,model)

And we get the following result:

case
  when ca_status in ('A13','A14') then 'node_2'
  when ca_status in ('A11','A12') AND mob < 22.5 AND credit_history in ('A32','A33','A34') AND credit_amount < 7492 AND purpose in ('A40','A41','A410','A42','A43','A45','A48','A49') AND mob < 11.5 then 'node_96'
  when ca_status in ('A11','A12') AND mob < 22.5 AND credit_history in ('A32','A33','A34') AND credit_amount < 7492 AND purpose in ('A40','A41','A410','A42','A43','A45','A48','A49') AND mob >= 11.5 AND credit_amount >= 1388 then 'node_194'
  when ca_status in ('A11','A12') AND mob < 22.5 AND credit_history in ('A32','A33','A34') AND credit_amount < 7492 AND purpose in ('A40','A41','A410','A42','A43','A45','A48','A49') AND mob >= 11.5 AND credit_amount < 1388 AND property in ('A121','A122') then 'node_390'
  when ca_status in ('A11','A12') AND mob < 22.5 AND credit_history in ('A32','A33','A34') AND credit_amount < 7492 AND purpose in ('A40','A41','A410','A42','A43','A45','A48','A49') AND mob >= 11.5 AND credit_amount < 1388 AND property in ('A123','A124') then 'node_391'
  when ca_status in ('A11','A12') AND mob < 22.5 AND credit_history in ('A32','A33','A34') AND credit_amount < 7492 AND purpose in ('A44','A46') then 'node_49'
  when ca_status in ('A11','A12') AND mob < 22.5 AND credit_history in ('A32','A33','A34') AND credit_amount >= 7492 then 'node_25'
  when ca_status in ('A11','A12') AND mob < 22.5 AND credit_history in ('A30','A31') then 'node_13'
  when ca_status in ('A11','A12') AND mob >= 22.5 AND savings in ('A64','A65') then 'node_14'
  when ca_status in ('A11','A12') AND mob >= 22.5 AND savings in ('A61','A62','A63') AND mob < 47.5 AND purpose in ('A41') then 'node_60'
  when ca_status in ('A11','A12') AND mob >= 22.5 AND savings in ('A61','A62','A63') AND mob < 47.5 AND purpose in ('A40','A410','A42','A43','A45','A46','A49') then 'node_61'
  when ca_status in ('A11','A12') AND mob >= 22.5 AND savings in ('A61','A62','A63') AND mob >= 47.5 then 'node_31'
end

This is valid SQL that can be used in most database engines (I'm using this in SQLite and PostgreSQL).

The parse_tree function takes two arguments: a data frame and a model. Every variable used in the model must exist in the data frame and have the same type as in the model. You can find the parse_tree function on GitHub. Let me know if it works for you.
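A minimal sketch of one way the data frame can be used when translating a single split label (as produced by rpart's `path.rpart()` or `labels()`, e.g. `ca_status=A13,A14` or `mob< 22.5`) into a SQL predicate: the helper checks that the variable exists in the data frame, then emits a numeric comparison for numeric columns and a quoted IN list for factors. `split_to_sql()` is a hypothetical name; this is not the author's implementation.

```r
# Hypothetical helper (not parse_tree): translate one rpart split label into
# a SQL predicate, using the column type in `data` to pick the SQL form.
split_to_sql <- function(label, data) {
  pos <- regexpr("<=|>=|<|>|=", label)          # first comparison operator
  if (pos < 0) stop("no operator in split label: ", label)
  var <- substr(label, 1, pos - 1)
  op  <- regmatches(label, pos)
  rhs <- trimws(substr(label, pos + attr(pos, "match.length"), nchar(label)))
  if (!var %in% names(data)) stop("variable not found in data: ", var)
  if (is.numeric(data[[var]])) {
    sprintf("%s %s %s", var, op, rhs)           # numeric split, e.g. mob < 22.5
  } else {
    # categorical split: rpart lists the levels sent down this branch,
    # separated by commas (assumes level names contain no commas or quotes)
    lev <- strsplit(rhs, ",", fixed = TRUE)[[1]]
    sprintf("%s in (%s)", var, paste0("'", lev, "'", collapse = ","))
  }
}

d <- data.frame(ca_status = factor(c("A11", "A13")), mob = c(10, 30))
split_to_sql("ca_status=A13,A14", d)   # "ca_status in ('A13','A14')"
split_to_sql("mob>=22.5", d)           # "mob >= 22.5"
```

Split labels from rpart are formatted to a limited number of significant digits, which appears to be why the cut point 7491.5 shows up as 7492 in the SQL above. For an integer column such as credit_amount both cut points select the same rows, but for continuous variables the exact thresholds are available in `model$splits`.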
