SVI Part IV: Tips and Tricks

The three SVI tutorials leading up to this one (Part I, Part II, & Part III) go through the various steps involved in using Pyro to do variational inference. Along the way we defined models and guides (i.e. variational distributions), set up variational objectives (in particular ELBOs), and constructed optimizers (pyro.optim). The effect of all this machinery is to cast Bayesian inference as a stochastic optimization problem.

This is all very useful, but in order to arrive at our ultimate goal—learning model parameters, inferring approximate posteriors, making predictions with the posterior predictive distribution, etc.—we need to successfully solve this optimization problem. Depending on the details of the particular problem—for example the dimensionality of the latent space, whether we have discrete latent variables, and so on—this can be easy or hard. In this tutorial we cover a few tips and tricks we expect to be generally useful for users doing variational inference in Pyro. ELBO not converging!? Running into NaNs!? Look below for possible solutions!

Pyro Forum

If you’re still having trouble with optimization after reading this tutorial, please don’t hesitate to ask a question on our forum!

1. Start with a small learning rate

While large learning rates might be appropriate for some problems, it’s usually good practice to start with small learning rates like \(10^{-3}\) or \(10^{-4}\):

optimizer = pyro.optim.Adam({"lr": 0.001})

This is because ELBO gradients are stochastic, and potentially high variance, so large learning rates can quickly lead to regions of model/guide parameter space that are numerically unstable or otherwise undesirable.

You can try a larger learning rate once you have achieved stable ELBO optimization using a smaller learning rate. This is often a good idea because excessively small learning rates can lead to poor optimization. In particular, small learning rates can lead to getting stuck in poor local optima of the ELBO.

2. Use Adam or ClippedAdam by default

Use Adam or ClippedAdam by default when doing Stochastic Variational Inference. Note that ClippedAdam is just a convenient extension of Adam that provides built-in support for learning rate decay and gradient clipping.

The basic reason these optimization algorithms often do well in the context of variational inference is that the smoothing they provide via per-parameter momentum is often essential when the optimization problem is very stochastic. Note that in SVI stochasticity can come from sampling latent variables, from subsampling data, or from both.

In addition to tuning the learning rate, in some cases it may be necessary to also tune the pair of betas hyperparameters that control the momentum used by Adam. In particular, for very stochastic models it may make sense to use higher values of \(\beta_1\):

betas = (0.95, 0.999)

instead of

betas = (0.90, 0.999)

3. Consider using a decaying learning rate

While a moderately large learning rate can be useful at the beginning of optimization when you’re far from the optimum and want to take large gradient steps, it’s often useful to have a smaller learning rate later on so that you don’t bounce around the optimum excessively without converging. One way to do this is to use the learning rate schedulers provided by Pyro. For example usage, see the code snippet here. Another convenient way to do this is to use the ClippedAdam optimizer, which has built-in support for learning rate decay via the lrd argument:

num_steps = 1000
initial_lr = 0.001
gamma = 0.1  # final learning rate will be gamma * initial_lr
lrd = gamma ** (1 / num_steps)
optim = pyro.optim.ClippedAdam({'lr': initial_lr, 'lrd': lrd})

4. Make sure your model and guide distributions have the same support

Suppose you have a distribution in your model with constrained support, e.g. a LogNormal distribution, which has support on the positive real axis:

def model():
    pyro.sample("x", dist.LogNormal(0.0, 1.0))

Then you need to ensure that the accompanying sample site in the guide has the same support:

def good_guide():
    loc = pyro.param("loc", torch.tensor(0.0))
    pyro.sample("x", dist.LogNormal(loc, 1.0))

If you fail to do this and use for example the following inadmissible guide:

def bad_guide():
    loc = pyro.param("loc", torch.tensor(0.0))
    # Normal may sample x < 0
    pyro.sample("x", dist.Normal(loc, 1.0))

you will likely run into NaNs very quickly. This is because the log_prob of a LogNormal distribution evaluated at a sample x that satisfies x<0 is undefined, and the bad_guide is likely to produce such samples.

5. Constrain parameters that need to be constrained

In a similar vein, you need to make sure that the parameters used to instantiate distributions are valid; otherwise you will quickly run into NaNs. For example the scale parameter of a Normal distribution needs to be positive. Thus the following bad_guide is problematic:

def bad_guide():
    scale = pyro.param("scale", torch.tensor(1.0))
    pyro.sample("x", dist.Normal(0.0, scale))

while the following good_guide correctly uses a constraint to ensure positivity:

from pyro.distributions import constraints

