SVI Part II: Conditional Independence, Subsampling, and Amortization

The Goal: Scaling SVI to Large Datasets

For a model with \(N\) observations, running the model and guide and constructing the ELBO involves evaluating log pdfs whose cost grows with \(N\). This is a problem if we want to scale to large datasets. Luckily, the ELBO objective naturally supports subsampling, provided that our model and guide have some conditional independence structure that we can take advantage of. For example, in the case that the observations are conditionally independent given the latents, the log likelihood term in the ELBO can be approximated with

\[ \sum_{i=1}^N \log p({\bf x}_i | {\bf z}) \approx \frac{N}{M} \sum_{i\in{\mathcal{I}_M}} \log p({\bf x}_i | {\bf z})\]

where \(\mathcal{I}_M\) is a mini-batch of indices of size \(M\) with \(M<N\) (for a discussion please see references [1,2]). Great, problem solved! But how do we do this in Pyro?
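The rescaling by \(N/M\) is what makes this approximation correct on average. The plain-Python sketch below checks this claim: it averages the mini-batch estimate over every possible mini-batch of size \(M\) and confirms that the average recovers the full sum, so the estimator is unbiased. The per-datapoint values are made up and stand in for \(\log p({\bf x}_i | {\bf z})\). The helper names are hypothetical and are not part of Pyro.

```python
import itertools
import math

def subsample_estimate(log_liks, batch):
    """Scale the mini-batch sum by N/M so that it estimates the full sum."""
    n, m = len(log_liks), len(batch)
    return (n / m) * sum(log_liks[i] for i in batch)

def average_over_all_batches(log_liks, m):
    """Average the estimate over every size-m subset of indices (uniform subsampling)."""
    batches = list(itertools.combinations(range(len(log_liks)), m))
    return sum(subsample_estimate(log_liks, b) for b in batches) / len(batches)

# Made-up per-datapoint log likelihood values, for illustration only.
log_liks = [-0.3, -1.2, -0.7, -2.5, -0.1, -0.9]
full_sum = sum(log_liks)
for m in range(1, len(log_liks) + 1):
    # Each index appears in a fraction M/N of all batches, which cancels the N/M factor.
    assert math.isclose(average_over_all_batches(log_liks, m), full_sum)
```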

Marking Conditional Independence in Pyro

If a user wants to do this sort of thing in Pyro, he or she first needs to make sure that the model and guide are written in such a way that Pyro can leverage the relevant conditional independencies. Let’s see how this is done. Pyro provides two language primitives for marking conditional independencies: plate and markov. Let’s start with the simpler of the two.

Sequential plate

Let’s return to the example we used in the previous tutorial. For convenience let’s replicate the main logic of the model here:

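import torch
import pyro
import pyro.distributions as dist

# hyperparameters of the Beta prior (illustrative values)
alpha0 = torch.tensor(10.0)
beta0 = torch.tensor(10.0)

def model(data):
    # sample f from the beta prior
    f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
    # loop over the observed data using pyro.sample with the obs keyword argument
    for i in range(len(data)):
        # observe datapoint i using the bernoulli likelihood
        pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])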

For this model the observations are conditionally independent given the latent random variable latent_fairness. To explicitly mark this in Pyro we just need to replace the Python builtin range with the Pyro construct plate:

def model(data):
    # sample f from the beta prior
    f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
    # loop over the observed data [WE ONLY CHANGE THE NEXT LINE]
    for i in pyro.plate("data_loop", len(data)):
        # observe datapoint i using the bernoulli likelihood
        pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

We see that pyro.plate is very similar to range with one main difference: each invocation of plate requires the user to provide a unique name. The second argument is an integer just like for range.

So far so good. Pyro can now leverage the conditional independence of the observations given the latent random variable. But how does this actually work? pyro.plate is implemented using a context manager. On every execution of the body of the for loop we enter a new (conditional) independence context, which is exited again at the end of that execution of the loop body. Let’s be very explicit about this:

  • because each observed pyro.sample statement occurs within a different execution of the body of the for loop, Pyro marks each observation as independent.

  • this independence is properly a conditional independence given latent_fairness because latent_fairness is sampled outside of the context of data_loop.
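The mechanism described above can be imitated with a toy in plain Python. This is not Pyro's implementation. `toy_plate`, `toy_sample`, `CONTEXT_STACK` and `TRACE` are all hypothetical stand-ins. The toy plate is a generator that enters a fresh context (its name plus the iteration number) before yielding each index and leaves it once the loop body has finished. As a result, each observation records its own data_loop context, while latent_fairness, which is sampled outside the loop, records none.

```python
CONTEXT_STACK = []   # independence contexts currently entered
TRACE = {}           # site name -> tuple of contexts active when the site ran

def toy_plate(name, size):
    """Yield 0..size-1, entering a fresh context for each execution of the loop body."""
    for i in range(size):
        CONTEXT_STACK.append((name, i))   # enter the context for iteration i
        try:
            yield i
        finally:
            CONTEXT_STACK.pop()           # exit it when the body finishes

def toy_sample(site_name):
    """Record which independence contexts enclose this sample statement."""
    TRACE[site_name] = tuple(CONTEXT_STACK)

def toy_model(data):
    toy_sample("latent_fairness")                # global: outside the plate
    for i in toy_plate("data_loop", len(data)):
        toy_sample("obs_{}".format(i))           # local: inside the plate

toy_model([1.0, 0.0, 1.0])
print(TRACE)
# {'latent_fairness': (), 'obs_0': (('data_loop', 0),), 'obs_1': (('data_loop', 1),), 'obs_2': (('data_loop', 2),)}
```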

Before moving on, let’s mention some gotchas to be avoided when using sequential plate. Consider the following variant of the above code snippet:

# WARNING do not do this!
my_reified_list = list(pyro.plate("data_loop", len(data)))
for i in my_reified_list:
    pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

This will not achieve the desired behavior, since list() will enter and exit the data_loop context completely before a single pyro.sample statement is called. Similarly, the user needs to take care not to leak mutable computations across the boundary of the context manager, as this may lead to subtle bugs. For example, pyro.plate is not appropriate for temporal models where each iteration of a loop depends on the previous iteration; in this case a range or pyro.markov should be used instead.
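The same kind of toy generator shows the gotcha in action. It is again hypothetical and is not Pyro's code. `list()` drives the generator to exhaustion, so it enters and exits every context before any loop body runs. The sample statements then execute outside any independence context.

```python
CONTEXT_STACK = []   # independence contexts currently entered
TRACE = {}           # site name -> tuple of contexts active when the site ran

def toy_plate(name, size):
    """Toy sequential plate: one context per iteration, held while the body runs."""
    for i in range(size):
        CONTEXT_STACK.append((name, i))
        try:
            yield i
        finally:
            CONTEXT_STACK.pop()

def toy_sample(site_name):
    TRACE[site_name] = tuple(CONTEXT_STACK)

# Correct: each sample statement runs while the generator is suspended inside its context.
for i in toy_plate("data_loop", 2):
    toy_sample("good_obs_{}".format(i))

# WARNING: list() enters and exits every context before any body code runs.
my_reified_list = list(toy_plate("data_loop", 2))
for i in my_reified_list:
    toy_sample("bad_obs_{}".format(i))

print(TRACE)
# {'good_obs_0': (('data_loop', 0),), 'good_obs_1': (('data_loop', 1),), 'bad_obs_0': (), 'bad_obs_1': ()}
```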

Vectorized plate

Conceptually, vectorized plate is the same as sequential plate except that it is a vectorized operation (as torch.arange is to range). As such it potentially enables large speed-ups compared to the explicit for loop that appears with sequential plate. Let’s see how this looks for our running example. First we need data to be in the form of a tensor:

data = torch.zeros(10)
data[0:6] = torch.ones(6)  # 6 heads and 4 tails

Then we have:

with pyro.plate('observe_data'):
    pyro.sample('obs', dist.Bernoulli(f), obs=data)

Let’s compare this to the analogous sequential plate usage point-by-point:

  • both patterns require the user to specify a unique name.

  • note that this code snippet only introduces a single (observed) random variable (namely obs), since the entire tensor is considered at once.

  • since there is no need for an iterator in this case, there is no need to specify the length of the tensor(s) involved in the plate context.
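The whole tensor forms a single observed site, so its log probability is the sum of the per-datapoint Bernoulli log probabilities that the sequential version computes one at a time. For the 6 heads and 4 tails above, this sum is \(6 \log f + 4 \log(1-f)\). The plain-Python check below confirms that the two computations agree, using an arbitrary value \(f = 0.6\) chosen only for illustration. In Pyro the vectorized version would be evaluated as a single tensor operation.

```python
import math

def bernoulli_log_prob(x, f):
    """log p(x | f) for a single Bernoulli observation x in {0, 1}."""
    return x * math.log(f) + (1 - x) * math.log(1 - f)

data = [1.0] * 6 + [0.0] * 4   # 6 heads and 4 tails, as in the tensor above
f = 0.6                         # arbitrary fairness value for illustration

# Sequential plate: one site per datapoint, each contributing its own log prob.
sequential = sum(bernoulli_log_prob(x, f) for x in data)

# Vectorized plate: one site for the whole batch; its log prob is the same sum.
heads = sum(data)
vectorized = heads * math.log(f) + (len(data) - heads) * math.log(1 - f)

assert math.isclose(sequential, vectorized)
```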

Note that the gotchas mentioned in the case of sequential plate also apply to vectorized plate.

Subsampling

We now know how to mark conditional independence in Pyro. This is useful in and of itself (see the dependency tracking section in SVI Part III), but we’d also like to do subsampling so that we can do SVI on large datasets. Depending on the structure of the model and guide, Pyro supports several ways of doing subsampling. Let’s go through these one by one.

Automatic subsampling with plate

Let’s look at the simplest case first, in which we get subsampling for free with one or two additional arguments to plate:

for i in pyro.plate("data_loop", len(data), subsample_size=5):
    pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

That’s all there is to it: we just use the argument subsample_size. Whenever we run model() we now only evaluate the log likelihood for 5 randomly chosen datapoints in data; in addition, the log likelihood will be automatically scaled by the appropriate factor of \(\tfrac{10}{5} = 2\). What about vectorized plate? The incantation is entirely analogous:

with pyro.plate('observe_data', size=10, subsample_size=5) as ind:
    pyro.sample('obs', dist.Bernoulli(f),
                obs=data.index_select(0, ind))

Importantly, plate now returns a tensor of indices ind, which in this case will be of length 5. Note that in addition to the argument subsample_size we also pass the argument size, so that plate knows the full size of the tensor data and can compute the correct scaling factor. Just like for sequential plate, the user is responsible for selecting the correct datapoints using the indices provided by plate.
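The bookkeeping that subsample_size performs for the coin-flip data can be reproduced by hand. Draw 5 distinct indices out of 10, select the corresponding datapoints (the plain-Python analogue of `data.index_select(0, ind)`), and scale their log likelihood by \(10/5 = 2\). The helper `scaled_subsample_log_lik` is hypothetical. Pyro does this internally, and this sketch does not claim to mirror its exact sampling code.

```python
import math
import random

def bernoulli_log_prob(x, f):
    """log p(x | f) for a single Bernoulli observation x in {0, 1}."""
    return x * math.log(f) + (1 - x) * math.log(1 - f)

def scaled_subsample_log_lik(data, f, subsample_size, rng):
    """Pick distinct random indices, select those datapoints, and rescale by size / subsample_size."""
    size = len(data)
    ind = rng.sample(range(size), subsample_size)   # distinct random indices
    selected = [data[i] for i in ind]               # analogue of data.index_select(0, ind)
    scale = size / subsample_size                   # here 10 / 5 = 2
    return ind, scale, scale * sum(bernoulli_log_prob(x, f) for x in selected)

data = [1.0] * 6 + [0.0] * 4   # 6 heads and 4 tails
ind, scale, estimate = scaled_subsample_log_lik(data, 0.6, 5, random.Random(0))
print(len(ind), scale)   # 5 2.0
```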

Finally, note that the user must pass a device argument to plate if data is on the GPU.

Custom subsampling strategies with plate

Every time the above model() is run, plate will sample new subsample indices. Since this subsampling is stateless, it can lead to problems: for a sufficiently large dataset, even after a large number of iterations there is a non-negligible probability that some of the datapoints will never have been selected. To avoid this the user can take control of subsampling by making use of the subsample argument to plate. See the docs for details.
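"Non-negligible" can be quantified. Under uniform random subsampling, a single mini-batch of size \(M\) misses a given datapoint with probability \(1 - M/N\). After \(T\) independent iterations, the datapoint has never been selected with probability \((1 - M/N)^T\). With \(N = 10^6\), \(M = 100\) and \(T = 10^4\) this is about \(e^{-1} \approx 0.37\), so roughly a third of the dataset is never visited. One common remedy is to shuffle the indices once per epoch and walk through them in order, which guarantees that every datapoint is visited once per epoch. The generator `epoch_batches` is a hypothetical helper. In Pyro, each resulting batch would be passed (as a tensor of indices) to the subsample argument of plate; check the docs for the exact expected form.

```python
import math
import random

def prob_never_selected(n, m, t):
    """Probability that a fixed datapoint misses all t independent uniform mini-batches of size m."""
    return (1 - m / n) ** t

def epoch_batches(n, m, rng):
    """Yield index batches that cover a fresh random permutation of range(n) exactly once."""
    perm = list(range(n))
    rng.shuffle(perm)
    for start in range(0, n, m):
        yield perm[start:start + m]   # the final batch may be smaller than m

p = prob_never_selected(10**6, 100, 10**4)
print(round(p, 3))   # 0.368
```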

Subsampling when there are only local random variables

We have in mind a model with a joint probability density given by

\[p({\bf x}, {\bf z}) = \prod_{i=1}^N p({\bf x}_i | {\bf z}_i) p({\bf z}_i)\]

For a model with this dependency structure the scale factor introduced by subsampling scales all the terms in the ELBO by the same amount. This is the case, for example, for a vanilla VAE. This explains why for the VAE it’s permissible for the user to take complete control over subsampling and pass mini-batches directly to the model and guide; plate is still used, but subsample_size and subsample are not. To see how this looks in detail, see the VAE tutorial.
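A common factor can be dropped because multiplying the whole objective by a positive constant does not move its maximizer. For plain gradient ascent, the constant only acts like a change of learning rate. The toy below uses hypothetical quadratic per-datapoint terms \(-(\theta - x_i)^2\) in place of real ELBO terms. With only local terms, the mini-batch optimum is the same with or without the factor \(N/M\). Once an unscaled global term \(-\theta^2\) is added, as happens with a global latent variable, the factor changes the optimum and can no longer be dropped.

```python
def batch_optimum(batch, scale, global_weight):
    """Closed-form maximizer of -global_weight*theta**2 - scale*sum((theta - x)**2 for x in batch)."""
    return scale * sum(batch) / (global_weight + scale * len(batch))

data = [0.5, 1.5, 2.0, 4.0, 1.0, 3.0, 2.5, 0.0]   # N = 8 made-up datapoints
batch = data[:4]                                   # M = 4
n_over_m = len(data) / len(batch)                  # scale factor N/M = 2

# Only local terms: the factor N/M does not move the optimum.
local_only = (batch_optimum(batch, n_over_m, 0.0), batch_optimum(batch, 1.0, 0.0))
# With an unscaled global term it does (16/9 versus 8/5 here).
with_global = (batch_optimum(batch, n_over_m, 1.0), batch_optimum(batch, 1.0, 1.0))
print(local_only)   # (2.0, 2.0)
```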

Subsampling when there are both global and local random variables

In the coin flip examples above plate appeared in the model but not in the guide, since the only thing being subsampled was the observations. Let’s look at a more complicated example where subsampling appears in both the model and guide. To make things simple let’s keep the discussion somewhat abstract and avoid writing a complete model and guide.

Consider the model specified by the following joint distribution:

\[ p({\bf x}, {\bf z}, \beta) = p(\beta) \prod_{i=1}^N p({\bf x}_i | {\bf z}_i) p({\bf z}_i | \beta)\]

There are \(N\) observations \(\{ {\bf x}_i \}\) and \(N\) local latent random variables \(\{ {\bf z}_i \}\). There is also a global latent random variable \(\beta\). Our guide will be factorized as

\[q({\bf z}, \beta) = q(\beta) \prod_{i=1}^N q({\bf z}_i | \beta, \lambda_i)\]

Here we’ve been explicit about introducing \(N\) local variational parameters \(\{\lambda_i \}\), while the other variational parameters are left implicit. Both the model and guide have conditional independencies. In particular, on the model side, given the \(\{ {\bf z}_i \}\) the observations \(\{ {\bf x}_i \}\) are independent. In addition, given \(\beta\) the latent random variables \(\{ {\bf z}_i \}\) are independent. On the guide side, given the variational parameters \(\{\lambda_i \}\) and \(\beta\) the latent random variables \(\{ {\bf z}_i \}\) are independent. To mark these conditional independencies in Pyro and do subsampling we need to make use of plate in both the model and the guide. Let’s sketch out the basic logic using sequential plate (a more complete piece of code would include pyro.param statements, etc.). First, the model:

def model(data):
    beta = pyro.sample("beta", ...) # sample the global RV
    for i in pyro.plate("locals", len(data)):
        z_i = pyro.sample("z_{}".format(i), ...)
        # compute the parameter used to define the observation
        # likelihood using the local random variable
        theta_i = compute_something(z_i)
        pyro.sample("obs_{}".format(i), dist.MyDist(theta_i), obs=data[i])

Note that in contrast to our running coin flip example, here we have pyro.sample statements both inside and outside of the plate loop. Next the guide:

def guide(data):
    beta = pyro.sample("beta", ...) # sample the global RV
    for i in pyro.plate("locals", len(data), subsample_size=5):
        # sample the local RVs
        pyro.sample("z_{}".format(i), ..., lambda_i)

Note that, crucially, the indices are subsampled only once, in the guide; the Pyro backend makes sure that the same set of indices is used during execution of the model, matching the two plates through their shared name "locals". For this reason subsample_size only needs to be specified in the guide.
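The model and guide share the same indices \(\mathcal{I}_M\). The global terms therefore enter the ELBO estimate once and unscaled, while the local terms are scaled by \(N/M\). The estimate takes the form

\[ \log p(\beta) - \log q(\beta) + \frac{N}{M} \sum_{i\in\mathcal{I}_M} \Big[ \log p({\bf x}_i | {\bf z}_i) + \log p({\bf z}_i | \beta) - \log q({\bf z}_i | \beta, \lambda_i) \Big]\]

evaluated at samples drawn from the guide. The plain-Python toy below imitates this bookkeeping. `guide_subsample`, `model_subsample` and the `trace` dictionary keyed by plate name are hypothetical stand-ins, not Pyro internals, and made-up numbers stand in for the log density terms.

```python
import random

def guide_subsample(trace, plate_name, size, subsample_size, rng):
    """Guide side: draw the mini-batch indices once and record them under the plate name."""
    ind = sorted(rng.sample(range(size), subsample_size))
    trace[plate_name] = ind
    return ind

def model_subsample(trace, plate_name):
    """Model side: reuse the indices the guide recorded for the plate with the same name."""
    return trace[plate_name]

def elbo_estimate(global_term, local_terms, ind):
    """Global term counted once, unscaled; mini-batch local terms scaled by N/M."""
    n, m = len(local_terms), len(ind)
    return global_term + (n / m) * sum(local_terms[i] for i in ind)

trace = {}
guide_ind = guide_subsample(trace, "locals", size=10, subsample_size=5, rng=random.Random(0))
model_ind = model_subsample(trace, "locals")

# Made-up numbers standing in for log p(beta) - log q(beta) and the per-datapoint local terms.
global_term = -1.5
local_terms = [-0.2 * (i + 1) for i in range(10)]
estimate = elbo_estimate(global_term, local_terms, model_ind)
```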

Amortization

Let’s again consider a model with global and local latent random variables and local variational parameters:

\[ p({\bf x}, {\bf z}, \beta) = p(\beta) \prod_{i=1}^N p({\bf x}_i | {\bf z}_i) p({\bf z}_i | \beta) \qquad \qquad q({\bf z}, \beta) = q(\beta) \prod_{i=1}^N q({\bf z}_i | \beta, \lambda_i)\]

For small to medium-sized \(N\), using local variational parameters like this can be a good approach. If \(N\) is large, however, the fact that the space we’re doing optimization over grows with \(N\) can be a real problem. One way to avoid this nasty growth with the size of the dataset is amortization.

This works as follows. Instead of introducing local variational parameters, we’re going to learn a single parametric function \(f(\cdot)\) and work with a variational distribution that has the form

\[q(\beta) \prod_{i=1}^N q({\bf z}_i | f({\bf x}_i))\]

The function \(f(\cdot)\), which maps a given observation to a set of variational parameters tailored to that datapoint, will need to be sufficiently rich to capture the posterior accurately, but now we can handle large datasets without having to introduce an obscene number of variational parameters. This approach has other benefits too: for example, during learning \(f(\cdot)\) effectively allows us to share statistical power among different datapoints. Note that this is precisely the approach used in the VAE.
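The difference in the size of the optimization space can be made concrete. The sketch below assumes, hypothetically, that each \(q({\bf z}_i | \cdot)\) is a one-dimensional Normal with a location and a positive scale. Local variational parameters then need two numbers per datapoint. The amortizing function \(f\) here is a hypothetical affine map from a scalar observation, with a softplus applied for the scale, and it has a fixed four parameters however large \(N\) is. A real \(f(\cdot)\), such as the encoder network of a VAE, would be far richer.

```python
import math

def num_local_params(n):
    """One (loc, scale) pair per datapoint: grows linearly with N."""
    return 2 * n

class AmortizedEncoder:
    """Hypothetical f(x) -> (loc, scale) with a fixed set of four parameters."""

    def __init__(self, w_loc=1.0, b_loc=0.0, w_scale=0.5, b_scale=0.0):
        self.params = [w_loc, b_loc, w_scale, b_scale]

    def __call__(self, x):
        w_loc, b_loc, w_scale, b_scale = self.params
        loc = w_loc * x + b_loc
        scale = math.log1p(math.exp(w_scale * x + b_scale))  # softplus keeps the scale positive
        return loc, scale

f = AmortizedEncoder()
for n in (10, 10**6):
    print(n, num_local_params(n), len(f.params))
# 10 20 4
# 1000000 2000000 4
```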

Tensor shapes and vectorized plate

The usage of pyro.plate in this tutorial was limited to relatively simple cases. For example, none of the plates were nested inside of other plates. In order to make full use of plate, the user must be careful to use Pyro’s tensor shape semantics. For a discussion see the tensor shapes tutorial.

References

[1] Stochastic Variational Inference, Matthew D. Hoffman, David M. Blei, Chong Wang, John Paisley

[2] Auto-Encoding Variational Bayes, Diederik P. Kingma, Max Welling