Example: reducing boilerplate with pyro.contrib.autoname

View on GitHub

Mixture

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import argparse

import torch
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
from pyro.contrib.autoname import named
from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO
from pyro.optim import Adam

# This is a simple Gaussian mixture model.
#
# The example demonstrates how to pass named.Object instances from a global
# model to a local model implemented as a helper function.


def model(data, k):
    """Gaussian mixture model over the observations in ``data``.

    Global mixture parameters are attached to a ``named.Object`` called
    ``"latent"``; autoname derives parameter/site names from the attribute
    path (e.g. ``latent.probs``), so no names are spelled out by hand.

    :param data: iterable of observations; each element is passed as ``obs``.
    :param int k: number of mixture components.
    """
    latent = named.Object("latent")

    # Mixture weights must stay on the simplex and scales must stay positive.
    latent.probs.param_(torch.ones(k) / k, constraint=constraints.simplex)
    latent.locs.param_(torch.zeros(k))
    latent.scales.param_(torch.ones(k), constraint=constraints.positive)

    # Each observation gets its own child object from a named.List, which is
    # handed to the local model together with the shared global parameters.
    latent.local = named.List()
    shared = (latent.probs, latent.locs, latent.scales)
    for obs in data:
        local_model(latent.local.add(), *shared, obs=obs)


def local_model(latent, ps, locs, scales, obs=None):
    """Sample one data point's component assignment and its emission.

    :param latent: the per-observation ``named.Object`` to record sites on.
    :param ps: mixture weights.
    :param locs: per-component locations.
    :param scales: per-component scales.
    :param obs: observed value, or ``None`` to draw the ``x`` site freely.
    :returns: the value of the ``x`` site.
    """
    component = latent.id.sample_(dist.Categorical(ps))
    emission = dist.Normal(locs[component], scales[component])
    return latent.x.sample_(emission, obs=obs)


def guide(data, k):
    """Guide with an independent local guide for every observation.

    Mirrors ``model``'s ``latent.local`` list so the per-observation ``id``
    sites line up. Only the number of observations matters here; the
    observed values themselves are not read.

    :param data: iterable of observations (one local guide per element).
    :param int k: number of mixture components.
    """
    latent = named.Object("latent")
    latent.local = named.List()
    for _ in data:
        local_guide(latent.local.add(), k)


def local_guide(latent, k):
    """Guess a category assignment for a single observation.

    Registers a per-observation ``probs`` parameter (initialised uniform over
    ``k`` components) and samples the ``id`` site from it.
    """
    uniform = torch.ones(k) / k
    # Categorical normalises its probs, so only positivity is constrained.
    latent.probs.param_(uniform, constraint=constraints.positive)
    latent.id.sample_(dist.Categorical(latent.probs))


def main(args):
    """Fit the mixture model with SVI and print the learned parameters.

    :param args: parsed CLI namespace with ``num_epochs`` (int) and ``jit``
        (bool; use ``JitTrace_ELBO`` instead of ``Trace_ELBO``).
    """
    pyro.set_rng_seed(0)
    # Start from fresh parameters: otherwise a second call in the same
    # process would resume from previously learned values, defeating the
    # fixed seed above.
    pyro.clear_param_store()

    optim = Adam({"lr": 0.1})
    elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
    inference = SVI(model, guide, optim, loss=elbo)
    data = torch.tensor([0.0, 1.0, 2.0, 20.0, 30.0, 40.0])
    k = 2

    print("Step\tLoss")
    loss = 0.0
    for step in range(args.num_epochs):
        # Report the sum of the previous 10 step losses, then reset.
        if step and step % 10 == 0:
            print("{}\t{:0.5g}".format(step, loss))
            loss = 0.0
        loss += inference.step(data, k=k)

    print("Parameters:")
    for name, value in sorted(pyro.get_param_store().items()):
        print("{} = {}".format(name, value.detach().cpu().numpy()))


if __name__ == "__main__":
    # The example is pinned to a specific Pyro release.
    assert pyro.__version__.startswith("1.9.1")
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument("-n", "--num-epochs", default=200, type=int)
    parser.add_argument("--jit", action="store_true")
    main(parser.parse_args())

