pyro.contrib.funsor, a new backend for Pyro - Building inference algorithms (Part 2)
[1]:
from collections import OrderedDict
import functools
import torch
from torch.distributions import constraints
import funsor
from pyro import set_rng_seed as pyro_set_rng_seed
from pyro.ops.indexing import Vindex
from pyro.poutine.messenger import Messenger
funsor.set_backend("torch")
torch.set_default_dtype(torch.float32)
pyro_set_rng_seed(101)
Introduction
In part 1 of this tutorial, we were introduced to the new pyro.contrib.funsor backend for Pyro.
Here we’ll look at how to use the components in pyro.contrib.funsor to implement a variable elimination inference algorithm from scratch. This tutorial assumes readers are familiar with enumeration-based inference algorithms in Pyro. For background and motivation, readers should consult the enumeration tutorial.
As before, we’ll use pyroapi so that we can write our model with standard Pyro syntax.
[2]:
import pyro.contrib.funsor
import pyroapi
from pyroapi import infer, handlers, ops, optim, pyro
from pyroapi import distributions as dist
We will be working with the following model throughout. It is a discrete-state continuous-observation hidden Markov model with learnable transition and emission distributions that depend on a global random variable.
[3]:
data = [torch.tensor(1.)] * 10

def model(data, verbose):
    p = pyro.param("probs", lambda: torch.rand((3, 3)), constraint=constraints.simplex)
    locs_mean = pyro.param("locs_mean", lambda: torch.ones((3,)))
    locs = pyro.sample("locs", dist.Normal(locs_mean, 1.).to_event(1))
    if verbose:
        print("locs.shape = {}".format(locs.shape))
    x = 0
    for i in pyro.markov(range(len(data))):
        x = pyro.sample("x{}".format(i), dist.Categorical(p[x]), infer={"enumerate": "parallel"})
        if verbose:
            print("x{}.shape = ".format(i), x.shape)
        pyro.sample("y{}".format(i), dist.Normal(Vindex(locs)[..., x], 1.), obs=data[i])
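For reference, the generative process defined by model can be written out as follows. The notation is introduced here for illustration only and does not appear in the code: \mu stands for locs_mean, P for probs, \ell for locs, and T for len(data). The convention x_{-1} = 0 mirrors the initial assignment x = 0 before the loop.

```latex
\begin{aligned}
\ell &\sim \mathcal{N}(\mu, I_3), \\
x_t \mid x_{t-1} &\sim \mathrm{Categorical}(P_{x_{t-1}}), \qquad x_{-1} = 0, \quad t = 0, \dots, T-1, \\
y_t \mid x_t, \ell &\sim \mathcal{N}(\ell_{x_t}, 1), \\
p(\ell, x_{0:T-1}, y_{0:T-1}) &= \mathcal{N}(\ell; \mu, I_3) \prod_{t=0}^{T-1} P_{x_{t-1}, x_t} \, \mathcal{N}(y_t; \ell_{x_t}, 1).
\end{aligned}
```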
We can run model under the default Pyro backend and the new contrib.funsor backend with pyroapi:
[4]:
# default backend: "pyro"
with pyroapi.pyro_backend("pyro"):
    model(data, verbose=True)

# new backend: "contrib.funsor"
with pyroapi.pyro_backend("contrib.funsor"):
    model(data, verbose=True)
locs.shape = torch.Size([3])
x0.shape = torch.Size([])
x1.shape = torch.Size([])
x2.shape = torch.Size([])
x3.shape = torch.Size([])
x4.shape = torch.Size([])
x5.shape = torch.Size([])
x6.shape = torch.Size([])
x7.shape = torch.Size([])
x8.shape = torch.Size([])
x9.shape = torch.Size([])
locs.shape = torch.Size([3])
x0.shape = torch.Size([])
x1.shape = torch.Size([])
x2.shape = torch.Size([])
x3.shape = torch.Size([])
x4.shape = torch.Size([])
x5.shape = torch.Size([])
x6.shape = torch.Size([])
x7.shape = torch.Size([])
x8.shape = torch.Size([])
x9.shape = torch.Size([])
Enumerating discrete variables
Our first step is to implement an effect handler that performs parallel enumeration of discrete latent variables. Here we will implement a stripped-down version of pyro.poutine.enum, the effect handler behind Pyro’s most powerful general-purpose inference algorithms pyro.infer.TraceEnum_ELBO and pyro.infer.mcmc.HMC.
We’ll do that by constructing a funsor.Tensor representing the support of each discrete latent variable and using the new pyro.to_data primitive from part 1 to convert it to a torch.Tensor with the appropriate shape.
[5]:
from pyro.contrib.funsor.handlers.named_messenger import NamedMessenger

class EnumMessenger(NamedMessenger):

    @pyroapi.pyro_backend("contrib.funsor")  # necessary since we invoke pyro.to_data and pyro.to_funsor
    def _pyro_sample(self, msg):
        if msg["done"] or msg["is_observed"] or msg["infer"].get("enumerate") != "parallel":
            return

        # We first compute a raw value using the standard enumerate_support method.
        # enumerate_support returns a value of shape:
        # (support_size,) + (1,) * len(msg["fn"].batch_shape).
        raw_value = msg["fn"].enumerate_support(expand=False)

        # Next we'll use pyro.to_funsor to indicate that this dimension is fresh.
        # This is guaranteed because we use msg['name'], the name of this pyro.sample site,
        # as the name for this positional dimension, and sample site names must be unique.
        funsor_value = pyro.to_funsor(
            raw_value,
            output=funsor.Bint[raw_value.shape[0]],
            dim_to_name={-raw_value.dim(): msg["name"]},
        )

        # Finally, we convert the value back to a PyTorch tensor with to_data,
        # which has the effect of reshaping and possibly permuting dimensions of raw_value.
        # Applying to_funsor and to_data in this way guarantees that
        # each enumerated random variable gets a unique fresh positional dimension
        # and that we can convert the model's log-probability tensors to funsor.Tensors
        # in a globally consistent manner.
        msg["value"] = pyro.to_data(funsor_value)
        msg["done"] = True
Because this is an introductory tutorial, this implementation of EnumMessenger works directly with the site’s PyTorch distribution, since users familiar with PyTorch and Pyro may find it easier to understand. However, when using contrib.funsor to implement an inference algorithm in a more realistic setting, it is usually preferable to do as much computation as possible on funsors, as this tends to simplify complex indexing, broadcasting or shape manipulation logic.
For example, in EnumMessenger, we might instead call pyro.to_funsor on msg["fn"]:
funsor_dist = pyro.to_funsor(msg["fn"], output=funsor.Real)(value=msg["name"])
# enumerate_support defined whenever isinstance(funsor_dist, funsor.distribution.Distribution)
funsor_value = funsor_dist.enumerate_support(expand=False)
raw_value = pyro.to_data(funsor_value)
Most of the more complete inference algorithms implemented in pyro.contrib.funsor follow this pattern, and we will see an example later in this tutorial. Before we continue, let’s see what effect EnumMessenger has on the shapes of random variables in our model:
[6]:
with pyroapi.pyro_backend("contrib.funsor"), \
        EnumMessenger():
    model(data, True)
locs.shape = torch.Size([3])
x0.shape = torch.Size([3, 1, 1, 1, 1])
x1.shape = torch.Size([3, 1, 1, 1, 1, 1])
x2.shape = torch.Size([3, 1, 1, 1, 1])
x3.shape = torch.Size([3, 1, 1, 1, 1, 1])
x4.shape = torch.Size([3, 1, 1, 1, 1])
x5.shape = torch.Size([3, 1, 1, 1, 1, 1])
x6.shape = torch.Size([3, 1, 1, 1, 1])
x7.shape = torch.Size([3, 1, 1, 1, 1, 1])
x8.shape = torch.Size([3, 1, 1, 1, 1])
x9.shape = torch.Size([3, 1, 1, 1, 1, 1])
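The shapes above alternate between two positional dimensions rather than growing with every time step. This is consistent with pyro.markov allowing a dimension to be reused once the variable that held it is no longer in scope. The following is a toy, dependency-free sketch of that bookkeeping, not the actual allocator in pyro.contrib.funsor: the function allocate_dims, its history argument, and the starting dimension -5 (read off the printed shapes) are all assumptions made for illustration.

```python
def allocate_dims(num_steps, history=1, first_dim=-5):
    """Toy allocator: give each step the first dim (counting leftward from
    first_dim) that is not held by any of the previous `history` steps."""
    assignment = {}
    for i in range(num_steps):
        # dims still held by variables that remain in scope
        in_use = {assignment["x{}".format(j)] for j in range(max(0, i - history), i)}
        dim = first_dim
        while dim in in_use:
            dim -= 1
        assignment["x{}".format(i)] = dim
    return assignment


def shape_for(dim, support_size=3):
    """Shape of an enumerated value whose support lies along negative dim `dim`."""
    return (support_size,) + (1,) * (-dim - 1)


dims = allocate_dims(4)
shapes = {name: shape_for(dim) for name, dim in dims.items()}
print(dims)
print(shapes)
```

Under these assumptions x0 and x2 share dimension -5 while x1 and x3 share dimension -6, matching the alternating pattern in the printed shapes. With a longer history, more dimensions would stay in use at once.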
Vectorizing a model across multiple samples
Next, since our priors over global variables are continuous and cannot be enumerated exactly, we will implement an effect handler that uses a global dimension to draw multiple samples in parallel from the model. Our implementation will allocate a new particle dimension using pyro.to_data as in EnumMessenger above, but unlike the enumeration dimensions, we want the particle dimension to be shared across all sample sites, so we will mark it as a DimType.GLOBAL dimension when invoking pyro.to_funsor.
Recall that in part 1 we saw that DimType.GLOBAL dimensions must be deallocated manually or they will persist until the final effect handler has exited. This low-level detail is taken care of automatically by the GlobalNamedMessenger handler provided in pyro.contrib.funsor as a base class for any effect handlers that allocate global dimensions. Our vectorization effect handler will inherit from this class.
[7]:
from pyro.contrib.funsor.handlers.named_messenger import GlobalNamedMessenger
from pyro.contrib.funsor.handlers.runtime import DimRequest, DimType

class VectorizeMessenger(GlobalNamedMessenger):

    def __init__(self, size, name="_PARTICLES"):
        super().__init__()
        self.name = name
        self.size = size

    @pyroapi.pyro_backend("contrib.funsor")
    def _pyro_sample(self, msg):
        if msg["is_observed"] or msg["done"] or msg["infer"].get("enumerate") == "parallel":
            return

        # We'll first draw a raw batch of samples similarly to EnumMessenger.
        # However, since we are drawing a single batch from the joint distribution,
        # we don't need to take multiple samples if the site is already batched.
        if self.name in pyro.to_funsor(msg["fn"], funsor.Real).inputs:
            raw_value = msg["fn"].rsample()
        else:
            raw_value = msg["fn"].rsample(sample_shape=(self.size,))

        # As before, we'll use pyro.to_funsor to register the new dimension.
        # This time, we indicate that the particle dimension should be treated as a global dimension.
        fresh_dim = len(msg["fn"].event_shape) - raw_value.dim()
        funsor_value = pyro.to_funsor(
            raw_value,
            output=funsor.Reals[tuple(msg["fn"].event_shape)],
            dim_to_name={fresh_dim: DimRequest(value=self.name, dim_type=DimType.GLOBAL)},
        )

        # Finally, convert the sample to a PyTorch tensor using to_data as before.
        msg["value"] = pyro.to_data(funsor_value)
        msg["done"] = True
Let’s see what effect VectorizeMessenger has on the shapes of the values in model:
[8]:
with pyroapi.pyro_backend("contrib.funsor"), \
        VectorizeMessenger(size=10):
    model(data, verbose=True)
locs.shape = torch.Size([10, 1, 1, 1, 1, 3])
x0.shape = torch.Size([])
x1.shape = torch.Size([])
x2.shape = torch.Size([])
x3.shape = torch.Size([])
x4.shape = torch.Size([])
x5.shape = torch.Size([])
x6.shape = torch.Size([])
x7.shape = torch.Size([])
x8.shape = torch.Size([])
x9.shape = torch.Size([])
And now in combination with EnumMessenger:
[9]:
with pyroapi.pyro_backend("contrib.funsor"), \
        VectorizeMessenger(size=10), EnumMessenger():
    model(data, verbose=True)
locs.shape = torch.Size([10, 1, 1, 1, 1, 3])
x0.shape = torch.Size([3, 1, 1, 1, 1, 1])
x1.shape = torch.Size([3, 1, 1, 1, 1, 1, 1])
x2.shape = torch.Size([3, 1, 1, 1, 1, 1])
x3.shape = torch.Size([3, 1, 1, 1, 1, 1, 1])
x4.shape = torch.Size([3, 1, 1, 1, 1, 1])
x5.shape = torch.Size([3, 1, 1, 1, 1, 1, 1])
x6.shape = torch.Size([3, 1, 1, 1, 1, 1])
x7.shape = torch.Size([3, 1, 1, 1, 1, 1, 1])
x8.shape = torch.Size([3, 1, 1, 1, 1, 1])
x9.shape = torch.Size([3, 1, 1, 1, 1, 1, 1])
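Because the particle dimension and each enumeration dimension occupy distinct positions, ordinary right-aligned broadcasting is enough to combine them. The sketch below uses a hypothetical pure-Python helper, broadcast_shape, that follows the usual PyTorch/NumPy broadcasting rule. It is applied to the shapes printed above to show what shape the emission mean Vindex(locs)[..., x] and a pairwise factor over x0 and x1 would be expected to have. It only reasons about shapes and is not part of pyro.contrib.funsor.

```python
from itertools import zip_longest


def broadcast_shape(*shapes):
    """Right-aligned broadcasting of shapes (standard PyTorch/NumPy rule)."""
    result = []
    # walk the shapes from the rightmost dimension, padding short shapes with 1
    for sizes in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        non_trivial = {n for n in sizes if n != 1}
        if len(non_trivial) > 1:
            raise ValueError("incompatible sizes {}".format(sizes))
        result.append(non_trivial.pop() if non_trivial else 1)
    return tuple(reversed(result))


# locs.shape with its trailing event dimension of size 3 dropped
locs_batch_shape = (10, 1, 1, 1, 1)
x0_shape = (3, 1, 1, 1, 1, 1)
x1_shape = (3, 1, 1, 1, 1, 1, 1)

# indexing locs by x0 broadcasts the particle dim against x0's enumeration dim
y0_loc_shape = broadcast_shape(locs_batch_shape, x0_shape)
# a transition factor depends on both x0 and x1
pair_shape = broadcast_shape(x0_shape, x1_shape)
print(y0_loc_shape, pair_shape)
```

The particle dimension of size 10 and the enumeration dimensions of size 3 never collide, which is what lets the log-probability tensors be converted to funsors with consistent named inputs.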
Computing an ELBO with variable elimination
Now that we have tools for enumerating discrete variables and drawing batches of samples, we can use those to compute quantities of interest for inference algorithms.
Most inference algorithms in Pyro work with pyro.poutine.Traces, custom data structures that contain parameters, sample site distributions and values, and all of the associated metadata needed for inference computations. Our third effect handler LogJointMessenger departs from this design pattern, eliminating a tremendous amount of boilerplate in the process. It will automatically build up a lazy Funsor expression for the logarithm of the joint probability density of a model. When working with Traces, by contrast, this process must be triggered manually by calling Trace.compute_log_prob() and eagerly computing an objective from the resulting individual log-probability tensors in the trace.
In our implementation of LogJointMessenger, unlike the previous two effect handlers, we will call pyro.to_funsor on both the sample value and the distribution to show how nearly all inference operations, including log-probability density evaluation, can be performed on funsor.Funsors directly.
[10]:
class LogJointMessenger(Messenger):

    def __enter__(self):
        self.log_joint = funsor.Number(0.)
        return super().__enter__()

    @pyroapi.pyro_backend("contrib.funsor")
    def _pyro_post_sample(self, msg):
        # for Monte Carlo-sampled variables, we don't include a log-density term:
        if not msg["is_observed"] and not msg["infer"].get("enumerate"):
            return

        with funsor.interpreter.interpretation(funsor.terms.lazy):
            funsor_dist = pyro.to_funsor(msg["fn"], output=funsor.Real)
            funsor_value = pyro.to_funsor(msg["value"], output=funsor_dist.inputs["value"])
            self.log_joint += funsor_dist(value=funsor_value)
And finally the actual loss function, which applies our three effect handlers to compute an expression for the log-density, marginalizes over discrete variables with funsor.ops.logaddexp, averages over Monte Carlo samples with funsor.ops.add, and evaluates the final lazy expression using Funsor’s optimize interpretation for variable elimination.
Note that log_z exactly collapses the model’s local discrete latent variables but is an ELBO with respect to any continuous latent variables, and is thus equivalent to a simple version of TraceEnum_ELBO with an empty guide.
[11]:
@pyroapi.pyro_backend("contrib.funsor")
def log_z(model, model_args, size=10):
    with LogJointMessenger() as tr, \
            VectorizeMessenger(size=size) as v, \
            EnumMessenger():
        model(*model_args)

    with funsor.interpreter.interpretation(funsor.terms.lazy):
        prod_vars = frozenset({v.name})
        sum_vars = frozenset(tr.log_joint.inputs) - prod_vars

        # sum over the discrete random variables we enumerated
        expr = tr.log_joint.reduce(funsor.ops.logaddexp, sum_vars)

        # average over the sample dimension
        expr = expr.reduce(funsor.ops.add, prod_vars) - funsor.Number(float(size))

    return pyro.to_data(funsor.optimizer.apply_optimizer(expr))
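To make the role of variable elimination concrete, here is a minimal plain-Python sketch for a single fixed value of locs (one particle). It compares a brute-force logsumexp over every joint assignment of the discrete states with eliminating one state at a time. The helper names, the transition matrix, and the data are illustrative assumptions. Funsor's optimizer chooses its own contraction order, whereas this sketch hard-codes the left-to-right order of the classical forward algorithm.

```python
import itertools
import math


def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def normal_logpdf(y, loc, scale=1.0):
    return -0.5 * ((y - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def brute_force_log_marginal(trans, locs, ys, init_state=0):
    """Sum over all K**T state sequences explicitly (exponential cost)."""
    terms = []
    for states in itertools.product(range(len(locs)), repeat=len(ys)):
        prev, total = init_state, 0.0
        for x, y in zip(states, ys):
            total += math.log(trans[prev][x]) + normal_logpdf(y, locs[x])
            prev = x
        terms.append(total)
    return logsumexp(terms)


def eliminated_log_marginal(trans, locs, ys, init_state=0):
    """Eliminate x_0, x_1, ... in order; cost is linear in the sequence length."""
    K = len(locs)
    # alpha[k] = log p(y_0, ..., y_t, x_t = k)
    alpha = [math.log(trans[init_state][k]) + normal_logpdf(ys[0], locs[k]) for k in range(K)]
    for y in ys[1:]:
        alpha = [
            logsumexp([alpha[j] + math.log(trans[j][k]) for j in range(K)])
            + normal_logpdf(y, locs[k])
            for k in range(K)
        ]
    return logsumexp(alpha)


# illustrative values, not the parameters learned in this tutorial
trans = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.3, 0.4]]
locs = [0.0, 1.0, 2.0]
ys = [1.0] * 5
brute = brute_force_log_marginal(trans, locs, ys)
fast = eliminated_log_marginal(trans, locs, ys)
print(brute, fast)
```

Both functions compute the same log-marginal likelihood. The brute-force version enumerates 3**5 sequences here, while the eliminated version performs only a 3-by-3 contraction per time step, and that saving is what makes exact enumeration practical for longer sequences.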
Putting it all together
Finally, with all this machinery implemented, we can compute stochastic gradients with respect to the ELBO.
[12]:
with pyroapi.pyro_backend("contrib.funsor"):
    model(data, verbose=False)  # initialize parameters
    params = [pyro.param("probs").unconstrained(), pyro.param("locs_mean").unconstrained()]

optimizer = torch.optim.Adam(params, lr=0.1)
for step in range(5):
    optimizer.zero_grad()
    log_marginal = log_z(model, (data, False))
    (-log_marginal).backward()
    optimizer.step()
    print(log_marginal)
tensor(-133.6274, grad_fn=<AddBackward0>)
tensor(-129.2379, grad_fn=<AddBackward0>)
tensor(-125.9609, grad_fn=<AddBackward0>)
tensor(-123.7484, grad_fn=<AddBackward0>)
tensor(-122.3034, grad_fn=<AddBackward0>)