{
"cells": [
{
"cell_type": "markdown",
"id": "857effec",
"metadata": {},
"source": [
"# Automatic rendering of Pyro models\n",
"\n",
"In this tutorial we will demonstrate how to create beautiful visualizations of your probabilistic graphical models using [pyro.render_model()](https://docs.pyro.ai/en/latest/infer.util.html#pyro.infer.inspect.render_model)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "8a068eb0",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"import torch.nn.functional as F\n",
"import pyro\n",
"import pyro.distributions as dist\n",
"import pyro.distributions.constraints as constraints\n",
"\n",
"smoke_test = ('CI' in os.environ)\n",
"assert pyro.__version__.startswith('1.9.0')"
]
},
{
"cell_type": "markdown",
"id": "d1266c17",
"metadata": {},
"source": [
"## A Simple Example\n",
"\n",
"Rendering a model only requires passing it, together with its arguments, to `pyro.render_model`:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "855a7d8f",
"metadata": {},
"outputs": [],
"source": [
"def model(data):\n",
"    m = pyro.sample(\"m\", dist.Normal(0, 1))\n",
"    sd = pyro.sample(\"sd\", dist.LogNormal(m, 1))\n",
"    with pyro.plate(\"N\", len(data)):\n",
"        pyro.sample(\"obs\", dist.Normal(m, sd), obs=data)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "e1e9628e",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": "\n\n\n\n\n",
"text/plain": [
""
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = torch.ones(10)\n",
"pyro.render_model(model, model_args=(data,))"
]
},
{
"cell_type": "markdown",
"id": "eb1f65be",
"metadata": {},
"source": [
"The visualization can be saved to a file by providing `filename='path'` to `pyro.render_model`. You can use different formats such as PDF or PNG by changing the filename's suffix.\n",
"When not saving to a file (`filename=None`), you can also change the format with `graph.format = 'pdf'` where `graph` is the object returned by `pyro.render_model`."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "700a0917",
"metadata": {},
"outputs": [],
"source": [
"graph = pyro.render_model(model, model_args=(data,), filename=\"model.pdf\")"
]
},
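{
"cell_type": "markdown",
"id": "5b7e2c41",
"metadata": {},
"source": [
"A minimal sketch of the `filename=None` route, assuming the standard `graphviz` Python API (`Digraph.format` and `Digraph.render`): set the format on the returned object, then write the file yourself. The output name `model` is only an example, and writing image files requires the Graphviz `dot` executable to be installed."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9d03a6f8",
"metadata": {},
"outputs": [],
"source": [
"# render without a filename, then choose the output format on the returned graphviz.Digraph\n",
"graph = pyro.render_model(model, model_args=(data,))\n",
"graph.format = \"png\"\n",
"# Digraph.render writes the DOT source to `model` and the image to `model.png`;\n",
"# cleanup=True deletes the intermediate DOT source file afterwards\n",
"graph.render(\"model\", cleanup=True)"
]
},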
{
"cell_type": "markdown",
"id": "3a4ea614",
"metadata": {},
"source": [
"## Tweaking the visualization\n",
"\n",
"Because `pyro.render_model` returns a `graphviz.Digraph` object, you can post-process the graph with the usual `graphviz` tools.\n",
"For example, you could use the [unflatten preprocessor](https://graphviz.readthedocs.io/en/stable/api.html#graphviz.unflatten) to improve the layout aspect ratio for more complex models."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "01a4b74b",
"metadata": {},
"outputs": [],
"source": [
"def mace(positions, annotations):\n",
"    \"\"\"\n",
"    This model corresponds to the plate diagram in Figure 3 of https://www.aclweb.org/anthology/Q18-1040.pdf.\n",
"    \"\"\"\n",
"    num_annotators = int(torch.max(positions)) + 1\n",
"    num_classes = int(torch.max(annotations)) + 1\n",
"    num_items, num_positions = annotations.shape\n",
"\n",
"    with pyro.plate(\"annotator\", num_annotators):\n",
"        epsilon = pyro.sample(\"ε\", dist.Dirichlet(torch.full((num_classes,), 10.)))\n",
"        theta = pyro.sample(\"θ\", dist.Beta(0.5, 0.5))\n",
"\n",
"    with pyro.plate(\"item\", num_items, dim=-2):\n",
"        # NB: constant (all-zero) logits give a discrete uniform prior over the classes\n",
"        c = pyro.sample(\"c\", dist.Categorical(logits=torch.zeros(num_classes)))\n",
"\n",
"        with pyro.plate(\"position\", num_positions):\n",
"            s = pyro.sample(\"s\", dist.Bernoulli(1 - theta[positions]))\n",
"            probs = torch.where(\n",
"                s[..., None] == 0, F.one_hot(c, num_classes).float(), epsilon[positions]\n",
"            )\n",
"            pyro.sample(\"y\", dist.Categorical(probs), obs=annotations)\n",
"\n",
"\n",
"positions = torch.tensor([1, 1, 1, 2, 3, 4, 5])\n",
"# fmt: off\n",
"annotations = torch.tensor([\n",
" [1, 3, 1, 2, 2, 2, 1, 3, 2, 2, 4, 2, 1, 2, 1,\n",
" 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 2, 1, 1, 1,\n",
" 1, 3, 1, 2, 2, 4, 2, 2, 3, 1, 1, 1, 2, 1, 2],\n",
" [1, 3, 1, 2, 2, 2, 2, 3, 2, 3, 4, 2, 1, 2, 2,\n",
" 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 3, 1, 1, 1,\n",
" 1, 3, 1, 2, 2, 3, 2, 3, 3, 1, 1, 2, 3, 2, 2],\n",
" [1, 3, 2, 2, 2, 2, 2, 3, 2, 2, 4, 2, 1, 2, 1,\n",
" 1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1, 2, 1, 1, 2,\n",
" 1, 3, 1, 2, 2, 3, 1, 2, 3, 1, 1, 1, 2, 1, 2],\n",
" [1, 4, 2, 3, 3, 3, 2, 3, 2, 2, 4, 3, 1, 3, 1,\n",
" 2, 1, 1, 2, 1, 2, 2, 3, 2, 1, 1, 2, 1, 1, 1,\n",
" 1, 3, 1, 2, 3, 4, 2, 3, 3, 1, 1, 2, 2, 1, 2],\n",
" [1, 3, 1, 1, 2, 3, 1, 4, 2, 2, 4, 3, 1, 2, 1,\n",
" 1, 1, 1, 2, 3, 2, 2, 2, 2, 1, 1, 2, 1, 1, 1,\n",
" 1, 2, 1, 2, 2, 3, 2, 2, 4, 1, 1, 1, 2, 1, 2],\n",
" [1, 3, 2, 2, 2, 2, 1, 3, 2, 2, 4, 4, 1, 1, 1,\n",
" 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 2, 1, 1, 2,\n",
" 1, 3, 1, 2, 3, 4, 3, 3, 3, 1, 1, 1, 2, 1, 2],\n",
" [1, 4, 2, 1, 2, 2, 1, 3, 3, 3, 4, 3, 1, 2, 1,\n",
" 1, 1, 1, 1, 2, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1,\n",
" 1, 3, 1, 2, 2, 3, 2, 3, 2, 1, 1, 1, 2, 1, 2],\n",
"]).T\n",
"# fmt: on\n",
"\n",
"# the raw data are 1-indexed; subtract 1 to get the 0-based indices used by PyTorch\n",
"positions -= 1\n",
"annotations -= 1\n",
"\n",
"mace_graph = pyro.render_model(mace, model_args=(positions, annotations))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "311896af",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": "\n\n\n\n\n",
"text/plain": [
""
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# default layout\n",
"mace_graph"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "8babebb4",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": "\n\n\n\n\n",
"text/plain": [
""
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# layout after post-processing with unflatten\n",
"mace_graph.unflatten(stagger=2)"
]
},
{
"cell_type": "markdown",
"id": "50a92902",
"metadata": {},
"source": [
"## Rendering the parameters"
]
},
{
"cell_type": "markdown",
"id": "74f32e20",
"metadata": {},
"source": [
"Parameters declared with `pyro.param` can be rendered by setting `render_params=True` in `pyro.render_model`."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "645df936",
"metadata": {},
"outputs": [],
"source": [
"def model(data):\n",
"    sigma = pyro.param(\"sigma\", torch.tensor([1.]), constraint=constraints.positive)\n",
"    mu = pyro.param(\"mu\", torch.tensor([0.]))\n",
"    x = pyro.sample(\"x\", dist.Normal(mu, sigma))\n",
"    y = pyro.sample(\"y\", dist.LogNormal(x, 1))\n",
"    with pyro.plate(\"N\", len(data)):\n",
"        pyro.sample(\"z\", dist.Normal(x, y), obs=data)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "66fc9f55",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": "\n\n\n\n\n",
"text/plain": [
""
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = torch.ones(10)\n",
"pyro.render_model(model, model_args=(data,), render_params=True)"
]
},
{
"cell_type": "markdown",
"id": "f09e73e1",
"metadata": {},
"source": [
"## Distribution and Constraint annotations\n",
"\n",
"To display the distribution of each random variable (RV) in the generated plot, pass `render_distributions=True` to `pyro.render_model`. With this option, the constraints associated with parameters are displayed as well."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "8130359c",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": "\n\n\n\n\n",
"text/plain": [
""
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = torch.ones(10)\n",
"pyro.render_model(model, model_args=(data,), render_params=True, render_distributions=True)"
]
},
{
"cell_type": "markdown",
"id": "3a1e4c4e",
"metadata": {},
"source": [
"In the plot above, **'~'** denotes the distribution of a random variable and **'$\\in$'** denotes the constraint of a parameter."
]
},
{
"cell_type": "markdown",
"id": "4dbe4925",
"metadata": {},
"source": [
"## Overlapping non-nested plates\n",
"\n",
"Note that overlapping non-nested plates may be drawn as multiple rectangles."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "9f1a4ebf",
"metadata": {},
"outputs": [],
"source": [
"def model():\n",
"    plate1 = pyro.plate(\"plate1\", 2, dim=-2)\n",
"    plate2 = pyro.plate(\"plate2\", 3, dim=-1)\n",
"    with plate1:\n",
"        x = pyro.sample(\"x\", dist.Normal(0, 1))\n",
"    with plate1, plate2:\n",
"        y = pyro.sample(\"y\", dist.Normal(x, 1))\n",
"    with plate2:\n",
"        pyro.sample(\"z\", dist.Normal(y.sum(-2, True), 1), obs=torch.zeros(3))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "8514d7df",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": "\n\n\n\n\n",
"text/plain": [
""
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pyro.render_model(model)"
]
},
{
"cell_type": "markdown",
"id": "e7224816",
"metadata": {},
"source": [
"## Semisupervised models\n",
"\n",
"Pyro supports semisupervised models, where the same model is called with different sets of `*args, **kwargs` (for example, with a variable observed in some calls and latent in others). To render such a model, pass a list of `model_args` tuples and/or a list of `model_kwargs` dicts, one entry per way the model is used."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "da332f10",
"metadata": {},
"outputs": [],
"source": [
"def model(x, y=None):\n",
"    with pyro.plate(\"N\", 2):\n",
"        z = pyro.sample(\"z\", dist.Normal(0, 1))\n",
"        y = pyro.sample(\"y\", dist.Normal(0, 1), obs=y)\n",
"        pyro.sample(\"x\", dist.Normal(y + z, 1), obs=x)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "20b69fb9",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"image/svg+xml": "\n\n\n\n\n",
"text/plain": [
""
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pyro.render_model(\n",
"    model,\n",
"    model_kwargs=[\n",
"        {\"x\": torch.zeros(2)},\n",
"        {\"x\": torch.zeros(2), \"y\": torch.zeros(2)},\n",
"    ]\n",
")"
]
},
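{
"cell_type": "markdown",
"id": "e4c18b27",
"metadata": {},
"source": [
"The same semisupervised model can also be rendered from positional arguments. In this sketch (reusing the `model` defined above; `semisup_graph` is just an illustrative name), each tuple in the `model_args` list describes one way the model is called: first with only `x` observed, then with both `x` and `y` observed."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2a9f6d3c",
"metadata": {},
"outputs": [],
"source": [
"semisup_graph = pyro.render_model(\n",
"    model,\n",
"    model_args=[\n",
"        (torch.zeros(2),),  # only x is observed, y is latent\n",
"        (torch.zeros(2), torch.zeros(2)),  # both x and y are observed\n",
"    ],\n",
")\n",
"semisup_graph"
]
},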
{
"attachments": {},
"cell_type": "markdown",
"id": "837047a8",
"metadata": {},
"source": [
"## Rendering deterministic variables\n",
"\n",
"Pyro allows deterministic variables to be defined using `pyro.deterministic`. These variables can be rendered by setting `render_deterministic=True` in `pyro.render_model` as follows:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "d90dc8d7",
"metadata": {},
"outputs": [],
"source": [
"def model_deterministic(data):\n",
"    sigma = pyro.param(\"sigma\", torch.tensor([1.]), constraint=constraints.positive)\n",
"    mu = pyro.param(\"mu\", torch.tensor([0.]))\n",
"    x = pyro.sample(\"x\", dist.Normal(mu, sigma))\n",
"    log_y = pyro.sample(\"y\", dist.Normal(x, 1))\n",
"    y = pyro.deterministic(\"y_deterministic\", log_y.exp())\n",
"    with pyro.plate(\"N\", len(data)):\n",
"        eps_z_loc = pyro.sample(\"eps_z_loc\", dist.Normal(0, 1))\n",
"        z_loc = pyro.deterministic(\"z_loc\", eps_z_loc + x, event_dim=0)\n",
"        pyro.sample(\"z\", dist.Normal(z_loc, y), obs=data)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "6fcc43d8",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": "\n\n\n\n\n",
"text/plain": [
""
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = torch.ones(10)\n",
"pyro.render_model(\n",
"    model_deterministic,\n",
"    model_args=(data,),\n",
"    render_params=True,\n",
"    render_distributions=True,\n",
"    render_deterministic=True\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}