def good_guide():
    scale = pyro.param("scale", torch.tensor(0.05),
                       constraint=constraints.positive)
    pyro.sample("x", dist.Normal(0.0, scale))

6. If you are having trouble constructing a custom guide, use an AutoGuide

In order for a model/guide pair to lead to stable optimization a number of conditions need to be satisfied, some of which we have covered above. Sometimes it can be difficult to diagnose the reason for numerical instability or poor convergence. Among other reasons, this is because the fundamental issue could arise in a number of different places: in the model, in the guide, or in the choice of optimization algorithm or hyperparameters.

Sometimes the problem is actually in your model even though you think it’s in the guide. Conversely, sometimes the problem is in your guide even though you think it’s in the model or somewhere else. For these reasons it can be helpful to reduce the number of moving parts while you try to identify the underlying issue. One convenient way to do this is to replace your custom guide with a pyro.infer.AutoGuide.

For example, if all the latent variables in your model are continuous, you can try a pyro.infer.AutoNormal guide. Alternatively, you can use MAP inference instead of full-blown variational inference. See the MLE/MAP tutorial for further details. Once you have MAP inference working, there’s good reason to believe that your model is set up correctly (at least as far as basic numerical stability is concerned). If you’re interested in obtaining approximate posterior distributions, you can now follow up with full-blown SVI. Indeed a natural order of operations might use the following sequence of increasingly flexible autoguides:

AutoDelta \(\rightarrow\) AutoNormal \(\rightarrow\) AutoLowRankMultivariateNormal

If you find that you want a more flexible guide or that you want to take more control over how exactly the guide is defined, at this juncture you can proceed to build a custom guide. One way to go about doing this is to leverage easy guides, which strike a balance between the control of a fully custom guide and the automation of an autoguide.

Also note that autoguides offer several initialization strategies, and in some cases it may be necessary to experiment with these in order to get good optimization performance. One way to control initialization behavior is via the init_loc_fn argument. For example usage of init_loc_fn, including example usage for the easy guide API, see here.

7. Parameter initialization matters: initialize guide distributions to have low variance

Initialization in optimization problems can make all the difference between finding a good solution and failing catastrophically. It is difficult to come up with a comprehensive set of good practices for initialization, as good initialization schemes are often very problem dependent. In the context of Stochastic Variational Inference it is generally a good idea to initialize your guide distributions so that they have low variance. This is because the ELBO gradients you use to optimize the ELBO are stochastic. If the ELBO gradients you get at the beginning of ELBO optimization exhibit high variance, you may be led into numerically unstable or otherwise undesirable regions of parameter space. One way to guard against this potential hazard is to pay close attention to parameters in your guide that control variance. For example, we would generally expect this to be a reasonably initialized guide:

from pyro.distributions import constraints

def good_guide():
    scale = pyro.param("scale", torch.tensor(0.05),
                       constraint=constraints.positive)
    pyro.sample("x", dist.Normal(0.0, scale))

while the following high-variance guide is very likely to lead to problems:

def bad_guide():
    scale = pyro.param("scale", torch.tensor(12345.6),
                       constraint=constraints.positive)
    pyro.sample("x", dist.Normal(0.0, scale))

Note that the initial variance of autoguides can be controlled with the init_scale argument, see e.g. here for AutoNormal.

8. Explore trade-offs controlled by num_particles, mini-batch size, etc.

Optimization can be difficult if your ELBO exhibits large variance. One way you can try to mitigate this issue is to increase the number of particles used to compute each stochastic ELBO estimate:

elbo = pyro.infer.Trace_ELBO(num_particles=10,
                             vectorize_particles=True)

(Note that to use vectorize_particles=True you need to ensure your model and guide are properly vectorized; see the tensor shapes tutorial for best practices.) This results in lower variance gradients at the cost of more compute. If you are doing data subsampling, the mini-batch size offers a similar trade-off: larger mini-batch sizes reduce the variance at the cost of more compute. Although what’s best is problem dependent, it’s usually worth taking more gradient steps with fewer particles than fewer gradient steps with more particles. An important caveat is when you’re running on a GPU: there (at least for some models) the cost of increasing num_particles or your mini-batch size may be sublinear, in which case increasing num_particles is likely more attractive.

9. Use TraceMeanField_ELBO if applicable

The basic ELBO implementation in Pyro, Trace_ELBO, uses stochastic samples to estimate the KL divergence term. When analytic KL divergences are available, you may be able to lower ELBO variance by using analytic KL divergences instead. This functionality is provided by TraceMeanField_ELBO.

