Visualizing trees with Sklearn


Tree-based models are probably the second easiest ML technique to explain to a non-data scientist. I am a big fan of tree-based models because of their simplicity and interpretability. But visualizing them has always gotten on my nerves: there are so many packages out there for doing it. Sklearn has finally provided us with a new API to visualize trees through matplotlib. In this tutorial, I will show you how to visualize trees using sklearn for both classification and regression.

Importing libraries

The following are the libraries that are required to load datasets, split data, train models and visualize them.

from sklearn.datasets import load_wine, fetch_california_housing
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree, DecisionTreeClassifier, DecisionTreeRegressor

Classification

In this section, our objective is to

  1. Load the wine dataset
  2. Split the data into train and test
  3. Train a decision tree classifier
  4. Visualize the decision tree

# load wine data set
data = load_wine()
x = data.data
y = data.target

# split into train and test data
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.30, random_state=42)

# create a decision tree classifier
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(x_train, y_train)

# plot the classification tree
plt.figure(figsize=(10, 8))
plot_tree(clf, feature_names=data.feature_names, class_names=data.target_names, filled=True)
plt.show()

Once you execute the above code, you should have the following or similar decision tree for the wine dataset model.

Classification tree
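As a quick sanity check (my addition, not part of the original post), sklearn also ships export_text, which prints the same tree as plain text, and the classifier's score method reports accuracy on the held-out test split. A minimal sketch, reusing the clf, data, x_test and y_test objects from the code above:

# optional: text view of the fitted tree and test-set accuracy (sketch, not from the original post)
from sklearn.tree import export_text

print(export_text(clf, feature_names=list(data.feature_names)))
print("test accuracy:", clf.score(x_test, y_test))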

Regression

Similar to classification, in this section we will train and visualize a model for regression:

  1. Load the California housing dataset
  2. Split the data into train and test
  3. Train a decision tree regressor
  4. Visualize the decision tree

# load data set
data = fetch_california_housing()
x = data.data
y = data.target

# split into train and test data
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.30, random_state=42)

# create a decision tree regressor
reg = DecisionTreeRegressor(max_depth=2, random_state=0)
reg.fit(x_train, y_train)

# plot the regression tree
plt.figure(figsize=(10, 8))
plot_tree(reg, feature_names=data.feature_names, filled=True)
plt.show()

Once you execute the above code, you should end up with a graph similar to the one below.

Regression tree
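Similarly, as a side note (my addition, not from the original post), the regressor's score method returns R² on the test split, and the figure can be written to disk with matplotlib's savefig. A minimal sketch, reusing the objects from the code above; the file name is only an example:

# optional: R^2 on the test split and saving the plot to a file (sketch, not from the original post)
print("test R^2:", reg.score(x_test, y_test))

plt.figure(figsize=(10, 8))
plot_tree(reg, feature_names=data.feature_names, filled=True)
plt.savefig("regression_tree.png", dpi=150, bbox_inches="tight")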

As you can see, visualizing a decision tree has become a lot simpler with sklearn. In the past, it would take me about 10 to 15 minutes to write code using two different packages for something that can now be done in two lines. I am definitely looking forward to future updates that support random forests and other ensemble models.
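In the meantime, one workaround (my addition, not something covered in the original post) is that plot_tree already works on a single tree pulled out of a fitted ensemble, because random forests expose their individual trees through the estimators_ attribute. A minimal sketch, reusing the California housing split from the regression example above:

# sketch: plot one tree out of a random forest (not from the original post)
from sklearn.ensemble import RandomForestRegressor

rf = RandomForestRegressor(n_estimators=10, max_depth=2, random_state=0)
rf.fit(x_train, y_train)  # reuses the California housing train split from above

plt.figure(figsize=(10, 8))
plot_tree(rf.estimators_[0], feature_names=data.feature_names, filled=True)
plt.show()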

Thank you for going through this article. Kindly post below if you have any questions or comments.

You can also find code for this on my Github page.
