Writing guides using EasyGuide¶
This tutorial describes the pyro.contrib.easyguide module. It assumes the reader is already familiar with SVI and tensor shapes.
Summary¶
- For simple black-box guides, try using components in pyro.infer.autoguide.
- For more complex guides, try using components in pyro.contrib.easyguide.
- Decorate with @easy_guide(model).
- Select multiple model sites using group = self.group(match="my_regex").
- Guide a group of sites by a single distribution using group.sample(...).
- Inspect the concatenated group shape using group.batch_shape, group.event_shape, etc.
- Use self.plate(...) instead of pyro.plate(...).
- To be compatible with subsampling, pass the event_dim arg to pyro.param(...).
- To MAP estimate model site "foo", use foo = self.map_estimate("foo").
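Putting these pieces together, the skeleton of an EasyGuide looks like the following. This is a minimal sketch only (assuming the imports in the next cell): my_model, its sites x_0, x_1, ..., the global site "foo", and the mean-field Normal are hypothetical placeholders, not the model used in this tutorial.
[ ]:
# A minimal sketch; my_model and its sites are hypothetical placeholders.
@easy_guide(my_model)
def sketch_guide(self, data, subsample, full_size):
    self.map_estimate("foo")              # MAP estimate of the hypothetical site "foo".
    group = self.group(match="x_[0-9]*")  # Select the hypothetical sites x_0, x_1, ...
    with self.plate("data", full_size, subsample=subsample):
        # event_dim=1 keeps this param compatible with subsampling.
        loc = pyro.param("x_loc",
                         lambda: torch.zeros((full_size,) + group.event_shape),
                         event_dim=1)
        scale = pyro.param("x_scale", lambda: torch.ones(()),
                           constraint=constraints.positive)
        # Guide all sites in the group with a single joint distribution.
        group.sample("x_aux", dist.Normal(loc, scale).to_event(1))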
[ ]:
import os
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.contrib.easyguide import easy_guide
from pyro.optim import Adam
from torch.distributions import constraints
smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.9.1')
Modeling time series data¶
Consider a time-series model with a slowly-varying continuous latent state and Bernoulli observations with a logistic link function.
[ ]:
def model(batch, subsample, full_size):
    batch = list(batch)
    num_time_steps = len(batch)
    drift = pyro.sample("drift", dist.LogNormal(-1, 0.5))
    with pyro.plate("data", full_size, subsample=subsample):
        z = 0.
        for t in range(num_time_steps):
            # Gaussian random walk for the latent state.
            z = pyro.sample("state_{}".format(t),
                            dist.Normal(z, drift))
            # Bernoulli observation with a logistic link.
            batch[t] = pyro.sample("obs_{}".format(t),
                                   dist.Bernoulli(logits=z),
                                   obs=batch[t])
    return torch.stack(batch)
Let’s generate some data directly from the model.
[ ]:
full_size = 100
num_time_steps = 7
pyro.set_rng_seed(123456789)
data = model([None] * num_time_steps, torch.arange(full_size), full_size)
assert data.shape == (num_time_steps, full_size)
Writing a guide without EasyGuide¶
Consider a possible guide for this model where we point-estimate the drift parameter using a Delta distribution, and then model local time series using shared uncertainty but local means, using a LowRankMultivariateNormal distribution. There is a single global sample site which we can model with a param and sample statement. Then we sample a global pair of uncertainty parameters cov_diag and cov_factor. Next we sample a local loc parameter using pyro.param(..., event_dim=...) and an auxiliary sample site. Finally we unpack that auxiliary site into one element per time series. The auxiliary-unpacked-to-Deltas pattern is quite common.
[ ]:
rank = 3

def guide(batch, subsample, full_size):
    num_time_steps, batch_size = batch.shape
    # MAP estimate the drift.
    drift_loc = pyro.param("drift_loc", lambda: torch.tensor(0.1),
                           constraint=constraints.positive)
    pyro.sample("drift", dist.Delta(drift_loc))
    # Model local states using shared uncertainty + local mean.
    cov_diag = pyro.param("state_cov_diag",
                          lambda: torch.full((num_time_steps,), 0.01),
                          constraint=constraints.positive)
    cov_factor = pyro.param("state_cov_factor",
                            lambda: torch.randn(num_time_steps, rank) * 0.01)
    with pyro.plate("data", full_size, subsample=subsample):
        # Sample local mean.
        loc = pyro.param("state_loc",
                         lambda: torch.full((full_size, num_time_steps), 0.5),
                         event_dim=1)
        states = pyro.sample("states",
                             dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag),
                             infer={"is_auxiliary": True})
        # Unpack the joint states into one sample site per time step.
        for t in range(num_time_steps):
            pyro.sample("state_{}".format(t), dist.Delta(states[:, t]))
Let’s train using SVI and Trace_ELBO, manually batching data into small minibatches.
[ ]:
def train(guide, num_epochs=1 if smoke_test else 101, batch_size=20):
    full_size = data.size(-1)
    pyro.get_param_store().clear()
    pyro.set_rng_seed(123456789)
    svi = SVI(model, guide, Adam({"lr": 0.02}), Trace_ELBO())
    for epoch in range(num_epochs):
        pos = 0
        losses = []
        while pos < full_size:
            subsample = torch.arange(pos, pos + batch_size)
            batch = data[:, pos:pos + batch_size]
            pos += batch_size
            losses.append(svi.step(batch, subsample, full_size=full_size))
        epoch_loss = sum(losses) / len(losses)
        if epoch % 10 == 0:
            print("epoch {} loss = {}".format(epoch, epoch_loss / data.numel()))
[ ]:
train(guide)
Using EasyGuide¶
Now let’s simplify using the @easy_guide decorator. Our modifications are:

1. Decorate with @easy_guide and add self to the args.
2. Replace the Delta guide for drift with a simple map_estimate().
3. Select a group of model sites and read their concatenated event_shape.
4. Replace the auxiliary site and Delta slices with a single group.sample().
[ ]:
@easy_guide(model)
def guide(self, batch, subsample, full_size):
    # MAP estimate the drift.
    self.map_estimate("drift")
    # Model local states using shared uncertainty + local mean.
    group = self.group(match="state_[0-9]*")  # Selects all local variables.
    cov_diag = pyro.param("state_cov_diag",
                          lambda: torch.full(group.event_shape, 0.01),
                          constraint=constraints.positive)
    cov_factor = pyro.param("state_cov_factor",
                            lambda: torch.randn(group.event_shape + (rank,)) * 0.01)
    with self.plate("data", full_size, subsample=subsample):
        # Sample local mean.
        loc = pyro.param("state_loc",
                         lambda: torch.full((full_size,) + group.event_shape, 0.5),
                         event_dim=1)
        # Automatically sample the joint latent, then unpack and replay model sites.
        group.sample("states", dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag))
Note we’ve used group.event_shape to determine the total flattened, concatenated shape of all matched sites in the group.
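For intuition, here is that shape arithmetic spelled out in plain torch (a sketch of the bookkeeping, not the EasyGuide API): the regex matches the seven scalar sites state_0, ..., state_6, each with empty event shape, and the group concatenates their flattened shapes.
[ ]:
# Plain-torch sketch of the arithmetic behind group.event_shape.
site_event_shapes = [torch.Size([])] * num_time_steps  # state_0 ... state_6
total = sum(s.numel() for s in site_event_shapes)      # numel of () is 1, so 7
assert torch.Size([total]) == torch.Size([7])          # matches group.event_shape here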
[ ]:
train(guide)
Amortized guides¶
EasyGuide also makes it easy to write amortized guides (guides where we learn a function that predicts latent variables from data, rather than learning one parameter per datapoint). Let’s modify the last guide to predict the latent loc as an affine function of observed data, rather than memorizing each data point’s latent variable. This amortized guide is more useful in practice because it can handle new data.
[ ]:
@easy_guide(model)
def guide(self, batch, subsample, full_size):
    num_time_steps, batch_size = batch.shape
    self.map_estimate("drift")
    group = self.group(match="state_[0-9]*")
    cov_diag = pyro.param("state_cov_diag",
                          lambda: torch.full(group.event_shape, 0.01),
                          constraint=constraints.positive)
    cov_factor = pyro.param("state_cov_factor",
                            lambda: torch.randn(group.event_shape + (rank,)) * 0.01)
    # Predict latent propensity as an affine function of observed data.
    if not hasattr(self, "nn"):
        self.nn = torch.nn.Linear(group.event_shape.numel(), group.event_shape.numel())
        self.nn.weight.data.fill_(1.0 / num_time_steps)
        self.nn.bias.data.fill_(-0.5)
    pyro.module("state_nn", self.nn)
    with self.plate("data", full_size, subsample=subsample):
        loc = self.nn(batch.t())
        group.sample("states", dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag))
[ ]:
train(guide)
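Because the amortized guide computes loc from the batch itself rather than indexing a per-datapoint parameter, we can run it on data it never saw during training. Here is a sketch that generates ten fresh series from the model and reads the guide’s latent states out of a trace; new_size and the trace bookkeeping here are illustrative additions, not part of the tutorial’s training loop.
[ ]:
# Sketch: apply the trained amortized guide to freshly generated data.
new_size = 10
new_data = model([None] * num_time_steps, torch.arange(new_size), new_size)
with pyro.poutine.trace() as tr:
    guide(new_data, torch.arange(new_size), full_size=new_size)
# group.sample() produced one site per time step; stack their values.
states = torch.stack([tr.trace.nodes["state_{}".format(t)]["value"]
                      for t in range(num_time_steps)])
assert states.shape == (num_time_steps, new_size)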