# Example: reducing boilerplate with pyro.contrib.autoname
#
# Part 1: Mixture
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import argparse
import torch
from torch.distributions import constraints
import pyro
import pyro.distributions as dist
from pyro.contrib.autoname import named
from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO
from pyro.optim import Adam
# This is a simple gaussian mixture model.
#
# The example demonstrates how to pass named.Objects() from a global model to
# a local model implemented as a helper function.
def model(data, k):
    """Gaussian mixture model over ``data`` with ``k`` components.

    Global mixture parameters live on the root named.Object; each datum
    gets its own local latent passed into ``local_model``.
    """
    root = named.Object("latent")
    # Global parameters: component weights, means, and scales.
    root.probs.param_(torch.ones(k) / k, constraint=constraints.simplex)
    root.locs.param_(torch.zeros(k))
    root.scales.param_(torch.ones(k), constraint=constraints.positive)
    # One local latent per observation, each handled by the helper.
    root.local = named.List()
    for datum in data:
        site = root.local.add()
        local_model(site, root.probs, root.locs, root.scales, obs=datum)
def local_model(latent, ps, locs, scales, obs=None):
    """Sample a component id, then a Normal observation from that component."""
    component = latent.id.sample_(dist.Categorical(ps))
    normal = dist.Normal(locs[component], scales[component])
    return latent.x.sample_(normal, obs=obs)
def guide(data, k):
    """Variational guide mirroring the model's per-datum structure."""
    root = named.Object("latent")
    root.local = named.List()
    for _ in data:
        # Each datum gets its own local guide over the assignment.
        local_guide(root.local.add(), k)
def local_guide(latent, k):
    """Guess the mixture-component assignment for a single datum.

    Learns a per-datum probability vector over the ``k`` components and
    samples the assignment site from it.
    """
    # Constrain the learned probabilities to the simplex so they form a
    # proper distribution, matching the ``constraints.simplex`` used for
    # ``latent.probs`` in ``model``. (The previous ``constraints.positive``
    # only kept entries nonnegative and relied on Categorical's implicit
    # normalization.)
    latent.probs.param_(torch.ones(k) / k, constraint=constraints.simplex)
    latent.id.sample_(dist.Categorical(latent.probs))
def main(args):
    """Fit the mixture with SVI and print the learned parameters."""
    pyro.set_rng_seed(0)
    optimizer = Adam({"lr": 0.1})
    # Optionally JIT-compile the ELBO computation.
    loss_fn = JitTrace_ELBO() if args.jit else Trace_ELBO()
    svi = SVI(model, guide, optimizer, loss=loss_fn)

    data = torch.tensor([0.0, 1.0, 2.0, 20.0, 30.0, 40.0])
    k = 2

    print("Step\tLoss")
    running_loss = 0.0
    for step in range(args.num_epochs):
        if step and step % 10 == 0:
            # Report the loss accumulated since the previous report.
            print("{}\t{:0.5g}".format(step, running_loss))
            running_loss = 0.0
        running_loss += svi.step(data, k=k)

    print("Parameters:")
    for name, value in sorted(pyro.get_param_store().items()):
        print("{} = {}".format(name, value.detach().cpu().numpy()))
if __name__ == "__main__":
    assert pyro.__version__.startswith("1.9.1")
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument("-n", "--num-epochs", default=200, type=int)
    parser.add_argument("--jit", action="store_true")
    main(parser.parse_args())
# Part 2: Scoping
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import argparse
import torch
from torch.distributions import constraints
import pyro
import pyro.distributions as dist
import pyro.optim
from pyro.contrib.autoname import scope
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
def model(K, data):
    """Mixture model: global parameters plus a scoped local model over data."""
    # Global parameters.
    mixing = pyro.param("weights", torch.ones(K) / K, constraint=constraints.simplex)
    centers = pyro.param("locs", 10 * torch.randn(K))
    spread = pyro.param("scale", torch.tensor(0.5), constraint=constraints.positive)
    with pyro.plate("data"):
        return local_model(mixing, centers, spread, data)
@scope(prefix="local")
def local_model(weights, locs, scale, data):
    """Per-datum model; ``scope`` prefixes its sample sites with "local"."""
    assignment_dist = dist.Categorical(weights).expand_by([len(data)])
    assignment = pyro.sample("assignment", assignment_dist)
    likelihood = dist.Normal(locs[assignment], scale)
    return pyro.sample("obs", likelihood, obs=data)
def guide(K, data):
    """Mean-field guide over per-datum component assignments."""
    uniform_init = torch.ones(len(data), K) / K
    probs = pyro.param(
        "assignment_probs", uniform_init, constraint=constraints.unit_interval
    )
    with pyro.plate("data"):
        return local_guide(probs)
@scope(prefix="local")
def local_guide(probs):
    """Guide counterpart of ``local_model``'s (prefixed) assignment site."""
    assignment_dist = dist.Categorical(probs)
    return pyro.sample("assignment", assignment_dist)
