Animating the Metropolis algorithm
Want to share your content on R-bloggers? click here if you have a blog, or here if you don't.
The Metropolis algorithm and its generalization, the Metropolis-Hastings algorithm, provide elegant methods for obtaining sequences of random samples from complex probability distributions. When I first read about modern MCMC methods, I had trouble visualizing the convergence of Markov chains in higher-dimensional cases, so I thought I might put together a visualization of a two-dimensional case.
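For reference, the textbook acceptance rule can be written out as follows (standard notation, not taken from the original post). With target density $p(\theta)$, current state $\theta^{(t-1)}$, and a proposal $\theta^*$ drawn from $q(\theta^* \mid \theta^{(t-1)})$, Metropolis-Hastings accepts $\theta^*$ with probability

$$
\alpha = \min\left(1,\ \frac{p(\theta^*)\, q(\theta^{(t-1)} \mid \theta^*)}{p(\theta^{(t-1)})\, q(\theta^* \mid \theta^{(t-1)})}\right).
$$

When the proposal is symmetric, $q(a \mid b) = q(b \mid a)$, the $q$ terms cancel and the rule reduces to the original Metropolis form $\alpha = \min\left(1,\ p(\theta^*) / p(\theta^{(t-1)})\right)$. Only a ratio of densities appears, so $p$ only needs to be known up to a normalizing constant, which is what makes the method practical for posterior distributions.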
I’ll use a simple example: estimating a population mean and standard deviation. We’ll define some population-level parameters, collect some data, then use the Metropolis algorithm to draw samples from the joint posterior distribution of the mean and standard deviation.
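Written out from the code below (the notation is my own choice), the model is

$$
x_i \sim \mathcal{N}(\mu, \sigma), \quad i = 1, \dots, n, \qquad \mu \sim \mathcal{N}(0, 100), \qquad \sigma \sim \mathrm{Uniform}(0, 10),
$$

where, following R's `dnorm()`, the second argument of $\mathcal{N}$ is a standard deviation. Up to an additive constant, the log posterior evaluated by `post()` is

$$
\log p(\mu, \sigma \mid x) = \sum_{i=1}^{n} \log \mathcal{N}(x_i \mid \mu, \sigma) + \log \mathcal{N}(\mu \mid 0, 100) + \log \mathrm{Uniform}(\sigma \mid 0, 10) + \text{const},
$$

and because the normal jump is symmetric, a proposal $(\mu^*, \sigma^*)$ is accepted with probability $\min(1, r)$, where

$$
r = \exp\left(\log p(\mu^*, \sigma^* \mid x) - \log p(\mu^{(t-1)}, \sigma^{(t-1)} \mid x)\right).
$$

Working on the log scale matters because, with larger samples, the product of many densities can underflow to zero in floating point, while the sum of log densities stays well-behaved.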
```r
# population-level parameters
mu <- 7
sigma <- 3

# collect some data (e.g. a sample of heights)
n <- 50
x <- rnorm(n, mu, sigma)

# log-likelihood function
ll <- function(x, muhat, sigmahat){
  sum(dnorm(x, muhat, sigmahat, log = TRUE))
}

# prior densities (log scale)
pmu <- function(mu){
  dnorm(mu, 0, 100, log = TRUE)
}

psigma <- function(sigma){
  dunif(sigma, 0, 10, log = TRUE)
}

# posterior density function (log scale, up to a constant)
post <- function(x, mu, sigma){
  # sigma must be positive: return -Inf (zero density) rather than letting
  # dnorm() return NaN, which would break the accept/reject step below
  if (sigma <= 0) return(-Inf)
  ll(x, mu, sigma) + pmu(mu) + psigma(sigma)
}

# random starting values for each chain
geninits <- function(){
  list(mu = runif(1, 4, 10),
       sigma = runif(1, 2, 6))
}

# random-walk proposal: must be symmetric for the Metropolis algorithm
jump <- function(x, dist = .1){
  x + rnorm(1, 0, dist)
}

iter <- 10000
chains <- 3
posterior <- array(dim = c(chains, 2, iter))
accepted <- array(dim = c(chains, iter - 1))

for (ch in 1:chains){
  theta.post <- array(dim = c(2, iter))
  inits <- geninits()
  theta.post[1, 1] <- inits$mu
  theta.post[2, 1] <- inits$sigma
  for (t in 2:iter){
    # theta_star = proposed next values for parameters
    theta_star <- c(jump(theta.post[1, t-1]), jump(theta.post[2, t-1]))
    pstar <- post(x, mu = theta_star[1], sigma = theta_star[2])
    pprev <- post(x, mu = theta.post[1, t-1], sigma = theta.post[2, t-1])
    lr <- pstar - pprev
    r <- exp(lr)

    # theta_star is accepted if posterior density is higher w/ theta_star
    # if posterior density is not higher, it is accepted with probability r
    # else theta does not change from time t-1 to t
    accept <- rbinom(1, 1, prob = min(r, 1))
    accepted[ch, t - 1] <- accept
    if (accept == 1){
      theta.post[, t] <- theta_star
    } else {
      theta.post[, t] <- theta.post[, t-1]
    }
  }
  posterior[ch, , ] <- theta.post
}
```
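The animation is a visual convergence check. A common numerical companion is the Gelman-Rubin potential scale reduction factor ($\hat{R}$), which compares the variance between chains with the variance within them. The function below is a minimal sketch of the classic (non-split) version; the name `rhat` and the toy chains are my own assumptions, and the coda package provides a maintained implementation in `coda::gelman.diag()`.

```r
# Classic Gelman-Rubin potential scale reduction factor (R-hat).
# `draws`: numeric matrix, one row per chain, one column per iteration.
rhat <- function(draws) {
  m <- nrow(draws)  # number of chains
  n <- ncol(draws)  # iterations per chain
  stopifnot(m >= 2, n >= 2)
  chain_means <- rowMeans(draws)
  W <- mean(apply(draws, 1, var))     # average within-chain variance
  B <- n * var(chain_means)           # between-chain variance
  var_hat <- (n - 1) / n * W + B / n  # pooled estimate of posterior variance
  sqrt(var_hat / W)
}

# Toy example: three well-mixed chains vs. three chains stuck in different places
set.seed(1)
mixed <- matrix(rnorm(3 * 2000), nrow = 3)
stuck <- mixed + c(0, 5, 10)  # recycling shifts row i by the i-th offset
rhat(mixed)  # close to 1
rhat(stuck)  # well above 1
```

Applied to the sampler above after discarding, say, the first half of each chain as burn-in (an arbitrary choice), `rhat(posterior[, 1, 5001:iter])` and `rhat(posterior[, 2, 5001:iter])` give the factor for mu and sigma; values close to 1 are consistent with the chains having reached the same distribution. `rowMeans(accepted)` gives each chain's acceptance rate.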
Then, to visualize the evolution of the Markov chains, we can plot the chains in the two-parameter space alongside the estimated posterior density at different iterations, and join these plots together with ImageMagick's command-line `convert` tool to create an animated .gif:
```r
library(sm)

# frames: every 5th iteration early on, then progressively sparser
seq1 <- seq(1, 300, by = 5)
seq2 <- seq(300, 500, by = 10)
seq3 <- seq(500, iter, by = 300)
sequence <- unique(c(seq1, seq2, seq3))  # drop the repeated iteration 500
length(sequence)  # number of frames

xlims <- c(4, 10)
ylims <- c(1, 6)

dir.create("metropolis_ex")
setwd("metropolis_ex")

png(filename = "metrop%03d.png", width = 700, height = 350)
for (i in sequence){
  par(mfrow = c(1, 2))
  # left panel: paths of the three chains in (mu, sigma) space
  plot(posterior[1, 1, 1:i], posterior[1, 2, 1:i],
       type = "l", xlim = xlims, ylim = ylims, col = "blue",
       xlab = "mu", ylab = "sigma", main = "Markov chains")
  lines(posterior[2, 1, 1:i], posterior[2, 2, 1:i],
        col = "purple")
  lines(posterior[3, 1, 1:i], posterior[3, 2, 1:i],
        col = "red")
  text(x = 7, y = 1.2, paste("Iteration", i), cex = 1.5)
  # right panel: density of the draws pooled across all chains so far
  sm.density(x = cbind(c(posterior[, 1, 1:i]), c(posterior[, 2, 1:i])),
             xlab = "mu", ylab = "sigma",
             zlab = "", zlim = c(0, .7),
             xlim = xlims, ylim = ylims, col = "white")
  title("Posterior density")
}
dev.off()
# stitch the frames into an animated gif with ImageMagick
system("convert -delay 15 *.png metrop.gif")
# remove the intermediate frames (regex anchored on the .png extension)
file.remove(list.files(pattern = "\\.png$"))
```
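The right-hand panel relies on the sm package. If sm is not available, one possible substitute (my assumption, not the approach used above) is a 2-D kernel density estimate from `MASS::kde2d()`, drawn with base R's `persp()`. The helper name `frame_density` and the toy draws are hypothetical; inside the animation loop you would pass `c(posterior[, 1, 1:i])` and `c(posterior[, 2, 1:i])` instead.

```r
library(MASS)  # recommended package distributed with R; provides kde2d()

# Hypothetical helper: 2-D kernel density estimate of (mu, sigma) draws on a
# regular grid covering the same plotting window as the animation.
frame_density <- function(mu_draws, sigma_draws,
                          lims = c(4, 10, 1, 6), n_grid = 50) {
  # lims is c(x_lower, x_upper, y_lower, y_upper), as kde2d() expects
  kde2d(mu_draws, sigma_draws, n = n_grid, lims = lims)
}

# Toy draws standing in for sampler output (not results from the post)
set.seed(42)
mu_draws <- rnorm(1000, 7, 0.4)
sigma_draws <- rnorm(1000, 3, 0.3)

dens <- frame_density(mu_draws, sigma_draws)
persp(dens$x, dens$y, dens$z, xlab = "mu", ylab = "sigma", zlab = "",
      theta = 30, phi = 30, col = "white")
```

`kde2d()` returns a list with the grid coordinates `x` and `y` and an `n_grid` by `n_grid` matrix `z` of density values, which plugs directly into `persp()` or `contour()`.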