Ensemble Methods are Doomed to Fail in High Dimensions

March 15, 2017

(This article was first published on R – Statistical Modeling, Causal Inference, and Social Science, and kindly contributed to R-bloggers)

Ensemble methods

By ensemble methods, I (Bob, not Andrew) mean approaches that scatter points in parameter space and then make moves by interpolating or extrapolating among subsets of them. Two prominent examples are:

There are extensions and computer implementations of these algorithms. For example, the Python package emcee implements Goodman and Weare’s walker algorithm and is popular in astrophysics.
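To make the interpolate-or-extrapolate move concrete, here is a minimal R sketch of one sweep of a stretch move in the style of Goodman and Weare. It is not the emcee code: the function name stretch_sweep, the scale a = 2, and the choice of 200 walkers are assumptions made for illustration.

```r
# Sketch of one sweep of a Goodman-and-Weare-style stretch move.
# walkers: N x K matrix, one walker per row; lpdf: any log density function.
# The scale a = 2 is an assumed default, not a value taken from the post.
stretch_sweep <- function(walkers, lpdf, a = 2) {
  N <- nrow(walkers)
  K <- ncol(walkers)
  accepted <- 0
  for (k in seq_len(N)) {
    # choose a different walker j to move along the line through j and k
    others <- setdiff(seq_len(N), k)
    j <- others[sample.int(length(others), 1)]
    # z has density proportional to 1 / sqrt(z) on [1 / a, a];
    # z < 1 interpolates toward walker j, z > 1 extrapolates away from it
    z <- ((a - 1) * runif(1) + 1)^2 / a
    proposal <- walkers[j, ] + z * (walkers[k, ] - walkers[j, ])
    # Metropolis acceptance with the z^(K - 1) volume correction
    log_accept <- (K - 1) * log(z) + lpdf(proposal) - lpdf(walkers[k, ])
    if (log(runif(1)) < log_accept) {
      walkers[k, ] <- proposal
      accepted <- accepted + 1
    }
  }
  list(walkers = walkers, accept_rate = accepted / N)
}

set.seed(1)
K <- 100
N <- 2 * K
# start the walkers at independent draws from the target itself
walkers <- matrix(rnorm(N * K), nrow = N)
res <- stretch_sweep(walkers, function(y) sum(dnorm(y, log = TRUE)))
print(res$accept_rate)
```

Rerunning with a smaller K, or with a closer to 1, is one way to see how the acceptance rate and the size of accepted moves trade off.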

Typical sets in high dimensions

If you want to get the background on typical sets, I’d highly recommend Michael Betancourt’s video lectures on MCMC in general and HMC in particular; they both focus on typical sets and their relation to the integrals we use MCMC to calculate:

It was Michael who made a doughnut in the air, pointed at the middle of it and said, “It’s obvious ensemble methods won’t work.” This is just fleshing out the details with some code for the rest of us without such sharp geometric intuitions.

MacKay’s information theory book is another excellent source on typical sets. Don’t bother with the Wikipedia on this one.

Why ensemble methods fail: Executive summary

  1. We want to draw a sample from the typical set
  2. The typical set is a thin shell a fixed radius from the mode in a multivariate normal
  3. Interpolating or extrapolating two points in this shell is unlikely to fall in this shell
  4. The only steps that get accepted will be near one of the starting points
  5. The samplers devolve to a random walk with poorly biased choice of direction
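
Points 2 and 3 can be checked numerically. The following is a small sketch (the seed and the sample size of 1000 are arbitrary choices) that measures the distance from the mode for plain draws and for midpoints of pairs of draws in 100 dimensions:

```r
set.seed(42)
K <- 100
n <- 1000
# distance from the mode (the origin) for n independent draws
radius <- replicate(n, sqrt(sum(rnorm(K)^2)))
# distance from the mode for midpoints of pairs of independent draws
radius_mid <- replicate(n, sqrt(sum(((rnorm(K) + rnorm(K)) / 2)^2)))
print(summary(radius))      # concentrates near sqrt(K) = 10
print(summary(radius_mid))  # concentrates near sqrt(K / 2), about 7.07
```

The draws sit in a narrow band of radii around sqrt(K), while the midpoints sit in a different band around sqrt(K / 2), well inside the shell.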

Several years ago, Peter Li built the Goodman and Weare walker methods for Stan (all they need is log density) on a branch for evaluation. They failed in practice exactly the way the theory says they will fail. Which is too bad, because the ensemble methods are very easy to implement and embarrassingly parallel.

Why ensemble methods fail: R simulation

OK, so let’s see why they fail in practice. I’m going to write some simple R code to do the job for us. Here’s an R function to generate a K-dimensional standard isotropic normal variate (each element is generated normal(0, 1) independently); we’ll use K = 100 below:

normal_rng <- function(K) rnorm(K);

This function computes the log density of a draw:

normal_lpdf <- function(y) sum(dnorm(y, log=TRUE));

Next, generate two draws from a 100-dimensional version:

K <- 100;
y1 <- normal_rng(K);
y2 <- normal_rng(K);

and then interpolate by choosing a point between them:

lambda <- 0.5;
y12 <- lambda * y1 + (1 - lambda) * y2;

Now let's see what we get:

print(normal_lpdf(y1), digits=1);
print(normal_lpdf(y2), digits=1);
print(normal_lpdf(y12), digits=1);

[1] -153
[1] -142
[1] -123

Hmm. Why is the log density of the interpolated vector so much higher? Given that it's a multivariate normal, the answer is that it's closer to the mode. That should be a good thing, right? No, it's not. The typical set is defined as the region within "typical" density bounds. When I take a random draw from a 100-dimensional standard normal, I expect log densities that hover around -142, mostly between about -130 and -155. That interpolated vector y12 with a log density of -123 isn't in the typical set! It's a bad draw, even though it's closer to the mode. Still confused? Watch Michael's videos above.

Ironically, there's a description in the Goodman and Weare paper, in a discussion of why they can use ensemble averages, that also explains why their own sampler doesn't scale: the variance of averages is lower than the variance of individual draws, and we want to cover the actual posterior, not get closer to the mode.
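
The size of the gap can also be worked out in closed form. As a sketch (the helper name expected_lpdf is made up for this illustration): each coordinate of lambda * y1 + (1 - lambda) * y2 is normal with variance lambda^2 + (1 - lambda)^2, so the expected log density under the standard normal follows directly.

```r
# Expected standard-normal log density of lambda * y1 + (1 - lambda) * y2,
# where y1 and y2 are independent K-dimensional standard normal draws.
# Each coordinate of the combination has variance lambda^2 + (1 - lambda)^2.
expected_lpdf <- function(K, lambda) {
  v <- lambda^2 + (1 - lambda)^2
  -0.5 * K * log(2 * pi) - 0.5 * K * v
}
print(expected_lpdf(100, 1))    # a plain draw: about -141.9
print(expected_lpdf(100, 0.5))  # the midpoint: about -116.9
# standard deviation of the log density of a plain draw is sqrt(K / 2)
print(sqrt(100 / 2))            # about 7.07
```

With K = 100, the midpoint's expected log density is about 25 higher than a plain draw's, which is more than three standard deviations of the plain-draw log density.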

So let's put this in a little sharper perspective by simulating a thousand draws from a multivariate normal and a thousand midpoints of pairs of draws, then plotting them in two histograms. First, draw them and print a summary:

# log densities of 1000 independent draws
lp1 <- replicate(1000, normal_lpdf(normal_rng(K)));
print(summary(lp1));

# log densities of 1000 midpoints of pairs of independent draws
lp2 <- replicate(1000, normal_lpdf((normal_rng(K) + normal_rng(K)) / 2));
print(summary(lp2));

from which we get:

  Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
   -177    -146    -141    -142    -136    -121 

   Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
   -129    -119    -117    -117    -114    -108 

That's looking bad. It's even clearer with a faceted histogram:

library(ggplot2);
df <- data.frame(log_pdf = c(lp1, lp2),
                 case = c(rep("Y1", 1000), rep("(Y1 + Y2)/2", 1000)));
plot <- ggplot(df, aes(log_pdf)) +
        geom_histogram(color="grey") +
        facet_grid(case ~ .);
print(plot);

Here's the plot:

The bottom plot shows the distribution of log densities in independent draws from the standard normal (these are pure Monte Carlo draws). The top plot shows the distribution of the log density of the vector resulting from interpolating two independent draws from the same distribution. Obviously, the log densities of the averaged draws are much higher. In other words, they are atypical of draws from the target standard normal density.

Exercise

Check out what happens as (1) the number of dimensions K varies, and (2) lambda varies within or outside of [0, 1].

Hint: What you should see is that as lambda approaches 0 or 1, the draws get more and more typical, and more and more like random walk Metropolis with a small step size. As dimensionality increases, the typical set becomes more attenuated and the problem becomes worse (and vice versa as it decreases).
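
One possible starting point for the exercise is sketched below. The helper mean_lpdf, the grid of lambda values, and the choice to report gaps in units of sqrt(K / 2) (the standard deviation of a plain draw's log density) are all assumptions made for this illustration.

```r
set.seed(123)
normal_lpdf <- function(y) sum(dnorm(y, log = TRUE))

# Monte Carlo mean log density of lambda * y1 + (1 - lambda) * y2
mean_lpdf <- function(K, lambda, n = 1000) {
  mean(replicate(n, normal_lpdf(lambda * rnorm(K) + (1 - lambda) * rnorm(K))))
}

lambdas <- c(-0.5, 0, 0.1, 0.5, 0.9, 1, 1.5)
for (K in c(10, 100, 1000)) {
  baseline <- mean_lpdf(K, 1)  # lambda = 1 is just a plain draw
  # gap from the baseline, in standard deviations of a plain draw's log density
  gaps <- sapply(lambdas, function(l) (mean_lpdf(K, l) - baseline) / sqrt(K / 2))
  cat("K =", K, "\n")
  print(round(setNames(gaps, lambdas), 1))
}
```

Positive gaps mean the combined point is too close to the mode to be typical; negative gaps (from extrapolating) mean it is too far out in the tails.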

Does Hamiltonian Monte Carlo (HMC) have these problems?

Not so much. It scales much better with dimension. It'll slow down, but it won't break and devolve to a random walk like ensemble methods do.

