Example: Toy Mixture Model With Discrete EnumerationΒΆ

View toy_mixture_model_discrete_enumeration.py on github

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

"""
 A toy mixture model to provide a simple example for implementing discrete enumeration.

 (A) -> [B] -> (C)

 A is an observed Bernoulli variable with Beta prior.
 B is a hidden variable which is a mixture of two Bernoulli distributions (with Beta priors),
 chosen by A being true or false.
 C is observed, and like B, is a mixture of two Bernoulli distributions (with Beta priors),
 chosen by B being true or false.
 There is a plate over the three variables for n independent observations of data.

 Because B is hidden and discrete we wish to marginalize it out of the model.
 This is done by:
    1) marking the model method with `@pyro.infer.config_enumerate`
    2) marking the B sample site in the model with `infer={"enumerate": "parallel"}`
    3) passing `pyro.infer.SVI` the `pyro.infer.TraceEnum_ELBO` loss function
"""

import argparse

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.distributions import constraints
from torch.distributions.bernoulli import Bernoulli
from torch.distributions.beta import Beta
from tqdm import tqdm

import pyro
import pyro.distributions as dist
import pyro.infer
import pyro.optim
from pyro.ops.indexing import Vindex


def main(args):
    num_obs = args.num_obs
    num_steps = args.num_steps
    prior, CPDs, data = generate_data(num_obs)
    posterior_params = train(prior, data, num_steps, num_obs)
    evaluate(CPDs, posterior_params)


def generate_data(num_obs):
    # domain = [False, True]
    prior = {'A': torch.tensor([1., 10.]),
             'B': torch.tensor([[10., 1.],
                                [1., 10.]]),
             'C': torch.tensor([[10., 1.],
                                [1., 10.]])}
    CPDs = {'p_A': Beta(prior['A'][0], prior['A'][1]).sample(),
            'p_B': Beta(prior['B'][:, 0], prior['B'][:, 1]).sample(),
            'p_C': Beta(prior['C'][:, 0], prior['C'][:, 1]).sample(),
            }
    data = {'A': Bernoulli(torch.ones(num_obs) * CPDs['p_A']).sample()}
    data['B'] = Bernoulli(torch.gather(CPDs['p_B'], 0, data['A'].type(torch.long))).sample()
    data['C'] = Bernoulli(torch.gather(CPDs['p_C'], 0, data['B'].type(torch.long))).sample()
    return prior, CPDs, data


@pyro.infer.config_enumerate
def model(prior, obs, num_obs):
    p_A = pyro.sample('p_A', dist.Beta(1, 1))
    p_B = pyro.sample('p_B', dist.Beta(torch.ones(2), torch.ones(2)).to_event(1))
    p_C = pyro.sample('p_C', dist.Beta(torch.ones(2), torch.ones(2)).to_event(1))
    with pyro.plate('data_plate', num_obs):
        A = pyro.sample('A', dist.Bernoulli(p_A.expand(num_obs)), obs=obs['A'])
        # Vindex used to ensure proper indexing into the enumerated sample sites
        B = pyro.sample('B', dist.Bernoulli(Vindex(p_B)[A.type(torch.long)]), infer={"enumerate": "parallel"})
        pyro.sample('C', dist.Bernoulli(Vindex(p_C)[B.type(torch.long)]), obs=obs['C'])


def guide(prior, obs, num_obs):
    a = pyro.param('a', prior['A'], constraint=constraints.positive)
    pyro.sample('p_A', dist.Beta(a[0], a[1]))
    b = pyro.param('b', prior['B'], constraint=constraints.positive)
    pyro.sample('p_B', dist.Beta(b[:, 0], b[:, 1]).to_event(1))
    c = pyro.param('c', prior['C'], constraint=constraints.positive)
    pyro.sample('p_C', dist.Beta(c[:, 0], c[:, 1]).to_event(1))


def train(prior, data, num_steps, num_obs):
    pyro.enable_validation(True)
    pyro.clear_param_store()
    # max_plate_nesting = 1 because there is a single plate in the model
    loss_func = pyro.infer.TraceEnum_ELBO(max_plate_nesting=1)
    svi = pyro.infer.SVI(model,
                         guide,
                         pyro.optim.Adam({'lr': .01}),
                         loss=loss_func
                         )
    losses = []
    for _ in tqdm(range(num_steps)):
        loss = svi.step(prior, data, num_obs)
        losses.append(loss)
    plt.figure()
    plt.plot(losses)
    plt.show()
    posterior_params = {k: np.array(v.data) for k, v in pyro.get_param_store().items()}
    posterior_params['a'] = posterior_params['a'][None, :]  # reshape to same as other variables
    return posterior_params


def evaluate(CPDs, posterior_params):
    true_p_A, pred_p_A = get_true_pred_CPDs(CPDs['p_A'], posterior_params['a'])
    true_p_B, pred_p_B = get_true_pred_CPDs(CPDs['p_B'], posterior_params['b'])
    true_p_C, pred_p_C = get_true_pred_CPDs(CPDs['p_C'], posterior_params['c'])
    print('\np_A = True')
    print('actual:   ', true_p_A)
    print('predicted:', pred_p_A)
    print('\np_B = True | A = False/True')
    print('actual:   ', true_p_B)
    print('predicted:', pred_p_B)
    print('\np_C = True | B = False/True')
    print('actual:   ', true_p_C)
    print('predicted:', pred_p_C)


def get_true_pred_CPDs(CPD, posterior_param):
    true_p = CPD.numpy()
    pred_p = posterior_param[:, 0]/np.sum(posterior_param, axis=1)
    return true_p, pred_p


if __name__ == "__main__":
    assert pyro.__version__.startswith('1.5.0')
    parser = argparse.ArgumentParser(description="Toy mixture model")
    parser.add_argument("-n", "--num-steps", default=4000, type=int)
    parser.add_argument("-o", "--num-obs", default=10000, type=int)
    args = parser.parse_args()
    main(args)