def main(args):
    """Train the scoped mixture with enumeration-based SVI and print params."""
    pyro.set_rng_seed(0)
    pyro.clear_param_store()
    K = 2
    data = torch.tensor([0.0, 1.0, 2.0, 20.0, 30.0, 40.0])
    optimizer = pyro.optim.Adam({"lr": 0.1})
    # Enumerate the discrete assignment sites in the guide.
    svi = SVI(
        model,
        config_enumerate(guide),
        optimizer,
        loss=TraceEnum_ELBO(max_plate_nesting=1),
    )

    print("Step\tLoss")
    running_loss = 0.0
    for step in range(args.num_epochs):
        if step and step % 10 == 0:
            # Report the loss accumulated since the previous report.
            print("{}\t{:0.5g}".format(step, running_loss))
            running_loss = 0.0
        running_loss += svi.step(K, data)

    print("Parameters:")
    for name, value in sorted(pyro.get_param_store().items()):
        print("{} = {}".format(name, value.detach().cpu().numpy()))
if __name__ == "__main__":
    assert pyro.__version__.startswith("1.9.1")
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument("-n", "--num-epochs", default=200, type=int)
    main(parser.parse_args())
# Part 3: Autoname and tree-structured data
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import argparse
import torch
from torch.distributions import constraints
import pyro
import pyro.distributions as dist
from pyro.contrib.autoname import named
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
# This is a linear mixed-effects model over arbitrary json-like data.
# Data can be a number, a list of data, or a dict with data values.
#
# The goal is to learn a mean field approximation to the posterior
# values z, parameterized by parameters post_loc and post_scale.
#
# Notice that the named.Objects allow for modularity that fits well
# with the recursive model and guide functions.
def model(data):
    """Root of the hierarchical model: sample a global z, then recurse."""
    root = named.Object("latent")
    root.z.sample_(dist.Normal(0.0, 1.0))
    model_recurse(data, root)
def model_recurse(data, latent):
    """Recursively condition on json-like ``data`` at node ``latent``.

    Tensors are observed leaves; lists and dicts get a per-node prior
    scale and a child latent z per element.

    Raises:
        TypeError: if ``data`` is not a tensor, list, or dict.
    """
    if torch.is_tensor(data):
        # Leaf: observe the tensor under this node's z.
        latent.x.sample_(dist.Normal(latent.z, 1.0), obs=data)
        return
    if isinstance(data, list):
        latent.prior_scale.param_(torch.tensor(1.0), constraint=constraints.positive)
        latent.list = named.List()
        for item in data:
            child = latent.list.add()
            child.z.sample_(dist.Normal(latent.z, latent.prior_scale))
            model_recurse(item, child)
        return
    if isinstance(data, dict):
        latent.prior_scale.param_(torch.tensor(1.0), constraint=constraints.positive)
        latent.dict = named.Dict()
        for key, value in data.items():
            child = latent.dict[key]
            child.z.sample_(dist.Normal(latent.z, latent.prior_scale))
            model_recurse(value, child)
        return
    raise TypeError("Unsupported type {}".format(type(data)))
def guide(data):
    """Mean-field guide mirroring the model's recursive structure."""
    root = named.Object("latent")
    guide_recurse(data, root)
def guide_recurse(data, latent):
    """Attach variational parameters and sample z at ``latent``, then recurse.

    Raises:
        TypeError: if ``data`` is not a tensor, list, or dict.
    """
    # Per-node posterior: Normal(post_loc, post_scale) over z.
    latent.post_loc.param_(torch.tensor(0.0))
    latent.post_scale.param_(torch.tensor(1.0), constraint=constraints.positive)
    latent.z.sample_(dist.Normal(latent.post_loc, latent.post_scale))
    if torch.is_tensor(data):
        # Leaf: nothing below this node.
        return
    if isinstance(data, list):
        latent.list = named.List()
        for item in data:
            guide_recurse(item, latent.list.add())
    elif isinstance(data, dict):
        latent.dict = named.Dict()
        for key, value in data.items():
            guide_recurse(value, latent.dict[key])
    else:
        raise TypeError("Unsupported type {}".format(type(data)))
def main(args):
    """Run SVI over a nested json-like dataset and print learned parameters."""
    pyro.set_rng_seed(0)
    optimizer = Adam({"lr": 0.1})
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

    # Data is an arbitrary json-like structure with tensors at leaves.
    one = torch.tensor(1.0)
    data = {
        "foo": one,
        "bar": [0 * one, 1 * one, 2 * one],
        "baz": {
            "noun": {
                "concrete": 4 * one,
                "abstract": 6 * one,
            },
            "verb": 2 * one,
        },
    }

    print("Step\tLoss")
    running_loss = 0.0
    for step in range(args.num_epochs):
        running_loss += svi.step(data)
        if step and step % 10 == 0:
            # Report the loss accumulated since the previous report.
            print("{}\t{:0.5g}".format(step, running_loss))
            running_loss = 0.0

    print("Parameters:")
    for name, value in sorted(pyro.get_param_store().items()):
        print("{} = {}".format(name, value.detach().cpu().numpy()))
if __name__ == "__main__":
    assert pyro.__version__.startswith("1.9.1")
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument("-n", "--num-epochs", default=100, type=int)
    main(parser.parse_args())