pyro.contrib.funsor, a new backend for Pyro - Building inference algorithms (Part 2)

[1]:
from collections import OrderedDict
import functools

import torch
from torch.distributions import constraints

import funsor

from pyro import set_rng_seed as pyro_set_rng_seed
from pyro.ops.indexing import Vindex
from pyro.poutine.messenger import Messenger

funsor.set_backend("torch")
torch.set_default_dtype(torch.float32)
pyro_set_rng_seed(101)

Introduction

In part 1 of this tutorial, we were introduced to the new pyro.contrib.funsor backend for Pyro.

Here we’ll look at how to use the components in pyro.contrib.funsor to implement a variable elimination inference algorithm from scratch. This tutorial assumes readers are familiar with enumeration-based inference algorithms in Pyro. For background and motivation, readers should consult the enumeration tutorial.

As before, we’ll use pyroapi so that we can write our model with standard Pyro syntax.

[2]:
import pyro.contrib.funsor
import pyroapi
from pyroapi import infer, handlers, ops, optim, pyro
from pyroapi import distributions as dist

We will be working with the following model throughout. It is a discrete-state continuous-observation hidden Markov model with learnable transition and emission distributions that depend on a global random variable.

[3]:
data = [torch.tensor(1.)] * 10

def model(data, verbose):

    p = pyro.param("probs", lambda: torch.rand((3, 3)), constraint=constraints.simplex)
    locs_mean = pyro.param("locs_mean", lambda: torch.ones((3,)))
    locs = pyro.sample("locs", dist.Normal(locs_mean, 1.).to_event(1))
    if verbose:
        print("locs.shape = {}".format(locs.shape))

    x = 0
    for i in pyro.markov(range(len(data))):
        x = pyro.sample("x{}".format(i), dist.Categorical(p[x]), infer={"enumerate": "parallel"})
        if verbose:
            print("x{}.shape = ".format(i), x.shape)
        pyro.sample("y{}".format(i), dist.Normal(Vindex(locs)[..., x], 1.), obs=data[i])
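As a plain-Python reference for the generative process, the sketch below samples from the same HMM using only the standard library. It is an illustration, not part of the tutorial's code. `simulate_hmm` is a hypothetical helper, and `probs` and `locs` are passed in as fixed, made-up lists rather than coming from `pyro.param` and the Normal prior on `locs`. As in `model`, the chain starts from state 0, so `x0` is drawn from row 0 of the transition matrix.

```python
import random

def simulate_hmm(probs, locs, num_steps, seed=0):
    """Sample (xs, ys) from the HMM in `model` with probs and locs held fixed.

    probs: K x K row-stochastic transition matrix (list of lists)
    locs:  length-K list of emission means
    """
    rng = random.Random(seed)
    states = list(range(len(locs)))
    xs, ys = [], []
    x = 0  # `model` initializes the chain with x = 0 before the loop
    for _ in range(num_steps):
        # x_i ~ Categorical(probs[x_{i-1}])
        x = rng.choices(states, weights=probs[x])[0]
        # y_i ~ Normal(locs[x_i], 1)
        y = rng.gauss(locs[x], 1.0)
        xs.append(x)
        ys.append(y)
    return xs, ys

xs, ys = simulate_hmm(
    probs=[[0.8, 0.1, 0.1], [0.1, 0.8, 0.1], [0.1, 0.1, 0.8]],
    locs=[-1.0, 0.0, 1.0],
    num_steps=10,
)
```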

We can run model under the default Pyro backend and the new contrib.funsor backend with pyroapi:

[4]:
# default backend: "pyro"
with pyroapi.pyro_backend("pyro"):
    model(data, verbose=True)

# new backend: "contrib.funsor"
with pyroapi.pyro_backend("contrib.funsor"):
    model(data, verbose=True)
locs.shape = torch.Size([3])
x0.shape =  torch.Size([])
x1.shape =  torch.Size([])
x2.shape =  torch.Size([])
x3.shape =  torch.Size([])
x4.shape =  torch.Size([])
x5.shape =  torch.Size([])
x6.shape =  torch.Size([])
x7.shape =  torch.Size([])
x8.shape =  torch.Size([])
x9.shape =  torch.Size([])
locs.shape = torch.Size([3])
x0.shape =  torch.Size([])
x1.shape =  torch.Size([])
x2.shape =  torch.Size([])
x3.shape =  torch.Size([])
x4.shape =  torch.Size([])
x5.shape =  torch.Size([])
x6.shape =  torch.Size([])
x7.shape =  torch.Size([])
x8.shape =  torch.Size([])
x9.shape =  torch.Size([])

Enumerating discrete variables

Our first step is to implement an effect handler that performs parallel enumeration of discrete latent variables. Here we will implement a stripped-down version of pyro.poutine.enum, the effect handler behind Pyro's most powerful general-purpose inference algorithms, pyro.infer.TraceEnum_ELBO and pyro.infer.mcmc.HMC.

We’ll do that by constructing a funsor.Tensor representing the support of each discrete latent variable and using the new pyro.to_data primitive from part 1 to convert it to a torch.Tensor with the appropriate shape.

[5]:
from pyro.contrib.funsor.handlers.named_messenger import NamedMessenger

class EnumMessenger(NamedMessenger):

    @pyroapi.pyro_backend("contrib.funsor")  # necessary since we invoke pyro.to_data and pyro.to_funsor
    def _pyro_sample(self, msg):
        if msg["done"] or msg["is_observed"] or msg["infer"].get("enumerate") != "parallel":
            return

        # We first compute a raw value using the standard enumerate_support method.
        # enumerate_support returns a value of shape:
        #     (support_size,) + (1,) * len(msg["fn"].batch_shape).
        raw_value = msg["fn"].enumerate_support(expand=False)

        # Next we'll use pyro.to_funsor to indicate that this dimension is fresh.
        # This is guaranteed because we use msg['name'], the name of this pyro.sample site,
        # as the name for this positional dimension, and sample site names must be unique.
        funsor_value = pyro.to_funsor(
            raw_value,
            output=funsor.Bint[raw_value.shape[0]],
            dim_to_name={-raw_value.dim(): msg["name"]},
        )

        # Finally, we convert the value back to a PyTorch tensor with to_data,
        # which has the effect of reshaping and possibly permuting dimensions of raw_value.
        # Applying to_funsor and to_data in this way guarantees that
        # each enumerated random variable gets a unique fresh positional dimension
        # and that we can convert the model's log-probability tensors to funsor.Tensors
        # in a globally consistent manner.
        msg["value"] = pyro.to_data(funsor_value)
        msg["done"] = True

Because this is an introductory tutorial, this implementation of EnumMessenger works directly with the site’s PyTorch distribution since users familiar with PyTorch and Pyro may find it easier to understand. However, when using contrib.funsor to implement an inference algorithm in a more realistic setting, it is usually preferable to do as much computation as possible on funsors, as this tends to simplify complex indexing, broadcasting or shape manipulation logic.

For example, in EnumMessenger, we might instead call pyro.to_funsor on msg["fn"]:

funsor_dist = pyro.to_funsor(msg["fn"], output=funsor.Real)(value=msg["name"])
# enumerate_support defined whenever isinstance(funsor_dist, funsor.distribution.Distribution)
funsor_value = funsor_dist.enumerate_support(expand=False)
raw_value = pyro.to_data(funsor_value)

Most of the more complete inference algorithms implemented in pyro.contrib.funsor follow this pattern, and we will see an example later in this tutorial. Before we continue, let’s see what effect EnumMessenger has on the shapes of random variables in our model:

[6]:
with pyroapi.pyro_backend("contrib.funsor"), \
        EnumMessenger():
    model(data, True)
locs.shape = torch.Size([3])
x0.shape =  torch.Size([3, 1, 1, 1, 1])
x1.shape =  torch.Size([3, 1, 1, 1, 1, 1])
x2.shape =  torch.Size([3, 1, 1, 1, 1])
x3.shape =  torch.Size([3, 1, 1, 1, 1, 1])
x4.shape =  torch.Size([3, 1, 1, 1, 1])
x5.shape =  torch.Size([3, 1, 1, 1, 1, 1])
x6.shape =  torch.Size([3, 1, 1, 1, 1])
x7.shape =  torch.Size([3, 1, 1, 1, 1, 1])
x8.shape =  torch.Size([3, 1, 1, 1, 1])
x9.shape =  torch.Size([3, 1, 1, 1, 1, 1])
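The printed shapes alternate between 5 and 6 dimensions. This is consistent with `pyro.markov` letting variables that are more than one step apart reuse the same enumeration dimension. Since `x{i}` only interacts with `x{i-1}` and `x{i+1}`, two dimensions suffice for the whole chain. The sketch below is a toy model of that bookkeeping, not the allocator used by `pyro.contrib.funsor`. The hypothetical `markov_enum_dims` helper simply alternates between two dims, and the starting dim of -5 is an assumption read off the output above.

```python
def markov_enum_dims(num_steps, first_dim=-5, history=1):
    """Toy allocator: step i reuses the dim of step i - (history + 1)."""
    window = history + 1  # number of distinct dims needed along the chain
    return [first_dim - (i % window) for i in range(num_steps)]

def enum_value_shape(support_size, dim):
    """Shape of a value whose only non-singleton dim sits at negative position `dim`."""
    return (support_size,) + (1,) * (-dim - 1)

dims = markov_enum_dims(10)
shapes = [enum_value_shape(3, d) for d in dims]
for i, shape in enumerate(shapes):
    print("x{}.shape = {}".format(i, shape))
```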

Vectorizing a model across multiple samples

Next, since our priors over global variables are continuous and cannot be enumerated exactly, we will implement an effect handler that uses a global dimension to draw multiple samples in parallel from the model. Our implementation will allocate a new particle dimension using pyro.to_data as in EnumMessenger above, but unlike the enumeration dimensions, we want the particle dimension to be shared across all sample sites, so we will mark it as a DimType.GLOBAL dimension when invoking pyro.to_funsor.

Recall that in part 1 we saw that DimType.GLOBAL dimensions must be deallocated manually or they will persist until the final effect handler has exited. This low-level detail is taken care of automatically by the GlobalNamedMessenger handler provided in pyro.contrib.funsor as a base class for any effect handlers that allocate global dimensions. Our vectorization effect handler will inherit from this class.
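The difference between fresh enumeration dimensions and a shared particle dimension comes down to broadcasting. The sketch below reimplements the NumPy/PyTorch broadcasting rule for shapes in plain Python. `broadcast_shape` is our own illustrative helper; PyTorch offers `torch.broadcast_shapes` for the same job. The shapes are simplified stand-ins for the ones printed in this tutorial. Two 3-state variables on different dimensions broadcast to a 3 x 3 table of joint configurations. Two per-particle quantities on the same particle dimension stay aligned, so particle n of one site is only ever combined with particle n of the other.

```python
from itertools import zip_longest

def broadcast_shape(*shapes):
    """Right-align shapes and combine them using NumPy/PyTorch broadcasting rules."""
    result = []
    for sizes in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        non_singleton = {n for n in sizes if n != 1}
        if len(non_singleton) > 1:
            raise ValueError("shapes {} cannot be broadcast".format(shapes))
        result.append(non_singleton.pop() if non_singleton else 1)
    return tuple(reversed(result))

particle_shape = (10, 1, 1)  # a per-particle quantity, particle dim at -3
x0_shape = (3, 1)            # x0 enumerated at its own fresh dim -2
x1_shape = (3,)              # x1 enumerated at its own fresh dim -1

# Distinct fresh dims: broadcasting forms all 3 * 3 joint configurations.
enum_joint = broadcast_shape(x0_shape, x1_shape)            # (3, 3)

# If x1 reused x0's dim, values would pair elementwise: 3 entries, not 9.
collided = broadcast_shape(x0_shape, (3, 1))                # (3, 1)

# Two global sites on the same particle dim stay paired particle by particle.
shared = broadcast_shape(particle_shape, particle_shape)    # (10, 1, 1)

# A log-joint term: one particle dim plus one dim per enumerated variable.
log_joint_shape = broadcast_shape(particle_shape, x0_shape, x1_shape)  # (10, 3, 3)
```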

[7]:
from pyro.contrib.funsor.handlers.named_messenger import GlobalNamedMessenger
from pyro.contrib.funsor.handlers.runtime import DimRequest, DimType

class VectorizeMessenger(GlobalNamedMessenger):

    def __init__(self, size, name="_PARTICLES"):
        super().__init__()
        self.name = name
        self.size = size

    @pyroapi.pyro_backend("contrib.funsor")
    def _pyro_sample(self, msg):
        if msg["is_observed"] or msg["done"] or msg["infer"].get("enumerate") == "parallel":
            return

        # we'll first draw a raw batch of samples similarly to EnumMessenger.
        # However, since we are drawing a single batch from the joint distribution,
        # we don't need to take multiple samples if the site is already batched.
        if self.name in pyro.to_funsor(msg["fn"], funsor.Real).inputs:
            raw_value = msg["fn"].rsample()
        else:
            raw_value = msg["fn"].rsample(sample_shape=(self.size,))

        # As before, we'll use pyro.to_funsor to register the new dimension.
        # This time, we indicate that the particle dimension should be treated as a global dimension.
        fresh_dim = len(msg["fn"].event_shape) - raw_value.dim()
        funsor_value = pyro.to_funsor(
            raw_value,
            output=funsor.Reals[tuple(msg["fn"].event_shape)],
            dim_to_name={fresh_dim: DimRequest(value=self.name, dim_type=DimType.GLOBAL)},
        )

        # finally, convert the sample to a PyTorch tensor using to_data as before
        msg["value"] = pyro.to_data(funsor_value)
        msg["done"] = True

Let’s see what effect VectorizeMessenger has on the shapes of the values in model:

[8]:
with pyroapi.pyro_backend("contrib.funsor"), \
        VectorizeMessenger(size=10):
    model(data, verbose=True)
locs.shape = torch.Size([10, 1, 1, 1, 1, 3])
x0.shape =  torch.Size([])
x1.shape =  torch.Size([])
x2.shape =  torch.Size([])
x3.shape =  torch.Size([])
x4.shape =  torch.Size([])
x5.shape =  torch.Size([])
x6.shape =  torch.Size([])
x7.shape =  torch.Size([])
x8.shape =  torch.Size([])
x9.shape =  torch.Size([])

And now in combination with EnumMessenger:

[9]:
with pyroapi.pyro_backend("contrib.funsor"), \
        VectorizeMessenger(size=10), EnumMessenger():
    model(data, verbose=True)
locs.shape = torch.Size([10, 1, 1, 1, 1, 3])
x0.shape =  torch.Size([3, 1, 1, 1, 1, 1])
x1.shape =  torch.Size([3, 1, 1, 1, 1, 1, 1])
x2.shape =  torch.Size([3, 1, 1, 1, 1, 1])
x3.shape =  torch.Size([3, 1, 1, 1, 1, 1, 1])
x4.shape =  torch.Size([3, 1, 1, 1, 1, 1])
x5.shape =  torch.Size([3, 1, 1, 1, 1, 1, 1])
x6.shape =  torch.Size([3, 1, 1, 1, 1, 1])
x7.shape =  torch.Size([3, 1, 1, 1, 1, 1, 1])
x8.shape =  torch.Size([3, 1, 1, 1, 1, 1])
x9.shape =  torch.Size([3, 1, 1, 1, 1, 1, 1])

Computing an ELBO with variable elimination

Now that we have tools for enumerating discrete variables and drawing batches of samples, we can use those to compute quantities of interest for inference algorithms.

Most inference algorithms in Pyro work with pyro.poutine.Traces. These are custom data structures that contain parameters, sample site distributions and values, and all of the associated metadata needed for inference computations. Our third effect handler, LogJointMessenger, departs from this design pattern, eliminating a tremendous amount of boilerplate in the process. It automatically builds up a lazy Funsor expression for the logarithm of the joint probability density of a model. When working with Traces, by contrast, this process must be triggered manually by calling Trace.compute_log_probs() and eagerly computing an objective from the resulting individual log-probability tensors in the trace.

In our implementation of LogJointMessenger, unlike the previous two effect handlers, we will call pyro.to_funsor on both the sample value and the distribution. This shows how nearly all inference operations, including log-probability density evaluation, can be performed on funsor.Funsors directly.

[10]:
class LogJointMessenger(Messenger):

    def __enter__(self):
        self.log_joint = funsor.Number(0.)
        return super().__enter__()

    @pyroapi.pyro_backend("contrib.funsor")
    def _pyro_post_sample(self, msg):

        # for Monte Carlo-sampled variables, we don't include a log-density term:
        if not msg["is_observed"] and not msg["infer"].get("enumerate"):
            return

        with funsor.interpreter.interpretation(funsor.terms.lazy):
            funsor_dist = pyro.to_funsor(msg["fn"], output=funsor.Real)
            funsor_value = pyro.to_funsor(msg["value"], output=funsor_dist.inputs["value"])
            self.log_joint += funsor_dist(value=funsor_value)

Next comes the actual loss function. It applies our three effect handlers to compute an expression for the log-density, marginalizes over discrete variables with funsor.ops.logaddexp, averages over Monte Carlo samples with funsor.ops.add, and evaluates the final lazy expression using Funsor's optimize interpretation for variable elimination.
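To make the two reductions concrete, the sketch below applies them to a small hand-written log-joint table in plain Python rather than funsor. Rows index Monte Carlo particles and columns index the values of a single enumerated discrete variable. The numbers are made up for illustration. Reducing each row with log-sum-exp (the n-ary form of `logaddexp`) marginalizes the discrete variable exactly. Summing the results over particles and dividing by the particle count gives the Monte Carlo average.

```python
import math

def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values)), i.e. a logaddexp reduction."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

# Hypothetical log-joint table: 4 particles x 3 discrete states.
log_joint = [
    [-2.0, -1.5, -3.0],
    [-1.2, -2.2, -2.8],
    [-2.5, -0.9, -1.7],
    [-1.8, -1.8, -2.0],
]

# Exact marginalization of the discrete variable within each particle.
per_particle = [logsumexp(row) for row in log_joint]

# Monte Carlo average over particles: add-reduction divided by the particle count.
estimate = sum(per_particle) / len(per_particle)
```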

Note that log_z exactly collapses the model's local discrete latent variables but is an ELBO with respect to any continuous latent variables. It is thus equivalent to a simple version of TraceEnum_ELBO with an empty guide.
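The claim that the discrete variables are collapsed exactly can be checked on a toy case. Hold `locs` fixed, which corresponds to a single Monte Carlo particle. Summing the joint density over every assignment of `x0, ..., x{T-1}` then gives the same number as eliminating the variables one at a time, which is what the forward algorithm does. The sketch below uses plain Python, a short chain and made-up parameters. `brute_force_log_z` and `forward_log_z` are hypothetical helpers, not part of `pyro.contrib.funsor`.

```python
import itertools
import math

def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def normal_logpdf(y, loc, scale=1.0):
    return -0.5 * ((y - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)

def brute_force_log_z(probs, locs, ys):
    """Enumerate all K**T state paths and log-sum-exp their joint log-densities."""
    K = len(locs)
    terms = []
    for path in itertools.product(range(K), repeat=len(ys)):
        prev, total = 0, 0.0  # the chain starts from state 0, as in `model`
        for x, y in zip(path, ys):
            total += math.log(probs[prev][x]) + normal_logpdf(y, locs[x])
            prev = x
        terms.append(total)
    return logsumexp(terms)

def forward_log_z(probs, locs, ys):
    """Eliminate x0, x1, ... in order: O(T * K**2) work instead of O(K**T)."""
    K = len(locs)
    # alpha[k] = log p(y_0, ..., y_i, x_i = k)
    alpha = [math.log(probs[0][k]) + normal_logpdf(ys[0], locs[k]) for k in range(K)]
    for y in ys[1:]:
        alpha = [
            logsumexp([alpha[j] + math.log(probs[j][k]) for j in range(K)])
            + normal_logpdf(y, locs[k])
            for k in range(K)
        ]
    return logsumexp(alpha)

probs = [[0.6, 0.3, 0.1], [0.2, 0.5, 0.3], [0.25, 0.25, 0.5]]
locs = [0.5, 1.0, 2.0]
ys = [1.0] * 6  # mirrors the constant observations in `data`, but shorter
print(brute_force_log_z(probs, locs, ys), forward_log_z(probs, locs, ys))
```

Funsor's variable elimination finds an elimination order like this automatically from the lazy log-joint, rather than requiring a hand-written recursion.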

[11]:
@pyroapi.pyro_backend("contrib.funsor")
def log_z(model, model_args, size=10):
    with LogJointMessenger() as tr, \
            VectorizeMessenger(size=size) as v, \
            EnumMessenger():
        model(*model_args)

    with funsor.interpreter.interpretation(funsor.terms.lazy):
        prod_vars = frozenset({v.name})
        sum_vars = frozenset(tr.log_joint.inputs) - prod_vars

        # sum over the discrete random variables we enumerated
        expr = tr.log_joint.reduce(funsor.ops.logaddexp, sum_vars)

        # average over the sample dimension: sum over particles, then divide by their count
        expr = expr.reduce(funsor.ops.add, prod_vars) / funsor.Number(float(size))

    return pyro.to_data(funsor.optimizer.apply_optimizer(expr))

Putting it all together

Finally, with all this machinery implemented, we can compute stochastic gradients of the ELBO with respect to the model parameters.

[12]:
with pyroapi.pyro_backend("contrib.funsor"):
    model(data, verbose=False)  # initialize parameters
    params = [pyro.param("probs").unconstrained(), pyro.param("locs_mean").unconstrained()]

optimizer = torch.optim.Adam(params, lr=0.1)
for step in range(5):
    optimizer.zero_grad()
    log_marginal = log_z(model, (data, False))
    (-log_marginal).backward()
    optimizer.step()
    print(log_marginal)
tensor(-133.6274, grad_fn=<AddBackward0>)
tensor(-129.2379, grad_fn=<AddBackward0>)
tensor(-125.9609, grad_fn=<AddBackward0>)
tensor(-123.7484, grad_fn=<AddBackward0>)
tensor(-122.3034, grad_fn=<AddBackward0>)