SVI Part III: ELBO Gradient Estimators

Setup

We’ve defined a Pyro model with observations \({\bf x}\) and latents \({\bf z}\) of the form \(p_{\theta}({\bf x}, {\bf z}) = p_{\theta}({\bf x}|{\bf z}) p_{\theta}({\bf z})\). We’ve also defined a Pyro guide (i.e. a variational distribution) of the form \(q_{\phi}({\bf z})\). Here \(\theta\) and \(\phi\) are parameters for the model and guide, respectively. (In particular these are not random variables that call for a Bayesian treatment.)

We’d like to maximize the log evidence \(\log p_{\theta}({\bf x})\) by maximizing the ELBO (the evidence lower bound) given by

\[{\rm ELBO} \equiv \mathbb{E}_{q_{\phi}({\bf z})} \left [ \log p_{\theta}({\bf x}, {\bf z}) - \log q_{\phi}({\bf z}) \right]\]

To do this we’re going to take (stochastic) gradient steps on the ELBO in the parameter space \(\{ \theta, \phi \}\) (see references [1,2] for early work on this approach). So we need to be able to compute unbiased estimates of

\[\nabla_{\theta,\phi} {\rm ELBO} = \nabla_{\theta,\phi}\mathbb{E}_{q_{\phi}({\bf z})} \left [ \log p_{\theta}({\bf x}, {\bf z}) - \log q_{\phi}({\bf z}) \right]\]

How do we do this for general stochastic functions model() and guide()? To simplify notation let’s generalize our discussion a bit and ask how we can compute gradients of expectations of an arbitrary cost function \(f({\bf z})\). Let’s also drop any distinction between \(\theta\) and \(\phi\). So we want to compute

\[\nabla_{\phi}\mathbb{E}_{q_{\phi}({\bf z})} \left [ f_{\phi}({\bf z}) \right]\]

Let’s start with the easiest case.

Easy Case: Reparameterizable Random Variables

Suppose that we can reparameterize things such that

\[\mathbb{E}_{q_{\phi}({\bf z})} \left [f_{\phi}({\bf z}) \right] =\mathbb{E}_{q({\bf \epsilon})} \left [f_{\phi}(g_{\phi}({\bf \epsilon})) \right]\]

Crucially we’ve moved all the \(\phi\) dependence inside of the expectation; \(q({\bf \epsilon})\) is a fixed distribution with no dependence on \(\phi\). This kind of reparameterization can be done for many distributions (e.g. the normal distribution); see reference [3] for a discussion. In this case we can pass the gradient straight through the expectation to get

\[\nabla_{\phi}\mathbb{E}_{q({\bf \epsilon})} \left [f_{\phi}(g_{\phi}({\bf \epsilon})) \right]= \mathbb{E}_{q({\bf \epsilon})} \left [\nabla_{\phi}f_{\phi}(g_{\phi}({\bf \epsilon})) \right]\]

Assuming \(f(\cdot)\) and \(g(\cdot)\) are sufficiently smooth, we can now get unbiased estimates of the gradient of interest by taking a Monte Carlo estimate of this expectation.
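To make this concrete, here is a minimal single-sample sketch in raw PyTorch (not Pyro code). It assumes a Normal guide with variational parameters mu and log_sigma and an arbitrary smooth cost function f, and uses the reparameterization \({\bf z} = \mu + \sigma {\bf \epsilon}\):

import torch
from torch.autograd import Variable

# variational parameters phi = (mu, log_sigma)
mu = Variable(torch.Tensor([0.5]), requires_grad=True)
log_sigma = Variable(torch.Tensor([0.0]), requires_grad=True)

def f(z):
    # an arbitrary smooth cost function standing in for f_phi(z)
    return (z - 2.0) ** 2

eps = Variable(torch.randn(1))           # eps ~ q(eps), with no dependence on phi
z = mu + torch.exp(log_sigma) * eps      # z = g_phi(eps)
cost = f(z)
cost.backward()
# mu.grad and log_sigma.grad now hold an unbiased single-sample
# Monte Carlo estimate of the gradient of E_{q_phi(z)}[f(z)]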

Tricky Case: Non-reparameterizable Random Variables

What if we can’t do the above reparameterization? Unfortunately this is the case for many distributions of interest, for example all discrete distributions. In this case our estimator takes a somewhat more complicated form.

We begin by expanding the gradient of interest as

\[\nabla_{\phi}\mathbb{E}_{q_{\phi}({\bf z})} \left [ f_{\phi}({\bf z}) \right]= \nabla_{\phi} \int d{\bf z} \; q_{\phi}({\bf z}) f_{\phi}({\bf z})\]

and use the product rule to write this as

\[\int d{\bf z} \; \left \{ (\nabla_{\phi} q_{\phi}({\bf z})) f_{\phi}({\bf z}) + q_{\phi}({\bf z})(\nabla_{\phi} f_{\phi}({\bf z}))\right \}\]

At this point we run into a problem. We know how to generate samples from \(q(\cdot)\)—we just run the guide forward—but \(\nabla_{\phi} q_{\phi}({\bf z})\) isn’t even a valid probability density. So we need to massage this formula so that it’s in the form of an expectation w.r.t. \(q(\cdot)\). This is easily done using the identity

\[ \nabla_{\phi} q_{\phi}({\bf z}) = q_{\phi}({\bf z})\nabla_{\phi} \log q_{\phi}({\bf z})\]

which allows us to rewrite the gradient of interest as

\[\mathbb{E}_{q_{\phi}({\bf z})} \left [ (\nabla_{\phi} \log q_{\phi}({\bf z})) f_{\phi}({\bf z}) + \nabla_{\phi} f_{\phi}({\bf z})\right]\]

This form of the gradient estimator—variously known as the REINFORCE estimator or the score function estimator or the likelihood ratio estimator—is amenable to simple Monte Carlo estimation.

Note that one way to package this result (which is convenient for implementation) is to introduce a surrogate loss function

\[{\rm surrogate \;loss} \equiv \log q_{\phi}({\bf z}) \overline{f_{\phi}({\bf z})} + f_{\phi}({\bf z})\]

Here the bar indicates that the term is held constant (i.e. it is not to be differentiated w.r.t. \(\phi\)). To get a (single-sample) Monte Carlo gradient estimate, we sample the latent random variables, compute the surrogate loss, and differentiate. The result is an unbiased estimate of \(\nabla_{\phi}\mathbb{E}_{q_{\phi}({\bf z})} \left [ f_{\phi}({\bf z}) \right]\). In equations:

\[\nabla_{\phi}\mathbb{E}_{q_{\phi}({\bf z})} \left [ f_{\phi}({\bf z}) \right] = \mathbb{E}_{q_{\phi}({\bf z})} \left [ \nabla_{\phi} ({\rm surrogate \; loss}) \right]\]
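To see what the bar means in practice, here is a minimal sketch in raw PyTorch (not Pyro’s implementation) for a single Bernoulli latent variable. The held-constant factor is realized with detach(), and log_p is a made-up stand-in for the model term \(\log p_{\theta}({\bf x}, {\bf z})\):

import torch
from torch.autograd import Variable

logits = Variable(torch.Tensor([0.0]), requires_grad=True)   # phi
p = torch.sigmoid(logits)

z = torch.bernoulli(p.detach())          # z ~ q_phi(z); not reparameterizable
log_q = (z * torch.log(p) + (1.0 - z) * torch.log(1.0 - p)).sum()
log_p = (-(z - 0.7) ** 2).sum()          # made-up stand-in for log p_theta(x, z)
f = log_p - log_q                        # the cost f_phi(z)

surrogate = log_q * f.detach() + f       # log q_phi(z) * \bar{f} + f
surrogate.backward()
# logits.grad is a single-sample unbiased estimate of
# grad_phi E_{q_phi(z)}[f_phi(z)]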

Variance or Why I Wish I Was Doing MLE Deep Learning

We now have a general recipe for an unbiased gradient estimator of expectations of cost functions. Unfortunately, in the more general case where our \(q(\cdot)\) includes non-reparameterizable random variables, this estimator tends to have high variance. Indeed in many cases of interest the variance is so high that the estimator is effectively unusable. So we need strategies to reduce variance (for a discussion see reference [4]). We’re going to pursue two strategies. The first strategy takes advantage of the particular structure of the cost function \(f(\cdot)\). The second strategy effectively introduces a way to reduce variance by using information from previous estimates of \(\mathbb{E}_{q_{\phi}({\bf z})} [ f_{\phi}({\bf z})]\). As such it is somewhat analogous to using momentum in stochastic gradient descent.

Reducing Variance via Dependency Structure

In the above discussion we stuck to a general cost function \(f_{\phi}({\bf z})\). We could continue in this vein (the approach we’re about to discuss is applicable in the general case) but for concreteness let’s zoom back in. In the case of stochastic variational inference, we’re interested in a particular cost function of the form

\[\log p_{\theta}({\bf x} | {\rm Pa}_p ({\bf x})) + \sum_i \log p_{\theta}({\bf z}_i | {\rm Pa}_p ({\bf z}_i)) - \sum_i \log q_{\phi}({\bf z}_i | {\rm Pa}_q ({\bf z}_i))\]

where we’ve broken the log ratio \(\log p_{\theta}({\bf x}, {\bf z})/q_{\phi}({\bf z})\) into an observation log likelihood piece and a sum over the different latent random variables \(\{{\bf z}_i \}\). We’ve also introduced the notation \({\rm Pa}_p (\cdot)\) and \({\rm Pa}_q (\cdot)\) to denote the parents of a given random variable in the model and in the guide, respectively. (The reader might worry what the appropriate notion of dependency would be in the case of general stochastic functions; here we simply mean regular ol’ dependency within a single execution trace). The point is that different terms in the cost function have different dependencies on the random variables \(\{ {\bf z}_i \}\) and this is something we can leverage.

To make a long story short, for any non-reparameterizable latent random variable \({\bf z}_i\) the surrogate loss is going to have a term

\[\log q_{\phi}({\bf z}_i) \overline{f_{\phi}({\bf z})}\]

It turns out that we can remove some of the terms in \(\overline{f_{\phi}({\bf z})}\) and still get an unbiased gradient estimator; furthermore, doing so will generally decrease the variance. In particular (see reference [4] for details) we can remove any terms in \(\overline{f_{\phi}({\bf z})}\) that are not downstream of the latent variable \({\bf z}_i\) (downstream w.r.t. the dependency structure of the guide).
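As a schematic example, suppose the guide samples \({\bf z}_1\) and then \({\bf z}_2\), and the cost decomposes as

\[f_{\phi}({\bf z}) = f_1({\bf z}_1) + f_2({\bf z}_1, {\bf z}_2)\]

Since \(f_1\) does not depend on \({\bf z}_2\) (it is not downstream of \({\bf z}_2\) in the guide), the surrogate term for \({\bf z}_2\) can be reduced to

\[\log q_{\phi}({\bf z}_2 | {\bf z}_1) \; \overline{f_2({\bf z}_1, {\bf z}_2)}\]

without introducing any bias.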

In Pyro, all of this logic is taken care of automatically by the SVI class. In particular as long as we switch on trace_graph=True, Pyro will keep track of the dependency structure within the execution traces of the model and guide and construct a surrogate loss that has all the unnecessary terms removed:

svi = SVI(model, guide, optimizer, "ELBO", trace_graph=True)

Note that leveraging this dependency information requires extra computation, so trace_graph=True should only be used when your model has non-reparameterizable random variables.

Aside: Dependency tracking in Pyro

Finally, a word about dependency tracking. Tracking dependency within a stochastic function that includes arbitrary Python code is a bit tricky. The approach currently implemented in Pyro is analogous to the one used in WebPPL (cf. reference [5]). Briefly, a conservative notion of dependency is used that relies on sequential ordering. If random variable \({\bf z}_2\) follows \({\bf z}_1\) in a given stochastic function then \({\bf z}_2\) may be dependent on \({\bf z}_1\) and therefore is assumed to be dependent. To mitigate the overly coarse conclusions that can be drawn by this kind of dependency tracking, Pyro includes constructs for declaring things as independent, namely irange and iarange (see the previous tutorial). For use cases with non-reparameterizable variables, it is therefore important for the user to make use of these constructs (when applicable) to take full advantage of the variance reduction provided by SVI. In some cases it may also pay to consider reordering random variables within a stochastic function (if possible). It’s also worth noting that we expect to add finer notions of dependency tracking in a future version of Pyro.

Reducing Variance with Data-Dependent Baselines

The second strategy for reducing variance in our ELBO gradient estimator goes under the name of baselines (see e.g. reference [6]). It actually makes use of the same bit of math that underlies the variance reduction strategy discussed above, except now instead of removing terms we’re going to add terms. Basically, instead of removing terms with zero expectation that tend to contribute to the variance, we’re going to add specially chosen terms with zero expectation that work to reduce the variance. As such, this is a control variate strategy.

In more detail, the idea is to take advantage of the fact that for any constant \(b\), the following identity holds

\[\mathbb{E}_{q_{\phi}({\bf z})} \left [\nabla_{\phi} (\log q_{\phi}({\bf z}) \times b) \right]=0\]

This follows since \(q(\cdot)\) is normalized:

\[\mathbb{E}_{q_{\phi}({\bf z})} \left [\nabla_{\phi} \log q_{\phi}({\bf z}) \right]= \int \!d{\bf z} \; q_{\phi}({\bf z}) \nabla_{\phi} \log q_{\phi}({\bf z})= \int \! d{\bf z} \; \nabla_{\phi} q_{\phi}({\bf z})= \nabla_{\phi} \int \! d{\bf z} \; q_{\phi}({\bf z})=\nabla_{\phi} 1 = 0\]

What this means is that we can replace any term

\[\log q_{\phi}({\bf z}_i) \overline{f_{\phi}({\bf z})}\]

in our surrogate loss with

\[\log q_{\phi}({\bf z}_i) \left(\overline{f_{\phi}({\bf z})}-b\right)\]

Doing so doesn’t affect the mean of our gradient estimator but it does affect the variance. If we choose \(b\) wisely, we can hope to reduce the variance. In fact, \(b\) need not be a constant: it can depend on any of the random choices upstream (or sidestream) of \({\bf z}_i\).
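For intuition, here is a small self-contained experiment in raw PyTorch (illustrative only, not Pyro code) with a Bernoulli \(q_{\phi}\) and a toy cost \(f({\bf z}) = 3 + {\bf z}\). The score function gradient estimate has the same mean with or without a baseline, but choosing \(b\) near the mean of \(f\) shrinks the variance dramatically:

import statistics
import torch
from torch.autograd import Variable

def grad_estimates(b, n=2000):
    grads = []
    for _ in range(n):
        logits = Variable(torch.Tensor([0.0]), requires_grad=True)   # phi
        p = torch.sigmoid(logits)
        z = torch.bernoulli(p.detach())                              # z ~ q_phi(z)
        log_q = (z * torch.log(p) + (1.0 - z) * torch.log(1.0 - p)).sum()
        f = float(3.0 + z)                   # toy cost with no direct phi dependence
        surrogate = log_q * (f - b)          # so the surrogate reduces to this term
        surrogate.backward()
        grads.append(float(logits.grad))
    return statistics.mean(grads), statistics.variance(grads)

print(grad_estimates(b=0.0))   # mean ~0.25, large variance
print(grad_estimates(b=3.5))   # mean ~0.25, variance near zero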

Baselines in Pyro

There are several ways the user can instruct Pyro to use baselines in the context of stochastic variational inference. Since baselines can be attached to any non-reparameterizable random variable, the current baseline interface is at the level of the pyro.sample statement. In particular the baseline interface makes use of an argument baseline, which is a dictionary that specifies baseline options. Note that it only makes sense to specify baselines for sample statements within the guide (and not in the model).

Decaying Average Baseline

The simplest baseline is constructed from a running average of recent samples of \(\overline{f_{\phi}({\bf z})}\). In Pyro this kind of baseline can be invoked as follows

z = pyro.sample("z", dist.bernoulli, ...,
                baseline={'use_decaying_avg_baseline': True,
                          'baseline_beta': 0.95})

The optional argument baseline_beta specifies the decay rate of the decaying average (default value: 0.90).
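The quantity being tracked is just an exponential moving average; a minimal sketch of the idea (not Pyro’s exact implementation) looks like:

import random

beta = 0.95                               # decay rate, as in 'baseline_beta'
avg = 0.0
for step in range(100):
    f_value = random.gauss(3.0, 1.0)      # stand-in for the current value of f
    avg = beta * avg + (1.0 - beta) * f_value
# avg tracks a running average of f and plays the role of the baseline b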

Neural Baselines

In some cases a decaying average baseline works well. In others using a baseline that depends on upstream randomness is crucial for getting good variance reduction. A powerful approach for constructing such a baseline is to use a neural network that can be adapted during the course of learning. Pyro provides two ways to specify such a baseline (for an extended example see the AIR tutorial).

First the user needs to decide what inputs the baseline is going to consume (e.g. the current datapoint under consideration or the previously sampled random variable). Then the user needs to construct a nn.Module that encapsulates the baseline computation. This might look something like

class BaselineNN(nn.Module):
    def __init__(self, dim_input, dim_hidden):
        super(BaselineNN, self).__init__()
        self.linear = nn.Linear(dim_input, dim_hidden)
        # ... finish initialization ...

    def forward(self, x):
        hidden = self.linear(x)
        # ... do more computations ...
        return baseline

Then, assuming the BaselineNN object baseline_module has been initialized somewhere else, in the guide we’ll have something like

def guide(x):  # here x is the current mini-batch of data
    pyro.module("my_baseline", baseline_module, tags="baseline")
    # ... other computations ...
    z = pyro.sample("z", dist.bernoulli, ...,
                    baseline={'nn_baseline': baseline_module,
                              'nn_baseline_input': x})

Here the argument nn_baseline tells Pyro which nn.Module to use to construct the baseline. On the backend the argument nn_baseline_input is fed into the forward method of the module to compute the baseline \(b\). Note that the baseline module needs to be registered with Pyro with a pyro.module call so that Pyro is aware of the trainable parameters within the module.

Under the hood Pyro constructs a loss of the form

\[{\rm baseline\; loss} \equiv\left(\overline{f_{\phi}({\bf z})} - b \right)^2\]

which is used to adapt the parameters of the neural network. There’s no theorem that suggests this is the optimal loss function to use in this context (it’s not), but in practice it can work pretty well. Just as for the decaying average baseline, the idea is that a baseline that tracks the mean of \(\overline{f_{\phi}({\bf z})}\) will help reduce the variance. Under the hood SVI takes one step on the baseline loss in conjunction with a step on the ELBO.
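Schematically, a single update of the baseline parameters looks like the following sketch (stand-in names baseline_net, x, and f; not Pyro’s internals), where both the inputs and \(\overline{f_{\phi}({\bf z})}\) are detached so that this loss only trains the baseline:

import torch
import torch.nn as nn
from torch.autograd import Variable

baseline_net = nn.Linear(2, 1)            # stand-in for the user's baseline nn.Module
x = Variable(torch.randn(1, 2))           # stand-in for nn_baseline_input
f = Variable(torch.Tensor([[-1.3]]))      # stand-in for \bar{f_phi(z)}

b = baseline_net(x.detach())                            # the baseline b
baseline_loss = torch.pow(f.detach() - b, 2.0).sum()    # (\bar{f} - b)^2
baseline_loss.backward()
# gradients from baseline_loss reach only the parameters of baseline_net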

Note that the module baseline_module has been tagged with the string "baseline" above; this has the effect of tagging all parameters inside of baseline_module with the parameter tag "baseline". This gives the user a convenient handle for controlling how the baseline parameters are optimized. For example, if the user wants the baseline parameters to have a larger learning rate (usually a good idea) an appropriate optimizer could be constructed as follows:

def per_param_args(module_name, param_name, tags):
    if 'baseline' in tags:
        return {"lr": 0.010}
    else:
        return {"lr": 0.001}

optimizer = optim.Adam(per_param_args)

Note that in order for the overall procedure to be correct, the baseline parameters should only be optimized through the baseline loss. Similarly, the model and guide parameters should only be optimized through the ELBO. To ensure that this is the case, under the hood SVI detaches the baseline \(b\) that enters the ELBO from the autograd graph. Also, since the inputs to the neural baseline may depend on the parameters of the model and guide, the inputs are also detached from the autograd graph before they are fed into the neural network.

Finally, there is an alternate way for the user to specify a neural baseline. Simply use the argument baseline_value:

b = ...  # do baseline computation
z = pyro.sample("z", dist.bernoulli, ...,
                baseline={'baseline_value': b})

This works as above, except in this case it’s the user’s responsibility to make sure that any autograd tape connecting \(b\) to the parameters of the model and guide has been cut. Or to say the same thing in language more familiar to PyTorch users, any inputs to \(b\) that depend on \(\theta\) or \(\phi\) need to be detached from the autograd graph with detach() statements.

A complete example with baselines

Recall that in the first SVI tutorial we considered a bernoulli-beta model for coin flips. Because the beta random variable is non-reparameterizable, the corresponding ELBO gradients are quite noisy. In that context we dealt with this problem by dialing up the number of Monte Carlo samples used to form the estimator. This isn’t necessarily a bad approach, but it can be an expensive one. Here we showcase how a simple decaying average baseline can reduce the variance. While we’re at it, we also use iarange to write our model in a fully vectorized manner.

Instead of directly comparing gradient variances, we’re going to see how many steps it takes for SVI to converge. Recall that for this particular model (because of conjugacy) we can compute the exact posterior. So to assess the utility of baselines in this context, we set up the following simple experiment. We initialize the guide at a specified set of variational parameters. We then do SVI until the variational parameters have gotten to within a fixed tolerance of the parameters of the exact posterior. We do this both with and without the decaying average baseline. We then compare the number of gradient steps we needed in the two cases. Here’s the complete code:

(Since apart from the use of iarange and use_decaying_avg_baseline, this code is very similar to the code in parts I and II of the SVI tutorial, we’re not going to go through the code line by line.)

from __future__ import print_function
import numpy as np
import torch
from torch.autograd import Variable
import pyro
import pyro.distributions as dist
import pyro.optim as optim
from pyro.infer import SVI
import sys


def param_abs_error(name, target):
    return torch.sum(torch.abs(target - pyro.param(name))).data.numpy()[0]


class BernoulliBetaExample(object):
    def __init__(self):
        # the two hyperparameters for the beta prior
        self.alpha0 = Variable(torch.Tensor([10.0]))
        self.beta0 = Variable(torch.Tensor([10.0]))
        # the dataset consists of six 1s and four 0s
        self.data = Variable(torch.zeros(10,1))
        self.data[0:6, 0].data = torch.ones(6)
        self.n_data = self.data.size(0)
        # compute the alpha parameter of the exact beta posterior
        self.alpha_n = self.alpha0 + self.data.sum()
        # compute the beta parameter of the exact beta posterior
        self.beta_n = self.beta0 - self.data.sum() + Variable(torch.Tensor([self.n_data]))
        # for convenience compute the logs
        self.log_alpha_n = torch.log(self.alpha_n)
        self.log_beta_n = torch.log(self.beta_n)

    def setup(self):
        # initialize values of the two variational parameters
        # set to be quite close to the true values
        # so that the experiment doesn't take too long
        self.log_alpha_q_0 = Variable(torch.Tensor([np.log(15.0)]), requires_grad=True)
        self.log_beta_q_0 = Variable(torch.Tensor([np.log(15.0)]), requires_grad=True)

    def model(self, use_decaying_avg_baseline):
        # sample `latent_fairness` from the beta prior
        f = pyro.sample("latent_fairness", dist.beta, self.alpha0, self.beta0)
        # use iarange to indicate that the observations are
        # conditionally independent given f and get vectorization
        with pyro.iarange("data_iarange"):
            # observe all ten datapoints using the bernoulli likelihood
            pyro.observe("obs", dist.bernoulli, self.data, f)

    def guide(self, use_decaying_avg_baseline):
        # register the two variational parameters with pyro
        log_alpha_q = pyro.param("log_alpha_q", self.log_alpha_q_0)
        log_beta_q = pyro.param("log_beta_q", self.log_beta_q_0)
        alpha_q, beta_q = torch.exp(log_alpha_q), torch.exp(log_beta_q)
        # sample f from the beta variational distribution
        baseline_dict = {'use_decaying_avg_baseline': use_decaying_avg_baseline,
                         'baseline_beta': 0.90}
        # note that the baseline_dict specifies whether we're using
        # decaying average baselines or not
        pyro.sample("latent_fairness", dist.beta, alpha_q, beta_q,
                    baseline=baseline_dict)

    def do_inference(self, use_decaying_avg_baseline, tolerance=0.05):
        # clear the param store in case we're in a REPL
        pyro.clear_param_store()
        # initialize the variational parameters for this run
        self.setup()
        # setup the optimizer and the inference algorithm
        optimizer = optim.Adam({"lr": .0008, "betas": (0.93, 0.999)})
        svi = SVI(self.model, self.guide, optimizer, loss="ELBO", trace_graph=True)
        print("Doing inference with use_decaying_avg_baseline=%s" % use_decaying_avg_baseline)

        # do up to 10000 steps of inference
        for k in range(10000):
            svi.step(use_decaying_avg_baseline)
            if k % 100 == 0:
                print('.', end='')
                sys.stdout.flush()

            # compute the distance to the parameters of the true posterior
            alpha_error = param_abs_error("log_alpha_q", self.log_alpha_n)
            beta_error = param_abs_error("log_beta_q", self.log_beta_n)

            # stop inference early if we're close to the true posterior
            if alpha_error < tolerance and beta_error < tolerance:
                break

        print("\nDid %d steps of inference." % k)
        print(("Final absolute errors for the two variational parameters " +
               "(in log space) were %.4f & %.4f") % (alpha_error, beta_error))

# do the experiment
bbe = BernoulliBetaExample()
bbe.do_inference(use_decaying_avg_baseline=True)
bbe.do_inference(use_decaying_avg_baseline=False)

Sample output:

Doing inference with use_decaying_avg_baseline=True
...........
Did 2070 steps of inference.
Final absolute errors for the two variational parameters (in log space) were 0.0500 & 0.0443
Doing inference with use_decaying_avg_baseline=False
.....................
Did 4159 steps of inference.
Final absolute errors for the two variational parameters (in log space) were 0.0500 & 0.0306

For this particular run we can see that baselines roughly halved the number of steps of SVI we needed to do. The results are stochastic and will vary from run to run, but this is an encouraging result. For certain model and guide pairs, baselines can provide an even bigger win.

References

[1] Automated Variational Inference in Probabilistic Programming, David Wingate, Theo Weber

[2] Black Box Variational Inference, Rajesh Ranganath, Sean Gerrish, David M. Blei

[3] Auto-Encoding Variational Bayes, Diederik P. Kingma, Max Welling

[4] Gradient Estimation Using Stochastic Computation Graphs, John Schulman, Nicolas Heess, Theophane Weber, Pieter Abbeel

[5] Deep Amortized Inference for Probabilistic Programs, Daniel Ritchie, Paul Horsfall, Noah D. Goodman

[6] Neural Variational Inference and Learning in Belief Networks, Andriy Mnih, Karol Gregor