10. Consider normalizing your ELBO

By default Pyro computes an unnormalized ELBO, i.e. it computes a lower bound to the log evidence of the full set of data that is being conditioned on. For large datasets this can be a number of large magnitude. Since computers use finite precision (e.g. 32-bit floats) to do arithmetic, large numbers can be problematic for numerical stability, since they can lead to loss of precision, under/overflow, etc. For this reason it can be helpful in many cases to normalize your ELBO so that it is roughly order one. This can also be helpful for getting a rough feeling for how good your ELBO numbers are. For example if we have \(N\) datapoints of dimension \(D\) (e.g. \(N\) real-valued vectors of dimension \(D\)) then we generally expect a reasonably well optimized ELBO to be of order \(N \times D\). Thus if we renormalize our ELBO by a factor of \(N \times D\) we expect an ELBO of order one. While this is just a rough rule-of-thumb, if we use this kind of normalization and obtain ELBO values like \(-123.4\) or \(1234.5\) then something is probably wrong: perhaps our model is terribly mis-specified; perhaps our initialization is catastrophically bad, etc. For details on how you can scale your ELBO by a normalization constant see this tutorial.

11. Pay attention to scales

Scales of numbers matter. They matter for at least two important reasons: i) scales can make or break a particular initialization scheme; ii) as discussed in the previous section, scales can have an impact on numerical precision and stability.

To make this concrete, suppose you are doing linear regression, i.e. you’re learning a linear map of the form \(Y = W @ X\). Often the data comes with particular units. For example some of the components of the covariate \(X\) may be in units of dollars (e.g. house prices), while others may be in units of density (e.g. residents per square mile). Perhaps the first covariate has typical values like \(10^5\), while the second covariate has typical values like \(10^2\). You should always pay attention when you encounter numbers that range across many orders of magnitude. In many cases it makes sense to normalize things so that they are order unity. For example you might measure house prices in units of $100,000.

These sorts of data transformations can have a number of benefits for downstream modeling and inference. For example if you’ve normalized all of your covariates appropriately, it may be reasonable to set a simple isotropic prior on your weights:

pyro.sample("W", dist.Normal(torch.zeros(2), torch.ones(2)))

instead of having to specify different prior scales for different covariates:

prior_scale = torch.tensor([1.0e-5, 1.0e-2])
pyro.sample("W", dist.Normal(torch.zeros(2), prior_scale))

There are other benefits too. It now becomes easier to initialize appropriate parameters for your guide. It is also now much more likely that the default initializations used by a pyro.infer.AutoGuide will work for your problem.

12. Keep validation enabled

By default Pyro enables validation logic that can be helpful in debugging models and guides. For example, validation logic will inform you when distribution parameters become invalid. Unless you have good reason to do otherwise, keep the validation logic enabled. Once you’re satisfied with a model and inference procedure, you may wish to disable validation using pyro.enable_validation(False).

Similarly in the context of ELBOs it is a good idea to set

strict_enumeration_warning=True

when you are enumerating discrete latent variables.

13. Tensor shape errors

If you’re running into tensor shape errors please make sure you have carefully read the corresponding tutorial.

14. Enumerate discrete latent variables if possible

If your model contains discrete latent variables it may make sense to enumerate them out exactly, since this can significantly reduce ELBO variance. For more discussion see the corresponding tutorial.

15. Some complex models can benefit from KL annealing

The particular form of the ELBO encodes a trade-off between model fit via the expected log likelihood term and a prior regularization term via the KL divergence. In some cases the KL divergence can act as a barrier that makes it difficult to find good optima. In these cases it can help to anneal the relative strength of the KL divergence term during optimization. For further discussion see the Deep Markov Model tutorial.

16. Consider clipping gradients or constraining parameters defensively

Certain parameters in your model or guide may control distribution parameters that can be sensitive to numerical issues. For example, the concentration and rate parameters that define a Gamma distribution may exhibit such sensitivity. In these cases it may make sense to clip gradients or constrain parameters defensively. See this code snippet for an example of gradient clipping. For a simple example of “defensive” parameter constraints consider the concentration parameter of a Gamma distribution. This parameter must be positive: concentration > 0. If we want to ensure that concentration stays away from zero we can use a param statement with an appropriate constraint:

from pyro.distributions import constraints

concentration = pyro.param("concentration", torch.tensor(0.5),
                           constraint=constraints.greater_than(0.001))

These kinds of tricks can help ensure that your models and guides stay away from numerically dangerous parts of parameter space.