Interactive prior predictive checks

This notebook demonstrates how to interactively examine model priors using ipywidgets.

⚠️ This notebook is intended to be run interactively. Please run it locally or in Colab.

The first step in a Bayesian workflow is to create a model. The second step is to check prior samples from the model. This notebook shows how to interactively check prior samples and tune the parameters of the top-level prior distributions while visualizing model outputs.

Summary

  • Wrap your model in a plotting function.

  • Use ipywidgets.interact() to create sliders for each parameter of your prior.

  • For expensive models, use a Resampler.

[1]:
!pip install -q pyro-ppl  # for colab
[2]:
import os
from ipywidgets import interact, FloatSlider
import matplotlib.pyplot as plt
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer.resampler import Resampler

assert pyro.__version__.startswith('1.9.1')
smoke_test = ('CI' in os.environ)  # for CI testing only
[3]:
def model(T: int = 1000, data=None):
    # Sample parameters from the prior.
    df = pyro.sample("df", dist.LogNormal(0, 1))
    p_scale = pyro.sample("p_scale", dist.LogNormal(0, 1))  # process noise
    m_scale = pyro.sample("m_scale", dist.LogNormal(0, 1))  # measurement noise

    # Simulate a time series.
    with pyro.plate("dt", T):
        process_noise = pyro.sample("process_noise", dist.StudentT(df, 0, p_scale))
    trend = pyro.deterministic("trend", process_noise.cumsum(-1))
    with pyro.plate("t", T):
        return pyro.sample("obs", dist.Normal(trend, m_scale), obs=data)
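To make the generative story concrete, here is a minimal plain-Python sketch of a single prior predictive draw with the same structure as `model` above: three LogNormal(0, 1) top-level parameters, a random walk with Student-t increments, and Gaussian measurement noise. It uses only the standard library, so it has none of Pyro's plates, tracing, or conditioning. The helpers `student_t` and `prior_predictive` are hypothetical names for illustration, and the Student-t draw uses the textbook construction z / sqrt(chi2 / df).

```python
import math
import random


def student_t(rng: random.Random, df: float, scale: float) -> float:
    """Draw from a zero-centered Student-t via z / sqrt(chi2 / df)."""
    z = rng.gauss(0.0, 1.0)
    chi2 = rng.gammavariate(df / 2.0, 2.0)  # chi-squared with df degrees of freedom
    # Guard against chi2 underflowing to 0 when df is tiny (very heavy tails).
    return scale * z / math.sqrt(max(chi2 / df, 1e-300))


def prior_predictive(T: int = 1000, seed: int = 0) -> dict:
    """One prior predictive draw mirroring the structure of `model` (hypothetical helper)."""
    rng = random.Random(seed)
    # Top-level parameters from LogNormal(0, 1) priors.
    df = rng.lognormvariate(0.0, 1.0)
    p_scale = rng.lognormvariate(0.0, 1.0)  # process noise
    m_scale = rng.lognormvariate(0.0, 1.0)  # measurement noise
    # Random-walk trend: cumulative sum of Student-t increments.
    trend, level = [], 0.0
    for _ in range(T):
        level += student_t(rng, df, p_scale)
        trend.append(level)
    # Noisy observations of the trend.
    obs = [rng.gauss(mu, m_scale) for mu in trend]
    return {"df": df, "p_scale": p_scale, "m_scale": m_scale, "trend": trend, "obs": obs}


draw = prior_predictive(T=200, seed=1)
print(len(draw["obs"]), draw["df"])
```

Fixing the seed plays the same role as `pyro.set_rng_seed(12345)` in the plotting functions: the same settings always reproduce the same trajectory.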
[4]:
def plot_trajectory(df=1.0, p_scale=1.0, m_scale=1.0):
    pyro.set_rng_seed(12345)
    data = {
        "df": torch.as_tensor(df),
        "p_scale": torch.as_tensor(p_scale),
        "m_scale": torch.as_tensor(m_scale),
    }
    trajectory = poutine.condition(model, data)()
    plt.figure(figsize=(8, 4)).patch.set_color("white")
    plt.plot(trajectory)
    plt.xlabel("time")
    plt.ylabel("obs")

Now we can examine what model trajectories look like for particular values of the top-level latent variables.

[5]:
interact(
    plot_trajectory,
    df=FloatSlider(value=1.0, min=0.01, max=10.0),
    p_scale=FloatSlider(value=0.1, min=0.01, max=1.0),
    m_scale=FloatSlider(value=1.0, min=0.01, max=10.0),
);

But to tune the parameters of our priors, we’d like to look at an ensemble of trajectories, each with top-level parameters sampled from the current prior. Let’s rewrite our model so that the prior parameters are inputs.

[6]:
def model2(T: int = 1000, data=None, df0=0, df1=1, p0=0, p1=1, m0=0, m1=1):
    # Sample parameters from the prior.
    df = pyro.sample("df", dist.LogNormal(df0, df1))
    p_scale = pyro.sample("p_scale", dist.LogNormal(p0, p1))  # process noise
    m_scale = pyro.sample("m_scale", dist.LogNormal(m0, m1))  # measurement noise

    # Simulate a time series.
    with pyro.plate("dt", T):
        process_noise = pyro.sample("process_noise", dist.StudentT(df, 0, p_scale))
    trend = pyro.deterministic("trend", process_noise.cumsum(-1))
    with pyro.plate("t", T):
        return pyro.sample("obs", dist.Normal(trend, m_scale), obs=data)
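Each LogNormal(loc, scale) pair in `model2` (df0 and df1, p0 and p1, m0 and m1) is easier to tune once you can read it on the original scale. Because log X is Normal(loc, scale) when X is LogNormal(loc, scale), the median is exp(loc) and quantiles map through exp. The hypothetical helper below is a small sketch, not part of Pyro, for translating slider settings into a median and a central 90% interval.

```python
import math
from statistics import NormalDist


def lognormal_summary(loc: float, scale: float) -> tuple:
    """Median and central 90% interval of LogNormal(loc, scale)."""
    z95 = NormalDist().inv_cdf(0.95)  # standard normal 95th percentile, about 1.645
    median = math.exp(loc)
    return median, math.exp(loc - z95 * scale), math.exp(loc + z95 * scale)


# Compare the default df prior with a larger df0 (df1 held at 1.0).
for df0 in (0.0, 4.0):
    median, lo, hi = lognormal_summary(df0, 1.0)
    print(f"df0={df0}: median df={median:.2f}, 90% interval=({lo:.2f}, {hi:.2f})")
```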
[7]:
def plot_trajectories(**kwargs):
    pyro.set_rng_seed(12345)
    with pyro.plate("trajectories", 20, dim=-2):
        trajectories = model2(**kwargs)
    plt.figure(figsize=(8, 5)).patch.set_color("white")
    plt.plot(trajectories.T)
    plt.xlabel("time")
    plt.ylabel("obs")
[8]:
interact(
    plot_trajectories,
    df0=FloatSlider(value=0.0, min=-5, max=5),
    df1=FloatSlider(value=1.0, min=0.1, max=10),
    p0=FloatSlider(value=0.0, min=-5, max=5),
    p1=FloatSlider(value=1.0, min=0.1, max=10),
    m0=FloatSlider(value=0.0, min=-5, max=5),
    m1=FloatSlider(value=1.0, min=0.1, max=10),
);

Yikes! It looks like our initial priors generate very weird trajectories, but we can use the sliders to find better priors. Try increasing df0.
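One plausible reading of why increasing df0 helps (our interpretation, not stated in the original): with df0 = 0 and df1 = 1 the prior median of df is exp(0) = 1, and a Student-t with df ≤ 2 has infinite variance (for df ≤ 1 even its mean is undefined), so occasional enormous increments dominate the random walk. A short sketch of the prior probability of that regime, using the hypothetical helper `prob_df_at_most`:

```python
import math
from statistics import NormalDist


def prob_df_at_most(threshold: float, df0: float, df1: float) -> float:
    """P(df <= threshold) when df ~ LogNormal(df0, df1), via log df ~ Normal(df0, df1)."""
    return NormalDist(mu=df0, sigma=df1).cdf(math.log(threshold))


# Prior probability that the Student-t increments have infinite variance (df <= 2).
for df0 in (0.0, 2.0, 4.0):
    print(f"df0={df0}: P(df <= 2) = {prob_df_at_most(2.0, df0, 1.0):.4f}")
```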

Resampler

For more expensive simulations, sampling may be too slow to generate fresh samples at each slider change. As a computational trick, we can draw many samples once from a diffuse distribution and then, at each change, importance-resample them so that they approximate samples from the modified distribution. Pyro provides an importance Resampler to aid in interactively visualizing expensive models.
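The trick can be sketched in a few lines of standard-library Python. This is a one-dimensional illustration under simplifying assumptions, not Pyro's implementation: the "expensive" draws come from a diffuse LogNormal(0, 10) guide, the modified prior is LogNormal(2, 1), and both are handled in log space, where they are Normal. Working in log space is harmless here because the LogNormal density's extra 1/x factor is identical for guide and target and cancels in the weight ratio. The function names are hypothetical.

```python
import math
import random


def normal_logpdf(z: float, loc: float, scale: float) -> float:
    """Log density of Normal(loc, scale) at z."""
    return -0.5 * ((z - loc) / scale) ** 2 - math.log(scale * math.sqrt(2.0 * math.pi))


def draw_guide(loc: float, scale: float, n: int, seed: int = 0) -> list:
    """Expensive step, done once: draw log-parameters from a diffuse guide."""
    rng = random.Random(seed)
    return [rng.gauss(loc, scale) for _ in range(n)]


def importance_resample(zs: list, guide: tuple, target: tuple, k: int, seed: int = 1) -> list:
    """Cheap step, redone at each slider change: reuse zs, weighted by target/guide."""
    log_w = [normal_logpdf(z, *target) - normal_logpdf(z, *guide) for z in zs]
    top = max(log_w)  # subtract the max before exponentiating, for numerical stability
    weights = [math.exp(lw - top) for lw in log_w]
    return random.Random(seed).choices(zs, weights=weights, k=k)


guide = (0.0, 10.0)  # diffuse LogNormal(0, 10), written in log space
target = (2.0, 1.0)  # modified prior LogNormal(2, 1), written in log space
zs = draw_guide(*guide, n=10_000)
resampled = importance_resample(zs, guide, target, k=2_000)
print(sum(resampled) / len(resampled))  # approximately the target log-median, 2.0
```

Only `importance_resample` would run on each slider change; `draw_guide` stands in for the expensive simulation done once up front.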

We’ll start with our original model and create a way to make parametrized partial models with given priors. These partial models are just the top half of our model: the top-level parameters.

[9]:
def make_partial_model(df0, df1, p0, p1, m0, m1):
    def partial_model():
        # Sample parameters from the prior.
        pyro.sample("df", dist.LogNormal(df0, df1))
        pyro.sample("p_scale", dist.LogNormal(p0, p1))  # process noise
        pyro.sample("m_scale", dist.LogNormal(m0, m1))  # measurement noise
    return partial_model
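Why is a partial model enough? Conceptually, each cached sample pairs top-level parameters with simulated lower-level values. Changing the prior only changes the density of the top-level sites; the conditional density of everything simulated below them is the same under the old and new priors, so it cancels in the importance weight. A minimal sketch of that log weight, assuming the three LogNormal sites of `make_partial_model` (the helper names are hypothetical, not Pyro's internals):

```python
import math

SITES = ("df", "p_scale", "m_scale")


def lognormal_logpdf(x: float, loc: float, scale: float) -> float:
    """Log density of LogNormal(loc, scale) at x > 0."""
    z = (math.log(x) - loc) / scale
    return -0.5 * z * z - math.log(x * scale * math.sqrt(2.0 * math.pi))


def partial_log_prob(params: dict, hyper: tuple) -> float:
    """Joint log density of the top-level sites, with hyper = (df0, df1, p0, p1, m0, m1)."""
    pairs = zip(hyper[0::2], hyper[1::2])  # (df0, df1), (p0, p1), (m0, m1)
    return sum(lognormal_logpdf(params[site], loc, scale) for site, (loc, scale) in zip(SITES, pairs))


def log_weight(params: dict, new_hyper: tuple, guide_hyper: tuple) -> float:
    """Importance log weight; the simulator's conditional density has cancelled."""
    return partial_log_prob(params, new_hyper) - partial_log_prob(params, guide_hyper)


guide_hyper = (0, 10, 0, 10, 0, 10)  # matches the diffuse guide created below
params = {"df": 3.0, "p_scale": 0.5, "m_scale": 1.2}
print(log_weight(params, (0, 1, 0, 1, 0, 1), guide_hyper))
```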

Next we’ll initialize the Resampler with a diffuse guide that covers most of our desired parameter space. This step simulates the model for every guide sample, so it can be expensive in real applications; you might want to run it overnight.

[10]:
%%time
partial_guide = make_partial_model(0, 10, 0, 10, 0, 10)
resampler = Resampler(partial_guide, model, num_guide_samples=10000)
CPU times: user 940 ms, sys: 146 ms, total: 1.09 s
Wall time: 934 ms
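How diffuse is diffuse enough? A standard diagnostic for importance weights, not something the original notebook computes, is Kish's effective sample size, (Σw)² / Σw², which equals N when all weights are equal and is close to 1 when a single sample dominates. A small sketch computing it from unnormalized log weights:

```python
import math


def effective_sample_size(log_weights: list) -> float:
    """Kish effective sample size (sum w)^2 / sum(w^2), from unnormalized log weights."""
    top = max(log_weights)
    w = [math.exp(lw - top) for lw in log_weights]  # the common rescaling cancels in the ratio
    return sum(w) ** 2 / sum(wi * wi for wi in w)


print(effective_sample_size([0.0] * 10_000))          # equal weights: 10000.0
print(effective_sample_size([0.0] + [-1e3] * 9_999))  # one dominant sample: 1.0
```

If the ESS for a slider setting is not much larger than the 20 trajectories being plotted, the resampled plot will likely repeat the same few trajectories, which suggests widening the guide or increasing num_guide_samples.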

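The Resampler.sample() method takes a modified partial model and returns resampled values as a dictionary keyed by sample site name, including the simulated "obs" site.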

[11]:
def plot_resampled(df0, df1, p0, p1, m0, m1):
    partial_model = make_partial_model(df0, df1, p0, p1, m0, m1)
    samples = resampler.sample(partial_model, num_samples=20)
    trajectories = samples["obs"]
    plt.figure(figsize=(8, 5)).patch.set_color("white")
    plt.plot(trajectories.T)
    plt.xlabel("time")
    plt.ylabel("obs")
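`plot_resampled` indexes the result by site name (`samples["obs"]`), and no simulation happens inside it. The sketch below illustrates that idea in plain Python under our own assumptions (a hypothetical record format and helper name, not Pyro's internals): each cached record holds the top-level parameters together with their simulated output, so resampling only chooses indices and then gathers every site from the chosen records.

```python
import math
import random


def resample_records(records: list, log_weights: list, num_samples: int, seed: int = 0) -> dict:
    """Resample cached (parameters, simulation) records and group them by site name.

    The expensive simulator output ("obs") is stored alongside its parameters,
    so resampling only picks indices; nothing is re-simulated.
    """
    top = max(log_weights)
    weights = [math.exp(lw - top) for lw in log_weights]
    idx = random.Random(seed).choices(range(len(records)), weights=weights, k=num_samples)
    return {site: [records[i][site] for i in idx] for site in records[0]}


records = [{"df": float(i + 1), "obs": [float(i)] * 3} for i in range(5)]
log_weights = [0.0, 0.0, 0.0, 0.0, float("-inf")]  # last record gets zero weight
samples = resample_records(records, log_weights, num_samples=20)
print(samples["df"])
```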
[12]:
interact(
    plot_resampled,
    df0=FloatSlider(value=0.0, min=-5, max=5),
    df1=FloatSlider(value=1.0, min=0.1, max=10),
    p0=FloatSlider(value=0.0, min=-5, max=5),
    p1=FloatSlider(value=1.0, min=0.1, max=10),
    m0=FloatSlider(value=0.0, min=-5, max=5),
    m1=FloatSlider(value=1.0, min=0.1, max=10),
);

After deciding on good prior parameters, we can then hard-code those into the model:

[13]:
def model(T: int = 1000, data=None):
    df = pyro.sample("df", dist.LogNormal(4, 1))  # <-- changed 0 to 4
    p_scale = pyro.sample("p_scale", dist.LogNormal(1, 1))  # <-- changed 0 to 1
    m_scale = pyro.sample("m_scale", dist.LogNormal(0, 1))

    with pyro.plate("dt", T):
        process_noise = pyro.sample("process_noise", dist.StudentT(df, 0, p_scale))
    trend = pyro.deterministic("trend", process_noise.cumsum(-1))
    with pyro.plate("t", T):
        return pyro.sample("obs", dist.Normal(trend, m_scale), obs=data)