{ "cells": [ { "cell_type": "markdown", "id": "857effec", "metadata": {}, "source": [ "# Automatic rendering of Pyro models\n", "\n", "In this tutorial we will demonstrate how to create beautiful visualizations of your probabilistic graphical models using [pyro.render_model()](https://docs.pyro.ai/en/latest/infer.util.html#pyro.infer.inspect.render_model)." ] }, { "cell_type": "code", "execution_count": 1, "id": "8a068eb0", "metadata": {}, "outputs": [], "source": [ "import os\n", "import torch\n", "import torch.nn.functional as F\n", "import pyro\n", "import pyro.distributions as dist\n", "import pyro.distributions.constraints as constraints\n", "\n", "smoke_test = ('CI' in os.environ)\n", "assert pyro.__version__.startswith('1.9.0')" ] }, { "cell_type": "markdown", "id": "d1266c17", "metadata": {}, "source": [ "## A Simple Example\n", "\n", "The visualization interface can be readily used with your models:" ] }, { "cell_type": "code", "execution_count": 2, "id": "855a7d8f", "metadata": {}, "outputs": [], "source": [ "def model(data):\n", " m = pyro.sample(\"m\", dist.Normal(0, 1))\n", " sd = pyro.sample(\"sd\", dist.LogNormal(m, 1))\n", " with pyro.plate(\"N\", len(data)):\n", " pyro.sample(\"obs\", dist.Normal(m, sd), obs=data)" ] }, { "cell_type": "code", "execution_count": 3, "id": "e1e9628e", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": "\n\n\n\n\n\n\n\ncluster_N\n\nN\n\n\n\nm\n\nm\n\n\n\nsd\n\nsd\n\n\n\nm->sd\n\n\n\n\n\nobs\n\nobs\n\n\n\nm->obs\n\n\n\n\n\nsd->obs\n\n\n\n\n\n", "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = torch.ones(10)\n", "pyro.render_model(model, model_args=(data,))" ] }, { "cell_type": "markdown", "id": "eb1f65be", "metadata": {}, "source": [ "The visualization can be saved to a file by providing `filename='path'` to `pyro.render_model`. 
You can use different formats such as PDF or PNG by changing the filename's suffix.\n", "When not saving to a file (`filename=None`), you can also change the format with `graph.format = 'pdf'` where `graph` is the object returned by `pyro.render_model`." ] }, { "cell_type": "code", "execution_count": 4, "id": "700a0917", "metadata": {}, "outputs": [], "source": [ "graph = pyro.render_model(model, model_args=(data,), filename=\"model.pdf\")" ] }, { "cell_type": "markdown", "id": "3a4ea614", "metadata": {}, "source": [ "## Tweaking the visualization\n", "\n", "As `pyro.render_model` returns an object of type `graphviz.Digraph`, you can further improve the visualization of this graph.\n", "For example, you could use the [unflatten preprocessor](https://graphviz.readthedocs.io/en/stable/api.html#graphviz.unflatten) to improve the layout aspect ratio for more complex models." ] }, { "cell_type": "code", "execution_count": 5, "id": "01a4b74b", "metadata": {}, "outputs": [], "source": [ "def mace(positions, annotations):\n", " \"\"\"\n", " This model corresponds to the plate diagram in Figure 3 of https://www.aclweb.org/anthology/Q18-1040.pdf.\n", " \"\"\"\n", " num_annotators = int(torch.max(positions)) + 1\n", " num_classes = int(torch.max(annotations)) + 1\n", " num_items, num_positions = annotations.shape\n", "\n", " with pyro.plate(\"annotator\", num_annotators):\n", " epsilon = pyro.sample(\"ε\", dist.Dirichlet(torch.full((num_classes,), 10.)))\n", " theta = pyro.sample(\"θ\", dist.Beta(0.5, 0.5))\n", "\n", " with pyro.plate(\"item\", num_items, dim=-2):\n", " # NB: using constant logits for discrete uniform prior\n", " # (Pyro does not have a DiscreteUniform distribution yet)\n", " c = pyro.sample(\"c\", dist.Categorical(logits=torch.zeros(num_classes)))\n", "\n", " with pyro.plate(\"position\", num_positions):\n", " s = pyro.sample(\"s\", dist.Bernoulli(1 - theta[positions]))\n", " probs = torch.where(\n", " s[..., None] == 0, F.one_hot(c, 
num_classes).float(), epsilon[positions]\n", " )\n", " pyro.sample(\"y\", dist.Categorical(probs), obs=annotations)\n", "\n", "\n", "positions = torch.tensor([1, 1, 1, 2, 3, 4, 5])\n", "# fmt: off\n", "annotations = torch.tensor([\n", " [1, 3, 1, 2, 2, 2, 1, 3, 2, 2, 4, 2, 1, 2, 1,\n", " 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 2, 1, 1, 1,\n", " 1, 3, 1, 2, 2, 4, 2, 2, 3, 1, 1, 1, 2, 1, 2],\n", " [1, 3, 1, 2, 2, 2, 2, 3, 2, 3, 4, 2, 1, 2, 2,\n", " 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 3, 1, 1, 1,\n", " 1, 3, 1, 2, 2, 3, 2, 3, 3, 1, 1, 2, 3, 2, 2],\n", " [1, 3, 2, 2, 2, 2, 2, 3, 2, 2, 4, 2, 1, 2, 1,\n", " 1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1, 2, 1, 1, 2,\n", " 1, 3, 1, 2, 2, 3, 1, 2, 3, 1, 1, 1, 2, 1, 2],\n", " [1, 4, 2, 3, 3, 3, 2, 3, 2, 2, 4, 3, 1, 3, 1,\n", " 2, 1, 1, 2, 1, 2, 2, 3, 2, 1, 1, 2, 1, 1, 1,\n", " 1, 3, 1, 2, 3, 4, 2, 3, 3, 1, 1, 2, 2, 1, 2],\n", " [1, 3, 1, 1, 2, 3, 1, 4, 2, 2, 4, 3, 1, 2, 1,\n", " 1, 1, 1, 2, 3, 2, 2, 2, 2, 1, 1, 2, 1, 1, 1,\n", " 1, 2, 1, 2, 2, 3, 2, 2, 4, 1, 1, 1, 2, 1, 2],\n", " [1, 3, 2, 2, 2, 2, 1, 3, 2, 2, 4, 4, 1, 1, 1,\n", " 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 2, 1, 1, 2,\n", " 1, 3, 1, 2, 3, 4, 3, 3, 3, 1, 1, 1, 2, 1, 2],\n", " [1, 4, 2, 1, 2, 2, 1, 3, 3, 3, 4, 3, 1, 2, 1,\n", " 1, 1, 1, 1, 2, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1,\n", " 1, 3, 1, 2, 2, 3, 2, 3, 2, 1, 1, 1, 2, 1, 2],\n", "]).T\n", "# fmt: on\n", "\n", "# we subtract 1 because the first index starts with 0 in Python\n", "positions -= 1\n", "annotations -= 1\n", "\n", "mace_graph = pyro.render_model(mace, model_args=(positions, annotations))" ] }, { "cell_type": "code", "execution_count": 6, "id": "311896af", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": "\n\n\n\n\n\n\n\ncluster_annotator\n\nannotator\n\n\ncluster_item\n\nitem\n\n\ncluster_position\n\nposition\n\n\n\nε\n\nε\n\n\n\ny\n\ny\n\n\n\nε->y\n\n\n\n\n\nθ\n\nθ\n\n\n\ns\n\ns\n\n\n\nθ->s\n\n\n\n\n\nc\n\nc\n\n\n\nc->y\n\n\n\n\n\ns->y\n\n\n\n\n\n", "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": 
"execute_result" } ], "source": [ "# default layout\n", "mace_graph" ] }, { "cell_type": "code", "execution_count": 7, "id": "8babebb4", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": "\n\n\n\n\n\n\n\ncluster_annotator\n\nannotator\n\n\ncluster_item\n\nitem\n\n\ncluster_position\n\nposition\n\n\n\nε\n\nε\n\n\n\ny\n\ny\n\n\n\nε->y\n\n\n\n\n\nθ\n\nθ\n\n\n\ns\n\ns\n\n\n\nθ->s\n\n\n\n\n\ns->y\n\n\n\n\n\nc\n\nc\n\n\n\nc->y\n\n\n\n\n\n", "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# layout after processing the layout with unflatten\n", "mace_graph.unflatten(stagger=2)" ] }, { "cell_type": "markdown", "id": "50a92902", "metadata": {}, "source": [ "## Rendering the parameters" ] }, { "cell_type": "markdown", "id": "74f32e20", "metadata": {}, "source": [ "We can render the parameters defined as `pyro.param` by setting `render_params=True` in `pyro.render_model`. " ] }, { "cell_type": "code", "execution_count": 8, "id": "645df936", "metadata": {}, "outputs": [], "source": [ "def model(data):\n", " sigma = pyro.param(\"sigma\", torch.tensor([1.]), constraint=constraints.positive)\n", " mu = pyro.param(\"mu\", torch.tensor([0.]))\n", " x = pyro.sample(\"x\", dist.Normal(mu, sigma))\n", " y = pyro.sample(\"y\", dist.LogNormal(x, 1))\n", " with pyro.plate(\"N\", len(data)):\n", " pyro.sample(\"z\", dist.Normal(x, y), obs=data)" ] }, { "cell_type": "code", "execution_count": 9, "id": "66fc9f55", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": "\n\n\n\n\n\n\n\ncluster_N\n\nN\n\n\n\nx\n\nx\n\n\n\ny\n\ny\n\n\n\nx->y\n\n\n\n\n\nz\n\nz\n\n\n\nx->z\n\n\n\n\n\ny->z\n\n\n\n\n\nsigma\n\nsigma\n\n\n\nsigma->x\n\n\n\n\n\nmu\n\nmu\n\n\n\nmu->x\n\n\n\n\n\n", "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = torch.ones(10)\n", "pyro.render_model(model, model_args=(data,), render_params=True)" ] }, { "cell_type": "markdown", "id": 
"f09e73e1", "metadata": {}, "source": [ "## Distribution and Constraint annotations\n", "\n", "It is possible to display the distribution of each random variable (RV) in the generated plot by providing `render_distributions=True` when calling `pyro.render_model`. The constraints associated with parameters are also displayed when `render_distributions=True`." ] }, { "cell_type": "code", "execution_count": 10, "id": "8130359c", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": "\n\n\n\n\n\n\n\ncluster_N\n\nN\n\n\n\nx\n\nx\n\n\n\ny\n\ny\n\n\n\nx->y\n\n\n\n\n\nz\n\nz\n\n\n\nx->z\n\n\n\n\n\ny->z\n\n\n\n\n\nsigma\n\nsigma\n\n\n\nsigma->x\n\n\n\n\n\nmu\n\nmu\n\n\n\nmu->x\n\n\n\n\n\ndistribution_description_node\nx ~ Normal\ny ~ LogNormal\nz ~ Normal\nsigma : GreaterThan(lower_bound=0.0)\nmu : Real()\n\n\n\n", "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = torch.ones(10)\n", "pyro.render_model(model, model_args=(data,), render_params=True ,render_distributions=True)" ] }, { "cell_type": "markdown", "id": "3a1e4c4e", "metadata": {}, "source": [ "In the above plot, **'~'** denotes the distribution of an RV and **':'** denotes the constraint on a parameter." ] }, { "cell_type": "markdown", "id": "4dbe4925", "metadata": {}, "source": [ "## Overlapping non-nested plates\n", "\n", "Note that overlapping non-nested plates may be drawn as multiple rectangles." 
] }, { "cell_type": "code", "execution_count": 11, "id": "9f1a4ebf", "metadata": {}, "outputs": [], "source": [ "def model():\n", " plate1 = pyro.plate(\"plate1\", 2, dim=-2)\n", " plate2 = pyro.plate(\"plate2\", 3, dim=-1)\n", " with plate1:\n", " x = pyro.sample(\"x\", dist.Normal(0, 1))\n", " with plate1, plate2:\n", " y = pyro.sample(\"y\", dist.Normal(x, 1))\n", " with plate2:\n", " pyro.sample(\"z\", dist.Normal(y.sum(-2, True), 1), obs=torch.zeros(3))" ] }, { "cell_type": "code", "execution_count": 12, "id": "8514d7df", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": "\n\n\n\n\n\n\n\ncluster_plate1\n\nplate1\n\n\ncluster_plate2\n\nplate2\n\n\ncluster_plate2__CLONE\n\nplate2\n\n\n\nx\n\nx\n\n\n\ny\n\ny\n\n\n\nx->y\n\n\n\n\n\nz\n\nz\n\n\n\ny->z\n\n\n\n\n\n", "text/plain": [ "" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pyro.render_model(model)" ] }, { "cell_type": "markdown", "id": "e7224816", "metadata": {}, "source": [ "## Semisupervised models\n", "\n", "Pyro supports semisupervised models by allowing different sets of `*args, **kwargs` to be passed to a model. You can render semisupervised models by passing a list of different `model_args` tuples and/or a list of different `model_kwargs` dicts, one entry for each way you call the model." 
] }, { "cell_type": "code", "execution_count": 13, "id": "da332f10", "metadata": {}, "outputs": [], "source": [ "def model(x, y=None):\n", " with pyro.plate(\"N\", 2):\n", " z = pyro.sample(\"z\", dist.Normal(0, 1))\n", " y = pyro.sample(\"y\", dist.Normal(0, 1), obs=y)\n", " pyro.sample(\"x\", dist.Normal(y + z, 1), obs=x)" ] }, { "cell_type": "code", "execution_count": 14, "id": "20b69fb9", "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/svg+xml": "\n\n\n\n\n\n\n\ncluster_N\n\nN\n\n\n\nz\n\nz\n\n\n\nx\n\nx\n\n\n\nz->x\n\n\n\n\n\ny\n\n\n\n\n\n\n\ny\n\n\n\ny->x\n\n\n\n\n\n", "text/plain": [ "" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pyro.render_model(\n", " model,\n", " model_kwargs=[\n", " {\"x\": torch.zeros(2)},\n", " {\"x\": torch.zeros(2), \"y\": torch.zeros(2)},\n", " ]\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "837047a8", "metadata": {}, "source": [ "## Rendering deterministic variables\n", "\n", "Pyro allows deterministic variables to be defined using `pyro.deterministic`. 
These variables can be rendered by setting `render_deterministic=True` in `pyro.render_model` as follows:" ] }, { "cell_type": "code", "execution_count": 15, "id": "d90dc8d7", "metadata": {}, "outputs": [], "source": [ "def model_deterministic(data):\n", " sigma = pyro.param(\"sigma\", torch.tensor([1.]), constraint=constraints.positive)\n", " mu = pyro.param(\"mu\", torch.tensor([0.]))\n", " x = pyro.sample(\"x\", dist.Normal(mu, sigma))\n", " log_y = pyro.sample(\"y\", dist.Normal(x, 1))\n", " y = pyro.deterministic(\"y_deterministic\", log_y.exp())\n", " with pyro.plate(\"N\", len(data)):\n", " eps_z_loc = pyro.sample(\"eps_z_loc\", dist.Normal(0, 1))\n", " z_loc = pyro.deterministic(\"z_loc\", eps_z_loc + x, event_dim=0)\n", " pyro.sample(\"z\", dist.Normal(z_loc, y), obs=data)" ] }, { "cell_type": "code", "execution_count": 16, "id": "6fcc43d8", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": "\n\n\n\n\n\n\n\ncluster_N\n\nN\n\n\n\nx\n\nx\n\n\n\ny\n\ny\n\n\n\nx->y\n\n\n\n\n\nz_loc\n\nz_loc\n\n\n\nx->z_loc\n\n\n\n\n\ny_deterministic\n\ny_deterministic\n\n\n\ny->y_deterministic\n\n\n\n\n\nz\n\nz\n\n\n\ny_deterministic->z\n\n\n\n\n\nsigma\n\nsigma\n\n\n\nsigma->x\n\n\n\n\n\nmu\n\nmu\n\n\n\nmu->x\n\n\n\n\n\neps_z_loc\n\neps_z_loc\n\n\n\neps_z_loc->z_loc\n\n\n\n\n\nz_loc->z\n\n\n\n\n\ndistribution_description_node\nx ~ Normal\ny ~ Normal\ny_deterministic ~ Deterministic\neps_z_loc ~ Normal\nz_loc ~ Deterministic\nz ~ Normal\nsigma : GreaterThan(lower_bound=0.0)\nmu : Real()\n\n\n\n", "text/plain": [ "" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = torch.ones(10)\n", "pyro.render_model(\n", " model_deterministic,\n", " model_args=(data,),\n", " render_params=True,\n", " render_distributions=True,\n", " render_deterministic=True\n", ")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": 
"ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.5" } }, "nbformat": 4, "nbformat_minor": 5 }