Draw nicer Classification and Regression Trees with the rpart.plot package

[This article was first published on Revolutions, and kindly contributed to R-bloggers].

by Joseph Rickert

The basic way to plot a classification or regression tree built with R’s rpart() function is simply to call plot(). In general, however, the results just aren’t pretty. As it turns out, for some time now there has been a better way to plot rpart() trees: the prp() function in Stephen Milborrow’s rpart.plot package. This function is a veritable “Swiss Army Knife” for plotting trees, and the package documentation is quite good: in addition to the package PDF, Stephen maintains a very nice, easy-to-read user manual on his web page. However, it wasn’t until I recently stumbled onto Graham Williams’ “one pager” about plotting trees that I realized just how powerful and flexible prp() is. This document is packed with ideas about drawing trees with prp() and is a very useful supplement to Graham’s superb book, Data Mining with Rattle and R.

For anyone not familiar with prp(), the following script is an idiosyncratic first look and, I hope, motivation for further investigation.

# Plotting Classification Trees with the rpart.plot and rattle packages

library(rpart)				        # Popular decision tree algorithm
library(rattle)					# Fancy tree plot
library(rpart.plot)				# Enhanced tree plots
library(RColorBrewer)				# Color selection for fancy tree plot
library(party)					# Alternative decision tree algorithm
library(partykit)				# Convert rpart object to a party object (as.party)
library(caret)					# Just a data source for this script
						# but probably one of the best R packages ever. 
data(segmentationData)				# Get some data
data <- segmentationData[,-c(1,2)]
# Make big tree
form <- as.formula(Class ~ .)
tree.1 <- rpart(form,data=data,control=rpart.control(minsplit=20,cp=0))
plot(tree.1)					# Will make a mess of the plot
prp(tree.1)					# Will plot the tree
prp(tree.1,varlen=3)				# Shorten variable names

# Interactively prune the tree
new.tree.1 <- prp(tree.1,snip=TRUE)$obj # interactively trim the tree
prp(new.tree.1) # display the new tree
tree.2 <- rpart(form,data)			# A more reasonable tree
prp(tree.2)					# A fast plot
fancyRpartPlot(tree.2)				# A fancy plot from rattle
# Plot a tree built with RevoScaleR
# Construct a model formula
sdNames <- names(segmentationData)
X <- as.vector(sdNames[-c(1,2,3)])
form <- as.formula(paste("Class","~", paste(X,collapse="+")))
# Run the model
rx.tree <- rxDTree(form, data = segmentationData,maxNumBins = 100,
                   minBucket = 10,maxDepth = 5,cp = 0.01, xVal = 0)
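# Plot the tree (rxAddInheritance() lets prp() treat the model like an rpart object)
prp(rxAddInheritance(rx.tree))
fancyRpartPlot(rxAddInheritance(rx.tree))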

After loading the data, the script makes some deliberately ill-advised choices in building an rpart() classification tree for the segmentationData from the caret package. Plotting the tree with plot() (not shown) produces a couple of black clouds of overlaid text, which is fairly typical of what you can expect from naively plotting a large tree. However, prp() does a pretty good job of plotting the tree and revealing its structure with just the default settings. Setting varlen = 3, which tells prp() to shorten the variable names, makes the plot even more readable.
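To put a number on “large”, here is a small sketch (our addition, not part of the original script) that counts the nodes and leaves of tree.1 and infers its depth from rpart’s node numbering, in which the children of node k are numbered 2k and 2k + 1. It relies only on the node table that rpart stores in the frame component of the fitted object.

```r
library(rpart)
library(caret)   # supplies segmentationData

data(segmentationData)
data <- segmentationData[, -c(1, 2)]   # drop the Cell and Case ID columns, as in the script

# The deliberately overgrown tree: cp = 0 keeps every split, however small its gain
tree.1 <- rpart(Class ~ ., data = data,
                control = rpart.control(minsplit = 20, cp = 0))

# Each row of tree.1$frame is one node; terminal nodes have var == "<leaf>"
n.nodes  <- nrow(tree.1$frame)
n.leaves <- sum(tree.1$frame$var == "<leaf>")

# Row names are node numbers; node k sits at depth floor(log2(k)), root = depth 0
depth <- max(floor(log2(as.numeric(rownames(tree.1$frame)))))

cat("nodes:", n.nodes, " leaves:", n.leaves, " depth:", depth, "\n")
```

Every one of those leaves needs a label, which is why a naive plot of this tree overlaps so badly.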


The next few lines of code show off prp()’s interactive pruning capability. The line that assigns the object new.tree.1 produces a “live” tree plot. Use the mouse to prune the tree, hit “QUIT”, and replot, and you have a fairly nice plot of the top part of the tree. This is a slick way to get a legible picture of the top of a large tree into a report.
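If you need a pruned tree that can be reproduced in a script rather than snipped by hand, rpart’s own prune() function cuts a tree back at a chosen complexity parameter (CP). The sketch below is our addition: picking the CP with the smallest cross-validated error (the xerror column of the cptable) is one common rule, not something the article prescribes, and the one-standard-error rule is a popular alternative. Because rpart’s cross-validation is random, the seed is set first; the value 1 is arbitrary.

```r
library(rpart)
library(rpart.plot)
library(caret)   # supplies segmentationData

data(segmentationData)
data <- segmentationData[, -c(1, 2)]   # drop the Cell and Case ID columns

set.seed(1)   # rpart's 10-fold cross-validation (xval = 10 by default) is random
tree.1 <- rpart(Class ~ ., data = data,
                control = rpart.control(minsplit = 20, cp = 0))

# One row per candidate subtree: CP, nsplit, rel error, xerror, xstd
cp.tab  <- tree.1$cptable
best.cp <- cp.tab[which.min(cp.tab[, "xerror"]), "CP"]

# Non-interactive counterpart to prp(..., snip = TRUE)
pruned <- prune(tree.1, cp = best.cp)
prp(pruned, varlen = 3)
```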

tree.2, a much more reasonable tree for the segmentationData, results from simply accepting the rpart() defaults. This tree is first plotted with prp() using the default settings and then, in the next line, with the fancyRpartPlot() function from Graham Williams’ rattle package.
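It may be worth seeing how small the difference between the two trees’ settings really is. The sketch below (our addition) prints the rpart.control() defaults: minsplit is already 20, so the only change the script made for tree.1 was cp = 0, which stops rpart from discarding splits that barely improve the fit. The last line compares the number of leaves in the two trees; count.leaves() is a hypothetical helper defined here for illustration.

```r
library(rpart)
library(caret)   # supplies segmentationData

data(segmentationData)
data <- segmentationData[, -c(1, 2)]

defaults <- rpart.control()
defaults$minsplit   # default 20, the same value the script passed for tree.1
defaults$cp         # default 0.01, versus cp = 0 for tree.1

tree.1 <- rpart(Class ~ ., data = data,
                control = rpart.control(minsplit = 20, cp = 0))
tree.2 <- rpart(Class ~ ., data = data)   # all defaults

# Hypothetical helper: terminal nodes are marked "<leaf>" in the frame
count.leaves <- function(fit) sum(fit$frame$var == "<leaf>")
c(tree.1 = count.leaves(tree.1), tree.2 = count.leaves(tree.2))
```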


This function is just a wrapper for prp(), but it is easy to use for plotting classification trees and is a very nice example of how aesthetics can facilitate communication. Each node box displays the predicted class, the probability of each class at that node (i.e. the probability of the class conditioned on the node) and the percentage of observations that reached that node. Notice how the dotted lines tend to emphasize the nodes rather than the tree itself, and how lining up the bottom level of leaves helps the viewer guess that the percentages in the node boxes are the percentages of observations arriving at each node. (The bottom row adds up to 100%... well, almost.) fancyRpartPlot() is well worth studying as an example of how to choose among the daunting number of parameter options available for prp().
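The node-box contents described above correspond to documented prp() options, so a similar look can be had without rattle. The sketch below is our approximation, not rattle’s actual code: the values type = 2, extra = 104, nn = TRUE, fallen.leaves = TRUE and branch.lty = 3 are our picks, and fancyRpartPlot() also sets colours and font sizes. The second half recomputes the node percentages from tree.2$frame to show why the leaves sum to exactly 100%, while separately rounded labels may not.

```r
library(rpart)
library(rpart.plot)
library(caret)   # supplies segmentationData

data(segmentationData)
data <- segmentationData[, -c(1, 2)]
tree.2 <- rpart(Class ~ ., data = data)

# Approximation of the fancyRpartPlot() look using prp() directly
prp(tree.2,
    type = 2,              # label every node, split labels below node labels
    extra = 104,           # class probabilities in the node + percent of observations
    nn = TRUE,             # show node numbers
    fallen.leaves = TRUE,  # line up all leaves along the bottom
    branch.lty = 3)        # dotted branch lines

# The percentage in each box is n(node) / n(root); row 1 of the frame is the root
fr <- tree.2$frame
pct <- 100 * fr$n / fr$n[1]
leaf.pct <- pct[fr$var == "<leaf>"]
sum(leaf.pct)          # every observation lands in exactly one leaf
sum(round(leaf.pct))   # rounding each leaf separately need not give 100
```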

The last few lines of the script construct a tree using rxDTree(), a function for building classification and regression tree models on massive data sets. It is included in the RevoScaleR package that ships with Revolution R Enterprise. Both prp() and fancyRpartPlot() work well for these trees too; it is just necessary to run rxAddInheritance() on the rxDTree() model object so that prp() knows what to do with it. The parameter settings used for rxDTree() should yield a plot very similar to the fancy plot above. Happy plotting!
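Readers without RevoScaleR can approximate the same tree with plain rpart. The sketch below is an assumption-laden translation, not an equivalence: we map rxDTree()’s minBucket, maxDepth, cp and xVal onto rpart.control()’s minbucket, maxdepth, cp and xval, and we drop maxNumBins, which as far as we know has no rpart counterpart because rpart does not bin numeric predictors. The resulting tree may therefore differ somewhat from the rxDTree() one.

```r
library(rpart)
library(rpart.plot)
library(caret)   # supplies segmentationData

data(segmentationData)

# Same formula as the script: every column except Cell, Case and Class as predictors
sdNames <- names(segmentationData)
X <- sdNames[-c(1, 2, 3)]
form <- as.formula(paste("Class", "~", paste(X, collapse = "+")))

# Rough analogue of rxDTree(minBucket = 10, maxDepth = 5, cp = 0.01, xVal = 0)
ctrl <- rpart.control(minbucket = 10, maxdepth = 5, cp = 0.01, xval = 0)
rp.tree <- rpart(form, data = segmentationData, control = ctrl)

# Depth from node numbers; the root counts as depth 0, as in rpart.control()
depth <- max(floor(log2(as.numeric(rownames(rp.tree$frame)))))

prp(rp.tree, type = 2, extra = 104, fallen.leaves = TRUE)
```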
