Animating the Metropolis algorithm

[This article was first published on Ecology in silico, and kindly contributed to R-bloggers].

The Metropolis algorithm and its generalization, the Metropolis-Hastings algorithm, provide elegant methods for obtaining sequences of random samples from complex probability distributions. When I first read about modern MCMC methods, I had trouble visualizing the convergence of Markov chains in higher-dimensional cases, so I put together a visualization for a two-dimensional case.

I’ll use a simple example: estimating a population mean and standard deviation. We’ll define some population-level parameters, simulate some data from them, then use the Metropolis algorithm to draw samples from the joint posterior of the mean and standard deviation.
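
For reference, the target distribution and the acceptance rule can be written out as follows. This is a sketch read off from the code in this post rather than a separate derivation: theta = (mu, sigma), N(. | m, s) denotes a normal density with mean m and standard deviation s (the parameterization used by R's dnorm), and the prior and proposal scales (100, 10, 0.1) are the values that appear in the code.

```latex
\begin{align*}
p(\mu, \sigma \mid x) &\propto \Big[ \prod_{i=1}^{n} \mathrm{N}(x_i \mid \mu, \sigma) \Big]
  \, \mathrm{N}(\mu \mid 0, 100) \, \mathrm{Uniform}(\sigma \mid 0, 10) \\
\theta^{*} &= \theta_{t-1} + \varepsilon, \qquad \varepsilon_j \sim \mathrm{N}(0, 0.1), \quad j = 1, 2 \\
r &= \frac{p(\theta^{*} \mid x)}{p(\theta_{t-1} \mid x)}
  = \exp\!\big( \log p(\theta^{*} \mid x) - \log p(\theta_{t-1} \mid x) \big) \\
\theta_{t} &=
  \begin{cases}
    \theta^{*} & \text{with probability } \min(r, 1) \\
    \theta_{t-1} & \text{otherwise}
  \end{cases}
\end{align*}
```

Because the proposal is symmetric, the proposal densities cancel in r, and the normalizing constant of the posterior cancels as well, so only the unnormalized log posterior is needed.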

    # population-level parameters
    mu <- 7
    sigma <- 3

    # collect some data (e.g. a sample of heights)
    n <- 50
    x <- rnorm(n, mu, sigma)

    # log-likelihood function
    ll <- function(x, muhat, sigmahat){
      sum(dnorm(x, muhat, sigmahat, log = TRUE))
    }

    # prior densities (log scale)
    pmu <- function(mu){
      dnorm(mu, 0, 100, log = TRUE)
    }

    psigma <- function(sigma){
      dunif(sigma, 0, 10, log = TRUE)
    }

    # posterior density function (log scale, up to an additive constant)
    post <- function(x, mu, sigma){
      # a non-positive sigma has zero posterior density; return early so
      # that dnorm() is never given an invalid standard deviation
      if (sigma <= 0) return(-Inf)
      ll(x, mu, sigma) + pmu(mu) + psigma(sigma)
    }

    # random starting values for each chain
    geninits <- function(){
      list(mu = runif(1, 4, 10),
           sigma = runif(1, 2, 6))
    }

    # proposal distribution: Gaussian random walk (must be symmetric)
    jump <- function(x, dist = .1){
      x + rnorm(1, 0, dist)
    }

    iter <- 10000
    chains <- 3
    posterior <- array(dim = c(chains, 2, iter))
    accepted <- array(dim = c(chains, iter - 1))

    # the loop index is named `chain` rather than `c` to avoid masking c()
    for (chain in 1:chains){
      theta.post <- array(dim = c(2, iter))
      inits <- geninits()
      theta.post[1, 1] <- inits$mu
      theta.post[2, 1] <- inits$sigma
      for (t in 2:iter){
        # theta_star = proposed next values for parameters
        theta_star <- c(jump(theta.post[1, t-1]), jump(theta.post[2, t-1]))
        pstar <- post(x, mu = theta_star[1], sigma = theta_star[2])
        pprev <- post(x, mu = theta.post[1, t-1], sigma = theta.post[2, t-1])
        lr <- pstar - pprev
        r <- exp(lr)

        # theta_star is accepted if posterior density is higher w/ theta_star
        # if posterior density is not higher, it is accepted with probability r
        # else theta does not change from time t-1 to t
        accept <- rbinom(1, 1, prob = min(r, 1))
        accepted[chain, t - 1] <- accept
        if (accept == 1){
          theta.post[, t] <- theta_star
        } else {
          theta.post[, t] <- theta.post[, t-1]
        }
      }
      posterior[chain, , ] <- theta.post
    }
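
The `accepted` array is filled in by the loop but not used afterwards. One natural use is to check the proportion of proposals each chain accepted, for example with `rowMeans(accepted)`. The sketch below is self-contained and is not the original code: `log_post()` and `run_chain()` are hypothetical helpers that repeat the same random-walk Metropolis step for a single chain, and the seed, starting values, and the two jump sizes are assumptions picked only to illustrate how the jump size affects the acceptance rate.

```r
set.seed(1)                      # assumed seed, only for reproducibility
x <- rnorm(50, 7, 3)             # simulated data, same settings as the example

# log posterior with the same priors as the example (hypothetical helper)
log_post <- function(x, mu, sigma) {
  if (sigma <= 0) return(-Inf)   # outside the support of sigma
  sum(dnorm(x, mu, sigma, log = TRUE)) +
    dnorm(mu, 0, 100, log = TRUE) +
    dunif(sigma, 0, 10, log = TRUE)
}

# run one random-walk Metropolis chain; returns the draws and accept flags
run_chain <- function(x, iter, dist, init = c(7, 3)) {
  theta <- matrix(NA_real_, nrow = iter, ncol = 2,
                  dimnames = list(NULL, c("mu", "sigma")))
  accept <- logical(iter - 1)
  theta[1, ] <- init
  for (t in 2:iter) {
    theta_star <- theta[t - 1, ] + rnorm(2, 0, dist)
    lr <- log_post(x, theta_star[1], theta_star[2]) -
      log_post(x, theta[t - 1, 1], theta[t - 1, 2])
    # comparing on the log scale is equivalent to accepting with
    # probability min(exp(lr), 1)
    accept[t - 1] <- log(runif(1)) < lr
    theta[t, ] <- if (accept[t - 1]) theta_star else theta[t - 1, ]
  }
  list(theta = theta, accept = accept)
}

small <- run_chain(x, iter = 2000, dist = 0.1)
large <- run_chain(x, iter = 2000, dist = 2)

# proportion of proposals accepted under each jump size
c(small = mean(small$accept), large = mean(large$accept))
```

A small jump size tends to give a high acceptance rate but slow movement through parameter space, while a large one produces many rejections, so the chain sits still for long stretches. Either extreme makes the animated chains take longer to cover the posterior.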

Then, to visualize the evolution of the Markov chains, we can plot the chains in two-parameter space alongside the estimated posterior density at a series of iterations, and join these plots together into an animated .gif using ImageMagick's convert command (called here from R through system()):

    library(sm)
    seq1 <- seq(1, 300, by = 5)
    seq2 <- seq(300, 500, by = 10)
    seq3 <- seq(500, iter, by = 300)
    # unique() drops iteration 500, which is in both seq2 and seq3
    sequence <- unique(c(seq1, seq2, seq3))
    length(sequence)

    xlims <- c(4, 10)
    ylims <- c(1, 6)

    dir.create("metropolis_ex")
    oldwd <- setwd("metropolis_ex")  # remember the previous working directory

    # one PNG per frame: metrop001.png, metrop002.png, ...
    png(filename = "metrop%03d.png", width = 700, height = 350)
      for (i in sequence){
        par(mfrow = c(1, 2))
        # left panel: paths of the three chains up to iteration i
        plot(posterior[1, 1, 1:i], posterior[1, 2, 1:i],
             type = "l", xlim = xlims, ylim = ylims, col = "blue",
             xlab = "mu", ylab = "sigma", main = "Markov chains")
        lines(posterior[2, 1, 1:i], posterior[2, 2, 1:i],
              col = "purple")
        lines(posterior[3, 1, 1:i], posterior[3, 2, 1:i],
              col = "red")
        text(x = 7, y = 1.2, paste("Iteration", i), cex = 1.5)
        # right panel: kernel density estimate pooled over all chains
        sm.density(x = cbind(c(posterior[, 1, 1:i]), c(posterior[, 2, 1:i])),
                   xlab = "mu", ylab = "sigma",
                   zlab = "", zlim = c(0, .7),
                   xlim = xlims, ylim = ylims, col = "white")
        title("Posterior density")
      }
    dev.off()

    # stitch the frames together with ImageMagick, then clean up
    system("convert -delay 15 *.png metrop.gif")
    file.remove(list.files(pattern = "\\.png$"))
    setwd(oldwd)
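
The call to `convert` assumes that ImageMagick is installed and on the PATH. A slightly more defensive variant is sketched below. `make_gif()` is a hypothetical helper, not part of the original post: it looks for the executable with `Sys.which()`, passes the frame files explicitly through `system2()` instead of relying on shell globbing, and skips the conversion with a message if there are no frames or the command cannot be found. The default command name and output file name are assumptions carried over from the code above.

```r
# hypothetical helper: assemble the PNG frames in `dir` into an animated GIF
make_gif <- function(dir = ".", out = "metrop.gif", delay = 15,
                     cmd = "convert") {
  # sorted so that metrop001.png, metrop002.png, ... stay in frame order
  frames <- sort(list.files(dir, pattern = "\\.png$", full.names = TRUE))
  args <- c("-delay", delay, shQuote(frames), shQuote(file.path(dir, out)))
  exe <- Sys.which(cmd)
  if (length(frames) == 0 || !nzchar(exe)) {
    # nothing to do, or the command is unavailable: report and skip
    message("skipping: no frames found or '", cmd, "' is not on the PATH")
    return(invisible(args))
  }
  system2(exe, args)
  invisible(args)
}
```

With the frames written to the metropolis_ex directory, this could be called as `make_gif("metropolis_ex")`. The argument vector is returned invisibly so the command can be inspected even when nothing is run.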
