{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# MLE and MAP Estimation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this short tutorial we review how to do Maximum Likelihood (MLE) and Maximum a Posteriori (MAP) estimation in Pyro." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch.distributions import constraints\n", "import pyro\n", "import pyro.distributions as dist\n", "from pyro.infer import SVI, Trace_ELBO\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We consider the simple \"fair coin\" example covered in a [previous tutorial](http://pyro.ai/examples/svi_part_i.html#A-simple-example)." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "data = torch.zeros(10)\n", "data[0:6] = 1.0\n", "\n", "def original_model(data):\n", "    f = pyro.sample(\"latent_fairness\", dist.Beta(10.0, 10.0))\n", "    with pyro.plate(\"data\", data.size(0)):\n", "        pyro.sample(\"obs\", dist.Bernoulli(f), obs=data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To facilitate comparison between different inference techniques, we construct a training helper:" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "def train(model, guide, lr=0.005, n_steps=201):\n", "    pyro.clear_param_store()\n", "    adam_params = {\"lr\": lr}\n", "    adam = pyro.optim.Adam(adam_params)\n", "    svi = SVI(model, guide, adam, loss=Trace_ELBO())\n", "\n", "    for step in range(n_steps):\n", "        loss = svi.step(data)\n", "        if step % 50 == 0:\n", "            print('[iter {}] loss: {:.4f}'.format(step, loss))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## MLE" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our model has a single latent variable `latent_fairness`. 
To do Maximum Likelihood Estimation we simply \"demote\" our latent variable `latent_fairness` to a Pyro parameter." ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "def model_mle(data):\n", "    # note that we need to include the interval constraint; \n", "    # in original_model() this constraint appears implicitly in \n", "    # the support of the Beta distribution.\n", "    f = pyro.param(\"latent_fairness\", torch.tensor(0.5), \n", "                   constraint=constraints.unit_interval)\n", "    with pyro.plate(\"data\", data.size(0)):\n", "        pyro.sample(\"obs\", dist.Bernoulli(f), obs=data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can render our model as shown below." ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "%3\n", "\n", "\n", "cluster_data\n", "\n", "data\n", "\n", "\n", "\n", "latent_fairness\n", "\n", "latent_fairness\n", "\n", "\n", "\n", "obs\n", "\n", "obs\n", "\n", "\n", "\n", "latent_fairness->obs\n", "\n", "\n", "\n", "\n", "\n", "distribution_description_node\n", "obs ~ Bernoulli\n", "latent_fairness ∈ Interval(lower_bound=0.0, upper_bound=1.0)\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pyro.render_model(model_mle, model_args=(data,), render_distributions=True, render_params=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since we no longer have any latent variables, our guide can be empty:" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "def guide_mle(data):\n", "    pass" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's see what result we get." 
] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[iter 0] loss: 6.9315\n", "[iter 50] loss: 6.7693\n", "[iter 100] loss: 6.7333\n", "[iter 150] loss: 6.7302\n", "[iter 200] loss: 6.7301\n" ] } ], "source": [ "train(model_mle, guide_mle)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Our MLE estimate of the latent fairness is 0.600\n" ] } ], "source": [ "mle_estimate = pyro.param(\"latent_fairness\").item()\n", "print(\"Our MLE estimate of the latent fairness is {:.3f}\".format(mle_estimate))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also compare our MLE estimate with the analytical MLE estimate which is given as: $\\frac{\\#Heads}{\\#Heads + \\#Tails}$. As we encode `Heads` as 1 and `Tails` as 0, we can directly find the analytical MLE as `data.sum()/data.size(0)` or `data.mean()`." ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The analytical MLE estimate of the latent fairness is 0.600\n" ] } ], "source": [ "print(\"The analytical MLE estimate of the latent fairness is {:.3f}\".format(\n", " data.mean()))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Thus with MLE we get a point estimate of `latent_fairness` which matches the analytical MLE estimate." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You may be wondering how to interpret the loss numbers in our experiment above. The loss is equivalent to the negative log likelihood (NLL) of observing the data under the Bernoulli likelihood. Thus, the above procedure was equivalent to minimizing the NLL. We confirm the same below." 
] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The negative log likelihood given latent fairness = 0.600 is 6.7301 which matches the loss obtained via our training procedure.\n" ] } ], "source": [ "nll = -dist.Bernoulli(mle_estimate).log_prob(data).sum()\n", "print(f\"The negative log likelihood given latent fairness = {mle_estimate:0.3f} is {nll:0.4f} which matches the loss obtained via our training procedure.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## MAP" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With Maximum a Posteriori estimation, we also get a point estimate of our latent variables. The difference to MLE is that these estimates will be regularized by the prior. We can understand the difference between the model we use for MLE and MAP via the rendering below, where we can see `latent_fairness` is a `pyro.sample` in original model." ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "%3\n", "\n", "\n", "cluster_data\n", "\n", "data\n", "\n", "\n", "\n", "latent_fairness\n", "\n", "latent_fairness\n", "\n", "\n", "\n", "obs\n", "\n", "obs\n", "\n", "\n", "\n", "latent_fairness->obs\n", "\n", "\n", "\n", "\n", "\n", "distribution_description_node\n", "latent_fairness ~ Beta\n", "obs ~ Bernoulli\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pyro.render_model(original_model, model_args=(data,), render_distributions=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To do MAP in Pyro we use a [Delta distribution](http://docs.pyro.ai/en/stable/distributions.html#pyro.distributions.Delta) for the guide. Recall that the `Delta` distribution puts all its probability mass at a single value. 
The `Delta` distribution will be parameterized by a learnable parameter. " ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "def guide_map(data):\n", "    f_map = pyro.param(\"f_map\", torch.tensor(0.5),\n", "                       constraint=constraints.unit_interval)\n", "    pyro.sample(\"latent_fairness\", dist.Delta(f_map))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's see how this result differs from MLE." ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[iter 0] loss: 5.6719\n", "[iter 50] loss: 5.6007\n", "[iter 100] loss: 5.6004\n", "[iter 150] loss: 5.6004\n", "[iter 200] loss: 5.6004\n" ] } ], "source": [ "train(original_model, guide_map)" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Our MAP estimate of the latent fairness is 0.536\n" ] } ], "source": [ "map_estimate = pyro.param(\"f_map\").item()\n", "print(\"Our MAP estimate of the latent fairness is {:.3f}\".format(map_estimate))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To understand what's going on, note that the prior mean of `latent_fairness` in our model is 0.5, since that is the mean of `Beta(10.0, 10.0)`. The MLE estimate (which ignores the prior) gives us a result that is entirely determined by the raw counts (6 heads and 4 tails). In contrast, the MAP estimate is regularized towards the prior mean, which is why the MAP estimate is somewhere between 0.5 and 0.6. We can also see this in the plot below. In fact, we can also analytically calculate the MAP estimate given the `Beta` prior and `Bernoulli` likelihood.\n", "\n", "Our `Beta` prior is parameterized by $\\alpha_{Heads}$ (= 10 in our example) and $\\alpha_{Tails}$ (= 10 in our example). 
The closed form expression for the MAP estimate, i.e. the mode of the `Beta` posterior, is:\n", "$\\frac{\\alpha_{Heads} + ~\\#Heads - 1}{\\alpha_{Heads} + ~\\#Heads +~ \\alpha_{Tails} + ~\\#Tails - 2}$ = $\\frac{10 + 6 - 1}{10 + 6 + 10 + 4 - 2}$ = $\\frac{15}{28} \\approx 0.536$, which matches the MAP estimate obtained above." ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "x = torch.linspace(0.0, 1.0, 100)\n", "plt.plot(x, dist.Beta(10, 10).log_prob(x).exp(), label='Prior~Beta(10, 10)')\n", "plt.xlabel(\"Latent Fairness\")\n", "plt.axvline(mle_estimate, color='k', linestyle='--', label='MLE')\n", "plt.axvline(map_estimate, color='r', linestyle='-.', label='MAP')\n", "plt.axvline(0.5, color='g', linestyle=':', label='Prior Mean')\n", "plt.legend(bbox_to_anchor=(1.44,1), borderaxespad=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Doing the same thing with AutoGuides\n", "\n", "In the above we defined guides by hand. \n", "It's often much easier to rely on Pyro's [AutoGuide machinery](https://docs.pyro.ai/en/stable/infer.autoguide.html?highlight=autoguide). \n", "Let's see how we can do MLE and MAP inference using AutoGuides.\n", "\n", "To do MLE estimation we first use [`mask(False)`](https://docs.pyro.ai/en/stable/poutine.html?highlight=mask#pyro.poutine.handlers.mask) to instruct Pyro to ignore the `log_prob` of the latent variable `latent_fairness` in the model. \n", "(Note we need to do this for every latent variable.)\n", "This way the only non-zero `log_prob` in the model will be from the Bernoulli likelihood and ELBO maximization will be equivalent to likelihood maximization." ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "def masked_model(data):\n", " f = pyro.sample(\"latent_fairness\", \n", " dist.Beta(10.0, 10.0).mask(False))\n", " with pyro.plate(\"data\", data.size(0)):\n", " pyro.sample(\"obs\", dist.Bernoulli(f), obs=data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we define an [`AutoDelta`](https://docs.pyro.ai/en/stable/infer.autoguide.html?highlight=autodelta#autodelta) guide, which learns a point estimate for each latent variable (i.e. 
we do not learn any uncertainty):" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[iter 0] loss: 7.0436\n", "[iter 50] loss: 6.8213\n", "[iter 100] loss: 6.7467\n", "[iter 150] loss: 6.7319\n", "[iter 200] loss: 6.7302\n", "Our MLE estimate of the latent fairness is 0.598\n" ] } ], "source": [ "autoguide_mle = pyro.infer.autoguide.AutoDelta(masked_model)\n", "train(masked_model, autoguide_mle)\n", "print(\"Our MLE estimate of the latent fairness is {:.3f}\".format(\n", " autoguide_mle.median(data)[\"latent_fairness\"].item()))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To do MAP inference we again use an `AutoDelta` guide but this time on the original model in which `latent_fairness` is a latent variable:" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[iter 0] loss: 5.6889\n", "[iter 50] loss: 5.6005\n", "[iter 100] loss: 5.6004\n", "[iter 150] loss: 5.6004\n", "[iter 200] loss: 5.6004\n", "Our MAP estimate of the latent fairness is 0.536\n" ] } ], "source": [ "autoguide_map = pyro.infer.autoguide.AutoDelta(original_model)\n", "train(original_model, autoguide_map)\n", "print(\"Our MAP estimate of the latent fairness is {:.3f}\".format(\n", " autoguide_map.median(data)[\"latent_fairness\"].item()))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also quickly verify that had we chosen a uniform prior in our original model, our MAP estimate would be the same as the MLE estimate." 
] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "def original_model_uniform_prior(data):\n", "    f = pyro.sample(\"latent_fairness\", dist.Uniform(low=0.0, high=1.0))\n", "    with pyro.plate(\"data\", data.size(0)):\n", "        pyro.sample(\"obs\", dist.Bernoulli(f), obs=data)" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[iter 0] loss: 6.7490\n", "[iter 50] loss: 6.7302\n", "[iter 100] loss: 6.7301\n", "[iter 150] loss: 6.7301\n", "[iter 200] loss: 6.7301\n", "Our MAP estimate of the latent fairness under the Uniform prior is 0.600 matching the MLE estimate\n" ] } ], "source": [ "autoguide_map_uniform_prior = pyro.infer.autoguide.AutoDelta(original_model_uniform_prior)\n", "train(original_model_uniform_prior, autoguide_map_uniform_prior)\n", "print(\"Our MAP estimate of the latent fairness under the Uniform prior is {:.3f} matching the MLE estimate\".format(\n", " autoguide_map_uniform_prior.median(data)[\"latent_fairness\"].item()))" ] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:root] *", "language": "python", "name": "conda-root-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 4 }