{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Writing guides using EasyGuide\n", "\n", "This tutorial describes the [pyro.contrib.easyguide](http://docs.pyro.ai/en/stable/contrib.easyguide.html) module. It assumes the reader is already familiar with [SVI](http://pyro.ai/examples/svi_part_ii.html) and [tensor shapes](http://pyro.ai/examples/tensor_shapes.html).\n", "\n", "#### Summary\n", "\n", "- For simple black-box guides, try using components in [pyro.infer.autoguide](http://docs.pyro.ai/en/stable/infer.autoguide.html).\n", "- For more complex guides, try using components in [pyro.contrib.easyguide](http://docs.pyro.ai/en/stable/contrib.easyguide.html).\n", "- Decorate with `@easy_guide(model)`.\n", "- Select multiple model sites using `group = self.group(match=\"my_regex\")`.\n", "- Guide a group of sites by a single distribution using `group.sample(...)`.\n", "- Inspect the concatenated group shape using `group.batch_shape`, `group.event_shape`, etc.\n", "- Use `self.plate(...)` instead of `pyro.plate(...)`.\n", "- To be compatible with subsampling, pass the `event_dim` arg to `pyro.param(...)`.\n", "- To MAP estimate model site \"foo\", use `foo = self.map_estimate(\"foo\")`.\n", "\n", "#### Table of contents\n", "\n", "- [Modeling time series data](#Modeling-time-series-data)\n", "- [Writing a guide without EasyGuide](#Writing-a-guide-without-EasyGuide)\n", "- [Using EasyGuide](#Using-EasyGuide)\n", "- [Amortized guides](#Amortized-guides)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import torch\n", "import pyro\n", "import pyro.distributions as dist\n", "from pyro.infer import SVI, Trace_ELBO\n", "from pyro.contrib.easyguide import easy_guide\n", "from pyro.optim import Adam\n", "from torch.distributions import constraints\n", "\n", "smoke_test = ('CI' in os.environ)\n", "assert pyro.__version__.startswith('1.9.0')" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "## Modeling time series data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Consider a time-series model with a slowly varying continuous latent state and Bernoulli observations with a logistic link function." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def model(batch, subsample, full_size):\n", "    batch = list(batch)\n", "    num_time_steps = len(batch)\n", "    drift = pyro.sample(\"drift\", dist.LogNormal(-1, 0.5))\n", "    with pyro.plate(\"data\", full_size, subsample=subsample):\n", "        z = 0.\n", "        for t in range(num_time_steps):\n", "            z = pyro.sample(\"state_{}\".format(t),\n", "                            dist.Normal(z, drift))\n", "            batch[t] = pyro.sample(\"obs_{}\".format(t),\n", "                                   dist.Bernoulli(logits=z),\n", "                                   obs=batch[t])\n", "    return torch.stack(batch)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's generate some data directly from the model." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "full_size = 100\n", "num_time_steps = 7\n", "pyro.set_rng_seed(123456789)\n", "data = model([None] * num_time_steps, torch.arange(full_size), full_size)\n", "assert data.shape == (num_time_steps, full_size)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Writing a guide without EasyGuide\n", "\n", "Consider a possible guide for this model where we point-estimate the `drift` parameter using a `Delta` distribution, and then model the local time series using shared uncertainty but local means, using a `LowRankMultivariateNormal` distribution. There is a single global sample site, which we can model with a `param` statement and a `sample` statement. Then we create a global pair of uncertainty parameters, `cov_diag` and `cov_factor`. Next we create a local `loc` parameter using `pyro.param(..., event_dim=...)` and use it to sample an auxiliary joint site `states`. Finally we unpack that auxiliary site into one `Delta` sample site per time step. This pattern of unpacking an auxiliary site into `Delta` sites is quite common." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "rank = 3\n", "\n", "def guide(batch, subsample, full_size):\n", "    num_time_steps, batch_size = batch.shape\n", "\n", "    # MAP estimate the drift.\n", "    drift_loc = pyro.param(\"drift_loc\", lambda: torch.tensor(0.1),\n", "                           constraint=constraints.positive)\n", "    pyro.sample(\"drift\", dist.Delta(drift_loc))\n", "\n", "    # Model local states using shared uncertainty + local mean.\n", "    cov_diag = pyro.param(\"state_cov_diag\",\n", "                          lambda: torch.full((num_time_steps,), 0.01),\n", "                          constraint=constraints.positive)\n", "    cov_factor = pyro.param(\"state_cov_factor\",\n", "                            lambda: torch.randn(num_time_steps, rank) * 0.01)\n", "    with pyro.plate(\"data\", full_size, subsample=subsample):\n", "        # Local mean parameter; the plate subsamples its rows.\n", "        loc = pyro.param(\"state_loc\",\n", "                         lambda: torch.full((full_size, num_time_steps), 0.5),\n", "                         event_dim=1)\n", "        states = pyro.sample(\"states\",\n", "                             dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag),\n", "                             infer={\"is_auxiliary\": True})\n", "        # Unpack the joint states into one sample site per time step.\n", "        for t in range(num_time_steps):\n", "            pyro.sample(\"state_{}\".format(t), dist.Delta(states[:, t]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's train using [SVI](http://docs.pyro.ai/en/stable/inference_algos.html#module-pyro.infer.svi) and [Trace_ELBO](http://docs.pyro.ai/en/stable/inference_algos.html#pyro.infer.trace_elbo.Trace_ELBO), manually batching data into small minibatches." 
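] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `train` function below walks through the data in contiguous blocks of `batch_size` series, passing each block's indices as `subsample` together with the matching columns of `data`. Here is a minimal standalone sketch of that indexing. The helper `contiguous_minibatches` and the `example_*` names are hypothetical and were chosen so they do not clash with the notebook's variables. This version clips the last block at `n`, so it also covers data sizes that are not a multiple of `batch_size`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "def contiguous_minibatches(n, batch_size):\n", "    # Yield index tensors covering 0, ..., n - 1 in contiguous blocks;\n", "    # the last block is shorter when n is not a multiple of batch_size.\n", "    for pos in range(0, n, batch_size):\n", "        yield torch.arange(pos, min(pos + batch_size, n))\n", "\n", "# 100 series in blocks of 20, matching the defaults used in this notebook.\n", "example_blocks = list(contiguous_minibatches(100, 20))\n", "example_sizes = [len(block) for block in example_blocks]\n", "print(example_sizes)"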
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def train(guide, num_epochs=1 if smoke_test else 101, batch_size=20):\n", "    full_size = data.size(-1)\n", "    pyro.get_param_store().clear()\n", "    pyro.set_rng_seed(123456789)\n", "    svi = SVI(model, guide, Adam({\"lr\": 0.02}), Trace_ELBO())\n", "    for epoch in range(num_epochs):\n", "        pos = 0\n", "        losses = []\n", "        while pos < full_size:\n", "            # Clip the last block so subsample indices never exceed full_size.\n", "            subsample = torch.arange(pos, min(pos + batch_size, full_size))\n", "            batch = data[:, pos:pos + batch_size]\n", "            pos += batch_size\n", "            losses.append(svi.step(batch, subsample, full_size=full_size))\n", "        epoch_loss = sum(losses) / len(losses)\n", "        if epoch % 10 == 0:\n", "            print(\"epoch {} loss = {}\".format(epoch, epoch_loss / data.numel()))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train(guide)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using EasyGuide\n", "\n", "Now let's simplify using the `@easy_guide` decorator. Our modifications are:\n", "1. Decorate with `@easy_guide` and add `self` as the first argument.\n", "2. Replace the `Delta` guide for drift with a simple `map_estimate()`.\n", "3. Select a `group` of model sites and read their concatenated `event_shape`.\n", "4. Replace the auxiliary site and `Delta` slices with a single `group.sample()`." 
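] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make step 3 concrete, here is a small shape calculation in plain PyTorch. It is a hypothetical illustration of flattening and concatenating per-site event shapes, not EasyGuide's internal code. It assumes the group holds this model's seven per-time-step `state_t` sites, each a scalar per data point (event shape `()`), with `rank = 3` as in the guide above. The `example_*` names are hypothetical." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "# Assumed per-site event shapes: seven scalar state sites.\n", "example_site_event_shapes = [torch.Size([])] * 7\n", "\n", "# Flatten each site's event shape to its size and concatenate along one dimension.\n", "example_event_shape = torch.Size([sum(s.numel() for s in example_site_event_shapes)])\n", "\n", "# Parameter shapes derived from it, mirroring the guides in this notebook.\n", "example_rank, example_full_size = 3, 100\n", "example_cov_diag_shape = example_event_shape\n", "example_cov_factor_shape = example_event_shape + (example_rank,)\n", "example_loc_shape = (example_full_size,) + example_event_shape\n", "print(example_cov_diag_shape, example_cov_factor_shape, example_loc_shape)"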
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@easy_guide(model)\n", "def guide(self, batch, subsample, full_size):\n", "    # MAP estimate the drift.\n", "    self.map_estimate(\"drift\")\n", "\n", "    # Model local states using shared uncertainty + local mean.\n", "    group = self.group(match=\"state_[0-9]*\")  # Selects all local variables.\n", "    cov_diag = pyro.param(\"state_cov_diag\",\n", "                          lambda: torch.full(group.event_shape, 0.01),\n", "                          constraint=constraints.positive)\n", "    cov_factor = pyro.param(\"state_cov_factor\",\n", "                            lambda: torch.randn(group.event_shape + (rank,)) * 0.01)\n", "    with self.plate(\"data\", full_size, subsample=subsample):\n", "        # Local mean parameter; the plate subsamples its rows.\n", "        loc = pyro.param(\"state_loc\",\n", "                         lambda: torch.full((full_size,) + group.event_shape, 0.5),\n", "                         event_dim=1)\n", "        # Automatically sample the joint latent, then unpack and replay model sites.\n", "        group.sample(\"states\", dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that we've used `group.event_shape` to obtain the total flattened, concatenated shape of all matched sites in the group. Here the seven scalar `state_t` sites give an `event_shape` of `(7,)`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train(guide)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Amortized guides\n", "\n", "`EasyGuide` also makes it easy to write amortized guides. These are guides that learn a function predicting latent variables from data, rather than learning one parameter per data point. Let's modify the last guide to predict the latent `loc` as an affine function of the observed data, rather than memorizing each data point's latent variable. This amortized guide is more useful in practice because it can be applied to new data points that were not seen during training." 
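] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To put numbers on the difference, here is a small count in plain PyTorch. It assumes this notebook's sizes (100 series, 7 time steps), and the `example_*` names are hypothetical. A non-amortized guide stores a `state_loc` row for every data point, so its size grows with the dataset. An affine map like the one in the next cell has a fixed number of parameters and can be applied to series it never saw." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "# Sizes assumed to match this notebook: 100 series, 7 time steps.\n", "example_full_size, example_num_steps = 100, 7\n", "\n", "# Non-amortized: one local mean vector per data point, like the state_loc param.\n", "example_state_loc = torch.full((example_full_size, example_num_steps), 0.5)\n", "\n", "# Amortized: a single affine map from a series' observations to its local mean.\n", "example_net = torch.nn.Linear(example_num_steps, example_num_steps)\n", "example_amortized_count = sum(p.numel() for p in example_net.parameters())\n", "\n", "# The same map applies to unseen series, e.g. three random binary sequences.\n", "example_new_obs = torch.bernoulli(torch.full((3, example_num_steps), 0.5))\n", "example_new_loc = example_net(example_new_obs)\n", "print(example_state_loc.numel(), example_amortized_count, tuple(example_new_loc.shape))"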
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@easy_guide(model)\n", "def guide(self, batch, subsample, full_size):\n", "    num_time_steps, batch_size = batch.shape\n", "    self.map_estimate(\"drift\")\n", "\n", "    group = self.group(match=\"state_[0-9]*\")\n", "    cov_diag = pyro.param(\"state_cov_diag\",\n", "                          lambda: torch.full(group.event_shape, 0.01),\n", "                          constraint=constraints.positive)\n", "    cov_factor = pyro.param(\"state_cov_factor\",\n", "                            lambda: torch.randn(group.event_shape + (rank,)) * 0.01)\n", "\n", "    # Predict latent propensity as an affine function of observed data.\n", "    if not hasattr(self, \"nn\"):\n", "        self.nn = torch.nn.Linear(group.event_shape.numel(), group.event_shape.numel())\n", "        self.nn.weight.data.fill_(1.0 / num_time_steps)\n", "        self.nn.bias.data.fill_(-0.5)\n", "    pyro.module(\"state_nn\", self.nn)\n", "    with self.plate(\"data\", full_size, subsample=subsample):\n", "        # batch.t() has one row of observations per series in the minibatch.\n", "        loc = self.nn(batch.t())\n", "        group.sample(\"states\", dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train(guide)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.10" } }, "nbformat": 4, "nbformat_minor": 2 }