{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# MLE and MAP Estimation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this short tutorial we review how to do Maximum Likelihood (MLE) and Maximum a Posteriori (MAP) estimation in Pyro." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch.distributions import constraints\n", "import pyro\n", "import pyro.distributions as dist\n", "from pyro.infer import SVI, Trace_ELBO\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We consider the simple \"fair coin\" example covered in a [previous tutorial](http://pyro.ai/examples/svi_part_i.html#A-simple-example)." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "data = torch.zeros(10)\n", "data[0:6] = 1.0\n", "\n", "def original_model(data):\n", " f = pyro.sample(\"latent_fairness\", dist.Beta(10.0, 10.0))\n", " with pyro.plate(\"data\", data.size(0)):\n", " pyro.sample(\"obs\", dist.Bernoulli(f), obs=data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To facilitate comparison between different inference techniques, we construct a training helper:" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "def train(model, guide, lr=0.005, n_steps=201):\n", " pyro.clear_param_store()\n", " adam_params = {\"lr\": lr}\n", " adam = pyro.optim.Adam(adam_params)\n", " svi = SVI(model, guide, adam, loss=Trace_ELBO())\n", "\n", " for step in range(n_steps):\n", " loss = svi.step(data)\n", " if step % 50 == 0:\n", " print('[iter {}] loss: {:.4f}'.format(step, loss))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## MLE" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our model has a single latent variable latent_fairness. 
To do Maximum Likelihood Estimation we simply \"demote\" our latent variable latent_fairness to a Pyro parameter." ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "def model_mle(data):\n", " # note that we need to include the interval constraint; \n", " # in original_model() this constraint appears implicitly in \n", " # the support of the Beta distribution.\n", " f = pyro.param(\"latent_fairness\", torch.tensor(0.5), \n", " constraint=constraints.unit_interval)\n", " with pyro.plate(\"data\", data.size(0)):\n", " pyro.sample(\"obs\", dist.Bernoulli(f), obs=data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can render our model as shown below." ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pyro.render_model(model_mle, model_args=(data,), render_distributions=True, render_params=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since we no longer have any latent variables, our guide can be empty:" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "def guide_mle(data):\n", " pass" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's see what result we get." 
] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[iter 0] loss: 6.9315\n", "[iter 50] loss: 6.7693\n", "[iter 100] loss: 6.7333\n", "[iter 150] loss: 6.7302\n", "[iter 200] loss: 6.7301\n" ] } ], "source": [ "train(model_mle, guide_mle)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Our MLE estimate of the latent fairness is 0.600\n" ] } ], "source": [ "mle_estimate = pyro.param(\"latent_fairness\").item()\n", "print(\"Our MLE estimate of the latent fairness is {:.3f}\".format(mle_estimate))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also compare our estimate with the analytical MLE, which is $\\frac{\\#Heads}{\\#Heads + \\#Tails}$. Since we encode heads as 1 and tails as 0, the analytical MLE is simply `data.sum() / data.size(0)`, i.e. `data.mean()`." ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The analytical MLE estimate of the latent fairness is 0.600\n" ] } ], "source": [ "print(\"The analytical MLE estimate of the latent fairness is {:.3f}\".format(\n", " data.mean()))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Thus MLE gives us a point estimate of `latent_fairness` that matches the analytical MLE." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You may be wondering how to interpret the loss values in the experiment above. Since the model has no latent variables and the guide is empty, the loss is exactly the negative log likelihood (NLL) of the data under the Bernoulli likelihood, so the training procedure above minimized the NLL. We confirm this below."
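, "

Because the likelihood is a plain Bernoulli, the NLL can also be checked by hand, without Pyro. The sketch below uses only the Python standard library and assumes the 6 heads and 4 tails in `data`. `bernoulli_nll` is a hypothetical helper, not a Pyro function. The sketch also accounts for the first logged loss: `latent_fairness` is initialized at 0.5, and the NLL there is 10 ln 2 ≈ 6.9315.

```python
import math


def bernoulli_nll(f, heads, tails):
    # Hypothetical helper: negative log-likelihood of the observed counts under Bernoulli(f).
    return -(heads * math.log(f) + tails * math.log(1.0 - f))


heads, tails = 6, 4  # assumed to match data: six 1s followed by four 0s
mle = heads / (heads + tails)  # analytical MLE

# A coarse grid search: no grid point should have a lower NLL than the analytical MLE.
grid = [i / 1000 for i in range(1, 1000)]
best = min(grid, key=lambda f: bernoulli_nll(f, heads, tails))

print(f'analytical MLE: {mle:.3f}, grid minimizer: {best:.3f}')
print(f'NLL at f = {mle:.3f}: {bernoulli_nll(mle, heads, tails):.4f}')
print(f'NLL at f = 0.500: {bernoulli_nll(0.5, heads, tails):.4f}')
```

The two printed NLL values should line up with the final and first losses logged by `train(model_mle, guide_mle)`."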
] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The negative log likelihood given latent fairness = 0.600 is 6.7301 which matches the loss obtained via our training procedure.\n" ] } ], "source": [ "nll = -dist.Bernoulli(mle_estimate).log_prob(data).sum()\n", "print(f\"The negative log likelihood given latent fairness = {mle_estimate:0.3f} is {nll:0.4f} which matches the loss obtained via our training procedure.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## MAP" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With Maximum a Posteriori (MAP) estimation we again get a point estimate of our latent variables. The difference from MLE is that these estimates are regularized by the prior. The rendering below shows how the model used for MAP differs from the one used for MLE: in `original_model`, `latent_fairness` is a `pyro.sample` site with a Beta prior, not a `pyro.param`." ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pyro.render_model(original_model, model_args=(data,), render_distributions=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To do MAP in Pyro we use a [Delta distribution](http://docs.pyro.ai/en/stable/distributions.html#pyro.distributions.Delta) for the guide. Recall that the Delta distribution puts all of its probability mass at a single value; here that value is a learnable parameter.
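
Why does a Delta guide give a MAP estimate? Pyro's `Delta` takes a `log_density` argument that defaults to 0. Assuming that default, which the guide defined next does not override, the guide contributes nothing to the ELBO. The ELBO at the guide's point is then just the model's log joint density, log p(data, f). Maximizing the ELBO therefore maximizes the unnormalized posterior density, so it finds the posterior mode, and the loss SVI reports is the negative log joint.

The sketch below is plain Python, not Pyro code. The function name `log_joint` and the grid search are illustrative choices of ours, and the sketch assumes the Beta(10.0, 10.0) prior and the 6 heads / 4 tails used in this notebook.

```python
import math

# Assumed setup, mirroring the notebook: Beta(10, 10) prior, 6 heads and 4 tails.
ALPHA, BETA = 10.0, 10.0
HEADS, TAILS = 6, 4


def log_joint(f):
    # log Beta(f; ALPHA, BETA) + Bernoulli log-likelihood of the observed counts.
    # With a Delta guide of log density 0 at f, the ELBO equals this value.
    log_norm = math.lgamma(ALPHA + BETA) - math.lgamma(ALPHA) - math.lgamma(BETA)
    log_prior = log_norm + (ALPHA - 1) * math.log(f) + (BETA - 1) * math.log(1 - f)
    log_lik = HEADS * math.log(f) + TAILS * math.log(1 - f)
    return log_prior + log_lik


# Brute-force maximization over a fine grid on the open interval (0, 1).
grid = [i / 10000 for i in range(1, 10000)]
f_map = max(grid, key=log_joint)

print(f'grid maximizer of the log joint: {f_map:.3f}')
print(f'negative log joint there: {-log_joint(f_map):.4f}')
```

If the reasoning above holds, the grid maximizer should sit close to the MAP estimate that SVI finds, and the negative log joint there should be close to the final training loss.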
" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "def guide_map(data):\n", " f_map = pyro.param(\"f_map\", torch.tensor(0.5),\n", " constraint=constraints.unit_interval)\n", " pyro.sample(\"latent_fairness\", dist.Delta(f_map))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's see how this result differs from MLE." ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[iter 0] loss: 5.6719\n", "[iter 50] loss: 5.6007\n", "[iter 100] loss: 5.6004\n", "[iter 150] loss: 5.6004\n", "[iter 200] loss: 5.6004\n" ] } ], "source": [ "train(original_model, guide_map)" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Our MAP estimate of the latent fairness is 0.536\n" ] } ], "source": [ "map_estimate = pyro.param(\"f_map\").item()\n", "print(\"Our MAP estimate of the latent fairness is {:.3f}\".format(map_estimate))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To understand what's going on, note that the prior mean of `latent_fairness` in our model is 0.5, the mean of Beta(10.0, 10.0). The MLE (which ignores the prior) is determined entirely by the raw counts (6 heads and 4 tails). In contrast, the MAP estimate is pulled towards the prior mean, which is why it lies between 0.5 and 0.6. The plot below illustrates this. In fact, because the Beta prior is conjugate to the Bernoulli likelihood, we can also calculate the MAP estimate analytically.\n", "\n", "Our Beta prior is parameterised by $\\alpha_{Heads}$ (= 10 in our example) and $\\alpha_{Tails}$ (= 10 in our example).
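The posterior is then Beta($\\alpha_{Heads} + \\#Heads$, $\\alpha_{Tails} + \\#Tails$), and the closed-form MAP estimate is its mode:\n", "$\\frac{\\alpha_{Heads} + \\#Heads - 1}{\\alpha_{Heads} + \\#Heads + \\alpha_{Tails} + \\#Tails - 2}$ = $\\frac{10 + 6 - 1}{10 + 6 + 10 + 4 - 2}$ = $\\frac{15}{28} \\approx 0.536$, which agrees with the estimate found above. (By contrast, $\\frac{\\alpha_{Heads} + \\#Heads}{\\alpha_{Heads} + \\#Heads + \\alpha_{Tails} + \\#Tails} = \\frac{16}{30} \\approx 0.533$ is the posterior mean.)" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "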