Parallel tempering example
Here we're going to take a simple (one-dimensional) multimodal distribution of the form $$ \pi_\beta \propto e^{-\beta f} $$ for some (non-convex) function $f$. Given that the state space is small, one can easily sample from $\pi_\beta$ using several algorithms (e.g. rejection or inversion). The point here, however, is to illustrate the use of parallel tempering in a model scenario. (If you want a situation where even rejection/inversion fails, just take a similar example in higher dimensions instead.)
%matplotlib inline
import numpy as np, matplotlib as mpl, scipy as sp, networkx as nx
from matplotlib import pylab, mlab, pyplot as plt
from matplotlib.pylab import plot, scatter, contour, xlabel, ylabel, title, legend
from matplotlib.animation import FuncAnimation
from tqdm.notebook import tqdm, trange
from numpy import sqrt, pi, exp, log, floor, ceil, sin, cos
from numpy import linspace, arange, empty, zeros, ones, empty_like, zeros_like, ones_like
from numpy.linalg import norm
from scipy.stats import wasserstein_distance as W
rng = np.random.default_rng()
plt.rcParams.update({
'image.cmap': 'coolwarm',
'animation.html': 'jshtml',
'animation.embed_limit': 40, # 40 mb
})
def figsize(x=None, y=None):
mpl.rcParams['figure.figsize'] = [x, y] if x is not None else mpl.rcParamsDefault['figure.figsize']
We choose our state space $\mathcal X$ to be $N$ linearly spaced points in the interval $[-1, 1]$, and take $f$ to be an asymmetric function with two minima.
N = 101 # Number of points
# The computational cost scales like N^2; on my system this runs in a few minutes for N=50.
# You can still illustrate the main point with N=21 or so, so reduce N if this runs really slowly.
xx = linspace(-1, 1, num=N)
dx = xx[1] - xx[0]
def f(x):
return -log(exp(-(x - .5)**2 / 2 / .2**2) + 1.3*exp(-(x+.5)**2 / 2 / .3**2))
plot(xx, f(xx), label='f')
legend()
# plt.savefig('figures/mm-f.svg')
In such a small state space, the normalized densities can easily be computed. Let's visualize $\pi_\beta$ for a few different values of $\beta$.
def compute_π(β):
Z = np.trapezoid(exp(-β*f(xx)), xx)
return exp(-β*f(xx))/Z
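As a quick sanity check (feasible only because the state space is tiny), one can also sample from $\pi_\beta$ directly by treating it as a categorical distribution on the grid; this is a minimal sketch, and the helper name direct_sample is just for illustration.
def direct_sample(β, size):
    # Probability mass of each grid point under π_β (essentially inversion sampling on the grid).
    p = exp(-β*f(xx))
    p /= p.sum()
    return rng.choice(xx, size=size, p=p)

# e.g. direct_sample(10, 5) gives five independent draws from π_10,
# which can later be compared against the MCMC output.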
for β in linspace(0, 10, num=7):
π = compute_π(β)
plot(xx, π, label=rf'$\beta={β:.2f}$')
legend()
# plt.savefig('figures/mm-pi-beta.svg')
Question
Implement the functions metropolis_hastings and parallel_tempering with the following signatures:
def metropolis_hastings(X, β):
    '''Perform one step of the Metropolis–Hastings chain, modifying X in place.
    X: array of current states, one per realization (size n_reals)
    β: inverse temperature of the stationary distribution of the chain
    '''
    n_reals = X.shape[-1]
    # Implement the algorithm here.
    # Modify X in place; don't return anything.
def parallel_tempering(X0, ββ, n_iters):
    '''Run the parallel tempering chain.
    X0: initial state of the chain (shape=(K, n_reals))
    ββ: array of K inverse temperature values
    n_iters: number of iterations
    Returns the final state of the chain X.
    '''
    n_reals = X0.shape[-1]
    K = len(ββ)
    X = X0.copy()
    for _ in trange(n_iters):
        # Implement the algorithm here.
        pass
    return X
Once you implement these functions, the rest of this notebook should run and show how well the algorithm performs.
Note: If you don't vectorize your code, it will run extremely slowly! My code runs in a few minutes on an i5 from 2020. If the notebook still runs too slowly for you, try reducing the parameters, e.g. N=21 and n_reals=1000.
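For reference, here is one possible vectorized sketch of the two functions (running it fills in the skeletons above). Treat it as a sketch under specific design choices rather than the intended solution: it assumes a symmetric nearest-neighbour proposal on the grid xx (proposals that leave $[-1, 1]$ are simply rejected), and one swap attempt per iteration between a randomly chosen pair of adjacent levels, accepted with the usual parallel tempering probability $\min\bigl(1, e^{(\beta_{k+1}-\beta_k)(f(x_{k+1})-f(x_k))}\bigr)$.
def metropolis_hastings(X, β):
    '''One vectorized Metropolis–Hastings step at inverse temperature β (modifies X in place).'''
    # Symmetric proposal: move each realization to a random nearest neighbour on the grid.
    Y = X + dx * rng.choice([-1.0, 1.0], size=X.shape)
    # Proposals outside [-1, 1] are rejected (the chain stays put); this keeps the proposal symmetric.
    Y = np.where(np.abs(Y) > 1 + dx/2, X, Y)
    # Accept with probability min(1, exp(-β (f(Y) - f(X)))).
    accept = rng.random(X.shape) < exp(-β * (f(Y) - f(X)))
    X[...] = np.where(accept, Y, X)

def parallel_tempering(X0, ββ, n_iters):
    '''Parallel tempering: a Metropolis–Hastings move at every level, then one swap attempt.'''
    n_reals = X0.shape[-1]
    K = len(ββ)
    X = X0.copy()
    for _ in trange(n_iters):
        # Advance each level with its own Metropolis–Hastings move (X[k] is a view into X).
        for k in range(K):
            metropolis_hastings(X[k], ββ[k])
        # Propose swapping the states of one randomly chosen pair of adjacent levels,
        # independently for each realization, with probability
        # min(1, exp((β_{k+1} - β_k) (f(x_{k+1}) - f(x_k)))).
        k = rng.integers(K - 1)
        swap = rng.random(n_reals) < exp((ββ[k+1] - ββ[k]) * (f(X[k+1]) - f(X[k])))
        X[k][swap], X[k+1][swap] = X[k+1][swap], X[k][swap]
    return X
Whether you attempt a single random swap per iteration or sweep over all adjacent pairs is a design choice; either kernel leaves the product of the $\pi_{\beta_k}$ invariant.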
# Helper functions for plots
bins = np.append(xx, xx[-1]+dx)-dx/2
def TV(X, π):
'''TV distance between the distribution of X and π'''
hist, _ = np.histogram(X, bins, density=True)
return sum(abs(hist-π))*dx/2
def plot_hist(Y, π, β, **kwargs):
hist_args = dict(density=True, bins=np.append(xx, xx[-1]+dx)-dx/2)
hist_args.update(**kwargs)
_ = plt.hist(Y, **hist_args)
plot(xx, π, label=rf'$\pi_\beta (\beta={β:.2f})$')
legend()
# Parameters
n_reals = 10000 # Increase this to get better histogram plots.
n_iters = 2*N**2 # Should be larger than the mixing time of the tempered chain
β0 = 1 # Inverse temperature at which Metropolis–Hastings mixes fast
β1 = 12 # Inverse temperature of the final target distribution
# Choose the worst possible starting point to test the algorithm.
# This should be when we put all mass at the local minimum at x=0.5
x0 = xx[3*N//4]
X0 = np.full(n_reals, x0)
# Run Metropolis Hastings for a small β (should mix fast)
β = β0
π = compute_π(β)
d = empty(n_iters)
X = X0.copy()
d[0] = TV(X, π)
for n in trange(n_iters):
metropolis_hastings(X, β)
d[n] = TV(X, π)
figsize(12.8, 4.8)
plt.subplot(1, 2, 1)
plot_hist(X, π, β, label=f'$X_N$ ($N = {n_iters}$)')
title(f'Distribution of $X_N$ (β={β})')
plt.subplot(1, 2, 2)
plot(d)
title(r'$\|\operatorname{dist}(X_N) - \pi\|_{\text{TV}}$')
xlabel('# Iterations')
figsize()
# Repeat the same for a large β (should mix very slowly)
β = β1
π = compute_π(β)
d = empty(n_iters)
X = X0.copy()
d[0] = TV(X, π)
for n in trange(n_iters):
metropolis_hastings(X, β)
d[n] = TV(X, π)
figsize(12.8, 4.8)
plt.subplot(1, 2, 1)
plot_hist(X, π, β, label=f'$X_N$ ($N = {n_iters}$)')
title(f'Distribution of $X_N$ (β={β})')
plt.subplot(1, 2, 2)
plot(d)
title(r'$\|\operatorname{dist}(X_N) - \pi\|_{\text{TV}}$')
xlabel('# Iterations')
figsize()
# Now run parallel tempering.
K = 8 # Number of levels
X0 = np.full((K, n_reals), x0)
ββ = linspace(β0, β1, num=K)
X = parallel_tempering(X0, ββ, n_iters)
i = -1
β = ββ[i]
π = compute_π(β)
d = TV(X[i], π)
plot_hist(X[i], π, β, label=f'$X_N$ ($N = {n_iters}$)')
title(f'β={β:.2f}, TV={d:.2f}')
# If everything worked, we should get good agreement with the target distribution here.
Output: β=12.00, TV=0.31
For a fair comparison, note that each iteration of parallel tempering advances $K$ chains, so it costs roughly $K$ times as much as a single Metropolis–Hastings step. We should therefore look at how well Metropolis–Hastings does after K*n_iters iterations. Let's increase this by a generous factor and see how well it does.
β = ββ[-1]
π = compute_π(β)
d = empty(4*K*n_iters)
X = X0[0].copy()
d[0] = TV(X, π)
for n in trange(len(d)):
metropolis_hastings(X, β)
d[n] = TV(X, π)
figsize(12.8, 4.8)
plt.subplot(1, 2, 1)
plot_hist(X, π, β, label=f'$X_N$ ($N = {n_iters}$)')
title(f'Distribution of $X_N$ (β={β})')
plt.subplot(1, 2, 2)
plot(d)
title(r'$\|\operatorname{dist}(X_N) - \pi\|_{\text{TV}}$')
xlabel('# Iterations')
figsize()
This shows clearly that vanilla Metropolis–Hastings did not get anywhere close to the target distribution, despite running for a much longer time. The number of iterations required by vanilla Metropolis–Hastings grows exponentially with the inverse temperature $\beta$. If you'd like to see how things vary, experiment by changing $\beta_1$ in this notebook.
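Heuristically, this is an Arrhenius-type estimate: to move mass out of the shallower well near $x = 0.5$ the chain must cross the barrier near $x = 0$, and for Metropolis-type dynamics the expected crossing time grows roughly like $$ \tau \sim e^{\beta\,\Delta f}, \qquad \Delta f = f(x_{\mathrm{barrier}}) - f(x_{\mathrm{local\ min}}), $$ up to a prefactor coming from the random walk inside the well. Parallel tempering avoids this barrier cost by letting states migrate (via swaps) down to small $\beta$, where crossing is cheap, and then back up to $\beta_1$.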