Customizing SVI objectives and training loops
Pyro provides support for various optimization-based approaches to Bayesian inference, with Trace_ELBO serving as the basic implementation of SVI (stochastic variational inference). See the docs for more information on the various SVI implementations, and SVI tutorials I, II, and III for background on SVI.
In this tutorial we show how advanced users can modify and/or augment the variational objectives (alternatively: loss functions) and the training step implementation provided by Pyro to support special use cases.
Basic SVI Usage
We first review the basic usage pattern of SVI objects in Pyro. We assume that the user has defined a model and a guide. The user then creates an optimizer and an SVI object:
import pyro

optimizer = pyro.optim.Adam({"lr": 0.001, "betas": (0.90, 0.999)})
svi = pyro.infer.SVI(model, guide, optimizer, loss=pyro.infer.Trace_ELBO())
Gradient steps can then be taken with a call to svi.step(...). The arguments to step() are then passed to model and guide.
A Lower-Level Pattern
The nice thing about the above pattern is that it allows Pyro to take care of various details for us, for example:
* pyro.optim.Adam dynamically creates a new torch.optim.Adam optimizer whenever a new parameter is encountered
* SVI.step() zeros gradients between gradient steps
If we want more control, we can directly manipulate the differentiable_loss method of the various ELBO classes. For example, this optimization loop:
svi = pyro.infer.SVI(model, guide, optimizer, loss=pyro.infer.Trace_ELBO())
for i in range(n_iter):
    loss = svi.step(X_train, y_train)
is equivalent to this low-level pattern:
import pyro
import torch

loss_fn = lambda model, guide: pyro.infer.Trace_ELBO().differentiable_loss(model, guide, X_train, y_train)

# evaluate the loss once under a param-only trace so that the parameters are
# created and we can collect the (unconstrained) tensors the optimizer should update
with pyro.poutine.trace(param_only=True) as param_capture:
    loss = loss_fn(model, guide)
params = set(site["value"].unconstrained()
             for site in param_capture.trace.nodes.values())
optimizer = torch.optim.Adam(params, lr=0.001, betas=(0.90, 0.999))
for i in range(n_iter):
    # compute loss
    loss = loss_fn(model, guide)
    loss.backward()
    # take a step and zero the parameter gradients
    optimizer.step()
    optimizer.zero_grad()
Example: Custom Regularizer
Suppose we want to add a custom regularization term to the SVI loss. Using the low-level pattern above, this is easy to do. First we define our regularizer:
def my_custom_L2_regularizer(my_parameters):
    # sum of squared entries over all parameter tensors
    reg_loss = 0.0
    for param in my_parameters:
        reg_loss = reg_loss + param.pow(2.0).sum()
    return reg_loss
Then the only change we need to make is:
- loss = loss_fn(model, guide)
+ loss = loss_fn(model, guide) + my_custom_L2_regularizer(my_parameters)
Example: Clipping Gradients
For some models the loss gradient can explode during training, leading to overflow and NaN values. One way to protect against this is gradient clipping. The optimizers in pyro.optim take an optional dictionary of clip_args (with key clip_norm or clip_value), which allows clipping either the gradient norm or the gradient value to fall within the given limit.
To add gradient-norm clipping to the basic example above:
- optimizer = pyro.optim.Adam({"lr": 0.001, "betas": (0.90, 0.999)})
+ optimizer = pyro.optim.Adam({"lr": 0.001, "betas": (0.90, 0.999)}, {"clip_norm": 10.0})
Further variants of gradient clipping can also be implemented manually by modifying the low-level pattern described above.
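Example: Scaling the Loss
Depending on the optimization algorithm, the scale of the loss may or may not matter. Plain SGD, for instance, is sensitive to it, whereas Adam is largely invariant to a constant rescaling of the gradients. Suppose we want to scale our loss function by the number of datapoints before we differentiate it. This is easily done: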
- loss = loss_fn(model, guide)
+ loss = loss_fn(model, guide) / N_data
Note that in the case of SVI, where each term in the loss function is a log probability from the model or guide, this same effect can be achieved using poutine.scale. For example, we can use the poutine.scale decorator to scale both the model and guide:
from pyro import poutine

@poutine.scale(scale=1.0/N_data)
def model(*args, **kwargs):
    pass  # model body goes here

@poutine.scale(scale=1.0/N_data)
def guide(*args, **kwargs):
    pass  # guide body goes here
Example: Beta VAE
We can also use poutine.scale to construct non-standard ELBO variational objectives in which, for example, the KL divergence is scaled differently relative to the expected log likelihood. In particular, for the Beta VAE the KL divergence is scaled by a factor beta:
import pyro
import pyro.distributions as dist

def model(data, beta=0.5):
    z_loc, z_scale = ...
    # only the latent site is scaled; the likelihood term is left unscaled
    with pyro.poutine.scale(scale=beta):
        z = pyro.sample("z", dist.Normal(z_loc, z_scale))
    pyro.sample("obs", dist.Bernoulli(...), obs=data)

def guide(data, beta=0.5):
    with pyro.poutine.scale(scale=beta):
        z_loc, z_scale = ...
        z = pyro.sample("z", dist.Normal(z_loc, z_scale))
With this choice of model and guide, the log densities corresponding to the latent variable z that enter into constructing the variational objective via
svi = pyro.infer.SVI(model, guide, optimizer, loss=pyro.infer.Trace_ELBO())
will be scaled by a factor of beta, resulting in a KL divergence that is likewise scaled by beta.
Example: Mixing Optimizers
The various optimizers in pyro.optim allow the user to specify optimization settings (e.g. learning rates) on a per-parameter basis. But what if we want to use different optimization algorithms for different parameters? We can do this using Pyro’s MultiOptimizer (see below), but we can also achieve the same thing if we directly manipulate differentiable_loss:
import pyro
import torch

# adam_parameters and sgd_parameters are disjoint collections of unconstrained parameter tensors
adam = torch.optim.Adam(adam_parameters, lr=0.001, betas=(0.90, 0.999))
sgd = torch.optim.SGD(sgd_parameters, lr=0.0001)
loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
# compute loss
loss = loss_fn(model, guide)
loss.backward()
# take a step and zero the parameter gradients
adam.step()
sgd.step()
adam.zero_grad()
sgd.zero_grad()
For completeness, we also show how we can do the same thing using MultiOptimizer, which allows us to combine multiple Pyro optimizers. Note that since MultiOptimizer uses torch.autograd.grad under the hood (instead of torch.Tensor.backward()), it has a slightly different interface; in particular, the step() method takes both the loss and a dictionary of parameters as inputs.
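import pyro
from pyro.optim.multi import MixedMultiOptimizer

def model():
    pyro.param('a', ...)
    pyro.param('b', ...)
    ...

adam = pyro.optim.Adam({'lr': 0.1})
sgd = pyro.optim.SGD({'lr': 0.01})
optim = MixedMultiOptimizer([(['a'], adam), (['b'], sgd)])
elbo = pyro.infer.Trace_ELBO()
with pyro.poutine.trace(param_only=True) as param_capture:
    loss = elbo.differentiable_loss(model, guide)
# map each param name to the unconstrained leaf tensor that gets optimized
params = {name: site["value"].unconstrained()
          for name, site in param_capture.trace.nodes.items()}
optim.step(loss, params)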
Example: Custom ELBO
In several of the previous examples we bypassed creating an SVI object and directly manipulated the differentiable loss function provided by an ELBO implementation. Another thing we can do is create custom ELBO implementations and pass those into the SVI machinery. For example, a simplified version of a Trace_ELBO loss function might look as follows:
from pyro import poutine
from pyro.infer import SVI

# note that simple_elbo takes a model, a guide, and their respective arguments as inputs
def simple_elbo(model, guide, *args, **kwargs):
    # run the guide and trace its execution
    guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
    # run the model and replay it against the samples from the guide
    model_trace = poutine.trace(
        poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)
    # construct the elbo loss function (the negative ELBO)
    return -(model_trace.log_prob_sum() - guide_trace.log_prob_sum())

svi = SVI(model, guide, optim, loss=simple_elbo)
Note that this is basically what the elbo implementation in “mini-pyro” looks like.
Example: KL Annealing
In the Deep Markov Model Tutorial the ELBO variational objective is modified during training. In particular, the various KL-divergence terms between latent random variables are scaled downward (i.e. annealed) relative to the log probabilities of the observed data. In the tutorial this is accomplished using poutine.scale. We can accomplish the same thing by defining a custom loss function. This latter option is not a very elegant pattern, but we include it anyway to show the flexibility we have at our disposal.
from pyro import poutine
from pyro.infer import SVI

def simple_elbo_kl_annealing(model, guide, *args, **kwargs):
    # get the annealing factor and latents to anneal from the keyword
    # arguments passed to the model and guide
    annealing_factor = kwargs.pop('annealing_factor', 1.0)
    latents_to_anneal = kwargs.pop('latents_to_anneal', [])
    # run the guide and replay the model against the guide
    guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
    model_trace = poutine.trace(
        poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)

    elbo = 0.0
    # loop through all the sample sites in the model and guide trace and
    # construct the loss; note that we scale all the log probabilities of
    # sample sites in `latents_to_anneal` by the factor `annealing_factor`
    for site in model_trace.nodes.values():
        if site["type"] == "sample":
            factor = annealing_factor if site["name"] in latents_to_anneal else 1.0
            elbo = elbo + factor * site["fn"].log_prob(site["value"]).sum()
    for site in guide_trace.nodes.values():
        if site["type"] == "sample":
            factor = annealing_factor if site["name"] in latents_to_anneal else 1.0
            elbo = elbo - factor * site["fn"].log_prob(site["value"]).sum()
    return -elbo

svi = SVI(model, guide, optim, loss=simple_elbo_kl_annealing)
svi.step(other_args, annealing_factor=0.2, latents_to_anneal=["my_latent"])
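In practice the annealing factor is usually changed over the course of training rather than held at a fixed value such as 0.2. One common schedule is a linear warm-up. The helper below, linear_annealing_factor, is a hypothetical example: its name and its min_factor default are assumptions, not part of Pyro. Its output would be passed as the annealing_factor keyword argument of svi.step.

```python
def linear_annealing_factor(step, warmup_steps, min_factor=0.1):
    """Hypothetical linear KL warm-up schedule (not a Pyro API).

    Ramps linearly from min_factor at step 0 to 1.0 at warmup_steps,
    then stays at 1.0.
    """
    if warmup_steps <= 0 or step >= warmup_steps:
        return 1.0
    return min_factor + (1.0 - min_factor) * step / warmup_steps


# Example usage inside a training loop (svi, other_args and "my_latent"
# as in the snippet above):
# for step in range(num_steps):
#     svi.step(other_args,
#              annealing_factor=linear_annealing_factor(step, warmup_steps=1000),
#              latents_to_anneal=["my_latent"])

schedule = [linear_annealing_factor(s, warmup_steps=4) for s in range(6)]
print(schedule)
```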