Automatic rendering of Pyro models

In this tutorial we will demonstrate how to create beautiful visualizations of your probabilistic graphical models using pyro.render_model().

[1]:
import os
import torch
import torch.nn.functional as F
import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.9.0')

A Simple Example

The visualization interface can be readily used with your models:

[2]:
def model(data):
    m = pyro.sample("m", dist.Normal(0, 1))
    sd = pyro.sample("sd", dist.LogNormal(m, 1))
    with pyro.plate("N", len(data)):
        pyro.sample("obs", dist.Normal(m, sd), obs=data)
[3]:
data = torch.ones(10)
pyro.render_model(model, model_args=(data,))
[3]:
_images/model_rendering_4_0.svg

The visualization can be saved to a file by passing filename='path' to pyro.render_model. You can use different output formats, such as PDF or PNG, by changing the filename's suffix. When not saving to a file (filename=None), you can still change the format by setting graph.format = 'pdf', where graph is the object returned by pyro.render_model.

[4]:
graph = pyro.render_model(model, model_args=(data,), filename="model.pdf")
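
If you prefer to pick the output format at render time rather than via the filename suffix, a minimal sketch is shown below. It uses the plain graphviz API of the returned object (nothing Pyro-specific); the filename model_graph is just an illustrative choice.

[ ]:
graph = pyro.render_model(model, model_args=(data,))
graph.format = "png"         # choose the output format on the graphviz object
graph.render("model_graph")  # writes the DOT source to "model_graph" and the image to "model_graph.png"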

Tweaking the visualization

Because pyro.render_model returns an object of type graphviz.dot.Digraph, you can tweak the rendered graph further. For example, you can use the unflatten preprocessor to improve the layout aspect ratio of more complex models.

[5]:
def mace(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 3 of https://www.aclweb.org/anthology/Q18-1040.pdf.
    """
    num_annotators = int(torch.max(positions)) + 1
    num_classes = int(torch.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with pyro.plate("annotator", num_annotators):
        epsilon = pyro.sample("ε", dist.Dirichlet(torch.full((num_classes,), 10.)))
        theta = pyro.sample("θ", dist.Beta(0.5, 0.5))

    with pyro.plate("item", num_items, dim=-2):
        # NB: constant logits give a discrete uniform prior over the classes
        c = pyro.sample("c", dist.Categorical(logits=torch.zeros(num_classes)))

        with pyro.plate("position", num_positions):
            s = pyro.sample("s", dist.Bernoulli(1 - theta[positions]))
            probs = torch.where(
                s[..., None] == 0, F.one_hot(c, num_classes).float(), epsilon[positions]
            )
            pyro.sample("y", dist.Categorical(probs), obs=annotations)


positions = torch.tensor([1, 1, 1, 2, 3, 4, 5])
# fmt: off
annotations = torch.tensor([
    [1, 3, 1, 2, 2, 2, 1, 3, 2, 2, 4, 2, 1, 2, 1,
     1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 2, 1, 1, 1,
     1, 3, 1, 2, 2, 4, 2, 2, 3, 1, 1, 1, 2, 1, 2],
    [1, 3, 1, 2, 2, 2, 2, 3, 2, 3, 4, 2, 1, 2, 2,
     1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 3, 1, 1, 1,
     1, 3, 1, 2, 2, 3, 2, 3, 3, 1, 1, 2, 3, 2, 2],
    [1, 3, 2, 2, 2, 2, 2, 3, 2, 2, 4, 2, 1, 2, 1,
     1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1, 2, 1, 1, 2,
     1, 3, 1, 2, 2, 3, 1, 2, 3, 1, 1, 1, 2, 1, 2],
    [1, 4, 2, 3, 3, 3, 2, 3, 2, 2, 4, 3, 1, 3, 1,
     2, 1, 1, 2, 1, 2, 2, 3, 2, 1, 1, 2, 1, 1, 1,
     1, 3, 1, 2, 3, 4, 2, 3, 3, 1, 1, 2, 2, 1, 2],
    [1, 3, 1, 1, 2, 3, 1, 4, 2, 2, 4, 3, 1, 2, 1,
     1, 1, 1, 2, 3, 2, 2, 2, 2, 1, 1, 2, 1, 1, 1,
     1, 2, 1, 2, 2, 3, 2, 2, 4, 1, 1, 1, 2, 1, 2],
    [1, 3, 2, 2, 2, 2, 1, 3, 2, 2, 4, 4, 1, 1, 1,
     1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 2, 1, 1, 2,
     1, 3, 1, 2, 3, 4, 3, 3, 3, 1, 1, 1, 2, 1, 2],
    [1, 4, 2, 1, 2, 2, 1, 3, 3, 3, 4, 3, 1, 2, 1,
     1, 1, 1, 1, 2, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1,
     1, 3, 1, 2, 2, 3, 2, 3, 2, 1, 1, 1, 2, 1, 2],
]).T
# fmt: on

# subtract 1 because indexing in Python starts at 0 (the raw positions and labels are 1-based)
positions -= 1
annotations -= 1

mace_graph = pyro.render_model(mace, model_args=(positions, annotations))
[6]:
# default layout
mace_graph
[6]:
_images/model_rendering_9_0.svg
[7]:
# layout after preprocessing with unflatten
mace_graph.unflatten(stagger=2)
[7]:
_images/model_rendering_10_0.svg
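
Since the returned object is an ordinary graphviz graph, other standard Graphviz tweaks apply as well. The snippet below is a sketch using the plain graphviz API rather than anything Pyro-specific; the attribute value and output filename are illustrative.

[ ]:
mace_graph.graph_attr["rankdir"] = "LR"  # standard Graphviz attribute: lay the graph out left to right

# unflatten() returns a new graphviz Source, which can be saved like any other graph
unflat = mace_graph.unflatten(stagger=2)
unflat.format = "png"
unflat.render("mace_model")              # writes "mace_model" (DOT source) and "mace_model.png"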

Rendering the parameters

We can render the parameters defined as pyro.param by setting render_params=True in pyro.render_model.

[8]:
def model(data):
    sigma = pyro.param("sigma", torch.tensor([1.]), constraint=constraints.positive)
    mu = pyro.param("mu", torch.tensor([0.]))
    x = pyro.sample("x", dist.Normal(mu, sigma))
    y = pyro.sample("y", dist.LogNormal(x, 1))
    with pyro.plate("N", len(data)):
        pyro.sample("z", dist.Normal(x, y), obs=data)
[9]:
data = torch.ones(10)
pyro.render_model(model, model_args=(data,), render_params=True)
[9]:
_images/model_rendering_14_0.svg

Distribution and Constraint annotations

It is possible to display the distribution of each RV in the generated plot by providing render_distributions=True when calling pyro.render_model. The constraints associated with parameters are also displayed when render_distributions=True.

[10]:
data = torch.ones(10)
pyro.render_model(model, model_args=(data,), render_params=True, render_distributions=True)
[10]:
_images/model_rendering_16_0.svg

In the above plot, '~' denotes the distribution of a random variable and '∈' denotes the constraint of a parameter.

Overlapping non-nested plates

Note that overlapping non-nested plates may be drawn as multiple rectangles.

[11]:
def model():
    plate1 = pyro.plate("plate1", 2, dim=-2)
    plate2 = pyro.plate("plate2", 3, dim=-1)
    with plate1:
        x = pyro.sample("x", dist.Normal(0, 1))
    with plate1, plate2:
        y = pyro.sample("y", dist.Normal(x, 1))
    with plate2:
        pyro.sample("z", dist.Normal(y.sum(-2, True), 1), obs=torch.zeros(3))
[12]:
pyro.render_model(model)
[12]:
_images/model_rendering_20_0.svg

Semisupervised models

Pyro supports semisupervised models, in which the same model is called with different sets of *args and **kwargs (for example, with or without observed labels). You can render such models by passing a list of model_args tuples and/or a list of model_kwargs dicts, one entry for each way the model is used.

[13]:
def model(x, y=None):
    with pyro.plate("N", 2):
        z = pyro.sample("z", dist.Normal(0, 1))
        y = pyro.sample("y", dist.Normal(0, 1), obs=y)
        pyro.sample("x", dist.Normal(y + z, 1), obs=x)
[14]:
pyro.render_model(
    model,
    model_kwargs=[
        {"x": torch.zeros(2)},
        {"x": torch.zeros(2), "y": torch.zeros(2)},
    ]
)
[14]:
_images/model_rendering_23_0.svg

Rendering deterministic variables

Pyro allows deterministic variables to be defined using pyro.deterministic. These variables can be rendered by setting render_deterministic=True in pyro.render_model as follows:

[15]:
def model_deterministic(data):
    sigma = pyro.param("sigma", torch.tensor([1.]), constraint=constraints.positive)
    mu = pyro.param("mu", torch.tensor([0.]))
    x = pyro.sample("x", dist.Normal(mu, sigma))
    log_y = pyro.sample("y", dist.Normal(x, 1))
    y = pyro.deterministic("y_deterministic", log_y.exp())
    with pyro.plate("N", len(data)):
        eps_z_loc = pyro.sample("eps_z_loc", dist.Normal(0, 1))
        z_loc = pyro.deterministic("z_loc", eps_z_loc + x, event_dim=0)
        pyro.sample("z", dist.Normal(z_loc, y), obs=data)
[16]:
data = torch.ones(10)
pyro.render_model(
    model_deterministic,
    model_args=(data,),
    render_params=True,
    render_distributions=True,
    render_deterministic=True
)
[16]:
_images/model_rendering_26_0.svg