Inference with Discrete Latent Variables

This tutorial describes Pyro’s enumeration strategy for discrete latent variable models. This tutorial assumes the reader is already familiar with the Tensor Shapes Tutorial.

Summary

  • Pyro implements automatic enumeration over discrete latent variables.
  • This strategy can be used alone or inside SVI (via TraceEnum_ELBO), HMC, or NUTS.
  • The standalone infer_discrete can generate samples or MAP estimates.
  • Annotate a sample site with infer={"enumerate": "parallel"} to trigger enumeration.
  • If a sample site determines downstream structure, instead use {"enumerate": "sequential"}.
  • Write your models to allow arbitrarily deep batching on the left, e.g. use broadcasting.
  • Inference cost is exponential in treewidth, so try to write models with narrow treewidth.
  • If you have trouble, ask for help on forum.pyro.ai!

In [1]:
import os
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate, infer_discrete
from pyro.contrib.autoguide import AutoDiagonalNormal

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('0.3.0')
pyro.enable_validation()
pyro.set_rng_seed(0)

Overview

Pyro’s enumeration strategy encompasses popular algorithms including variable elimination, exact message passing, forward-filter-backward-sample, inside-outside, Baum-Welch, and many other special-case algorithms. Aside from enumeration, Pyro implements a number of inference strategies including variational inference (SVI) and Monte Carlo (HMC and NUTS). Enumeration can be used either as a stand-alone strategy via infer_discrete, or as a component of other strategies. Thus enumeration allows Pyro to marginalize out discrete latent variables in HMC and SVI models, and to use variational enumeration of discrete variables in SVI guides.

Mechanics of enumeration

The core idea of enumeration is to interpret discrete pyro.sample statements as full enumeration rather than random sampling. Other inference algorithms can then sum out the enumerated values. For example a sample statement might return a tensor of scalar shape under the standard “sample” interpretation (we’ll illustrate with trivial model and guide):

In [2]:
def model():
    z = pyro.sample("z", dist.Categorical(torch.ones(5)))
    print('model z = {}'.format(z))

def guide():
    z = pyro.sample("z", dist.Categorical(torch.ones(5)))
    print('guide z = {}'.format(z))

elbo = Trace_ELBO()
elbo.loss(model, guide);
guide z = 4
model z = 4

However under the enumeration interpretation, the same sample site will return a fully enumerated set of values, based on its distribution’s .enumerate_support() method.

In [3]:
elbo = TraceEnum_ELBO(max_plate_nesting=0)
elbo.loss(model, config_enumerate(guide, "parallel"));
guide z = tensor([0, 1, 2, 3, 4])
model z = tensor([0, 1, 2, 3, 4])

Note that we’ve used “parallel” enumeration to enumerate along a new tensor dimension. This is cheap and allows Pyro to parallelize computation, but requires that downstream program structure avoid branching on the value of z. To support dynamic program structure, you can instead use “sequential” enumeration, which runs the entire model, guide pair once per sample value and is therefore more expensive.

In [4]:
elbo = TraceEnum_ELBO(max_plate_nesting=0)
elbo.loss(model, config_enumerate(guide, "sequential"));
guide z = 4
model z = 4
guide z = 3
model z = 3
guide z = 2
model z = 2
guide z = 1
model z = 1
guide z = 0
model z = 0

Parallel enumeration is cheaper but more complex than sequential enumeration, so we’ll focus the rest of this tutorial on the parallel variant. Note that both forms can be interleaved.
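
The trade-off between the two modes can be illustrated without Pyro. The following is a minimal pure-Python sketch, not Pyro’s implementation: it computes the expectation of a function of z under a uniform five-way categorical, once by calling a branching function per value (the sequential style) and once by calling a branch-free function on the whole support (the parallel style). The function names f_branching and f_vectorized are hypothetical.

```python
# Uniform distribution over 5 values, analogous to Categorical(torch.ones(5)).
probs = [0.2] * 5
support = list(range(5))


def f_branching(z):
    """Branches on the value of z, so it must be called once per value."""
    if z % 2 == 0:
        return 1.0
    return -1.0


def f_vectorized(zs):
    """Computes the same function without branching, for all values at once."""
    return [1.0 - 2.0 * (z % 2) for z in zs]


# Sequential style: one call of the function per enumerated value.
sequential = sum(p * f_branching(z) for p, z in zip(probs, support))

# Parallel style: a single call that receives the full enumerated support.
parallel = sum(p * v for p, v in zip(probs, f_vectorized(support)))

print(sequential, parallel)
```

Both styles sum over the same support and agree on the result; the parallel style simply requires that the downstream computation be written without value-dependent control flow.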

Multiple latent variables

We just saw that a single discrete sample site can be enumerated via nonstandard interpretation. A model with a single discrete latent variable is a mixture model. Models with multiple discrete latent variables can be more complex, including HMMs, CRFs, DBNs, and other structured models. In models with multiple discrete latent variables, Pyro enumerates each variable in a different tensor dimension (counting from the right; see Tensor Shapes Tutorial). This allows Pyro to determine the dependency graph among variables and then perform cheap exact inference using variable elimination algorithms.
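
To see why knowing the dependency graph makes exact inference cheap, here is a minimal pure-Python sketch (an illustration, not the code Pyro runs) that computes the marginal of the last variable in a chain x → y → z in two ways: by brute-force enumeration of the joint, and by variable elimination, which sums out one variable at a time. The table p and the helper names are hypothetical.

```python
from itertools import product

# Hypothetical row-stochastic table: p[i][j] = P(next = j | current = i).
p = [
    [0.5, 0.3, 0.2],
    [0.1, 0.6, 0.3],
    [0.2, 0.2, 0.6],
]
K = len(p)


def marginal_z_brute_force():
    """Sum the joint P(x, y, z) over all K**3 configurations."""
    marginal = [0.0] * K
    for x, y, z in product(range(K), repeat=3):
        marginal[z] += p[0][x] * p[x][y] * p[y][z]
    return marginal


def marginal_z_variable_elimination():
    """Eliminate x, then y: two K-by-K sums instead of one K**3 sum."""
    message = list(p[0])  # P(x), since x is drawn from the first row
    for _ in range(2):  # propagate x -> y, then y -> z
        message = [sum(message[i] * p[i][j] for i in range(K)) for j in range(K)]
    return message


brute = marginal_z_brute_force()
eliminated = marginal_z_variable_elimination()
print(brute)
print(eliminated)
```

For a chain of length T, the brute-force sum grows as K**T while the elimination cost grows as T * K**2, which is the kind of saving that the dependency structure makes possible.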

To understand enumeration dimension allocation, consider the following model, in which we collapse variables out of the model rather than enumerate them in the guide.

In [5]:
@config_enumerate
def model():
    p = pyro.param("p", torch.randn(3, 3).exp(), constraint=constraints.simplex)
    x = pyro.sample("x", dist.Categorical(p[0]))
    y = pyro.sample("y", dist.Categorical(p[x]))
    z = pyro.sample("z", dist.Categorical(p[y]))
    print('model x.shape = {}'.format(x.shape))
    print('model y.shape = {}'.format(y.shape))
    print('model z.shape = {}'.format(z.shape))
    return x, y, z

def guide():
    pass

pyro.clear_param_store()
elbo = TraceEnum_ELBO(max_plate_nesting=0)
elbo.loss(model, guide);
model x.shape = torch.Size([3])
model y.shape = torch.Size([3, 1])
model z.shape = torch.Size([3, 1, 1])

Examining discrete latent states

While enumeration in SVI allows fast learning of parameters like p above, it does not give access to predicted values of the discrete latent variables like x,y,z above. We can access these using a standalone infer_discrete handler. In this case the guide was trivial, so we can simply wrap the model in infer_discrete. We need to pass a first_available_dim argument to tell infer_discrete which dimensions are available for enumeration; this is related to the max_plate_nesting arg of TraceEnum_ELBO via

first_available_dim = -1 - max_plate_nesting
In [6]:
serving_model = infer_discrete(model, first_available_dim=-1)
x, y, z = serving_model()  # takes the same args as model(), here no args
print("x = {}".format(x))
print("y = {}".format(y))
print("z = {}".format(z))
model x.shape = torch.Size([3])
model y.shape = torch.Size([3, 1])
model z.shape = torch.Size([3, 1, 1])
model x.shape = torch.Size([])
model y.shape = torch.Size([])
model z.shape = torch.Size([])
x = 0
y = 2
z = 1

Notice that under the hood infer_discrete runs the model twice: first in forward-filter mode where sites are enumerated, then in replay-backward-sample mode where sites are sampled. infer_discrete can also perform MAP inference by passing temperature=0. Note that while infer_discrete produces correct posterior samples, it does not currently produce correct logprobs, and should not be used in other gradient-based inference algorithms.
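
As a rough picture of what MAP inference over an enumerated chain involves, the following pure-Python sketch runs a Viterbi-style max-product pass (a forward pass followed by a backward trace) on a three-variable chain and checks it against brute-force search. This is an illustration under assumed numbers, not the algorithm as implemented in infer_discrete; the table p and the function names are hypothetical.

```python
from itertools import product

# Hypothetical row-stochastic table: p[i][j] = P(next = j | current = i).
p = [
    [0.5, 0.3, 0.2],
    [0.1, 0.6, 0.3],
    [0.2, 0.2, 0.6],
]
K = len(p)


def joint(x, y, z):
    """Joint probability of the chain x -> y -> z, with x drawn from p[0]."""
    return p[0][x] * p[x][y] * p[y][z]


def map_brute_force():
    """Search all K**3 configurations for the most probable one."""
    return max(product(range(K), repeat=3), key=lambda xyz: joint(*xyz))


def map_max_product():
    """Forward pass keeps the best score per state; backward pass traces it."""
    best = [list(p[0])]  # best[t][j]: highest probability of a path ending in j
    back = []            # back[t][j]: predecessor state achieving best[t + 1][j]
    for _ in range(2):
        prev = best[-1]
        scores = [[prev[i] * p[i][j] for i in range(K)] for j in range(K)]
        back.append([s.index(max(s)) for s in scores])
        best.append([max(s) for s in scores])
    state = best[-1].index(max(best[-1]))
    path = [state]
    for pointers in reversed(back):
        state = pointers[state]
        path.append(state)
    return tuple(reversed(path))


print(map_brute_force(), map_max_product())
```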

Plates and enumeration

Pyro plates express conditional independence among random variables. Pyro’s enumeration strategy can take advantage of plates to reduce the high cost (exponential in the size of the plate) of enumerating a cartesian product down to a low cost (linear in the size of the plate) of enumerating conditionally independent random variables in lock-step. This is especially important for e.g. minibatched data.
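
The gap between the two costs can be checked on a toy mixture. The sketch below is pure Python with made-up weights, means, and data (all hypothetical); it computes the marginal likelihood once by enumerating the full cartesian product of assignments and once by summing over components per datum and multiplying the results.

```python
import math
from itertools import product

weights = [0.2, 0.5, 0.3]     # hypothetical mixture weights
locs = [-2.0, 0.0, 3.0]       # hypothetical component means
scale = 1.0                   # shared standard deviation
data = [-1.5, 0.2, 2.7, 0.1]  # hypothetical observations


def normal_pdf(x, loc, scale):
    """Density of a normal distribution with the given mean and scale."""
    z = (x - loc) / scale
    return math.exp(-0.5 * z * z) / (scale * math.sqrt(2.0 * math.pi))


def likelihood_cartesian():
    """Enumerate all K**N joint assignments of data points to components."""
    total = 0.0
    for assignment in product(range(len(weights)), repeat=len(data)):
        term = 1.0
        for x, k in zip(data, assignment):
            term *= weights[k] * normal_pdf(x, locs[k], scale)
        total += term
    return total


def likelihood_lockstep():
    """Use conditional independence: K terms per datum, N data points."""
    total = 1.0
    for x in data:
        total *= sum(w * normal_pdf(x, loc, scale) for w, loc in zip(weights, locs))
    return total


print(likelihood_cartesian(), likelihood_lockstep())
```

With 3 components and 4 data points the cartesian version visits 81 assignments while the lock-step version evaluates 12 terms, and the two agree up to floating-point error.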

To illustrate, consider a Gaussian mixture model with shared variance and different mean.

In [7]:
@config_enumerate
def model(data, num_components=3):
    print('Running model with {} data points'.format(len(data)))
    p = pyro.sample("p", dist.Dirichlet(0.5 * torch.ones(num_components)))
    scale = pyro.sample("scale", dist.LogNormal(0, num_components))
    with pyro.plate("components", num_components):
        loc = pyro.sample("loc", dist.Normal(0, 10))
    with pyro.plate("data", len(data)):
        x = pyro.sample("x", dist.Categorical(p))
        print("x.shape = {}".format(x.shape))
        pyro.sample("obs", dist.Normal(loc[x], scale), obs=data)
        print("dist.Normal(loc[x], scale).batch_shape = {}".format(
            dist.Normal(loc[x], scale).batch_shape))

guide = AutoDiagonalNormal(poutine.block(model, hide=["x", "data"]))

data = torch.randn(10)

pyro.clear_param_store()
elbo = TraceEnum_ELBO(max_plate_nesting=1)
elbo.loss(model, guide, data);
Running model with 10 data points
x.shape = torch.Size([10])
dist.Normal(loc[x], scale).batch_shape = torch.Size([10])
Running model with 10 data points
x.shape = torch.Size([3, 1])
dist.Normal(loc[x], scale).batch_shape = torch.Size([3, 1])

Observe that the model is run twice, first by the AutoDiagonalNormal to trace sample sites, and second by elbo to compute loss. In the first run, x has the standard interpretation of one sample per datum, hence shape (10,). In the second run enumeration can use the same three values, with shape (3,1), for all data points, and relies on broadcasting for any downstream sample or observe sites that depend on data. For example, in the pyro.sample("obs",...) statement, the distribution has shape (3,1), the data has shape (10,), and the broadcasted log probability tensor has shape (3,10).
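
The shape arithmetic in that last example follows the usual right-aligned broadcasting rule. As a small sketch, the helper below (broadcast_shape is a hypothetical pure-Python function written for this illustration, not a Pyro or PyTorch call) reproduces the (3,1) and (10,) to (3,10) computation.

```python
from itertools import zip_longest


def broadcast_shape(*shapes):
    """Right-align the shapes and expand size-1 dimensions (hypothetical helper)."""
    result = []
    for sizes in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        distinct = {n for n in sizes if n != 1}
        if len(distinct) > 1:
            raise ValueError("shapes {} are not broadcastable".format(shapes))
        result.append(distinct.pop() if distinct else 1)
    return tuple(reversed(result))


enumerated_shape = (3, 1)  # batch shape of Normal(loc[x], scale) under enumeration
data_shape = (10,)         # shape of the observed data
log_prob_shape = broadcast_shape(enumerated_shape, data_shape)
print(log_prob_shape)
```

The enumeration dimension sits to the left of the plate dimension, so the three enumerated values are paired with each of the ten data points without copying the data.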

For a more in-depth treatment of enumeration in mixture models, see the Gaussian Mixture Model Tutorial and the.

Dependencies among plates

The computational savings of enumerating in vectorized plates comes with restrictions on the dependency structure of models. These restrictions are in addition to the usual restrictions of conditional independence. The enumeration restrictions are checked by TraceEnum_ELBO and will result in an error if violated (however the usual conditional independence restriction cannot be generally verified by Pyro). For completeness we list all three restrictions:

Restriction 1: conditional independence

Variables within a plate may not depend on each other (along the plate dimension). This applies to any variable, whether or not it is enumerated. This applies to both sequential plates and vectorized plates. For example the following model is invalid:

def invalid_model():
    x = 0
    for i in pyro.plate("invalid", 10):
        x = pyro.sample("x_{}".format(i), dist.Normal(x, 1.))

Restriction 2: no downstream coupling

No variable outside of a vectorized plate can depend on an enumerated variable inside of that plate. This would violate Pyro’s exponential speedup assumption. For example the following model is invalid:

@config_enumerate
def invalid_model(data):
    with pyro.plate("plate", 10):  # <--- invalid vectorized plate
        x = pyro.sample("x", dist.Bernoulli(0.5))
    assert x.shape == (10,)
    pyro.sample("obs", dist.Normal(x.sum(), 1.), obs=data)

To work around this restriction, you can convert the vectorized plate to a sequential plate:

@config_enumerate
def valid_model(data):
    x = []
    for i in pyro.plate("plate", 10):  # <--- valid sequential plate
        x.append(pyro.sample("x_{}".format(i), dist.Bernoulli(0.5)))
    assert len(x) == 10
    pyro.sample("obs", dist.Normal(sum(x), 1.), obs=data)

Restriction 3: single path leaving each plate

The final restriction is subtle, but is required to enable Pyro’s exponential speedup:

For any enumerated variable x, the set of all enumerated variables on which x depends must be linearly orderable in their vectorized plate nesting.

This requirement only applies when there are at least two plates and at least three variables in different plate contexts. The simplest counterexample is a Boltzmann machine:

@config_enumerate
def invalid_model(data):
    plate_1 = pyro.plate("plate_1", 10, dim=-1)  # vectorized
    plate_2 = pyro.plate("plate_2", 10, dim=-2)  # vectorized
    with plate_1:
        x = pyro.sample("x", dist.Bernoulli(0.5))
    with plate_2:
        y = pyro.sample("y", dist.Bernoulli(0.5))
    with plate_1, plate_2:
        z = pyro.sample("z", dist.Bernoulli((1. + x + y) / 4.))
        ...

Here we see that the variable z depends on variable x (which is in plate_1 but not plate_2) and depends on variable y (which is in plate_2 but not plate_1). This model is invalid because there is no way to linearly order x and y such that one’s plate nesting is less than the other.
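
One way to read the restriction is as a chain condition on sets of vectorized plates. The sketch below is a simplified pure-Python check written for illustration; it is not the validation that TraceEnum_ELBO performs, and the function name is_linearly_ordered is hypothetical. It takes the plate context of each enumerated variable that some variable depends on and tests whether those contexts can be ordered by inclusion.

```python
def is_linearly_ordered(plate_sets):
    """Return True if the plate sets form a chain under set inclusion."""
    ordered = sorted((frozenset(s) for s in plate_sets), key=len)
    return all(a <= b for a, b in zip(ordered, ordered[1:]))


# Boltzmann machine: z depends on x (in plate_1 only) and y (in plate_2 only).
boltzmann = [{"plate_1"}, {"plate_2"}]

# Same dependencies when plate_2 is sequential: only plate_1 counts as a
# vectorized plate, so y has an empty vectorized plate context.
one_plate_sequential = [{"plate_1"}, set()]

print(is_linearly_ordered(boltzmann))
print(is_linearly_ordered(one_plate_sequential))
```

Neither {plate_1} nor {plate_2} contains the other, so the first case fails the check, while an empty context is contained in every other context and passes.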

To work around this restriction, you can convert one of the plates to a sequential plate:

@config_enumerate
def valid_model(data):
    plate_1 = pyro.plate("plate_1", 10, dim=-1)  # vectorized
    plate_2 = pyro.plate("plate_2", 10)          # sequential
    with plate_1:
        x = pyro.sample("x", dist.Bernoulli(0.5))
    for i in plate_2:
        y = pyro.sample("y_{}".format(i), dist.Bernoulli(0.5))
        with plate_1:
            z = pyro.sample("z_{}".format(i), dist.Bernoulli((1. + x + y) / 4.))
            ...

but beware that this increases the computational complexity, which may be exponential in the size of the sequential plate.

Time series example

Consider a discrete HMM with latent states \(x_t\) and observations \(y_t\). Suppose we want to learn the transition and emission probabilities.

In [8]:
data_dim = 4
num_steps = 10
data = dist.Categorical(torch.ones(num_steps, data_dim)).sample()

def hmm_model(data, data_dim, hidden_dim=10):
    print('Running for {} time steps'.format(len(data)))
    # Sample global matrices wrt a Jeffreys prior.
    with pyro.plate("hidden_state", hidden_dim):
        transition = pyro.sample("transition", dist.Dirichlet(0.5 * torch.ones(hidden_dim)))
        emission = pyro.sample("emission", dist.Dirichlet(0.5 * torch.ones(data_dim)))

    x = 0  # initial state
    for t, y in enumerate(data):
        x = pyro.sample("x_{}".format(t), dist.Categorical(transition[x]),
                        infer={"enumerate": "parallel"})
        pyro.sample("y_{}".format(t), dist.Categorical(emission[x]), obs=y)
        print("x_{}.shape = {}".format(t, x.shape))

We can learn the global parameters using SVI with an autoguide.

In [9]:
hmm_guide = AutoDiagonalNormal(poutine.block(hmm_model, expose=["transition", "emission"]))

pyro.clear_param_store()
elbo = TraceEnum_ELBO(max_plate_nesting=1)
elbo.loss(hmm_model, hmm_guide, data, data_dim=data_dim);
Running for 10 time steps
x_0.shape = torch.Size([])
x_1.shape = torch.Size([])
x_2.shape = torch.Size([])
x_3.shape = torch.Size([])
x_4.shape = torch.Size([])
x_5.shape = torch.Size([])
x_6.shape = torch.Size([])
x_7.shape = torch.Size([])
x_8.shape = torch.Size([])
x_9.shape = torch.Size([])
Running for 10 time steps
x_0.shape = torch.Size([10, 1])
x_1.shape = torch.Size([10, 1, 1])
x_2.shape = torch.Size([10, 1, 1, 1])
x_3.shape = torch.Size([10, 1, 1, 1, 1])
x_4.shape = torch.Size([10, 1, 1, 1, 1, 1])
x_5.shape = torch.Size([10, 1, 1, 1, 1, 1, 1])
x_6.shape = torch.Size([10, 1, 1, 1, 1, 1, 1, 1])
x_7.shape = torch.Size([10, 1, 1, 1, 1, 1, 1, 1, 1])
x_8.shape = torch.Size([10, 1, 1, 1, 1, 1, 1, 1, 1, 1])
x_9.shape = torch.Size([10, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

Notice that the model was run twice here: first it was run without enumeration by AutoDiagonalNormal, so that the autoguide could record all sample sites; then it was run by TraceEnum_ELBO with enumeration enabled. We see in the first run that samples have the standard interpretation, whereas in the second run samples have the enumeration interpretation.

For more complex examples, including minibatching and multiple plates, see the HMM tutorial.

How to enumerate more than 25 variables

PyTorch tensors have a dimension limit of 25 on CUDA and 64 on CPU. By default Pyro enumerates each sample site in a new dimension. If you need more sample sites, you can annotate your model with pyro.markov to tell Pyro when it is safe to recycle tensor dimensions. Let’s see how that works with the HMM model from above. The only change we need is to annotate the for loop with pyro.markov, informing Pyro that the variables in each step of the loop depend only on variables outside of the loop and variables at this step and the previous step of the loop:

- for t, y in enumerate(data):
+ for t, y in pyro.markov(enumerate(data)):
In [10]:
def hmm_model(data, data_dim, hidden_dim=10):
    with pyro.plate("hidden_state", hidden_dim):
        transition = pyro.sample("transition", dist.Dirichlet(0.5 * torch.ones(hidden_dim)))
        emission = pyro.sample("emission", dist.Dirichlet(0.5 * torch.ones(data_dim)))

    x = 0  # initial state
    for t, y in pyro.markov(enumerate(data)):
        x = pyro.sample("x_{}".format(t), dist.Categorical(transition[x]),
                        infer={"enumerate": "parallel"})
        pyro.sample("y_{}".format(t), dist.Categorical(emission[x]), obs=y)
        print("x_{}.shape = {}".format(t, x.shape))

# We'll reuse the same guide and elbo.
elbo.loss(hmm_model, hmm_guide, data, data_dim=data_dim);
x_0.shape = torch.Size([10, 1])
x_1.shape = torch.Size([10, 1, 1])
x_2.shape = torch.Size([10, 1])
x_3.shape = torch.Size([10, 1, 1])
x_4.shape = torch.Size([10, 1])
x_5.shape = torch.Size([10, 1, 1])
x_6.shape = torch.Size([10, 1])
x_7.shape = torch.Size([10, 1, 1])
x_8.shape = torch.Size([10, 1])
x_9.shape = torch.Size([10, 1, 1])

Notice that this model now only needs three tensor dimensions: one for the plate, one for even states, and one for odd states. For more complex examples, see the Dynamic Bayes Net model in the HMM example.
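
The printed shapes can be reproduced by a toy dimension allocator. The following pure-Python sketch only mimics the shapes shown above; it is not Pyro’s allocator, it assumes each step depends on just the previous step, and the function name enumeration_shapes is hypothetical.

```python
def enumeration_shapes(num_steps, hidden_dim, first_available_dim, markov):
    """Return the shape of each enumerated x_t under a toy allocation scheme."""
    shapes = []
    for t in range(num_steps):
        # Without markov every step claims a fresh dimension further left;
        # with markov the steps alternate between two recycled dimensions.
        offset = (t % 2) if markov else t
        dim = first_available_dim - offset
        shapes.append((hidden_dim,) + (1,) * (-dim - 1))
    return shapes


first_available_dim = -1 - 1  # from max_plate_nesting=1
without_markov = enumeration_shapes(10, 10, first_available_dim, markov=False)
with_markov = enumeration_shapes(10, 10, first_available_dim, markov=True)

print(max(len(s) for s in without_markov))  # dimensions needed without markov
print(max(len(s) for s in with_markov))     # dimensions needed with markov
```

Without recycling, the number of tensor dimensions grows with the number of time steps, so a long enough sequence would exceed the dimension limit; with recycling it stays fixed at three regardless of sequence length.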