Scoping

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import argparse

import torch
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
import pyro.optim
from pyro.contrib.autoname import scope
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate


def model(K, data):
    """Mixture of ``K`` Gaussians that share a single scale.

    Global parameters are registered here; the per-datum assignment and
    observation sites are created by ``local_model`` inside a ``data`` plate.

    :param int K: number of components.
    :param data: tensor of observations.
    :returns: whatever ``local_model`` returns (the observed sample).
    """
    # Global parameters.
    weights = pyro.param("weights", torch.ones(K) / K, constraint=constraints.simplex)
    centers = pyro.param("locs", 10 * torch.randn(K))
    spread = pyro.param("scale", torch.tensor(0.5), constraint=constraints.positive)

    with pyro.plate("data"):
        observed = local_model(weights, centers, spread, data)
    return observed


@scope(prefix="local")
def local_model(weights, locs, scale, data):
    """Draw one component assignment per datum, then score ``data`` under it.

    Site names are prefixed with ``local`` by the ``scope`` decorator.
    """
    num_points = len(data)
    prior = dist.Categorical(weights).expand_by([num_points])
    assignment = pyro.sample("assignment", prior)
    likelihood = dist.Normal(locs[assignment], scale)
    return pyro.sample("obs", likelihood, obs=data)


def guide(K, data):
    """Guide with a learnable categorical over assignments for every datum.

    :param int K: number of components.
    :param data: observations; only ``len(data)`` is used, to size the
        parameter.
    :returns: whatever ``local_guide`` returns (the sampled assignment).
    """
    # One row of component probabilities per datum. Rows need not sum to one
    # because Categorical normalises them.
    probs = pyro.param(
        "assignment_probs",
        torch.ones(len(data), K) / K,
        constraint=constraints.unit_interval,
    )
    with pyro.plate("data"):
        sampled = local_guide(probs)
    return sampled


@scope(prefix="local")
def local_guide(probs):
    """Sample the ``assignment`` site under the same ``local`` prefix as the model."""
    posterior = dist.Categorical(probs)
    return pyro.sample("assignment", posterior)


def main(args):
    """Fit the scoped mixture model with enumerated SVI and print parameters.

    :param args: parsed CLI namespace with ``num_epochs`` (int).
    """
    pyro.set_rng_seed(0)
    pyro.clear_param_store()
    K = 2

    data = torch.tensor([0.0, 1.0, 2.0, 20.0, 30.0, 40.0])
    adam = pyro.optim.Adam({"lr": 0.1})
    # The guide's discrete assignments are enumerated rather than sampled.
    enum_guide = config_enumerate(guide)
    elbo = TraceEnum_ELBO(max_plate_nesting=1)
    inference = SVI(model, enum_guide, adam, loss=elbo)

    print("Step\tLoss")
    window_loss = 0.0
    for step in range(args.num_epochs):
        if step > 0 and step % 10 == 0:
            print("{}\t{:0.5g}".format(step, window_loss))
            window_loss = 0.0
        window_loss += inference.step(K, data)

    print("Parameters:")
    for param_name, param_value in sorted(pyro.get_param_store().items()):
        print("{} = {}".format(param_name, param_value.detach().cpu().numpy()))


if __name__ == "__main__":
    # The example is pinned to a specific Pyro release.
    assert pyro.__version__.startswith("1.9.1")
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument("-n", "--num-epochs", default=200, type=int)
    main(parser.parse_args())

Autoname and tree-structured data

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import argparse

import torch
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
from pyro.contrib.autoname import named
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# This is a linear mixed-effects model over arbitrary JSON-like data.
# Data can be a tensor, a list of data, or a dict with data values.
#
# The goal is to learn a mean-field approximation to the posterior over the
# latent values z, parameterized by the parameters post_loc and post_scale.
#
# Notice that named.Object instances allow for modularity that fits well
# with the recursive model and guide functions.


def model(data):
    """Hierarchical Gaussian model over a JSON-like tree of tensors.

    Draws the root latent ``z ~ Normal(0, 1)`` and lets ``model_recurse``
    attach child latents and observations following the shape of ``data``.
    """
    root = named.Object("latent")
    root.z.sample_(dist.Normal(0.0, 1.0))
    model_recurse(data, root)


