{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "6oXxxX9LZL-h" }, "source": [ "# Dirichlet Process Mixture Models in Pyro\n", "\n", "\n", "## What are Bayesian nonparametric models?\n", "Bayesian nonparametric models are models where the number of parameters grows freely with the amount of data provided; thus, instead of training several models that vary in complexity and comparing them, one can design a model whose complexity grows as more data are observed. The prototypical example of Bayesian nonparametrics in practice is the *Dirichlet Process Mixture Model* (DPMM). A DPMM allows a practitioner to build a mixture model when the number of distinct clusters in the geometric structure of their data is unknown; in other words, the number of clusters is allowed to grow as more data are observed. This makes the DPMM well suited to exploratory data analysis, where little about the data in question is known in advance; this tutorial aims to demonstrate this.\n", "\n", "## The Dirichlet Process (Ferguson, 1973)\n", "Dirichlet processes are a family of probability distributions over discrete probability distributions. Formally, the Dirichlet process (DP) is specified by some base probability distribution $G_0: \\Omega \\to \\mathbb{R}$ and a positive real scaling (concentration) parameter, commonly denoted $\\alpha$. A sample $G$ from a Dirichlet process with parameters $G_0: \\Omega \\to \\mathbb{R}$ and $\\alpha$ is itself a distribution over $\\Omega$. For any finite partition of $\\Omega$ into disjoint measurable sets $\\Omega_1, ..., \\Omega_k$, and any sample $G \\sim DP(G_0, \\alpha)$, we have:\n", "\n", "$$(G(\\Omega_1), ..., G(\\Omega_k)) \\sim \\text{Dir}(\\alpha G_0(\\Omega_1), ..., \\alpha G_0(\\Omega_k))$$\n", "\n", "In words: however we partition the sample space $\\Omega$ into finitely many cells, the probabilities that $G$ assigns to those cells are jointly Dirichlet distributed, with parameters given by the mass $G_0$ places on each cell, scaled by $\\alpha$. 
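
The finite-partition property can be checked numerically. The sketch below is plain Python using only the standard library (it is not part of the Pyro workflow in this notebook); the three base masses in `base`, the two concentration values, and the helper names are illustrative assumptions. It samples the masses that $G$ assigns to each cell directly from the Dirichlet distribution above (as normalized Gamma draws) and compares the empirical mean and variance of the first cell's mass with the Dirichlet moments: mean `p` and variance `p * (1 - p) / (alpha + 1)`, where `p` is the mass $G_0$ places on that cell.

```python
import random


def sample_dirichlet(params, rng):
    # A Dirichlet draw is a vector of independent Gamma(a_i, 1) draws normalized to sum to one.
    gammas = [rng.gammavariate(a, 1.0) for a in params]
    total = sum(gammas)
    return [g / total for g in gammas]


def dp_partition_masses(alpha, base_masses, rng):
    # Masses that G ~ DP(G0, alpha) gives to the cells of a finite partition:
    # Dir(alpha * G0(cell_1), ..., alpha * G0(cell_k)).
    return sample_dirichlet([alpha * m for m in base_masses], rng)


def predicted_moments(alpha, p):
    # Mean and variance of G(cell) under the Dirichlet law when G0(cell) = p.
    return p, p * (1 - p) / (alpha + 1)


rng = random.Random(0)
base = [0.5, 0.3, 0.2]  # hypothetical masses G0 assigns to a three-cell partition
results = {}
for alpha in (1.0, 100.0):
    draws = [dp_partition_masses(alpha, base, rng) for _ in range(20000)]
    first = [d[0] for d in draws]
    mean = sum(first) / len(first)
    var = sum((x - mean) ** 2 for x in first) / len(first)
    results[alpha] = (mean, var)
    print(alpha, round(mean, 3), round(var, 4), predicted_moments(alpha, base[0]))
```

With `alpha = 100` the sampled masses stay close to the base masses, while with `alpha = 1` the first cell's mass varies widely between draws: larger concentration values make $G$ concentrate more tightly around $G_0$.
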
While quite abstract in formulation, the Dirichlet process is very useful as a prior in various graphical models. This becomes easier to see in the following scheme.\n", "\n", "## The Chinese Restaurant Process (Aldous, 1985)\n", "\n", "Imagine a restaurant with infinitely many tables (indexed by the positive integers) that accepts customers one at a time. The $n$th customer chooses their seat according to the following probabilities:\n", "\n", "* With probability $\\frac{n_t}{\\alpha + n - 1}$, sit at table $t$, where $n_t$ is the number of people already at table $t$\n", "* With probability $\\frac{\\alpha}{\\alpha + n - 1}$, sit at an empty table\n", "\n", "If we associate to each table $t$ a draw from a base distribution $G_0$ over $\\Omega$, and give that draw a probability mass proportional to $n_t$, then in the limit of infinitely many customers the resulting distribution over $\\Omega$ is distributed as a draw from the Dirichlet process $DP(G_0, \\alpha)$. \n", "\n", "Furthermore, we can easily extend this to define the generative process of a nonparametric mixture model: every table $t$ that has at least one customer seated is associated with a set of cluster parameters $\\theta_t$, which were themselves drawn from the base distribution $G_0$. For each new observation, first assign that observation to a table according to the above probabilities; then draw the observation from the distribution parameterized by that table's cluster parameters. If the observation was assigned to a new table, first draw a new set of cluster parameters from $G_0$, and then draw the observation from the distribution parameterized by those cluster parameters.\n", "\n", "While this formulation of a Dirichlet process mixture model is intuitive, it is very difficult to perform inference with in a probabilistic programming framework. This motivates an alternative formulation of DPMMs, which has empirically been shown to be more conducive to inference (e.g. 
Blei and Jordan, 2004).\n", "\n", "## The Stick-Breaking Method (Sethuraman, 1994)\n", "\n", "The generative process for the stick-breaking formulation of DPMMs proceeds as follows:\n", "\n", "* Draw $\\beta_i \\sim \\text{Beta}(1, \\alpha)$ for $i \\in \\mathbb{N}$\n", "* Draw $\\theta_i \\sim G_0$ for $i \\in \\mathbb{N}$\n", "* Construct the mixture weights $\\pi$ by taking $\\pi_i(\\beta_{1:\\infty}) = \\beta_i \\prod_{j=1}^{i-1} (1 - \\beta_j)$" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "data = torch.cat((MultivariateNormal(-8 * torch.ones(2), torch.eye(2)).sample([50]),\n", " MultivariateNormal(8 * torch.ones(2), torch.eye(2)).sample([50]),\n", " MultivariateNormal(torch.tensor([1.5, 2]), torch.eye(2)).sample([50]),\n", " MultivariateNormal(torch.tensor([-0.5, 1]), torch.eye(2)).sample([50])))\n", "\n", "plt.scatter(data[:, 0], data[:, 1])\n", "plt.title(\"Data Samples from Mixture of 4 Gaussians\")\n", "plt.show()\n", "N = data.shape[0]" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "CglLQke4gEYd" }, "source": [ "In this example, the cluster parameters $\\theta_i$ are two-dimensional vectors describing the means of a multivariate Gaussian with identity covariance. We choose the Dirichlet process base distribution $G_0$ to be a multivariate Gaussian as well (the conjugate prior for a Gaussian mean), although conjugacy is not especially useful computationally here, since we are performing black-box variational inference using Pyro rather than coordinate-ascent variational inference. 
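
To make the generative story concrete, here is a minimal forward simulation of a truncated stick-breaking mixture in plain Python (standard library only; it is not the Pyro model used for inference). It assumes a truncation level `T`, a base distribution $G_0$ equal to a zero-mean two-dimensional Gaussian with covariance `base_var` times the identity (the Pyro model in this notebook uses 5 times the identity), and unit-covariance observations around each cluster mean; the function names and parameter values are illustrative.

```python
import math
import random


def stick_breaking_weights(betas):
    # pi_i = beta_i * (stick left before break i); the last weight takes whatever stick remains.
    weights, remaining = [], 1.0
    for b in betas:
        weights.append(b * remaining)
        remaining *= 1 - b
    weights.append(remaining)
    return weights


def simulate_dpmm(n, T, alpha, base_var, rng):
    # Truncated stick-breaking DPMM with 2-D Gaussian clusters.
    betas = [rng.betavariate(1.0, alpha) for _ in range(T - 1)]  # beta_i ~ Beta(1, alpha)
    weights = stick_breaking_weights(betas)  # T mixture weights summing to one
    sd = math.sqrt(base_var)
    means = [(rng.gauss(0.0, sd), rng.gauss(0.0, sd)) for _ in range(T)]  # theta_t ~ G0
    z = rng.choices(range(T), weights=weights, k=n)  # z_n ~ Categorical(pi)
    x = [(rng.gauss(means[k][0], 1.0), rng.gauss(means[k][1], 1.0)) for k in z]  # x_n ~ N(theta_z, I)
    return weights, means, z, x


weights, means, z, x = simulate_dpmm(n=200, T=6, alpha=1.0, base_var=5.0, rng=random.Random(1))
print([round(w, 3) for w in weights])
print(len(set(z)), 'clusters used by', len(x), 'points')
```

Rerunning with a smaller `alpha` tends to put more of the stick on the first few weights, so fewer distinct clusters appear among the simulated points.
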
\n", "\n", "First, let's define the \"stick-breaking\" function that generates our mixture weights from our samples of $\\beta$:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": {}, "colab_type": "code", "id": "2ngrqFlDQYpV" }, "outputs": [], "source": [ "def mix_weights(beta):\n", "    # Stick remaining after each break: prod_{j=1}^{i} (1 - beta_j)\n", "    beta1m_cumprod = (1 - beta).cumprod(-1)\n", "    # pi_i = beta_i * prod_{j=1}^{i-1} (1 - beta_j); padding beta with a trailing 1 gives the\n", "    # T-th weight all of the remaining stick, so the T weights sum to one.\n", "    return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "LAC0bWL6Qcc3" }, "source": [ "Next, let's define our model. It may be helpful to refer to the definition of the stick-breaking model presented in the first part of this tutorial. \n", "\n", "Note that all $\\beta_i$ samples are conditionally independent, so we model them using a pyro.plate of size T-1; we do the same for all samples of our cluster parameters $\\mu_i$. We then pass the sampled $\\beta$ values through mix_weights to obtain the mixture weights, use them as the parameters of a Categorical distribution (line 9 of the code below), and sample the cluster assignment $z_n$ for each data point from that Categorical. Finally, we sample each observation $x_n$ from a multivariate Gaussian distribution whose mean is the cluster parameter selected by its assignment $z_n$. 
This can be seen in the Pyro code below:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": {}, "colab_type": "code", "id": "WfnbSIocRlvQ" }, "outputs": [], "source": [ "def model(data):\n", "    with pyro.plate(\"beta_plate\", T-1):\n", "        beta = pyro.sample(\"beta\", Beta(1, alpha))  # stick-breaking proportions\n", "\n", "    with pyro.plate(\"mu_plate\", T):\n", "        mu = pyro.sample(\"mu\", MultivariateNormal(torch.zeros(2), 5 * torch.eye(2)))  # cluster means from G_0\n", "\n", "    with pyro.plate(\"data\", N):\n", "        z = pyro.sample(\"z\", Categorical(mix_weights(beta)))  # cluster assignments\n", "        pyro.sample(\"obs\", MultivariateNormal(mu[z], torch.eye(2)), obs=data)  # unit-covariance likelihood" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "1gBQj5RKRn8Z" }, "source": [ "Now, it's time to define our guide and perform inference. \n", "\n", "The variational family $q(\\beta, \\theta, z)$ that we optimize over during variational inference is given by:\n", "\n", "$$q(\\beta, \\theta, z) = \\prod_{t=1}^{T-1} q_t(\\beta_t) \\prod_{t=1}^T q_t(\\theta_t) \\prod_{n=1}^N q_n(z_n)$$ \n", "\n", "Since we cannot computationally represent the infinitely many clusters posited by the model, we truncate our variational family at $T$ clusters. In the mathematical formulation this does not change the model; rather, it is a simplification made in the *inference* stage to allow tractability. (In the code above, the model is also written with $T$ components, so that its sample sites line up with those of the guide.) \n", "\n", "The guide is constructed exactly according to the definition of our variational family $q(\\beta, \\theta, z)$ above. 
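
Because the factors of $q$ are independent, expectations of the stick-breaking weights factor as well, which is convenient when summarizing a fitted guide. The sketch below is plain Python with hypothetical helper names; it assumes the variational factor for each stick proportion is Beta(1, kappa_t), matching the guide defined below, where kappa is the vector of positive variational parameters. Using E[beta_t] = 1 / (1 + kappa_t), the expected weight of component t is E[beta_t] times the product of E[1 - beta_j] over the earlier sticks, and the last of the $T$ weights takes the expected leftover stick. The Monte Carlo average is only a sanity check of that factorization.

```python
import random


def expected_weights(kappa):
    # With beta_t ~ Beta(1, kappa_t) independently under q:
    #   E[beta_t] = 1 / (1 + kappa_t) and E[1 - beta_t] = kappa_t / (1 + kappa_t),
    # so E[pi_t] = E[beta_t] * product of E[1 - beta_j] over earlier sticks.
    weights, remaining = [], 1.0
    for k in kappa:
        weights.append(remaining / (1.0 + k))
        remaining *= k / (1.0 + k)
    weights.append(remaining)  # truncation: the T-th weight is the expected leftover stick
    return weights


def monte_carlo_weights(kappa, rng, num_samples=20000):
    # Average the stick-breaking weights over sampled betas, as a sanity check.
    totals = [0.0] * (len(kappa) + 1)
    for _ in range(num_samples):
        remaining = 1.0
        for t, k in enumerate(kappa):
            b = rng.betavariate(1.0, k)
            totals[t] += b * remaining
            remaining *= 1 - b
        totals[-1] += remaining
    return [s / num_samples for s in totals]


kappa = [1.0, 1.0, 1.0]  # hypothetical fitted values for T = 4
print(expected_weights(kappa))  # [0.5, 0.25, 0.125, 0.125]
print([round(w, 3) for w in monte_carlo_weights(kappa, random.Random(0))])
```

Dropping components whose expected weight falls below a chosen threshold and renormalizing the rest is one simple way to see which of the $T$ truncated components the data actually use.
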
We have $T-1$ conditionally independent Beta distributions, one for each $\\beta_t$ sampled in our model, $T$ conditionally independent multivariate Gaussians, one for each cluster parameter $\\mu_t$, and $N$ conditionally independent Categorical distributions, one for each cluster assignment $z_n$.\n", "\n", "Our variational parameters (pyro.param) are therefore the $T-1$ positive scalars that set the second shape parameter of our variational Beta distributions (the first shape parameter is fixed at $1$, as in the model definition), the $T$ two-dimensional vectors that set the means of our variational multivariate Gaussian distributions (we do not parameterize the covariance matrices of the Gaussians, though for a real-world dataset they should also be learned for added flexibility), and the $N$ vectors of length $T$ that parameterize our variational Categorical distributions:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": {}, "colab_type": "code", "id": "Imw4wcVkT9er" }, "outputs": [], "source": [ "def guide(data):\n", "    # Variational parameters: Beta concentrations (kappa), Gaussian means (tau),\n", "    # and per-point assignment probabilities (phi), each initialized lazily at random.\n", "    kappa = pyro.param('kappa', lambda: Uniform(0, 2).sample([T-1]), constraint=constraints.positive)\n", "    tau = pyro.param('tau', lambda: MultivariateNormal(torch.zeros(2), 3 * torch.eye(2)).sample([T]))\n", "    phi = pyro.param('phi', lambda: Dirichlet(1/T * torch.ones(T)).sample([N]), constraint=constraints.simplex)\n", "\n", "    # q(beta_t) = Beta(1, kappa_t)\n", "    with pyro.plate(\"beta_plate\", T-1):\n", "        q_beta = pyro.sample(\"beta\", Beta(torch.ones(T-1), kappa))\n", "\n", "    # q(mu_t) = N(tau_t, I); only the means are learned\n", "    with pyro.plate(\"mu_plate\", T):\n", "        q_mu = pyro.sample(\"mu\", MultivariateNormal(tau, torch.eye(2)))\n", "\n", "    # q(z_n) = Categorical(phi_n)\n", "    with pyro.plate(\"data\", N):\n", "        z = pyro.sample(\"z\", Categorical(phi))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "H0He1id0T_bN" }, "source": [ "When performing inference, we set our guess for the maximum number of clusters in the dataset, i.e. the truncation level, to $T = 6$. 
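
One rough way to judge whether a truncation level is large enough is to look at how much stick the prior leaves unbroken after the first $T-1$ breaks. Since the stick proportions are independent Beta(1, alpha) draws with E[1 - beta_j] = alpha / (1 + alpha), the expected leftover mass is (alpha / (1 + alpha)) ** (T - 1). The sketch below is plain Python; the helper names, the alpha values, and the tolerance are illustrative assumptions, not values taken from this notebook. It evaluates the expected leftover for $T = 6$ and finds the smallest $T$ that brings it below a tolerance.

```python
def expected_leftover(alpha, T):
    # Prior mean of the stick remaining after T - 1 independent Beta(1, alpha) breaks.
    return (alpha / (1.0 + alpha)) ** (T - 1)


def smallest_truncation(alpha, tol):
    # Smallest T whose expected leftover mass is at most tol.
    T = 1
    while expected_leftover(alpha, T) > tol:
        T += 1
    return T


for alpha in (0.1, 1.0, 5.0):
    print(alpha, expected_leftover(alpha, 6), smallest_truncation(alpha, 0.01))
```

With `alpha = 1`, for instance, $T = 6$ leaves an expected 1/32 = 0.03125 of the prior mass beyond the first five sticks, all of which the truncated model assigns to the sixth component; smaller concentration values use up the stick faster.
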
We define the optimization algorithm (pyro.optim.Adam) along with the Pyro SVI object and train the model for 1000 iterations. \n", "\n", "After performing inference, we construct the Bayes estimators of the means (the expected values of each factor in our variational approximation) and plot them in red on top of the original dataset. Note that we also remove any clusters whose weight under the learned variational distributions falls below a threshold, and then renormalize the remaining weights so that they sum to one:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 372 }, "colab_type": "code", "id": "x1Yidukpd9wO", "outputId": "b0cc290b-3285-4f36-c2a6-7195b6801482" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 1000/1000 [00:15<00:00, 64.86it/s]\n", "100%|██████████| 1000/1000 [00:15<00:00, 65.47it/s]\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "