SVI Part III: ELBO Gradient Estimators

Setup

We’ve defined a Pyro model with observations \({\bf x}\) and latents \({\bf z}\) of the form \(p_{\theta}({\bf x}, {\bf z}) = p_{\theta}({\bf x}|{\bf z}) p_{\theta}({\bf z})\). We’ve also defined a Pyro guide (i.e. a variational distribution) of the form \(q_{\phi}({\bf z})\). Here \({\theta}\) and \(\phi\) are parameters for the model and guide, respectively. (In particular these are not random variables that call for a Bayesian treatment.)

We’d like to maximize the log evidence \(\log p_{\theta}({\bf x})\) by maximizing the ELBO (the evidence lower bound) given by

\[{\rm ELBO} \equiv \mathbb{E}_{q_{\phi}({\bf z})} \left [ \log p_{\theta}({\bf x}, {\bf z}) - \log q_{\phi}({\bf z}) \right]\]

To do this we’re going to take (stochastic) gradient steps on the ELBO in the parameter space \(\{ \theta, \phi \}\) (see references [1,2] for early work on this approach). So we need to be able to compute unbiased estimates of

\[\nabla_{\theta,\phi} {\rm ELBO} = \nabla_{\theta,\phi}\mathbb{E}_{q_{\phi}({\bf z})} \left [ \log p_{\theta}({\bf x}, {\bf z}) - \log q_{\phi}({\bf z}) \right]\]

How do we do this for general stochastic functions model() and guide()? To simplify notation let’s generalize our discussion a bit and ask how we can compute gradients of expectations of an arbitrary cost function \(f({\bf z})\). Let’s also drop any distinction between \(\theta\) and \(\phi\). So we want to compute

\[\nabla_{\phi}\mathbb{E}_{q_{\phi}({\bf z})} \left [ f_{\phi}({\bf z}) \right]\]

Let’s start with the easiest case.

Easy Case: Reparameterizable Random Variables

Suppose that we can reparameterize things such that

\[\mathbb{E}_{q_{\phi}({\bf z})} \left [f_{\phi}({\bf z}) \right] =\mathbb{E}_{q({\bf \epsilon})} \left [f_{\phi}(g_{\phi}({\bf \epsilon})) \right]\]

Crucially we’ve moved all the \(\phi\) dependence inside of the expectation; \(q({\bf \epsilon})\) is a fixed distribution with no dependence on \(\phi\). This kind of reparameterization can be done for many distributions (e.g. the normal distribution); see reference [3] for a discussion. In this case we can pass the gradient straight through the expectation to get

\[\nabla_{\phi}\mathbb{E}_{q({\bf \epsilon})} \left [f_{\phi}(g_{\phi}({\bf \epsilon})) \right]= \mathbb{E}_{q({\bf \epsilon})} \left [\nabla_{\phi}f_{\phi}(g_{\phi}({\bf \epsilon})) \right]\]

Assuming \(f(\cdot)\) and \(g(\cdot)\) are sufficiently smooth, we can now get unbiased estimates of the gradient of interest by taking a Monte Carlo estimate of this expectation.
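
To make this concrete, here is a minimal PyTorch-only sketch (not Pyro code) of a reparameterized Monte Carlo gradient estimate, assuming a Normal variational distribution; the names loc, log_scale, and the quadratic cost f are illustrative placeholders.

import torch

# toy cost function f(z); here it has no explicit phi dependence
def f(z):
    return (z - 2.0) ** 2

# variational parameters phi = (loc, log_scale) of q_phi(z) = Normal(loc, scale)
loc = torch.tensor(0.5, requires_grad=True)
log_scale = torch.tensor(0.0, requires_grad=True)

eps = torch.randn(1000)             # eps ~ q(eps) = Normal(0, 1), independent of phi
z = loc + log_scale.exp() * eps     # z = g_phi(eps), the reparameterization
loss = f(z).mean()                  # Monte Carlo estimate of E_q[f(z)]
loss.backward()                     # gradients flow through the samples to loc and log_scale
print(loc.grad, log_scale.grad)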

Tricky Case: Non-reparameterizable Random Variables

What if we can’t do the above reparameterization? Unfortunately this is the case for many distributions of interest, for example all discrete distributions. In this case our estimator takes a somewhat more complicated form.

We begin by expanding the gradient of interest as

\[\nabla_{\phi}\mathbb{E}_{q_{\phi}({\bf z})} \left [ f_{\phi}({\bf z}) \right]= \nabla_{\phi} \int d{\bf z} \; q_{\phi}({\bf z}) f_{\phi}({\bf z})\]

and use the product rule to write this as

\[\int d{\bf z} \; \left \{ (\nabla_{\phi} q_{\phi}({\bf z})) f_{\phi}({\bf z}) + q_{\phi}({\bf z})(\nabla_{\phi} f_{\phi}({\bf z}))\right \}\]

At this point we run into a problem. We know how to generate samples from \(q(\cdot)\)—we just run the guide forward—but \(\nabla_{\phi} q_{\phi}({\bf z})\) isn’t even a valid probability density. So we need to massage this formula so that it’s in the form of an expectation w.r.t. \(q(\cdot)\). This is easily done using the identity

\[ \nabla_{\phi} q_{\phi}({\bf z}) = q_{\phi}({\bf z})\nabla_{\phi} \log q_{\phi}({\bf z})\]

which allows us to rewrite the gradient of interest as

\[\mathbb{E}_{q_{\phi}({\bf z})} \left [ (\nabla_{\phi} \log q_{\phi}({\bf z})) f_{\phi}({\bf z}) + \nabla_{\phi} f_{\phi}({\bf z})\right]\]

This form of the gradient estimator—variously known as the REINFORCE estimator or the score function estimator or the likelihood ratio estimator—is amenable to simple Monte Carlo estimation.

Note that one way to package this result (which is convenient for implementation) is to introduce a surrogate objective function

\[{\rm surrogate \;objective} \equiv \log q_{\phi}({\bf z}) \overline{f_{\phi}({\bf z})} + f_{\phi}({\bf z})\]

Here the bar indicates that the term is held constant (i.e. it is not to be differentiated w.r.t. \(\phi\)). To get a (single-sample) Monte Carlo gradient estimate, we sample the latent random variables, compute the surrogate objective, and differentiate. The result is an unbiased estimate of \(\nabla_{\phi}\mathbb{E}_{q_{\phi}({\bf z})} \left [ f_{\phi}({\bf z}) \right]\). In equations:

\[\nabla_{\phi} {\rm ELBO} = \mathbb{E}_{q_{\phi}({\bf z})} \left [ \nabla_{\phi} ({\rm surrogate \; objective}) \right]\]
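
As an illustration, here is a minimal PyTorch-only sketch of a single-sample score-function surrogate for a Bernoulli latent; the cost below is a toy stand-in with no direct dependence on the variational parameter, so the second term in the surrogate drops out.

import torch
from torch.distributions import Bernoulli

logit = torch.tensor(0.0, requires_grad=True)   # variational parameter phi

q = Bernoulli(logits=logit)
z = q.sample()                                  # non-reparameterizable sample
cost = (z - 0.75) ** 2                          # toy stand-in for f(z)

# surrogate objective: log q_phi(z) * \bar{f(z)}; the bar is implemented with detach()
surrogate = q.log_prob(z) * cost.detach()
surrogate.backward()                            # single-sample unbiased gradient estimate
print(logit.grad)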

Variance or Why I Wish I Was Doing MLE Deep Learning

We now have a general recipe for an unbiased gradient estimator of expectations of cost functions. Unfortunately, in the more general case where our \(q(\cdot)\) includes non-reparameterizable random variables, this estimator tends to have high variance. Indeed in many cases of interest the variance is so high that the estimator is effectively unusable. So we need strategies to reduce variance (for a discussion see reference [4]). We’re going to pursue two strategies. The first strategy takes advantage of the particular structure of the cost function \(f(\cdot)\). The second strategy effectively introduces a way to reduce variance by using information from previous estimates of \(\mathbb{E}_{q_{\phi}({\bf z})} [ f_{\phi}({\bf z})]\). As such it is somewhat analogous to using momentum in stochastic gradient descent.

Reducing Variance via Dependency Structure

In the above discussion we stuck to a general cost function \(f_{\phi}({\bf z})\). We could continue in this vein (the approach we’re about to discuss is applicable in the general case) but for concreteness let’s zoom back in. In the case of stochastic variational inference, we’re interested in a particular cost function of the form

\[\log p_{\theta}({\bf x} | {\rm Pa}_p ({\bf x})) + \sum_i \log p_{\theta}({\bf z}_i | {\rm Pa}_p ({\bf z}_i)) - \sum_i \log q_{\phi}({\bf z}_i | {\rm Pa}_q ({\bf z}_i))\]

where we’ve broken the log ratio \(\log p_{\theta}({\bf x}, {\bf z})/q_{\phi}({\bf z})\) into an observation log likelihood piece and a sum over the different latent random variables \(\{{\bf z}_i \}\). We’ve also introduced the notation \({\rm Pa}_p (\cdot)\) and \({\rm Pa}_q (\cdot)\) to denote the parents of a given random variable in the model and in the guide, respectively. (The reader might worry what the appropriate notion of dependency would be in the case of general stochastic functions; here we simply mean regular ol’ dependency within a single execution trace). The point is that different terms in the cost function have different dependencies on the random variables \(\{ {\bf z}_i \}\) and this is something we can leverage.

To make a long story short, for any non-reparameterizable latent random variable \({\bf z}_i\) the surrogate objective is going to have a term

\[\log q_{\phi}({\bf z}_i) \overline{f_{\phi}({\bf z})}\]

It turns out that we can remove some of the terms in \(\overline{f_{\phi}({\bf z})}\) and still get an unbiased gradient estimator; furthermore, doing so will generally decrease the variance. In particular (see reference [4] for details) we can remove any terms in \(\overline{f_{\phi}({\bf z})}\) that are not downstream of the latent variable \({\bf z}_i\) (downstream w.r.t. the dependency structure of the guide). Note that this general trick—where certain random variables are dealt with analytically to reduce variance—often goes under the name of Rao-Blackwellization.

In Pyro, all of this logic is taken care of automatically by the SVI class. In particular as long as we use a TraceGraph_ELBO loss, Pyro will keep track of the dependency structure within the execution traces of the model and guide and construct a surrogate objective that has all the unnecessary terms removed:

svi = SVI(model, guide, optimizer, TraceGraph_ELBO())

Note that leveraging this dependency information incurs a small computational overhead, so TraceGraph_ELBO should only be used when your model has non-reparameterizable random variables; in most applications Trace_ELBO suffices.

An Example with Rao-Blackwellization

Suppose we have a Gaussian mixture model with \(K\) components. For each data point we: (i) first sample the component \(k \in \{1,...,K\}\); and (ii) observe the data point using the \(k^{\rm th}\) component distribution. The simplest way to write down a model of this sort is as follows:

ks = pyro.sample("k", dist.Categorical(probs)
                          .to_event(1))
pyro.sample("obs", dist.Normal(locs[ks], scale)
                       .to_event(1),
            obs=data)

Since the user hasn’t taken care to mark any of the conditional independencies in the model, the gradient estimator constructed by Pyro’s SVI class is unable to take advantage of Rao-Blackwellization, with the result that the gradient estimator will tend to suffer from high variance. To address this problem the user needs to explicitly mark the conditional independence. Happily, this is not much work:

# mark conditional independence
# (assumed to be along the rightmost tensor dimension)
with pyro.plate("foo", data.size(-1)):
    ks = pyro.sample("k", dist.Categorical(probs))
    pyro.sample("obs", dist.Normal(locs[ks], scale),
                obs=data)

That’s all there is to it.

Aside: Dependency tracking in Pyro

Finally, a word about dependency tracking. Pyro uses the concept of provenance for tracking dependencies within a stochastic function that includes arbitrary Python code (see reference [5]). In programming language theory, the provenance of a variable refers to the history of variables or computations that contributed to its value. The simple example below demonstrates how provenance is tracked through PyTorch ops in Pyro, where provenance is represented as a user-defined frozenset of objects:

import torch

from pyro.ops.provenance import get_provenance, track_provenance

a = track_provenance(torch.randn(3), frozenset({"a"}))
b = track_provenance(torch.randn(3), frozenset({"b"}))
c = torch.randn(3)  # no provenance information

# For a unary operation, the provenance of the output tensor
# equals the provenance of the input tensor
assert get_provenance(a.exp()) == frozenset({"a"})
# In general, the provenance of the output tensors of any op
# is the union of provenances of input tensors.
assert get_provenance(a * (b + c)) == frozenset({"a", "b"})

This concept is utilized by TraceGraph_ELBO to trace the fine-grained dynamic dependency information on non-reparameterizable random variables through intermediate computations as they come together to form a log-likelihood. Internally, non-reparameterizable sample sites are tracked using TrackNonReparam messenger:

import torch
import pyro
import pyro.distributions as dist
from pyro.ops.provenance import get_provenance
from pyro.poutine import trace
# TrackNonReparam is assumed to live alongside TraceGraph_ELBO
from pyro.infer.tracegraph_elbo import TrackNonReparam


def model():
    probs_a = torch.tensor([0.3, 0.7])
    probs_b = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
    probs_c = torch.tensor([[0.5, 0.5], [0.6, 0.4]])
    a = pyro.sample("a", dist.Categorical(probs_a))
    b = pyro.sample("b", dist.Categorical(probs_b[a]))
    pyro.sample("c", dist.Categorical(probs_c[b]), obs=torch.tensor(0))

with TrackNonReparam():
    model_tr = trace(model).get_trace()
model_tr.compute_log_prob()

assert get_provenance(model_tr.nodes["a"]["log_prob"]) == frozenset({'a'})
assert get_provenance(model_tr.nodes["b"]["log_prob"]) == frozenset({'b', 'a'})
assert get_provenance(model_tr.nodes["c"]["log_prob"]) == frozenset({'b', 'a'})

Reducing Variance with Data-Dependent Baselines

The second strategy for reducing variance in our ELBO gradient estimator goes under the name of baselines (see e.g. reference [6]). It actually makes use of the same bit of math that underlies the variance reduction strategy discussed above, except now instead of removing terms we’re going to add terms. Basically, instead of removing terms with zero expectation that tend to contribute to the variance, we’re going to add specially chosen terms with zero expectation that work to reduce the variance. As such, this is a control variate strategy.

In more detail, the idea is to take advantage of the fact that for any constant \(b\), the following identity holds

\[\mathbb{E}_{q_{\phi}({\bf z})} \left [\nabla_{\phi} (\log q_{\phi}({\bf z}) \times b) \right]=0\]

This follows since \(q(\cdot)\) is normalized:

\[\mathbb{E}_{q_{\phi}({\bf z})} \left [\nabla_{\phi} \log q_{\phi}({\bf z}) \right]= \int \!d{\bf z} \; q_{\phi}({\bf z}) \nabla_{\phi} \log q_{\phi}({\bf z})= \int \! d{\bf z} \; \nabla_{\phi} q_{\phi}({\bf z})= \nabla_{\phi} \int \! d{\bf z} \; q_{\phi}({\bf z})=\nabla_{\phi} 1 = 0\]
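
This identity is easy to check numerically; the following sketch estimates the left-hand side by Monte Carlo for a Bernoulli \(q\) and a constant \(b\) (any distribution and constant would do).

import torch
from torch.distributions import Bernoulli

logit = torch.tensor(0.3, requires_grad=True)
q = Bernoulli(logits=logit)
b = 5.0

z = q.sample((100000,))
# sum of per-sample score functions grad_phi log q_phi(z)
score_sum = torch.autograd.grad(q.log_prob(z).sum(), logit)[0]
# Monte Carlo estimate of E_q[ grad_phi log q_phi(z) * b ]; close to zero
print((b * score_sum / z.numel()).item())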

What this means is that we can replace any term

\[\log q_{\phi}({\bf z}_i) \overline{f_{\phi}({\bf z})}\]

in our surrogate objective with

\[\log q_{\phi}({\bf z}_i) \left(\overline{f_{\phi}({\bf z})}-b\right)\]

Doing so doesn’t affect the mean of our gradient estimator but it does affect the variance. If we choose \(b\) wisely, we can hope to reduce the variance. In fact, \(b\) need not be a constant: it can depend on any of the random choices upstream (or sidestream) of \({\bf z}_i\).

Baselines in Pyro

There are several ways the user can instruct Pyro to use baselines in the context of stochastic variational inference. Since baselines can be attached to any non-reparameterizable random variable, the current baseline interface is at the level of the pyro.sample statement. In particular the baseline interface is exposed through the infer argument of pyro.sample: the user passes a dictionary of baseline options under the key baseline. Note that it only makes sense to specify baselines for sample statements within the guide (and not in the model).

Decaying Average Baseline

The simplest baseline is constructed from a running average of recent samples of \(\overline{f_{\phi}({\bf z})}\). In Pyro this kind of baseline can be invoked as follows

z = pyro.sample("z", dist.Bernoulli(...),
                infer=dict(baseline={'use_decaying_avg_baseline': True,
                                     'baseline_beta': 0.95}))

The optional argument baseline_beta specifies the decay rate of the decaying average (default value: 0.90).
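
To make the mechanics concrete, here is a hand-rolled sketch of how such a running average evolves; Pyro maintains this bookkeeping for you when use_decaying_avg_baseline is set, and the actual implementation may differ in details (e.g. bias correction). Here beta plays the role of baseline_beta.

import torch

beta = 0.90                               # plays the role of baseline_beta
baseline = 0.0
costs = 2.0 + 0.5 * torch.randn(1000)     # stand-in stream of sampled costs \bar{f(z)}
for cost in costs:
    baseline = beta * baseline + (1.0 - beta) * cost.item()
print(baseline)                           # tracks the running mean of the costs (roughly 2.0)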

Neural Baselines

In some cases a decaying average baseline works well. In others using a baseline that depends on upstream randomness is crucial for getting good variance reduction. A powerful approach for constructing such a baseline is to use a neural network that can be adapted during the course of learning. Pyro provides two ways to specify such a baseline (for an extended example see the AIR tutorial).

First the user needs to decide what inputs the baseline is going to consume (e.g. the current datapoint under consideration or the previously sampled random variable). Then the user needs to construct a nn.Module that encapsulates the baseline computation. This might look something like

class BaselineNN(nn.Module):
    def __init__(self, dim_input, dim_hidden):
        super().__init__()
        self.linear = nn.Linear(dim_input, dim_hidden)
        # ... finish initialization ...

    def forward(self, x):
        hidden = self.linear(x)
        # ... do more computations ...
        return baseline

Then, assuming the BaselineNN object baseline_module has been initialized somewhere else, in the guide we’ll have something like

def guide(x):  # here x is the current mini-batch of data
    pyro.module("my_baseline", baseline_module)
    # ... other computations ...
    z = pyro.sample("z", dist.Bernoulli(...),
                    infer=dict(baseline={'nn_baseline': baseline_module,
                                         'nn_baseline_input': x}))

Here the argument nn_baseline tells Pyro which nn.Module to use to construct the baseline. On the backend the argument nn_baseline_input is fed into the forward method of the module to compute the baseline \(b\). Note that the baseline module needs to be registered with Pyro with a pyro.module call so that Pyro is aware of the trainable parameters within the module.

Under the hood Pyro constructs a loss of the form

\[{\rm baseline\; loss} \equiv\left(\overline{f_{\phi}({\bf z})} - b \right)^2\]

which is used to adapt the parameters of the neural network. There’s no theorem that suggests this is the optimal loss function to use in this context (it’s not), but in practice it can work pretty well. Just as for the decaying average baseline, the idea is that a baseline that can track the mean \(\overline{f_{\phi}({\bf z})}\) will help reduce the variance. Under the hood SVI takes one step on the baseline loss in conjunction with a step on the ELBO.

Note that in practice it can be important to use a different set of learning hyperparameters (e.g. a higher learning rate) for baseline parameters. In Pyro this can be done as follows:

def per_param_args(param_name):
    if 'baseline' in param_name:
        return {"lr": 0.010}
    else:
        return {"lr": 0.001}

optimizer = optim.Adam(per_param_args)

Note that in order for the overall procedure to be correct the baseline parameters should only be optimized through the baseline loss. Similarly the model and guide parameters should only be optimized through the ELBO. To ensure that this is the case, under the hood SVI detaches the baseline \(b\) that enters the ELBO from the autograd graph. Also, since the inputs to the neural baseline may depend on the parameters of the model and guide, the inputs are also detached from the autograd graph before they are fed into the neural network.
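
The following PyTorch-only sketch (not Pyro’s actual internals) illustrates how detach() keeps the two optimization problems separate: the score-function surrogate only sees the baseline as a constant, while the baseline loss only trains the baseline parameters. All names here are toy placeholders.

import torch
import torch.nn as nn
from torch.distributions import Bernoulli

logit = torch.tensor(0.0, requires_grad=True)      # guide parameter phi
baseline_net = nn.Linear(1, 1)                     # toy neural baseline
x = torch.tensor([1.0])                            # toy baseline input

q = Bernoulli(logits=logit)
z = q.sample()
cost = (z - 0.75) ** 2                             # stand-in downstream cost

b = baseline_net(x.detach()).squeeze()             # baseline inputs detached
surrogate = q.log_prob(z) * (cost.detach() - b.detach())  # baseline enters as a constant
baseline_loss = (cost.detach() - b).pow(2)         # gradients flow only into baseline_net

(surrogate + baseline_loss).backward()
print(logit.grad, baseline_net.weight.grad)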

Finally, there is an alternate way for the user to specify a neural baseline. Simply use the argument baseline_value:

b = # do baseline computation
z = pyro.sample("z", dist.Bernoulli(...),
                infer=dict(baseline={'baseline_value': b}))

This works as above, except in this case it’s the user’s responsibility to make sure that any autograd tape connecting \(b\) to the parameters of the model and guide has been cut. Or to say the same thing in language more familiar to PyTorch users, any inputs to \(b\) that depend on \(\theta\) or \(\phi\) need to be detached from the autograd graph with detach() statements.

A complete example with baselines

Recall that in the first SVI tutorial we considered a Bernoulli-Beta model for coin flips. Because the Beta random variable is non-reparameterizable (or rather not easily reparameterizable), the corresponding ELBO gradients can be quite noisy. In that context we dealt with this problem by using a Beta distribution that provides (approximate) reparameterized gradients. Here we showcase how a simple decaying average baseline can reduce the variance in the case where the Beta distribution is treated as non-reparameterized (so that the ELBO gradient estimator is of the score function type). While we’re at it, we also use plate to write our model in a fully vectorized manner.

Instead of directly comparing gradient variances, we’re going to see how many steps it takes for SVI to converge. Recall that for this particular model (because of conjugacy) we can compute the exact posterior. So to assess the utility of baselines in this context, we set up the following simple experiment. We initialize the guide at a specified set of variational parameters. We then do SVI until the variational parameters have gotten to within a fixed tolerance of the parameters of the exact posterior. We do this both with and without the decaying average baseline. We then compare the number of gradient steps we needed in the two cases. Here’s the complete code:

(Since apart from the use of plate and use_decaying_avg_baseline, this code is very similar to the code in parts I and II of the SVI tutorial, we’re not going to go through the code line by line.)

import os
import torch
import torch.distributions.constraints as constraints
import pyro
import pyro.distributions as dist
# Pyro also has a reparameterized Beta distribution so we import
# the non-reparameterized version to make our point
from pyro.distributions.testing.fakes import NonreparameterizedBeta
import pyro.optim as optim
from pyro.infer import SVI, TraceGraph_ELBO
import sys

assert pyro.__version__.startswith('1.9.1')

# this is for running the notebook in our testing framework
smoke_test = ('CI' in os.environ)
max_steps = 2 if smoke_test else 10000


def param_abs_error(name, target):
    return torch.sum(torch.abs(target - pyro.param(name))).item()


class BernoulliBetaExample:
    def __init__(self, max_steps):
        # the maximum number of inference steps we do
        self.max_steps = max_steps
        # the two hyperparameters for the beta prior
        self.alpha0 = 10.0
        self.beta0 = 10.0
        # the dataset consists of six 1s and four 0s
        self.data = torch.zeros(10)
        self.data[0:6] = torch.ones(6)
        self.n_data = self.data.size(0)
        # compute the alpha parameter of the exact beta posterior
        self.alpha_n = self.data.sum() + self.alpha0
        # compute the beta parameter of the exact beta posterior
        self.beta_n = - self.data.sum() + torch.tensor(self.beta0 + self.n_data)
        # initial values of the two variational parameters
        self.alpha_q_0 = 15.0
        self.beta_q_0 = 15.0

    def model(self, use_decaying_avg_baseline):
        # sample `latent_fairness` from the beta prior
        f = pyro.sample("latent_fairness", dist.Beta(self.alpha0, self.beta0))
        # use plate to indicate that the observations are
        # conditionally independent given f and get vectorization
        with pyro.plate("data_plate"):
            # observe all ten datapoints using the bernoulli likelihood
            pyro.sample("obs", dist.Bernoulli(f), obs=self.data)

    def guide(self, use_decaying_avg_baseline):
        # register the two variational parameters with pyro
        alpha_q = pyro.param("alpha_q", torch.tensor(self.alpha_q_0),
                             constraint=constraints.positive)
        beta_q = pyro.param("beta_q", torch.tensor(self.beta_q_0),
                            constraint=constraints.positive)
        # sample f from the beta variational distribution
        baseline_dict = {'use_decaying_avg_baseline': use_decaying_avg_baseline,
                         'baseline_beta': 0.90}
        # note that the baseline_dict specifies whether we're using
        # decaying average baselines or not
        pyro.sample("latent_fairness", NonreparameterizedBeta(alpha_q, beta_q),
                    infer=dict(baseline=baseline_dict))

    def do_inference(self, use_decaying_avg_baseline, tolerance=0.80):
        # clear the param store in case we're in a REPL
        pyro.clear_param_store()
        # setup the optimizer and the inference algorithm
        optimizer = optim.Adam({"lr": .0005, "betas": (0.93, 0.999)})
        svi = SVI(self.model, self.guide, optimizer, loss=TraceGraph_ELBO())
        print("Doing inference with use_decaying_avg_baseline=%s" % use_decaying_avg_baseline)

        # do up to this many steps of inference
        for k in range(self.max_steps):
            svi.step(use_decaying_avg_baseline)
            if k % 100 == 0:
                print('.', end='')
                sys.stdout.flush()

            # compute the distance to the parameters of the true posterior
            alpha_error = param_abs_error("alpha_q", self.alpha_n)
            beta_error = param_abs_error("beta_q", self.beta_n)

            # stop inference early if we're close to the true posterior
            if alpha_error < tolerance and beta_error < tolerance:
                break

        print("\nDid %d steps of inference." % k)
        print(("Final absolute errors for the two variational parameters " +
               "were %.4f & %.4f") % (alpha_error, beta_error))

# do the experiment
bbe = BernoulliBetaExample(max_steps=max_steps)
bbe.do_inference(use_decaying_avg_baseline=True)
bbe.do_inference(use_decaying_avg_baseline=False)

Sample output:

Doing inference with use_decaying_avg_baseline=True
....................
Did 1932 steps of inference.
Final absolute errors for the two variational parameters were 0.7997 & 0.0800
Doing inference with use_decaying_avg_baseline=False
..................................................
Did 4908 steps of inference.
Final absolute errors for the two variational parameters were 0.7991 & 0.2532

For this particular run we can see that baselines roughly halved the number of steps of SVI we needed to do. The results are stochastic and will vary from run to run, but this is an encouraging result. This is a pretty contrived example, but for certain model and guide pairs, baselines can provide a substantial win.

References

[1] Automated Variational Inference in Probabilistic Programming, David Wingate, Theo Weber

[2] Black Box Variational Inference, Rajesh Ranganath, Sean Gerrish, David M. Blei

[3] Auto-Encoding Variational Bayes, Diederik P. Kingma, Max Welling

[4] Gradient Estimation Using Stochastic Computation Graphs, John Schulman, Nicolas Heess, Theophane Weber, Pieter Abbeel

[5] Nonstandard Interpretations of Probabilistic Programs for Efficient Inference, David Wingate, Noah Goodman, Andreas Stuhlmüller, Jeffrey Siskind

[6] Neural Variational Inference and Learning in Belief Networks, Andriy Mnih, Karol Gregor