def model_recurse(data, latent):
    """Recursively extend ``latent`` with sites that mirror the shape of ``data``.

    - tensor leaf: observe ``data`` as ``x ~ Normal(latent.z, 1)``.
    - list / dict: register a positive ``prior_scale`` param on ``latent``,
      then for each child draw ``z ~ Normal(latent.z, prior_scale)`` and recurse.

    :raises TypeError: if ``data`` is not a tensor, list or dict.
    """
    if torch.is_tensor(data):
        latent.x.sample_(dist.Normal(latent.z, 1.0), obs=data)
        return

    # Pair each child datum with its child named.Object. The generators are
    # lazy, so children are created one at a time inside the loop below.
    if isinstance(data, list):
        latent.prior_scale.param_(torch.tensor(1.0), constraint=constraints.positive)
        latent.list = named.List()
        children = ((item, latent.list.add()) for item in data)
    elif isinstance(data, dict):
        latent.prior_scale.param_(torch.tensor(1.0), constraint=constraints.positive)
        latent.dict = named.Dict()
        children = ((value, latent.dict[key]) for key, value in data.items())
    else:
        raise TypeError("Unsupported type {}".format(type(data)))

    for child_data, child in children:
        child.z.sample_(dist.Normal(latent.z, latent.prior_scale))
        model_recurse(child_data, child)


def guide(data):
    """Guide entry point: recurse over ``data`` starting from the root latent."""
    root = named.Object("latent")
    guide_recurse(data, root)


def guide_recurse(data, latent):
    """Give ``latent.z`` an independent Normal posterior, then recurse into children.

    Every node, leaf or container, gets its own learnable ``post_loc`` and
    positive ``post_scale``. The list/dict structure mirrors
    ``model_recurse`` so the ``z`` site names match.

    :raises TypeError: if ``data`` is not a tensor, list or dict.
    """
    latent.post_loc.param_(torch.tensor(0.0))
    latent.post_scale.param_(torch.tensor(1.0), constraint=constraints.positive)
    latent.z.sample_(dist.Normal(latent.post_loc, latent.post_scale))

    if torch.is_tensor(data):
        return  # leaves have no children
    if isinstance(data, list):
        latent.list = named.List()
        for item in data:
            guide_recurse(item, latent.list.add())
    elif isinstance(data, dict):
        latent.dict = named.Dict()
        for key, value in data.items():
            guide_recurse(value, latent.dict[key])
    else:
        raise TypeError("Unsupported type {}".format(type(data)))


def main(args):
    """Fit the tree-structured model with SVI and print the learned parameters.

    :param args: parsed CLI namespace with ``num_epochs`` (int).
    """
    pyro.set_rng_seed(0)
    # Start from fresh parameters so repeated calls in one process are
    # reproducible under the fixed seed.
    pyro.clear_param_store()

    optim = Adam({"lr": 0.1})
    inference = SVI(model, guide, optim, loss=Trace_ELBO())

    # Data is an arbitrary json-like structure with tensors at leaves.
    one = torch.tensor(1.0)
    data = {
        "foo": one,
        "bar": [0 * one, 1 * one, 2 * one],
        "baz": {
            "noun": {
                "concrete": 4 * one,
                "abstract": 6 * one,
            },
            "verb": 2 * one,
        },
    }

    print("Step\tLoss")
    loss = 0.0
    for step in range(args.num_epochs):
        # Report before stepping so every printed value is the sum over
        # exactly the preceding 10 steps.
        if step and step % 10 == 0:
            print("{}\t{:0.5g}".format(step, loss))
            loss = 0.0
        loss += inference.step(data)

    print("Parameters:")
    for name, value in sorted(pyro.get_param_store().items()):
        print("{} = {}".format(name, value.detach().cpu().numpy()))


if __name__ == "__main__":
    # The example is pinned to a specific Pyro release.
    assert pyro.__version__.startswith("1.9.1")
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument("-n", "--num-epochs", default=100, type=int)
    main(parser.parse_args())