RStanARM basics: visualizing uncertainty in linear regression
Want to share your content on Rbloggers? click here if you have a blog, or here if you don't.
As part of my tutorial talk on RStanARM, I presented some examples of how to visualize the uncertainty in Bayesian linear regression models. This post is an expanded demonstration of the approaches I presented in that tutorial.
Data: Does brain mass predict how much mammals sleep in a day?
Let’s use the mammal sleep dataset from ggplot2. This dataset contains the number of hours spent sleeping per day for 83 different species of mammals along with each species’ brain mass (kg) and body mass (kg), among other measures. Here’s a first look at the data.
library(dplyr, warn.conflicts = FALSE)
library(ggplot2)
# Preview sorted by brain/body ratio. I chose this sorting so that humans would
# show up in the preview.
msleep %>%
  select(name, sleep_total, brainwt, bodywt, everything()) %>%
  arrange(desc(brainwt / bodywt))
#> # A tibble: 83 × 11
#> name sleep_total brainwt bodywt genus
#> <chr> <dbl> <dbl> <dbl> <chr>
#> 1 Thirteenlined ground squirrel 13.8 0.00400 0.101 Spermophilus
#> 2 Owl monkey 17.0 0.01550 0.480 Aotus
#> 3 Lesser shorttailed shrew 9.1 0.00014 0.005 Cryptotis
#> 4 Squirrel monkey 9.6 0.02000 0.743 Saimiri
#> 5 Macaque 10.1 0.17900 6.800 Macaca
#> 6 Little brown bat 19.9 0.00025 0.010 Myotis
#> 7 Galago 9.8 0.00500 0.200 Galago
#> 8 Mole rat 10.6 0.00300 0.122 Spalax
#> 9 Tree shrew 8.9 0.00250 0.104 Tupaia
#> 10 Human 8.0 1.32000 62.000 Homo
#> # ... with 73 more rows, and 6 more variables: vore <chr>, order <chr>,
#> # conservation <chr>, sleep_rem <dbl>, sleep_cycle <dbl>, awake <dbl>
ggplot(msleep) +
  aes(x = brainwt, y = sleep_total) +
  geom_point()
#> Warning: Removed 27 rows containing missing values (geom_point).
Hmmm, not very helpful! We should put our measures on a log10 scale. Also, 27 of the species don’t have brain mass data, so we’ll exclude those rows for the rest of this tutorial.
msleep <- msleep %>%
  filter(!is.na(brainwt)) %>%
  mutate(log_brainwt = log10(brainwt),
         log_bodywt = log10(bodywt),
         log_sleep_total = log10(sleep_total))
Now, plot the log-transformed data. But let’s also get a little fancy and label the points for some example critters 🐱 so that we can get some intuition about the data in this scaling. (Plus, I wanted to try out the annotation tips from the R4DS book.)
# Create a separate dataframe of species to highlight
ex_mammals <- c("Domestic cat", "Human", "Dog", "Cow", "Rabbit",
                "Big brown bat", "House mouse", "Horse", "Golden hamster")
# We will give some familiar species shorter names
renaming_rules <- c(
  "Domestic cat" = "Cat",
  "Golden hamster" = "Hamster",
  "House mouse" = "Mouse")
ex_points <- msleep %>%
  filter(name %in% ex_mammals) %>%
  mutate(name = stringr::str_replace_all(name, renaming_rules))
# Define these labels only once for all the plots
lab_lines <- list(
  brain_log = "Brain mass (kg., log-scaled)",
  sleep_raw = "Sleep per day (hours)",
  sleep_log = "Sleep per day (log-hours)"
)
ggplot(msleep) +
  aes(x = brainwt, y = sleep_total) +
  geom_point(color = "grey40") +
  # Circles around highlighted points + labels
  geom_point(size = 3, shape = 1, color = "grey40", data = ex_points) +
  ggrepel::geom_text_repel(aes(label = name), data = ex_points) +
  # Use log scaling on x-axis
  scale_x_log10(breaks = c(.001, .01, .1, 1)) +
  labs(x = lab_lines$brain_log, y = lab_lines$sleep_raw)
As a child growing up on a dairy farm 🐮, it was remarkable to me how little I saw cows sleeping, compared to dogs or cats. Were they okay? Are they constantly tired and groggy? Maybe they are asleep when I’m asleep? Here, it looks like they just don’t need very much sleep.
Next, let’s fit a classical regression model. We will use a log-scaled sleep measure so that the regression line doesn’t imply negative sleep (even though brains never get that large).
m1_classical <- lm(log_sleep_total ~ log_brainwt, data = msleep)
arm::display(m1_classical)
#> lm(formula = log_sleep_total ~ log_brainwt, data = msleep)
#> coef.est coef.se
#> (Intercept) 0.74 0.04
#> log_brainwt -0.13     0.02
#> ---
#> n = 56, k = 2
#> residual sd = 0.17, R-Squared = 0.40
We can interpret the model in the usual way: A mammal with 1 kg (0 log-kg) of brain mass sleeps 10^0.74 = 5.5 hours per day. A mammal with a tenth of that brain mass (-1 log-kg) sleeps 10^(0.74 + 0.13) = 7.4 hours.
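To check the back-transformation, here is a small sketch that plugs in the rounded coefficients printed by arm::display(), so the results inherit that rounding. The helper predict_hours() is hypothetical and not part of any package.

```r
# Rounded estimates from arm::display(m1_classical)
b0 <- 0.74   # intercept (log10 hours)
b1 <- -0.13  # slope (log10 hours per log10 kg of brain mass)

# Hypothetical helper: predicted sleep in hours for a given log10 brain mass
predict_hours <- function(log_brainwt) 10 ^ (b0 + b1 * log_brainwt)

predict_hours(0)   # 1 kg of brain mass: about 5.5 hours
predict_hours(-1)  # 0.1 kg of brain mass: about 7.4 hours

# On the log-log scale, every tenfold drop in brain mass multiplies
# predicted sleep by the same factor, 10^0.13 (about 1.35)
predict_hours(-1) / predict_hours(0)
```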
We illustrate the regression results to show the predicted mean of y and its 95% confidence interval. This task is readily accomplished in ggplot2 using stat_smooth(). This function fits a model and plots the mean and CI for each aesthetic grouping of data¹ in a plot.
ggplot(msleep) +
  aes(x = log_brainwt, y = log_sleep_total) +
  geom_point() +
  stat_smooth(method = "lm", level = .95) +
  scale_x_continuous(labels = function(x) 10 ^ x) +
  labs(x = lab_lines$brain_log, y = lab_lines$sleep_log)
This interval conveys some uncertainty in the estimate of the mean, but this interval has a frequentist interpretation which can be unintuitive for this sort of data.
Now, for the point of this post: What’s the Bayesian version of this kind of visualization? Specifically, we want to illustrate:
- Predictions from a regression model
- Some uncertainty about those predictions
- Raw data used to train the model
Option 1: The pile-of-lines plot
The regression line in the classical plot is just one particular line. It’s the line of best fit that satisfies a least-squares or maximum-likelihood objective. Our Bayesian model estimates an entire distribution of plausible regression lines. The first way to visualize our uncertainty is to plot our own line of best fit along with a sample of other lines from the posterior distribution of the model.
First, we fit a model with RStanARM using weakly informative priors.
library("rstanarm")
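m1 <- stan_glm(
  log_sleep_total ~ log_brainwt,
  family = gaussian(),
  data = msleep,
  prior = normal(0, 3),
  prior_intercept = normal(0, 3))
We now have 4,000 credible regression lines for our data. (By default, stan_glm() runs 4 chains of 2,000 iterations each and discards the first half of each chain as warmup, leaving 4 × 1,000 = 4,000 posterior draws.)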
summary(m1)
#> stan_glm(formula = log_sleep_total ~ log_brainwt, family = gaussian(),
#> data = msleep, prior = normal(0, 3), prior_intercept = normal(0,
#> 3))
#>
#> Family: gaussian (identity)
#> Algorithm: sampling
#> Posterior sample size: 4000
#> Observations: 56
#>
#> Estimates:
#> mean sd 2.5% 25% 50% 75% 97.5%
#> (Intercept) 0.7 0.0 0.6 0.7 0.7 0.8 0.8
#> log_brainwt  -0.1    0.0  -0.2  -0.1  -0.1  -0.1  -0.1
#> sigma 0.2 0.0 0.1 0.2 0.2 0.2 0.2
#> mean_PPD 1.0 0.0 0.9 0.9 1.0 1.0 1.0
#> log-posterior 12.0   1.2   9.0  11.5  12.3  12.9  13.4
#>
#> Diagnostics:
#> mcse Rhat n_eff
#> (Intercept) 0.0 1.0 3040
#> log_brainwt 0.0 1.0 3046
#> sigma 0.0 1.0 2862
#> mean_PPD 0.0 1.0 3671
#> log-posterior 0.0  1.0  2159
#>
#> For each parameter, mcse is Monte Carlo standard error, n_eff is a crude measure of effective sample size, and Rhat is the potential scale reduction factor on split chains (at convergence Rhat=1).
For models fit by RStanARM, the generic coefficient function coef() returns the median parameter values.
coef(m1)
#> (Intercept) log_brainwt
#>   0.7354829  -0.1263922
coef(m1_classical)
#> (Intercept) log_brainwt
#>   0.7363492  -0.1264049
We can see that the intercept and slope of the median line are pretty close to the classical model’s intercept and slope. The median line serves as the “point estimate” for our model: If we had to summarize the modeled relationship using just a single number for each parameter, we could use the medians.
One way to visualize our model, therefore, is to plot our point-estimate line plus a sample of the other credible lines from our model. First, we create a dataframe with all 4,000 regression lines.
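To make the idea of a point estimate concrete: the medians are just column-wise medians of the matrix of posterior draws. The sketch below uses simulated normal draws as a hypothetical stand-in for the real posterior (their centers and spreads are assumptions loosely matching the estimates above). With the fitted model, as.matrix(m1) would supply the actual draws.

```r
set.seed(20161115)
n_sims <- 4000

# Hypothetical stand-in for posterior draws, one column per parameter
sim_draws <- cbind(
  intercept   = rnorm(n_sims, mean = 0.74, sd = 0.04),
  log_brainwt = rnorm(n_sims, mean = -0.13, sd = 0.02)
)

# One summary number per parameter: the median of its draws
point_estimate <- apply(sim_draws, 2, median)
point_estimate
```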
# Coercing a model to a dataframe returns a dataframe of posterior samples,
# one row per sample. We drop sigma because each line only needs an
# intercept and a slope.
fits <- m1 %>%
  as.data.frame() %>%
  as_tibble() %>%
  rename(intercept = `(Intercept)`) %>%
  select(-sigma)
fits
#> # A tibble: 4,000 × 2
#> intercept log_brainwt
#> <dbl> <dbl>
#> 1  0.7529824  -0.1369554
#> 2  0.7243708  -0.1266290
#> 3  0.7575502  -0.1171410
#> 4  0.7855554  -0.1031353
#> 5  0.6327073  -0.1795992
#> 6  0.6474521  -0.1714347
#> 7  0.7512467  -0.1155559
#> 8  0.7363273  -0.1162038
#> 9  0.7490401  -0.1276618
#> 10 0.7238091  -0.1305896
#> # ... with 3,990 more rows
We now plot 500 randomly sampled lines from our model as light, semi-transparent lines.
# aesthetic controllers
n_draws <- 500
alpha_level <- .15
col_draw <- "grey60"
col_median <- "#3366FF"
ggplot(msleep) +
  aes(x = log_brainwt, y = log_sleep_total) +
  # Plot a random sample of rows as gray semi-transparent lines
  geom_abline(aes(intercept = intercept, slope = log_brainwt),
              data = sample_n(fits, n_draws), color = col_draw,
              alpha = alpha_level) +
  # Plot the median values in blue
  geom_abline(intercept = median(fits$intercept),
              slope = median(fits$log_brainwt),
              size = 1, color = col_median) +
  geom_point() +
  scale_x_continuous(labels = function(x) 10 ^ x) +
  labs(x = lab_lines$brain_log, y = lab_lines$sleep_log)
Each of these light lines represents a credible prediction of the mean across the values of x. As these lines pile up on top of each other, they create an uncertainty band around our line of best fit. More plausible lines are more likely to be sampled, so these lines overlap and create a uniform color around the median line. As we move left or right, getting farther away from the mean of x, the lines start to fan out and we see very faint individual lines for some of the more extreme (yet still plausible) lines.
The advantage of this plot is that it is a direct visualization of posterior samples: one line per sample. It provides an estimate for the central tendency in the data, but it also conveys uncertainty around that estimate.
This approach has limitations, however. Lines for subgroups require a little more effort, because we have to undo the interactions (that is, combine the main-effect and interaction coefficients into an intercept and slope for each subgroup). Also, the regression lines span the whole x-axis, which is not appropriate when subgroups only use a portion of the x-axis. (This limitation is solvable, though.) Finally, I haven’t found good defaults for the aesthetic options: the number of samples, the colors to use, and the transparency level. One can lose lots and lots and lots of time fiddling with those knobs!
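Why do the lines fan out? At a given x, the predicted mean is intercept + slope * x, so its spread across draws is smallest near the mean of x and grows with distance from it. The self-contained sketch below uses hypothetical simulated data and, as a stand-in for MCMC draws, samples (intercept, slope) from the normal approximation to a classical fit. It then measures how widely the plausible lines spread at each x.

```r
set.seed(1)

# Hypothetical data loosely shaped like the sleep data (not the real msleep)
x <- runif(56, min = -3.9, max = 0.8)
y <- 0.74 - 0.13 * x + rnorm(56, sd = 0.17)
fit <- lm(y ~ x)

# Stand-in for posterior draws: sample (intercept, slope) from
# N(coef(fit), vcov(fit)) using a Cholesky factor of the covariance
n_sims <- 4000
z <- matrix(rnorm(n_sims * 2), ncol = 2)
coef_draws <- sweep(z %*% chol(vcov(fit)), 2, coef(fit), "+")

# Each column holds the fitted mean (intercept + slope * x) at one grid value
x_grid <- seq(-4, 1, by = 0.25)
mean_draws <- coef_draws[, 1] + outer(coef_draws[, 2], x_grid)

# Spread of the plausible lines at each x: narrowest near mean(x)
spread <- apply(mean_draws, 2, sd)
x_grid[which.min(spread)]
mean(x)
```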
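One possible fix for the full-width lines replaces each infinite abline with a segment that starts and ends at a chosen x range, such as the range of one subgroup’s data. The helper lines_to_segments() is hypothetical, and the x range below is arbitrary. It only computes endpoints, which could then be passed to ggplot2’s geom_segment(aes(x = x, xend = xend, y = y, yend = yend)). The example reuses the first three (rounded) draws from the fits table above.

```r
# Hypothetical helper: turn intercept/slope pairs into line segments
# that only span x_range = c(lower, upper)
lines_to_segments <- function(intercept, slope, x_range) {
  data.frame(
    x    = x_range[1],
    xend = x_range[2],
    y    = intercept + slope * x_range[1],
    yend = intercept + slope * x_range[2]
  )
}

segs <- lines_to_segments(
  intercept = c(0.753, 0.724, 0.758),
  slope     = c(-0.137, -0.127, -0.117),
  x_range   = c(-3, 0)  # arbitrary range chosen for illustration
)
segs
```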
Option 2: Mean and its 95% interval
Another option is a direct port of the stat_smooth() plot: Draw a line of best fit and the 95% uncertainty interval around it.
To keep the lines within the range of the data on the x-axis, we’re going to create a sequence of 80 points spanning that range.
x_rng <- range(msleep$log_brainwt)
x_steps <- seq(x_rng[1], x_rng[2], length.out = 80)
new_data <- tibble(
  observation = seq_along(x_steps),
  log_brainwt = x_steps)
new_data
#> # A tibble: 80 × 2
#> observation log_brainwt
#> <int> <dbl>
#> 1 1 -3.853872
#> 2 2 -3.795509
#> 3 3 -3.737146
#> 4 4 -3.678784
#> 5 5 -3.620421
#> 6 6 -3.562058
#> 7 7 -3.503695
#> 8 8 -3.445332
#> 9 9 -3.386970
#> 10 10 -3.328607
#> # ... with 70 more rows
The function posterior_linpred() returns the model-fitted means for a dataframe of new data. I say means because the function computes 80 predicted means for each sample from the posterior. The result is a 4000 × 80 matrix of fitted means.
pred_lin <- posterior_linpred(m1, newdata = new_data)
dim(pred_lin)
#> [1] 4000 80
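For a Gaussian model with an identity link like m1, each row of that matrix is intercept + slope * x evaluated at the 80 new x values. The same shape can therefore be reproduced with a matrix product of the draws and a design matrix. The sketch below uses simulated draws as a hypothetical stand-in for as.matrix(m1), along with an x range chosen for illustration.

```r
set.seed(2016)
n_sims <- 4000

# Hypothetical stand-in for the posterior draws of (intercept, slope)
coef_draws <- cbind(
  intercept   = rnorm(n_sims, 0.74, 0.04),
  log_brainwt = rnorm(n_sims, -0.13, 0.02)
)

# 80 evenly spaced x values and the matching design matrix
x_steps <- seq(-3.85, 0.76, length.out = 80)
X_new <- cbind(1, x_steps)  # column of 1s for the intercept

# One row per posterior draw, one column per new x value
pred_lin_sketch <- coef_draws %*% t(X_new)
dim(pred_lin_sketch)
```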
We are going to reduce this down to just a median and 95% interval around each point. I do some tidying to get the data into a long format (one row per fitted mean per posterior sample), and then do a table join with the observation column included in new_data. I store these steps in a function because I have to do them again later in this post.
tidy_predictions <- function(mat_pred, df_data, obs_name = "observation",
                             prob_lwr = .025, prob_upr = .975) {
  # Get dataframe with one row per fitted value per posterior sample
  df_pred <- mat_pred %>%
    as.data.frame() %>%
    setNames(seq_len(ncol(.))) %>%
    tibble::rownames_to_column("posterior_sample") %>%
    tidyr::pivot_longer(-posterior_sample, names_to = obs_name,
                        values_to = "fitted")
  # Helps with joining later: the observation column comes back as
  # character, so coerce it to the class of the matching df_data column
  class(df_pred[[obs_name]]) <- class(df_data[[obs_name]])
  # Summarise prediction interval for each observation
  df_pred %>%
    group_by(across(all_of(obs_name))) %>%
    summarise(median = median(fitted),
              lower = quantile(fitted, prob_lwr),
              upper = quantile(fitted, prob_upr)) %>%
    left_join(df_data, by = obs_name)
}
df_pred_lin <- tidy_predictions(pred_lin, new_data)
df_pred_lin
#> # A tibble: 80 × 5
#> observation median lower upper log_brainwt
#> <int> <dbl> <dbl> <dbl> <dbl>
#> 1 1 1.223770 1.128224 1.320591 -3.853872
#> 2 2 1.216516 1.122147 1.311214 -3.795509
#> 3 3 1.209222 1.117190 1.301462 -3.737146
#> 4 4 1.201831 1.112268 1.291821 -3.678784
#> 5 5 1.194506 1.107512 1.282047 -3.620421
#> 6 6 1.187240 1.102580 1.272930 -3.562058
#> 7 7 1.179955 1.096945 1.263415 -3.503695
#> 8 8 1.172608 1.091237 1.254113 -3.445332
#> 9 9 1.165268 1.085800 1.244733 -3.386970
#> 10 10 1.157932 1.080823 1.235356 -3.328607
#> # ... with 70 more rows
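The long format in tidy_predictions() mainly exists so the summaries can be joined back onto new_data. If only the summaries are needed, a base R alternative is to apply quantile() down each column of the prediction matrix. This is a sketch rather than the approach used in this post, and the matrix below is a simulated stand-in.

```r
# Simulated stand-in prediction matrix: 4000 draws x 80 observations
set.seed(42)
mat_pred <- matrix(rnorm(4000 * 80, mean = 1, sd = 0.05), nrow = 4000)

# quantile() on each column gives a 3 x 80 matrix; t() makes it 80 x 3
probs <- c(lower = .025, median = .5, upper = .975)
summ <- t(apply(mat_pred, 2, quantile, probs = probs))
colnames(summ) <- names(probs)

df_summary <- data.frame(observation = seq_len(ncol(mat_pred)), summ)
head(df_summary)
```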
We can do the line-plus-interval plot using geom_ribbon() for the uncertainty band.
p_linpread <- ggplot(msleep) +
  aes(x = log_brainwt) +
  geom_ribbon(aes(ymin = lower, ymax = upper), data = df_pred_lin,
              alpha = 0.4, fill = "grey60") +
  geom_line(aes(y = median), data = df_pred_lin, colour = "#3366FF", size = 1) +
  geom_point(aes(y = log_sleep_total)) +
  scale_x_continuous(labels = function(x) 10 ^ x) +
  labs(x = lab_lines$brain_log, y = lab_lines$sleep_log)
p_linpread
This plot is just like the stat_smooth() plot, except the interval here is interpreted in terms of post-data probabilities: We’re 95% certain, given the data, model and our prior information, that the “true” average sleep duration is contained in this interval. I put “true” in quotes because this is truth in the “small world” of the model, to quote Statistical Rethinking, not necessarily the real world.
Although the interpretation of the interval changes (compared to a classical confidence interval), its location barely changes at all. If we overlay a stat_smooth() layer onto this plot, we can see that the two sets of intervals are virtually identical. With this much data and a model this simple, both types of models can make very similar estimates.
p_linpread + stat_smooth(aes(y = log_sleep_total), method = "lm")
The previous plot illustrates one limitation of this approach: Pragmatically speaking, stat_smooth() basically does the same thing, and we’re not taking advantage of the affordances provided by our model. This is why RStanARM, in a kind of amusing way, disowns posterior_linpred() in its documentation:
This function is occasionally convenient, but it should be used sparingly. Inference and model checking should generally be carried out using the posterior predictive distribution (see posterior_predict).
Occasionally convenient. 😮 And elsewhere:
See also: posterior_predict to draw from the posterior predictive distribution of the outcome, which is almost always preferable.
Option 3: Mean and 95% interval for model-generated data
The reason why posterior_predict() is preferable is that it uses more information from our model, namely the error term sigma. posterior_linpred() predicts averages; posterior_predict() predicts new observations. This posterior predictive checking helps us confirm whether our model (a story of how the data could have been generated) can produce new data that resembles our data.
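In a Gaussian model, the difference between the two functions comes down to one extra step: each simulated observation is a fitted mean plus normal noise scaled by that draw’s sigma. The sketch below shows that relationship with simulated stand-in draws and a few arbitrary x values, all assumptions rather than output from m1. It is meant to illustrate why the predictive interval is wider, not to reproduce RStanARM’s internals.

```r
set.seed(7)
n_sims <- 4000

# Hypothetical stand-in posterior draws
intercept <- rnorm(n_sims, 0.74, 0.04)
slope     <- rnorm(n_sims, -0.13, 0.02)
sigma     <- abs(rnorm(n_sims, 0.17, 0.02))  # residual sd must be positive

x_new <- c(-3, -1, 0)

# Fitted means: one row per draw, one column per new x
mu <- intercept + outer(slope, x_new)

# Simulated new observations: mean plus noise with each row's own sigma
# (rnorm recycles `sigma` down each column, so row i always uses sigma[i])
y_rep <- mu + matrix(rnorm(length(mu), sd = sigma), nrow = n_sims)

# Intervals for new observations are much wider than intervals for the mean
apply(mu, 2, quantile, c(.025, .975))
apply(y_rep, 2, quantile, c(.025, .975))
```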
Here, we can use the function we defined earlier to get prediction intervals.
# Still a matrix with one row per posterior draw and one column per observation
pred_post <- posterior_predict(m1, newdata = new_data)
dim(pred_post)
#> [1] 4000 80
df_pred_post <- tidy_predictions(pred_post, new_data)
df_pred_post
#> # A tibble: 80 × 5
#> observation median lower upper log_brainwt
#> <int> <dbl> <dbl> <dbl> <dbl>
#> 1 1 1.224866 0.8685090 1.577798 -3.853872
#> 2 2 1.207392 0.8395285 1.560691 -3.795509
#> 3 3 1.209352 0.8499785 1.569175 -3.737146
#> 4 4 1.203873 0.8333415 1.563349 -3.678784
#> 5 5 1.204020 0.8537000 1.554171 -3.620421
#> 6 6 1.183633 0.8284588 1.552674 -3.562058
#> 7 7 1.182420 0.8234048 1.549418 -3.503695
#> 8 8 1.177556 0.8111187 1.543201 -3.445332
#> 9 9 1.164234 0.8238208 1.524496 -3.386970
#> 10 10 1.161509 0.8130019 1.526353 -3.328607
#> # ... with 70 more rows
And we can plot the interval in the same way.
ggplot(msleep) +
  aes(x = log_brainwt) +
  geom_ribbon(aes(ymin = lower, ymax = upper), data = df_pred_post,
              alpha = 0.4, fill = "grey60") +
  geom_line(aes(y = median), data = df_pred_post, colour = "#3366FF", size = 1) +
  geom_point(aes(y = log_sleep_total)) +
  scale_x_continuous(labels = function(x) 10 ^ x) +
  labs(x = lab_lines$brain_log, y = lab_lines$sleep_log)
First, we can appreciate that this interval is much wider. That’s because the interval doesn’t summarize a particular statistic (like an average) but all of the observations that can be generated by our model. Okay, not all of the observations: just the middle 95% of the simulated observations at each value of x.
Next, we can also appreciate that the line and the ribbon are jagged due to simulation randomness. Each prediction is a random number draw, and at each value of x, we have 4000 such random draws. We computed a median and 95% interval at each x, but due to randomness from simulating new data, these medians do not smoothly connect together in the plot. That’s okay, because these fluctuations are relatively small.
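The wiggles are Monte Carlo error in the estimated quantiles, and they shrink roughly in proportion to one over the square root of the number of draws. The small base R sketch below makes a simplifying assumption: the predictive distribution is the same normal distribution at every x. Under that assumption, any wobble in the estimated upper bound is pure simulation noise. The helper wobble_of_upper_bound() is hypothetical.

```r
set.seed(3)

# Hypothetical setup: identical normal predictive distribution at each of
# n_x values of x, so the true 97.5% quantile is the same everywhere
wobble_of_upper_bound <- function(n_sims, n_x = 80, sigma = 0.17) {
  upper <- replicate(n_x, quantile(rnorm(n_sims, mean = 1, sd = sigma), 0.975))
  sd(upper)  # spread of the estimated upper bound across the x values
}

wobble_of_upper_bound(4000)    # draws per x, as in the post
wobble_of_upper_bound(40000)   # ten times as many draws
```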
Finally, we can see that there are only two points outside of the interval. These appear to be the restless roe deer and the ever-sleepy giant armadillo. These two represent the main outliers for our model because they fall slightly outside of the 95% prediction interval. In this way, the posterior predictive interval can help us discover which data points are relative outliers for our model.
(Maybe outliers isn’t the right word. It makes perfect sense that 2/56 = 3.6% of the observations fall outside of the 95% interval.)
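A quick check on that parenthetical: with 56 observations and a well-calibrated 95% interval, we would expect about 56 × 0.05 = 2.8 observations to fall outside it. The sketch below treats each observation as an independent 5% chance. That is an idealization, since the intervals come from a model fit to the same data. Under that assumption, the binomial distribution says two misses is entirely ordinary.

```r
n_obs <- 56
p_out <- 0.05

n_obs * p_out                           # expected number outside: 2.8
2 / n_obs                               # observed share outside: about 3.6%
dbinom(2, size = n_obs, prob = p_out)   # probability of exactly 2 outside
pbinom(2, size = n_obs, prob = p_out)   # probability of 2 or fewer outside
```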
#> # A tibble: 5 × 4
#> name sleep_total log_sleep_total log_brainwt
#> <chr> <dbl> <dbl> <dbl>
#> 1 Roe deer 3.0 0.4771213 -1.0078885
#> 2 Goat 5.3 0.7242759 -0.9393022
#> 3 Dog 10.1 1.0043214 -1.1549020
#> 4 Patas monkey 10.9 1.0374265 -0.9393022
#> 5 Giant armadillo 18.1 1.2576786 -1.0915150
This posterior prediction plot does reveal a shortcoming of our model when we add a reference line at 24 hours.
last_plot() +
  geom_hline(yintercept = log10(24), color = "grey50") +
  geom_label(x = 0, y = log10(24), label = "24 hours")
One faulty consequence of how our model was specified is that it predicts that some mammals sleep more than 24 hours per day. Oh, what a life to live 😴.
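To get a rough sense of how often that happens, the sketch below assumes a normal predictive distribution on the log10 scale. It plugs in two hand-copied approximations: a median prediction of about 1.22 log10-hours at the smallest brain mass (from df_pred_post) and a residual sd of about 0.17 (the classical estimate). It ignores uncertainty in the mean and sigma, so treat the result as a ballpark figure only.

```r
mu_small_brain <- 1.22        # approx. median prediction at the smallest brain mass
sigma_approx   <- 0.17        # approx. residual sd (log10 hours)
limit          <- log10(24)   # 24 hours on the log10 scale, about 1.38

# Upper-tail probability: share of predicted sleep totals above 24 hours
p_over_24 <- pnorm(limit, mean = mu_small_brain, sd = sigma_approx,
                   lower.tail = FALSE)
p_over_24
```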
Wrap up
In this post, I covered three different ways to plot the results of an RStanARM model, while demonstrating some of the key functions for working with RStanARM models. Time well spent, I think.
As for future directions, I learned about the under-development (as of November 2016) R package bayesplot by the Stan team. The package README shows off a lot of different ways to visualize posterior samples from a model. I’ll be sure to demo it on this dataset once it goes live.

¹ That is, if we map the plot’s color aesthetic to a categorical variable in the data, stat_smooth() will fit a separate model for each color/category. I figured this out when I tried to write my own function, stat_smooth_stan(), based on ggplot2’s extensions vignette and noticed that RStanARM was printing out MCMC sampling information for each color/category of the data.