Efficient Partial Dependence Plots with decision trees

Partial Dependence Plots (PDPs) are a standard inspection technique for machine learning models. There are two ways to compute PDPs:

  1. The slow and generic way that works for any model
  2. The fast way that only works for models that are based on (regression) decision trees. As far as I know, the fast method was originally presented in Friedman’s Gradient Boosting paper (pdf link).

This post will describe both techniques, and explain why the fast way is… well, faster. We will also see that they are not always equivalent.

Partial Dependence: definition

We will briefly describe partial dependence functions. For a more thorough introduction to PDPs, you can refer to the Bible, or to the Interpretable Machine Learning book.

In what follows, we will keep things simple and only consider a dataset with two features. The first feature $X_0$ is the one we want to get a PDP for, and the second feature $X_1$ is the one that will get averaged out. Following the notation from the references above, $X_0$ corresponds to $X_S$ and $X_1$ corresponds to $X_C$.

The partial dependence on $X_0$ is:

$$pd_{X_0}(x) = \mathbb{E}_{X_1}\left[ f(x, X_1) \right]$$

where $f(\cdot, \cdot)$ is our model, taking both features as input. This is also called the average dependence, which IMHO is a better name because it’s clear that the rest of the features get averaged out.

The slow way in general

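For a given value $x$ of $X_0$, the slow method approximates $pd_{X_0}(x)$ with an average over some data. This is usually the training set, but it could also be some validation data:

$$pd_{X_0}(x) \approx \frac{1}{n} \sum_{i=1}^{n} f(x, x_1^{(i)})$$

where $n$ is the number of samples and $x_1^{(i)}$ is the value of feature $X_1$ for the $i$th sample. To compute the above approximation, we need to:

  1. Create $n$ fake samples $(x, x_1^{(i)})$: copy the dataset and replace the value of $X_0$ with $x$ in every sample.
  2. Compute the prediction of the model for each of these fake samples.
  3. Average the $n$ predictions.

Repeating that procedure for each value $x$ of $X_0$, we end up with a PDP like the one below.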

The issue here is that for each value $x$, we need a full pass on the dataset. That’s slow, but that’s a very generic way that will work for any model.
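The slow method described above can be sketched in a few lines of plain Python. This is a minimal sketch rather than a reference implementation: `slow_partial_dependence`, `toy_predict` and the toy rows are hypothetical names and data introduced here, and `predict` is assumed to be any function that maps one sample to a number.

```python
def slow_partial_dependence(predict, rows, feature, grid):
    """Brute-force partial dependence of `predict` on column `feature`."""
    pd_values = []
    for x in grid:
        total = 0.0
        for row in rows:               # one full pass over the data per value x
            fake = list(row)           # fake sample: real values for the other features...
            fake[feature] = x          # ...but the feature of interest forced to x
            total += predict(fake)
        pd_values.append(total / len(rows))
    return pd_values


# Hypothetical toy model f(x0, x1) = 2 * x0 + x1, used only for illustration.
def toy_predict(sample):
    return 2 * sample[0] + sample[1]


toy_rows = [(0.0, 1.0), (1.0, 2.0), (2.0, 3.0)]
toy_pd = slow_partial_dependence(toy_predict, toy_rows, feature=0, grid=[0.0, 5.0])
```

The nested loops make the cost explicit: the model is called `len(grid) * len(rows)` times, whatever the model is.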

The slow way on decision trees

Let’s consider the following tree. Each non-leaf node indicates whether its split is on $X_0$ or on $X_1$:

That tree will partition the input space as in the following plot (thresholds are arbitrary). Each line corresponds to a split, i.e. a non-leaf node, and each rectangle corresponds to a leaf, with its own value.

Let’s apply the slow method to compute $pd_{X_0}(5)$. We need to create fake samples (whose values for $X_1$ depend on our dataset at hand), and then compute the prediction of the tree for each of these fake samples. Our fake samples are represented as dots in the following plot (notice that they naturally all have $X_0=5$). The prediction for a given sample is the value of the leaf in which the sample lands.

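To approximate $pd_{X_0}(5)$, we now average these predictions following the formula above. Since each prediction is the value of a leaf, the result is a weighted average of the leaf values, where the weight of a leaf is the proportion of fake samples that land in it.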

Repeat this for all values of $X_0$, and you get a PDP.
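Here is a minimal sketch of the slow method applied to a small hand-built tree. The `Node` class, the thresholds, the leaf values and the `x1_values` list are all made up for illustration; they do not reproduce the tree or the data from the plots.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    # Hypothetical minimal node: `feature` is 0 for X0, 1 for X1, None for a leaf.
    feature: Optional[int] = None
    threshold: float = 0.0
    left: Optional["Node"] = None
    right: Optional["Node"] = None
    value: float = 0.0


def predict(node, sample):
    """Regular tree prediction: descend until a leaf is reached."""
    while node.feature is not None:
        node = node.left if sample[node.feature] <= node.threshold else node.right
    return node.value


def slow_pd_tree(root, x, x1_values):
    """Average the tree predictions over the fake samples (x, x1)."""
    predictions = [predict(root, (x, x1)) for x1 in x1_values]
    return sum(predictions) / len(predictions)


# Made-up tree: the root splits on X0, its right child splits on X1.
tree = Node(feature=0, threshold=3.0,
            left=Node(value=10.0),
            right=Node(feature=1, threshold=2.0,
                       left=Node(value=20.0), right=Node(value=50.0)))

x1_values = [1.0, 1.5, 4.0, 6.0]               # X1 values taken from some dataset
pd_at_5 = slow_pd_tree(tree, 5.0, x1_values)   # two fake samples land in each right-hand leaf
```

With these made-up numbers, half of the fake samples land in the leaf with value 20 and half in the leaf with value 50, so the estimate is their plain average.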

The fast way on decision trees

In blue are the paths that the fake samples have followed:

We then just computed the proportion of fake samples landing in each leaf (H, I and G), and averaged the predictions. The key to the fast way is that we can compute the proportions ($\frac{1}{2}, \frac{1}{6}, \frac{1}{3}$) without having to create the fake samples. We’ll use the fact that each node of the tree remembers how many training samples went through it during the training stage.

For a given value $x$, the fast way simulates the $n$ tree traversals of the fake samples in a single tree traversal. During the traversal, we keep track of the proportions of training samples that would have followed each path (the “would” part is important here):

# input: root, the root node of a fitted tree, and x, a value of X0
# output: an approximation of the expectation pd_X0(x)
X0 = 0  # index of the feature of interest (X1 is then feature 1)

def fast_pd(root, x):
    def dfs(node, prop):
        # prop: proportion of training samples that would reach this node
        if node.is_leaf:
            return prop * node.value

        if node.split_feature == X0:
            # follow normal path: either left or right child
            child = node.left if x <= node.threshold else node.right
            return dfs(child, prop)

        # split on X1: follow both left and right, with appropriate proportions
        prop_samples_left = node.left.n_train_samples / node.n_train_samples
        return (dfs(node.left, prop * prop_samples_left)
                + dfs(node.right, prop * (1 - prop_samples_left)))

    return dfs(root, prop=1)

This is almost a regular depth-first (DFS) tree traversal. The main twist is that when a node splits on $X_1$, we follow both children instead of just one of them. You can think of the fast way as doing exactly what the slow way would do with the training data, that is, creating fake samples from the training data and passing them through the tree. The difference is that the fast version doesn’t explicitly create the fake samples: it simulates them by keeping track of the proportions.

The nodes that are visited by the fast method are the same nodes that would have been visited by the fake samples, had they been created. In particular, notice in the tree plot above that when the split is on $X_0$, only one of the children is followed. When the split is on $X_1$, both children are followed, each by a given proportion of the samples.

We only needed one traversal instead of $n$ traversals, and we got the same result!
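To make the comparison concrete, here is a self-contained sketch that runs both methods on a made-up tree. Everything in it is an assumption for illustration: the `Node` class, the six training samples, the thresholds and the leaf values. In this particular tree the only split on $X_1$ is at the root, so the node counts describe exactly how the fake samples would be routed.

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    # Hypothetical node: `feature` is 0 (X0), 1 (X1) or None for a leaf.
    n_train_samples: int
    feature: Optional[int] = None
    threshold: float = 0.0
    left: Optional["Node"] = None
    right: Optional["Node"] = None
    value: float = 0.0


def predict(node, sample):
    while node.feature is not None:
        node = node.left if sample[node.feature] <= node.threshold else node.right
    return node.value


def slow_pd(root, x, x1_values):
    # n tree traversals, one per fake sample (x, x1)
    return sum(predict(root, (x, x1)) for x1 in x1_values) / len(x1_values)


def fast_pd(node, x, prop=1.0):
    # a single traversal that carries the proportion of training samples along
    if node.feature is None:
        return prop * node.value
    if node.feature == 0:
        child = node.left if x <= node.threshold else node.right
        return fast_pd(child, x, prop)
    prop_left = node.left.n_train_samples / node.n_train_samples
    return (fast_pd(node.left, x, prop * prop_left)
            + fast_pd(node.right, x, prop * (1 - prop_left)))


# Made-up training set of (x0, x1) pairs, and a tree whose counts match it.
train = [(1, 1), (4, 1), (6, 2), (2, 3), (7, 5), (8, 7)]
tree = Node(n_train_samples=6, feature=1, threshold=3.0,
            left=Node(n_train_samples=4, feature=0, threshold=5.0,
                      left=Node(n_train_samples=3, value=10.0),
                      right=Node(n_train_samples=1, value=20.0)),
            right=Node(n_train_samples=2, value=40.0))

x1_train = [x1 for _, x1 in train]
slow, fast = slow_pd(tree, 6.0, x1_train), fast_pd(tree, 6.0)
```

Both calls weight the leaf with value 20 by 4/6 and the leaf with value 40 by 2/6, but `fast_pd` never touches the training samples: it only reads the counts stored in the nodes.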

All this generalizes very naturally to more than one feature in $X_S$ or in $X_C$. If you’re curious about how we implemented this in scikit-learn, here are the slow and the fast method implementations.

Differences between the fast and the slow method

The slow and the fast methods are almost equivalent, but they might differ in two ways.

First, they can only be equivalent if the slow method is using the training data. Indeed, both methods try to approximate the expectation $\mathbb{E}_{X_1}\left[ f(x, X_1) \right]$. The difference lies in the values of $X_1$ that are used to approximate this expectation with an average: the fast method always uses the values of the training data, the slow one uses whatever you feed it. As long as the data used by the slow method follows the same distribution as the training data, the differences should be in general minimal.

The second reason is trickier to grasp. Consider the degenerate dataset below, where all points have a target value of 0 except for one of them (in yellow) whose target value is 1000.

The corresponding tree is:

As we’ll see, the fast and the slow method will strongly disagree on the value of $pd_{X_0}(4)$.

The fast way will (virtually) map all the samples to the root’s right child, and then give a weight of $\frac{1}{2}$ to its left leaf, and a weight of $\frac{1}{2}$ to its right leaf. This results in $pd_{X_0}(4) = \frac{1}{2} \times 1000 + \frac{1}{2} \times 0 = 500$

The slow way will create $n$ fake samples that all fall on the vertical line $X_0 = 4$, and whose values for $X_1$ are those present in the training set. All these fake samples will be mapped to the root’s right child, just like in the fast method. But the vast majority of these fake samples have $X_1 < 9.5$, so most of them get a predicted value of 1000 while only a handful get a predicted value of 0. Averaging the predictions, the partial dependence is estimated at $pd_{X_0}(4) = 950$, which is very far from the 500 given by the fast method.
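The disagreement can be reproduced on a hand-built tree. The sketch below assumes a training set of 20 samples and a root threshold of 2 on $X_0$; both are made up so that the numbers match the 950 and 500 above, and the `Node` class is a hypothetical stand-in for a real tree structure.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    # Hypothetical node: `feature` is 0 (X0), 1 (X1) or None for a leaf.
    n_train_samples: int
    feature: Optional[int] = None
    threshold: float = 0.0
    left: Optional["Node"] = None
    right: Optional["Node"] = None
    value: float = 0.0


def predict(node, sample):
    while node.feature is not None:
        node = node.left if sample[node.feature] <= node.threshold else node.right
    return node.value


def slow_pd(root, x, x1_values):
    return sum(predict(root, (x, x1)) for x1 in x1_values) / len(x1_values)


def fast_pd(node, x, prop=1.0):
    if node.feature is None:
        return prop * node.value
    if node.feature == 0:
        child = node.left if x <= node.threshold else node.right
        return fast_pd(child, x, prop)
    prop_left = node.left.n_train_samples / node.n_train_samples
    return (fast_pd(node.left, x, prop * prop_left)
            + fast_pd(node.right, x, prop * (1 - prop_left)))


# Assumed layout: 18 samples go left at the root (target 0); only 2 go right,
# one on each side of the X1 <= 9.5 split (targets 1000 and 0).
tree = Node(n_train_samples=20, feature=0, threshold=2.0,
            left=Node(n_train_samples=18, value=0.0),
            right=Node(n_train_samples=2, feature=1, threshold=9.5,
                       left=Node(n_train_samples=1, value=1000.0),
                       right=Node(n_train_samples=1, value=0.0)))

# X1 values of the 20 assumed training samples: 19 of them are below 9.5.
x1_train = [float(i % 9) for i in range(18)] + [5.0, 10.0]

slow = slow_pd(tree, 4.0, x1_train)   # weights the leaves 19/20 and 1/20
fast = fast_pd(tree, 4.0)             # weights the leaves 1/2 and 1/2
```

The two estimates differ because the slow method uses the $X_1$ values of all 20 samples, while the fast method only uses the counts of the 2 samples that actually reached the node splitting on $X_1$.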

Which one is correct? Probably neither of them.

The slow method is creating fake samples that have unrealistic characteristics. It’s clear from the training data that samples with $X_0 > 0$ are very unlikely, and yet the slow method relies on the assumption that such samples exist.

On the other hand, the fast method considers that $\frac{1}{2}$ / $\frac{1}{2}$ is a reasonable approximation of the distribution of samples at that node. It might be true, but clearly the proportions are computed on the basis of very few samples (only 2), so they are not necessarily reliable.

As always with interpretation methods, one needs to be careful and always check the assumptions of the method being used.


Nicolas Hug

Associate Research Scientist at Columbia University
