Gaussian Mixture Model

This is a brief tutorial on training mixture models in Pyro. We’ll focus on the mechanics of config_enumerate() and on setting up mixture weights. To keep things simple, we’ll train a trivial two-component mixture of 1-D Gaussians on a tiny 5-point dataset.

In [1]:
from __future__ import print_function
import os
from collections import defaultdict
import numpy as np
import scipy.stats
import torch
from torch.distributions import constraints
from matplotlib import pyplot
%matplotlib inline

import pyro
import pyro.distributions as dist
from pyro.optim import Adam
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate

smoke_test = ('CI' in os.environ)
pyro.enable_validation(True)

Dataset

Here is our tiny dataset. It has five points.

In [2]:
data = torch.tensor([0., 1., 10., 11., 12.])

Maximum likelihood approach

Let’s start by optimizing the model parameters weights, locs, and scale directly, rather than treating them as random variables with priors. Our model will learn global mixture weights, the location of each mixture component, and a shared scale that is common to both components. Our guide will learn soft assignment probabilities for each data point.

Note that none of our parameters have priors. In this maximum-likelihood approach we can embed our parameters directly in the model rather than the guide. This is equivalent to adding them to the guide as pyro.sample(..., dist.Delta(...)) sites and using a uniform prior in the model; a sketch of that equivalent guide appears after the next code cell.

In [3]:
K = 2  # Fixed number of components.

def model(data):
    # Global parameters.
    weights = pyro.param('weights', torch.ones(K) / K, constraint=constraints.simplex)
    locs = pyro.param('locs', 10 * torch.randn(K))
    scale = pyro.param('scale', torch.tensor(0.5), constraint=constraints.positive)

    with pyro.iarange('data'):
        # Local variables.
        assignment = pyro.sample('assignment',
                                 dist.Categorical(weights).expand_by([len(data)]))
        pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)

def guide(data):
    with pyro.iarange('data'):
        # Local parameters.
        assignment_probs = pyro.param('assignment_probs', torch.ones(len(data), K) / K,
                                      constraint=constraints.unit_interval)
        pyro.sample('assignment', dist.Categorical(assignment_probs))
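
As noted above, this maximum-likelihood setup is equivalent to declaring the global parameters as Delta sites in the guide. For illustration only, here is a minimal sketch of such a guide (not part of the original notebook): the name delta_guide is hypothetical, a matching model would need 'weights', 'locs', and 'scale' sample sites with flat priors, and .independent(1) (spelled .to_event(1) in newer Pyro) packs each parameter vector into a single event.

def delta_guide(data):
    # Sketch only -- point-mass (Delta) sites over the global parameters.
    weights_q = pyro.param('weights', torch.ones(K) / K, constraint=constraints.simplex)
    locs_q = pyro.param('locs', 10 * torch.randn(K))
    scale_q = pyro.param('scale', torch.tensor(0.5), constraint=constraints.positive)
    pyro.sample('weights', dist.Delta(weights_q).independent(1))
    pyro.sample('locs', dist.Delta(locs_q).independent(1))
    pyro.sample('scale', dist.Delta(scale_q))
    # Local assignment sites, exactly as in guide() above.
    guide(data)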

To run inference with this (model, guide) pair, we use Pyro’s config_enumerate() function to enumerate over all assignments in each iteration. Since we’ve wrapped the batched Categorical assignment in a pyro.iarange independence context, this enumeration can happen in parallel: we enumerate only 2 possibilities, rather than 2**len(data) = 32. Finally, to use the parallel version of enumeration, we inform Pyro that we’re only using a single iarange via max_iarange_nesting=1; this lets Pyro know that we’re using the rightmost tensor dimension for the iarange, so that it can use any dimension to the left of it for parallel enumeration.
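
To make the counting concrete, here is a small illustrative check (not part of the original notebook; it relies on the enumerate_support() method of torch.distributions.Categorical, which Pyro’s Categorical inherits):

d = dist.Categorical(torch.ones(K) / K)   # same K = 2 as above
print(d.enumerate_support())              # tensor([0, 1]) -- only 2 values per site
# Under parallel enumeration, TraceEnum_ELBO broadcasts these 2 values along an
# extra tensor dimension to the left of the iarange dimension, so each step
# evaluates the model on 2 assignments rather than 2**len(data) = 32.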

In [4]:
optim = pyro.optim.Adam({'lr': 0.2, 'betas': [0.9, 0.99]})
inference = SVI(model, config_enumerate(guide, 'parallel'), optim,
                loss=TraceEnum_ELBO(max_iarange_nesting=1))

During training, we’ll collect both losses and gradient norms to monitor convergence. We can do this using PyTorch’s .register_hook() method.

In [5]:
pyro.set_rng_seed(1)      # Set seed to make results reproducible.
pyro.clear_param_store()  # Clear stale param values.

# Register hooks to monitor gradient norms.
gradient_norms = defaultdict(list)
inference.loss(model, guide, data)  # Initializes param store.
for name, value in pyro.get_param_store().named_parameters():
    value.register_hook(lambda g, name=name: gradient_norms[name].append(g.norm().item()))

losses = []
for i in range(500 if not smoke_test else 2):
    loss = inference.step(data)
    losses.append(loss)
    print('.' if i % 100 else '\n', end='')

...................................................................................................
...................................................................................................
...................................................................................................
...................................................................................................
...................................................................................................
In [6]:
pyplot.figure(figsize=(10,3), dpi=100).set_facecolor('white')
pyplot.plot(losses)
pyplot.xlabel('iters')
pyplot.ylabel('loss')
pyplot.yscale('log')
pyplot.title('Convergence of SVI');
[Figure: Convergence of SVI]
In [7]:
pyplot.figure(figsize=(10,4), dpi=100).set_facecolor('white')
for name, grad_norms in gradient_norms.items():
    pyplot.plot(grad_norms, label=name)
pyplot.xlabel('iters')
pyplot.ylabel('gradient norm')
pyplot.yscale('log')
pyplot.legend(loc='best')
pyplot.title('Gradient norms during SVI');
[Figure: Gradient norms during SVI]

Here are the learned parameters:

In [8]:
weights = pyro.param('weights')
locs = pyro.param('locs')
scale = pyro.param('scale')
print('weights = {}'.format(weights.data.numpy()))
print('locs = {}'.format(locs.data.numpy()))
print('scale = {}'.format(scale.data.numpy()))
weights = [0.59612644 0.4038736 ]
locs = [10.999848    0.50015897]
scale = 0.708275258541

The model’s weights are as expected, with about 3/5 of the data in the first component and 2/5 in the second component. We can also examine the guide’s local assignment_probs variable.
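
As a quick sanity check (a sketch, not part of the original notebook), we can also read off hard assignments by taking an argmax over the component dimension of the guide’s parameter:

# Sketch: hard (MAP) assignments derived from the guide's soft probabilities.
hard_assignment = pyro.param('assignment_probs').argmax(dim=-1)
print(hard_assignment)  # we expect {0., 1.} and {10., 11., 12.} to land in different components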

In [9]:
assignment_probs = pyro.param('assignment_probs')
pyplot.figure(figsize=(8, 4), dpi=100).set_facecolor('white')
pyplot.plot(data.data.numpy(), assignment_probs.data.numpy()[:, 0], 'ro',
            label='component with mean {:0.2g}'.format(locs[0]))
pyplot.plot(data.data.numpy(), assignment_probs.data.numpy()[:, 1], 'bo',
            label='component with mean {:0.2g}'.format(locs[1]))
pyplot.title('Mixture assignment probabilities')
pyplot.xlabel('data value')
pyplot.ylabel('assignment probability')
pyplot.legend(loc='center');
[Figure: Mixture assignment probabilities]

Next let’s visualize the mixture model.

In [10]:
X = np.arange(-3,15,0.1)
Y1 = weights[0].item() * scipy.stats.norm.pdf(X, loc=locs[0].item(), scale=scale.item())
Y2 = weights[1].item() * scipy.stats.norm.pdf(X, loc=locs[1].item(), scale=scale.item())

pyplot.figure(figsize=(10, 4), dpi=100).set_facecolor('white')
pyplot.plot(X, Y1, 'r-')
pyplot.plot(X, Y2, 'b-')
pyplot.plot(X, Y1 + Y2, 'k--')
pyplot.plot(data.data.numpy(), np.zeros(len(data)), 'k*')
pyplot.title('Density of two-component mixture model')
pyplot.ylabel('probability density');
[Figure: Density of two-component mixture model]

Finally, note that optimization with mixture models is non-convex and can often get stuck in local optima. For example, in this tutorial we observed that the mixture model gets stuck in an everything-in-one-cluster hypothesis if scale is initialized to be too large.
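
To illustrate this failure mode, here is a hedged sketch (not part of the original notebook; the initial value of 10.0 is a hypothetical choice and the outcome depends on the random seed):

# Sketch: re-run training with a deliberately large initial scale.
pyro.clear_param_store()
pyro.param('scale', torch.tensor(10.), constraint=constraints.positive)  # too large
for i in range(300 if not smoke_test else 2):
    inference.step(data)
# With an unlucky seed, both entries of 'locs' can end up near the overall data
# mean of 6.8, i.e. the everything-in-one-cluster local optimum.
print(pyro.param('locs'), pyro.param('scale'))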