JAGS and Stan

August 24, 2014

(This article was first published on Wiekvoet, and kindly contributed to R-bloggers)

During the last year I have been running estimations in both JAGS and Stan. In that period I saw one example where JAGS could not give me decent samples (in the sense of low Rhat and a high number of effective samples), but that was data I could not blog about. Two weeks ago I had a problem where part of my model did not converge well in JAGS, and I wondered how Stan would fare. Hence this post. It appears that Stan did not really do much better. What did become clear is that in this kind of difficult problem the results can vary depending on the inits and on the random samples drawn in the chains. This probably means more samples would help, but that is not the topic of this post.

Programs

I expect most readers of this blog to know about both Stan and JAGS, but a few lines about them do not seem amiss. Stan and JAGS can be used for the same kind of problems, but they are quite different. JAGS is a variation on BUGS, similar to WinBUGS and OpenBUGS, in which a model just states the relations between variables, so the order of statements does not matter. Stan, on the other hand, divides a model into clearly defined blocks, and the order of statements matters. Stan models are compiled (via C++), which takes some time by itself. Both Stan and JAGS can be run on their own, but I find it most convenient to run them from R. R is then used for pre-processing the data, setting up the model and finally summarizing the samples. Because JAGS and Stan work so differently, they need very different numbers of MCMC samples. Stan's sampler (NUTS, as the output below shows) is supposed to be more efficient, hence needing fewer samples to obtain a posterior result of similar quality.
From a model development point of view, JAGS (rjags, R2jags) is slightly better integrated in R than Stan (rstan), mostly because a JAGS model can be written as if it were an R function, so my editor lends a hand, while in rstan the model is just a character string. In addition, JAGS has no compilation time. The plus side of Stan, though, is its highly organized model code.
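To illustrate what writing a JAGS model "as an R function" means, here is a minimal base R sketch. It assumes, as R2jags (via R2WinBUGS) does, that the function is deparsed to text and that the `%_%` placeholder in front of `T(0,)` is removed before the text is written to a model file; the function `model.toy` and these conversion steps are hypothetical illustrations, not the packages' actual code. Because the function body is never evaluated, BUGS-only functions such as `pow()` and `T()` do not need to exist in R; the body only has to parse.

```r
# Hypothetical toy model in BUGS syntax, written as an R function so that
# R (and an R-aware editor) can parse it; %_% keeps T(0,) syntactically valid
model.toy <- function() {
  for (i in 1:n) {
    y[i] ~ dnorm(mu, tau) %_% T(0,)
  }
  mu ~ dunif(0, 100)
  tau <- pow(sd, -2)
  sd ~ dunif(0, 10)
}

# Deparse the (never evaluated) body to lines of text
bugs.lines <- deparse(body(model.toy))
# Drop the %_% placeholder, leaving plain JAGS truncation syntax
bugs.lines <- gsub("%_%", "", bugs.lines, fixed = TRUE)
# Write a JAGS model file: the keyword 'model' followed by the braced body
model.file <- tempfile(fileext = ".txt")
writeLines(c("model", bugs.lines), model.file)
cat(readLines(model.file), sep = "\n")
```

Stan has no equivalent trick: its model is a separate language, so in rstan it stays a character string until it is compiled.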

Models

The model describes the number of shootings per state, hierarchically under regions. Each state has a binomial probability of interest, and the state probabilities are beta distributed within their region. The parameters a and b of these beta distributions get normal priors, truncated at zero, whose means and standard deviations have uninformative uniform hyperpriors. After some tweaking the models should be equivalent, which means the JAGS model is slightly different from previous posts. For JAGS I used 250000 iterations per chain, of which 100000 burn-in; for Stan 4000 iterations, of which half warmup. I chose ten chains. Usually I would use four, but since I suspected some chains would misbehave, I opted for a larger number. There were two sets of inits. The 'smart' inits (inits1) are close to the target distribution, so the parameters mostly have to converge the regions and states to the correct values. The other inits (inits2) are around 1000, which means a number of parameters have to shift quite a bit to get the beta mean near 1 in 100000. In terms of model behavior I only look at the parameters at the prior and hyperprior levels. Especially a and b of the beta distributions are difficult to estimate, while their ratio (betamean) and the state level estimates are quite easy.
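A small simulation may make the hierarchy clearer. It uses made-up sizes and values (three regions of four states, random populations, hand-picked a and b), not the shooting data, and only shows how the counts arise: a beta distribution per region, a binomial probability per state. The last lines illustrate one plausible reason why a and b are hard to estimate while their ratio is easy: multiplying both by the same factor leaves the beta mean a/(a+b) unchanged.

```r
# Hypothetical values throughout: 3 regions of 4 states each, populations
# drawn at random, and hand-picked beta shape parameters a and b per region
set.seed(1)
nregion <- 3
region <- rep(1:nregion, each = 4)              # region index of each state
population <- round(runif(length(region), 5e5, 2e7))
a <- c(50, 60, 70)                              # beta shape a per region
b <- c(4e7, 4e7, 4e7)                           # beta shape b per region

# State level: p[i] ~ beta(a[region[i]], b[region[i]])
p <- rbeta(length(region), a[region], b[region])
# Observation level: count[i] ~ binomial(population[i], p[i])
count <- rbinom(length(region), size = population, prob = p)
print(data.frame(region, population, p = signif(p, 3), count))

# Region-level beta mean, the quantity behind betamean (before scaling)
beta.mean <- a / (a + b)
# Multiplying a and b by the same factor gives the same mean, only a
# narrower beta; with counts this small the data say little about that spread
k <- 10
beta.mean.k <- (k * a) / (k * a + k * b)
all.equal(beta.mean, beta.mean.k)
```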

Results

What I expected to write here is that Stan coped a bit better, especially when the inits were a bit off. That is what happened in the first version of this post. But then I did an additional calculation, and Stan got worse results too. So, part of the conclusion is that the results depend strongly on the inits and on the random sampling in the MCMC chains.

Speed

In terms of speed, Stan has the property that different chains run at markedly different speeds: one chain can take 90 seconds while the next takes 250 seconds. JAGS does not display the progress of individual chains, so there is no information there.
Overall the speeds were about the same, roughly 1000 to 2000 seconds elapsed per run (see the timings in the output below). If that seems long, this was on a Celeron processor with only one core used. MCMC chains are embarrassingly parallel, so gains can be made easily.
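Since the chains do not interact, they can simply be run in separate processes. The sketch below uses the `parallel` package that ships with R; the toy Metropolis sampler `run.chain` and its Beta(3, 7) target are hypothetical stand-ins for a JAGS or Stan chain. Forking via `mclapply` is not available on Windows, hence the fallback to one core there.

```r
library(parallel)

# Hypothetical stand-in for one MCMC chain: random-walk Metropolis on Beta(3, 7).
# Each call is fully independent of the others, which is what makes chains
# embarrassingly parallel.
run.chain <- function(chain.id, n.iter = 5000) {
  set.seed(1000 + chain.id)                 # distinct, reproducible seed per chain
  log.target <- function(x) dbeta(x, 3, 7, log = TRUE)
  x <- runif(1)                             # random initial value in (0, 1)
  draws <- numeric(n.iter)
  for (i in seq_len(n.iter)) {
    proposal <- x + rnorm(1, 0, 0.1)
    # dbeta gives -Inf outside (0, 1), so such proposals are always rejected
    if (log(runif(1)) < log.target(proposal) - log.target(x)) x <- proposal
    draws[i] <- x
  }
  draws
}

# mclapply forks processes, which Windows does not support: use 1 core there
n.cores <- if (.Platform$OS.type == "windows") 1L else
  max(1L, min(4L, detectCores(), na.rm = TRUE))
chains <- mclapply(1:4, run.chain, mc.cores = n.cores)
sapply(chains, mean)   # each chain mean should be near 3 / (3 + 7) = 0.3
```

For the real runs, current versions of rstan accept a `cores` argument in `stan()`, and R2jags offers `jags.parallel()`; check the documentation of the installed versions for details.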

Gelman diagnostic

The Gelman diagnostic was calculated with coda for all runs, so the values are comparable. There are eight series of values in the figure. Each pair belongs to one run, as coda gives both the point estimate and the upper limit of its 95% confidence interval. 'Smart' refers to the better inits (inits1), '1000' to the inits around 1000 (inits2). The x axis indexes the estimated parameters. There is something odd at variables 19, 20 and 28, 29. Just prior to posting I discovered that the variables, especially aa and bb, are ordered differently in Stan than in JAGS. In hindsight I should have put my parameters in alphabetical order.
The plot shows that Stan actually does a bit worse than JAGS, especially with the inits that should have made correct results easier to obtain.
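For readers who want to see what the diagnostic measures, here is a basic version of the Gelman-Rubin potential scale reduction factor for a single parameter: it compares the variance between chains with the variance within chains. This is a simplified sketch; `coda::gelman.diag` adds a degrees-of-freedom correction and the upper confidence limit, and the Rhat printed by rstan is computed on split chains, so their values will differ somewhat from this function (`psrf` is a hypothetical name).

```r
# Basic potential scale reduction factor for one parameter.
# draws: matrix with one column per chain, one row per iteration.
psrf <- function(draws) {
  n <- nrow(draws)                          # iterations per chain
  chain.means <- colMeans(draws)
  W <- mean(apply(draws, 2, var))           # mean within-chain variance
  B <- n * var(chain.means)                 # between-chain variance
  var.hat <- (n - 1) / n * W + B / n        # pooled estimate of posterior variance
  sqrt(var.hat / W)                         # near 1 when the chains agree
}

set.seed(42)
# Four chains sampling the same distribution: PSRF close to 1
mixed <- matrix(rnorm(4000), ncol = 4)
# Four chains stuck around different values: PSRF clearly above 1
stuck <- sapply(c(-1, 0, 1, 2), function(m) rnorm(1000, mean = m))
c(mixed = psrf(mixed), stuck = psrf(stuck))
```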
 

Effective number of samples

Again the effective sample sizes were calculated with coda, so the numbers are comparable. In effective sample size Stan seems to do a bit better for the more difficult parameters. For the easy parameters JAGS is a bit better; here the large number of samples for JAGS pays off.
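The effective sample size discounts the number of draws for autocorrelation within a chain. The sketch below uses a common textbook estimator, N / (1 + 2 * sum of autocorrelations), truncated at the first non-positive autocorrelation. This is an assumption made for illustration: `coda::effectiveSize` instead estimates the spectral density at frequency zero from an AR fit, so its numbers will not match exactly; `ess` is a hypothetical name.

```r
# Simplified effective sample size for a single chain
ess <- function(x, max.lag = 1000) {
  n <- length(x)
  # autocorrelations at lags 1..max.lag (drop lag 0, which is always 1)
  rho <- acf(x, lag.max = min(max.lag, n - 1), plot = FALSE)$acf[-1]
  # keep only the initial run of positive autocorrelations
  first.nonpos <- which(rho <= 0)[1]
  if (!is.na(first.nonpos)) rho <- rho[seq_len(first.nonpos - 1)]
  n / (1 + 2 * sum(rho))
}

set.seed(7)
n <- 20000
iid <- rnorm(n)                                  # independent draws: ESS near n
# AR(1) chain with phi = 0.9: theoretical ESS is n * (1 - 0.9) / (1 + 0.9)
ar1 <- as.numeric(arima.sim(list(ar = 0.9), n = n))
c(iid = ess(iid), ar1 = ess(ar1), ar1.theory = n * 0.1 / 1.9)
```

A strongly autocorrelated chain of 20000 draws is worth only about a thousand independent ones, which is why a sampler that needs fewer but less correlated draws can come out ahead.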

Conclusion

Neither JAGS nor Stan came out clearly on top, which was not what I expected. Nevertheless, JAGS still seems my tool for simple models, while Stan is the choice for more complex models. Conversion from JAGS to Stan was not difficult.

Output from runs

Stan & inits1

    user   system  elapsed
1992.725    2.666 2016.079


Inference for Stan model: model1.
10 chains, each with iter=4000; warmup=2000; thin=1;
post-warmup draws per chain=2000, total post-warmup draws=20000.

                   mean    se_mean          sd    n_eff Rhat
a[1]              53.44       4.38       44.95 NA   105 1.10
a[2]              65.07       7.55       53.74 NA    51 1.10
a[3]              66.92       6.93       54.54 NA    62 1.10
a[4]              67.28       7.00       55.38 NA    63 1.11
a[5]              62.51       4.99       50.49 NA   102 1.10
a[6]              70.32       8.18       59.19 NA    52 1.11
a[7]              62.44       7.46       53.44 NA    51 1.10
a[8]              59.37       6.37       49.08 NA    59 1.11
a[9]              63.54       5.29       53.66 NA   103 1.10
b[1]        44661405.38 3638329.39 37093869.18 NA   104 1.10
b[2]        39099064.06 4007140.91 32542035.41 NA    66 1.10
b[3]        38335546.03 4378451.81 31386394.46 NA    51 1.11
b[4]        38034644.61 3133779.95 31056909.88 NA    98 1.10
b[5]        39443027.62 4038448.85 32265357.99 NA    64 1.10
b[6]        35355902.63 2933955.81 29177099.95 NA    99 1.10
b[7]        41237192.06 4561471.29 33706574.62 NA    55 1.10
b[8]        42188390.23 3364049.93 34756182.10 NA   107 1.10
b[9]        39656654.17 4301318.67 32155276.73 NA    56 1.10
betamean[1]        0.08       0.00        0.02 NA   793 1.01
betamean[2]        0.11       0.00        0.02 NA  1253 1.01
betamean[3]        0.11       0.00        0.01 NA  1436 1.01
betamean[4]        0.11       0.00        0.02 NA  1608 1.01
betamean[5]        0.10       0.00        0.02 NA   947 1.01
betamean[6]        0.12       0.00        0.02 NA   683 1.01
betamean[7]        0.09       0.00        0.02 NA   759 1.01
betamean[8]        0.09       0.00        0.01 NA   717 1.01
betamean[9]        0.10       0.00        0.02 NA   737 1.02
aa                63.61       6.74       52.03 NA    60 1.11
bb          39230839.48 3945047.13 31100418.06 NA    62 1.11
sda               11.71       1.18       13.79 NA   137 1.07
sdb          6716646.80  546142.87  8203368.43 NA   226 1.04
lp__           -7555.39       2.81       24.19 NA    74 1.10

Samples were drawn using NUTS(diag_e) at Sun Aug 24 10:32:29 2014.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at
convergence, Rhat=1).


Stan & inits2

    user   system  elapsed
1039.834    1.779 1042.681 

Inference for Stan model: model1.
10 chains, each with iter=4000; warmup=2000; thin=1;
post-warmup draws per chain=2000, total post-warmup draws=20000.

                   mean    se_mean          sd    n_eff Rhat
a[1]              46.65       2.96       40.52 NA   188 1.04
a[2]              56.89       3.51       47.84 NA   186 1.03
a[3]              59.52       3.83       50.16 NA   171 1.04
a[4]              58.93       3.70       50.71 NA   188 1.04
a[5]              54.94       3.32       45.85 NA   190 1.03
a[6]              61.79       3.96       53.71 NA   184 1.04
a[7]              53.75       3.40       46.37 NA   186 1.04
a[8]              50.94       3.22       44.15 NA   187 1.04
a[9]              54.71       3.52       47.86 NA   185 1.04
b[1]        38295125.50 2367547.52 32822125.24 NA   192 1.04
b[2]        34100563.26 2135975.40 29277394.26 NA   188 1.04
b[3]        33605901.88 2089826.28 28387357.36 NA   185 1.04
b[4]        33282843.68 2064977.85 27983832.94 NA   184 1.04
b[5]        34582670.15 2189638.25 29708184.70 NA   184 1.04
b[6]        31320186.63 1969915.12 26631419.99 NA   183 1.04
b[7]        35911817.86 2246565.24 30569577.66 NA   185 1.04
b[8]        36698348.20 2277791.00 31350133.96 NA   189 1.03
b[9]        34452421.07 2110620.24 28624582.07 NA   184 1.04
betamean[1]        0.08       0.00        0.02 NA   531 1.02
betamean[2]        0.11       0.00        0.02 NA   935 1.01
betamean[3]        0.11       0.00        0.01 NA   377 1.02
betamean[4]        0.11       0.00        0.02 NA  1667 1.01
betamean[5]        0.10       0.00        0.02 NA   512 1.02
betamean[6]        0.12       0.00        0.02 NA   791 1.01
betamean[7]        0.09       0.00        0.01 NA  1239 1.01
betamean[8]        0.09       0.00        0.01 NA   426 1.02
betamean[9]        0.10       0.00        0.01 NA   704 1.01
aa                55.36       3.47       46.59 NA   181 1.04
bb          34503256.80 2129396.22 28613633.55 NA   181 1.04
sda               10.07       0.71       11.42 NA   256 1.03
sdb          5272208.87  372540.69  6542979.18 NA   308 1.02
lp__           -7557.13       1.74       23.38 NA   181 1.03

Samples were drawn using NUTS(diag_e) at Sun Aug 24 10:56:17 2014.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at
convergence, Rhat=1).


JAGS & inits1

    user   system  elapsed
1784.454    1.222 1791.519

Inference for Bugs model at "/tmp/Rtmp2VW032/model75d4f74fe01.txt", fit using jags,
 10 chains, each with 250000 iterations (first 1e+05 discarded), n.thin = 150
 n.sims = 10000 iterations saved
                 mu.vect      sd.vect int.matrix  Rhat n.eff
a[1]              52.059       43.877         NA 1.099    65
a[2]              62.548       51.117         NA 1.102    63
a[3]              64.949       53.088         NA 1.104    62
a[4]              65.120       54.238         NA 1.104    62
a[5]              61.225       50.014         NA 1.103    63
a[6]              67.498       56.420         NA 1.102    63
a[7]              59.181       49.137         NA 1.102    63
a[8]              57.787       47.902         NA 1.103    63
a[9]              61.403       51.773         NA 1.102    63
aa                60.755       49.814         NA 1.099    65
b[1]        42704592.037 35622401.674         NA 1.104    62
b[2]        37598265.554 31204530.763         NA 1.100    64
b[3]        36937885.481 30057853.515         NA 1.104    62
b[4]        36766568.933 30153113.256         NA 1.103    63
b[5]        38140646.333 31358488.234         NA 1.103    63
b[6]        34271877.336 28059552.425         NA 1.103    63
b[7]        39747380.703 32646090.587         NA 1.103    63
b[8]        40262327.803 33127265.445         NA 1.102    63
b[9]        38041673.881 30717741.484         NA 1.104    62
bb          37696463.522 30374989.608         NA 1.104    62
betamean[1]        0.079        0.018         NA 1.001  6100
betamean[2]        0.108        0.017         NA 1.005  1200
betamean[3]        0.111        0.012         NA 1.001 10000
betamean[4]        0.111        0.018         NA 1.002  4300
betamean[5]        0.103        0.015         NA 1.002  3200
betamean[6]        0.122        0.016         NA 1.004  1700
betamean[7]        0.094        0.015         NA 1.001 10000
betamean[8]        0.090        0.014         NA 1.002  4300
betamean[9]        0.098        0.015         NA 1.015   420
sda               10.819       15.196         NA 1.043   140
sdb          6136772.765  8870298.491         NA 1.046   130
deviance         266.189       18.534         NA 1.083    76

For each parameter, n.eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor (at convergence, Rhat=1).

DIC info (using the rule, pD = var(deviance)/2)
pD = 151.1 and DIC = 417.3
DIC is an estimate of expected predictive error (lower deviance is better).

JAGS & inits2

    user   system  elapsed
1854.296    0.863 1856.972

Inference for Bugs model at "/tmp/Rtmp2VW032/model75d1dda784d.txt", fit using jags,
 10 chains, each with 250000 iterations (first 1e+05 discarded), n.thin = 150
 n.sims = 10000 iterations saved
                 mu.vect      sd.vect int.matrix  Rhat n.eff
a[1]              62.153       46.590         NA 1.063    98
a[2]              74.562       53.977         NA 1.063    98
a[3]              77.974       56.111         NA 1.067    92
a[4]              77.895       56.781         NA 1.067    93
a[5]              73.379       53.374         NA 1.065    94
a[6]              80.658       59.219         NA 1.065    95
a[7]              71.378       52.600         NA 1.062    99
a[8]              69.329       50.894         NA 1.065    95
a[9]              73.640       55.117         NA 1.064    97
aa                72.389       52.305         NA 1.062    99
b[1]        51494466.291 38425814.261         NA 1.065    95
b[2]        44959657.895 33022305.640         NA 1.062    99
b[3]        44386799.990 31879332.379         NA 1.066    94
b[4]        43857406.601 31539586.787         NA 1.066    94
b[5]        45956962.745 33781483.665         NA 1.066    94
b[6]        40856777.369 29453791.246         NA 1.065    95
b[7]        47575125.880 34573426.507         NA 1.063    98
b[8]        48634592.709 35656103.487         NA 1.066    94
b[9]        45324100.091 32444527.795         NA 1.063    97
bb          44759762.836 31404548.768         NA 1.062    99
betamean[1]        0.078        0.018         NA 1.002  4500
betamean[2]        0.107        0.016         NA 1.004  1600
betamean[3]        0.111        0.012         NA 1.001  9700
betamean[4]        0.112        0.017         NA 1.002  4200
betamean[5]        0.102        0.014         NA 1.002  3500
betamean[6]        0.123        0.015         NA 1.004  1800
betamean[7]        0.094        0.015         NA 1.001 10000
betamean[8]        0.090        0.014         NA 1.001  7800
betamean[9]        0.099        0.014         NA 1.013   470
sda               14.300       21.320         NA 1.027   220
sdb          8080234.258 12065450.388         NA 1.027   220
deviance         269.866       17.631         NA 1.053   120

For each parameter, n.eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor (at convergence, Rhat=1).

DIC info (using the rule, pD = var(deviance)/2)
pD = 143.4 and DIC = 413.2
DIC is an estimate of expected predictive error (lower deviance is better).

Code

Reading data

r13 <- readLines('raw13.txt')
r14 <- readLines('raw14.txt')
r1 <- c(r13,r14)
# drop bracketed markers such as [12], blank lines and trailing tabs/'References'
r2 <- gsub('\\[[a-zA-Z0-9]*\\]','',r1)
r3 <- gsub('^ *$','',r2)
r4 <- r3[r3!='']
r5 <- gsub('\t$','',r4)
r6 <- gsub('\t References$','',r5)
# the cleaned lines are tab separated, with a header row
r7 <- read.table(textConnection(r6),
    sep='\t',
    header=TRUE,
    stringsAsFactors=FALSE)
r7$Location[r7$Location=='Washington DC'] <-
    'WashingtonDC, DC'
r8 <- read.table(textConnection(as.character(r7$Location)),
    sep=',',
    col.names=c('Location','State'),
    stringsAsFactors=FALSE)
r8$State <- gsub(' ','',r8$State)
r8$State[r8$State=='Tennessee'] <- 'TN'
r8$State[r8$State=='Ohio'] <- 'OH'
r8$State[r8$State %in% c('Kansas','KA')] <- 'KS'
r8$State[r8$State=='Louisiana'] <- 'LA'
r8$State[r8$State=='Illinois'] <- 'IL'
r8$State <- toupper(r8$State)
table(r8$State)
r7$StateAbb <- r8$State
r7$Location <- r8$Location
# Puerto Rico is not a state: drop it (uses the cleaned abbreviation column)
r7 <- r7[! (r7$StateAbb %in% c('PUERTORICO','PR')),]
r7$Date <- gsub('/13$','/2013',r7$Date)
r7$date <- as.Date(r7$Date,format="%m/%d/%Y")

states <- data.frame(
    StateAbb=as.character(state.abb),
    StateRegion=state.division,
    State=as.character(state.name)
)
states <- rbind(states,data.frame(StateAbb='DC',
        State='District of Columbia',
        StateRegion= 'Middle Atlantic'))
# http://www.census.gov/popest/data/state/totals/2013/index.html
inhabitants <- read.csv('NST-EST2013-01.treated.csv')
#put it all together

states <- merge(states,inhabitants)
r9 <- merge(r7,states)
#########################

r10 <- merge(as.data.frame(xtabs(~StateAbb,data=r9)),states,all=TRUE)
r10$Freq[is.na(r10$Freq)] <- 0
r10$Incidence <- r10$Freq*100000*365/r10$Population/
    as.numeric((max(r7$date)-min(r7$date)))

Common to both models

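# data shared by the Stan and JAGS models
datain <- list(
    count=r10$Freq,
    Population = r10$Population,
    n=nrow(r10),
    nregion =nlevels(r10$StateRegion),
    # integer code of each state's region (census division)
    Region=(1:nlevels(r10$StateRegion))[r10$StateRegion],
    # multiplying a probability over the observation period by scale
    # gives incidence per 100000 inhabitants per year
    scale=100000*365/
        as.numeric((max(r7$date)-min(r7$date))))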


parameters <- c('a','b','betamean','aa','bb','sda','sdb')

# 'smart' inits: close to the target distribution
inits1 <- function()
    list(a=rnorm(datain$nregion,100,10),
        b=rnorm(datain$nregion,1e8,1e7),
        aa=rnorm(1,100,10),
        sda=rnorm(1,100,10),
        sdb=rnorm(1,1e7,1e6),
        bb=rnorm(1,1e7,1e6),
        p1=rnorm(datain$n,1e-7,1e-8))

# inits around 1000 for the (hyper)parameters of the beta distributions
inits2 <- function()
    list(a=rnorm(datain$nregion,1000,100),
        b=rnorm(datain$nregion,1000,100),
        aa=rnorm(1,1000,100),
        sda=rnorm(1,1000,100),
        sdb=rnorm(1,1000,100),
        bb=rnorm(1,1000,100),
        p1=rnorm(datain$n,1e-7,1e-8))

# one list of inits per chain
inits1l<- lapply(1:10,function(x) inits1())
inits2l<- lapply(1:10,function(x) inits2())
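A quick check of how far apart the two init schemes start: the snippet below draws a and b with the same `rnorm` settings as `inits1` and `inits2` (for the nine regions) and computes the implied beta mean a/(a+b). In the output above, the posterior means of a (roughly 50 to 80) and b (roughly 3e7 to 5e7) put a/(a+b) on the order of 1e-6, so inits1 starts close to that while inits2 starts near 0.5.

```r
# Same rnorm settings as inits1 and inits2, for the nine regions
set.seed(3)
nregion <- 9
a1 <- rnorm(nregion, 100, 10)     # a in inits1
b1 <- rnorm(nregion, 1e8, 1e7)    # b in inits1
a2 <- rnorm(nregion, 1000, 100)   # a in inits2
b2 <- rnorm(nregion, 1000, 100)   # b in inits2

# implied starting beta mean a/(a+b) per region
start.mean <- rbind(inits1 = a1 / (a1 + b1), inits2 = a2 / (a2 + b2))
signif(start.mean, 2)
```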

Stan model

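library(rstan)

# Stan model (array[...] declarations and '=' assignment need Stan >= 2.26)
model1 <- '
    data {
       int<lower=0> n;
       int<lower=0> nregion;
       array[n] int count;
       array[n] int<lower=1,upper=nregion> Region;
       array[n] int Population;
       real scale;
    }
    parameters {
       vector<lower=0>[nregion] a;
       vector<lower=0>[nregion] b;
       vector<lower=0,upper=1>[n] p1;
       real<lower=0> aa;
       real<lower=0> bb;
       real<lower=0> sda;
       real<lower=0> sdb;
    }
    model {
       // state probabilities are beta distributed within their region
       for (i in 1:n) {
          p1[i] ~ beta(a[Region[i]], b[Region[i]]);
       }
       count ~ binomial(Population, p1);
       // a and b are bounded below by 0, so these act as truncated normals;
       // unlike T(0,) in the JAGS model, the truncation normalizing term
       // (which depends on aa, sda, bb and sdb) is not included here
       a ~ normal(aa, sda);
       b ~ normal(bb, sdb);
       aa ~ uniform(0, 1e5);
       bb ~ uniform(0, 1e8);
       sda ~ uniform(0, 1e5);
       sdb ~ uniform(0, 1e8);
    }
    generated quantities {
       vector[n] pp;
       vector[nregion] betamean;
       for (i in 1:nregion) {
          betamean[i] = scale * a[i] / (a[i] + b[i]);
       }
       pp = p1 * scale;
    }
'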

system.time(fits1 <- stan(model_code = model1,
    data = datain,
    pars=parameters,
    init=inits1l,
    iter = 4000,
    chains = 10))
print(fits1,probs=NA)

system.time(fits2 <- stan(model_code = model1,
        data = datain,
        pars=parameters,
        init=inits2l,
        iter = 4000,
        chains = 10))
print(fits2,probs=NA)

JAGS model

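library(R2jags)

# BUGS code written as an R function; R2jags writes it to a model file.
# The %_% only serves to let R parse the truncation T(0,).
model.jags <- function() {
    for (i in 1:n) {
        count[i] ~ dbin(p1[i],Population[i])
        p1[i] ~ dbeta(a[Region[i]],b[Region[i]])
        pp[i] <- p1[i]*scale
    }
    for (i in 1:nregion) {
        # dnorm in JAGS takes a precision, hence tau = sd^-2 below
        a[i] ~ dnorm(aa,tauaa) %_% T(0,)
        b[i] ~ dnorm(bb,taubb) %_% T(0,)
        betamean[i] <- scale * a[i]/(a[i]+b[i])
    }
    tauaa <- pow(sda,-2)
    sda ~ dunif(0,1e5)
    taubb <- pow(sdb,-2)
    sdb ~ dunif(0,1e8)
    aa ~ dunif(0,1e5)
    bb ~ dunif(0,1e8)
}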

system.time(jags1 <- jags(datain,
    inits=inits1l,
    parameters.to.save=parameters,
    model.file=model.jags,
    n.chains=10,
    n.iter=250000,
    n.burnin=100000))
print(jags1,intervals=NA)


system.time(jags2 <- jags(datain,
    inits=inits2l,
    parameters.to.save=parameters,
    model.file=model.jags,
    n.chains=10,
    n.iter=250000,
    n.burnin=100000))
print(jags2,intervals=NA)

Post-processing

library(coda)

# one coda mcmc object per Stan chain (as.array: iterations x chains x parameters)
st1coda <- mcmc.list(lapply(1:ncol(fits1),
        function(x) mcmc(as.array(fits1)[,x,])))
st2coda <- mcmc.list(lapply(1:ncol(fits2),
        function(x) mcmc(as.array(fits2)[,x,])))
jg1coda <- as.mcmc(jags1)
jg2coda <- as.mcmc(jags2)

options(width=100)

# Stan and JAGS order the parameters differently (and add lp__ and deviance
# respectively), so line the results up by parameter name, not by position
common <- sort(intersect(varnames(st1coda), varnames(jg1coda)))

gdiag <- cbind(gelman.diag(st1coda)[[1]][common,],
    gelman.diag(st2coda)[[1]][common,],
    gelman.diag(jg1coda)[[1]][common,],
    gelman.diag(jg2coda)[[1]][common,])

png('Gelman Diagnostic.png')
matplot(gdiag,ylim=c(1,1.5),ylab='Gelman Diagnostic')
legend(x='topleft',pch=format(1:8),
    c('1 Stan smart point',
        '2 Stan smart upper',
        '3 Stan 1000 point',
        '4 Stan 1000 upper',
        '5 Jags smart point',
        '6 Jags smart upper',
        '7 Jags 1000 point',
        '8 Jags 1000 upper'),
    col=c(1:6,1,2),ncol=2)
dev.off()

efs <- cbind(effectiveSize(st1coda)[common],
    effectiveSize(st2coda)[common],
    effectiveSize(jg1coda)[common],
    effectiveSize(jg2coda)[common])

png('Efsz.png')
matplot(efs,log='y',ylab='Effective sample size')
legend(x='topleft',pch=format(1:4),
    c('1 Stan smart',
        '2 Stan 1000',
        '3 Jags smart',
        '4 Jags 1000'),
    col=1:4)
dev.off()
