Introduction to Pyro

Probability is the mathematics of reasoning under uncertainty, much as calculus is the mathematics for reasoning about rates of change. It provides a unifying theoretical framework for understanding much of modern machine learning and AI: models built in the language of probability can capture complex reasoning, know what they do not know, and uncover structure in data without supervision.

Specifying probabilistic models directly can be cumbersome and implementing them can be very error-prone. Probabilistic programming languages (PPLs) solve these problems by marrying probability with the representational power of programming languages. A probabilistic program is a mix of ordinary deterministic computation and randomly sampled values representing a generative process for data.

By observing the outcome of a probabilistic program, we can describe an inference problem, roughly translated as: “what must be true if this random choice had a certain observed value?” PPLs explicitly enforce a separation of concerns already implicit in the mathematics of probability between the specification of a model, a query to be answered, and an algorithm for computing the answer.

Pyro is a probabilistic programming language built on Python and PyTorch. Pyro programs are just Python programs; its main inference technology is stochastic variational inference, which converts abstract probabilistic computations into concrete optimization problems solved with stochastic gradient descent in PyTorch. This makes probabilistic methods applicable to previously intractable model and dataset sizes.
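
As a taste of what such a program looks like (a minimal sketch using a hypothetical coin-flip model that is not part of the analysis below), a Pyro program is just a Python function that mixes ordinary computation with named random choices:

import pyro
import pyro.distributions as dist

def coin_model():
    # Latent fairness of a coin, drawn from a Beta prior.
    fairness = pyro.sample("fairness", dist.Beta(10., 10.))
    # A single coin flip whose distribution depends on the latent fairness.
    return pyro.sample("flip", dist.Bernoulli(fairness))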

In this tutorial, we take a brief, opinionated tour of the basic concepts of probabilistic machine learning and probabilistic programming with Pyro. We do so via an example data analysis problem involving linear regression, one of the most common and basic tasks in machine learning. We will see how to use Pyro’s modeling language and inference algorithms to incorporate uncertainty into estimates of regression coefficients.

Setup

Let’s begin by importing the modules we’ll need.

[41]:
%reset -s -f
[42]:
import logging
import os

import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import pyro
[43]:
smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.8.6')

pyro.enable_validation(True)
pyro.set_rng_seed(1)
logging.basicConfig(format='%(message)s', level=logging.INFO)

# Set matplotlib settings
%matplotlib inline
plt.style.use('default')

Background: probabilistic machine learning

Most data analysis problems can be understood as elaborations on three basic high-level questions:

  1. What do we know about the problem before observing any data?

  2. What conclusions can we draw from data given our prior knowledge?

  3. Do these conclusions make sense?

In the probabilistic or Bayesian approach to data science and machine learning, we formalize these in terms of mathematical operations on probability distributions.

Background: probabilistic models

First, we express everything we know about the variables in a problem and the relationships between them in the form of a probabilistic model, or a joint probability distribution over a collection of random variables. A model has observations \({\bf x}\) and latent random variables \({\bf z}\) as well as parameters \(\theta\). It usually has a joint density function of the form

\[p_{\theta}({\bf x}, {\bf z}) = p_{\theta}({\bf x}|{\bf z}) p_{\theta}({\bf z})\]

The distribution over latent variables \(p_{\theta}({\bf z})\) in this formula is called the prior, and the distribution over observed variables given latent variables \(p_{\theta}({\bf x}|{\bf z})\) is called the likelihood.

We typically require that the various conditional probability distributions \(p_i\) that make up a model \(p_{\theta}({\bf x}, {\bf z})\) have the following properties (generally satisfied by the distributions available in Pyro and PyTorch Distributions; see the short sketch after this list):

  • we can efficiently sample from each \(p_i\)

  • we can efficiently compute the pointwise probability density \(p_i\)

  • \(p_i\) is differentiable w.r.t. the parameters \(\theta\)
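
As a quick illustration of these three properties (a minimal sketch using a Normal distribution with a hypothetical learnable location parameter theta), we can draw a sample, evaluate its log-density, and differentiate that log-density with respect to the parameter:

import torch
import pyro.distributions as dist

theta = torch.tensor(0., requires_grad=True)  # a hypothetical learnable parameter
p = dist.Normal(theta, 1.)

x = p.sample()            # efficient sampling
log_p_x = p.log_prob(x)   # efficient pointwise (log-)density evaluation
log_p_x.backward()        # the density is differentiable w.r.t. theta
print(x, log_p_x, theta.grad)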

Probabilistic models are often depicted in a standard graphical notation for visualization and communication, summarized below, although it is possible in Pyro to represent models that do not have a fixed graphical structure. In models with lots of repetition, it is convenient to use plate notation, so called because it is shown graphically as a rectangular “plate” around variables to indicate multiple independent copies of the random variables inside.

[Figure: standard graphical-model notation, including plate notation for repeated random variables]

Background: inference, learning and evaluation

Once we have specified a model, Bayes’ rule tells us how to use it to perform inference, or draw conclusions about latent variables from data, by computing the posterior distribution over \(\bf z\):

\[ p_{\theta}({\bf z} | {\bf x}) = \frac{p_{\theta}({\bf x} , {\bf z})}{ \int \! d{\bf z}\; p_{\theta}({\bf x} , {\bf z}) }\]

To check the results of modeling and inference, we would like to know how well a model fits observed data \({\bf x}\), which we can quantify with the evidence or marginal likelihood

\[p_{\theta}({\bf x}) = \int \! d{\bf z}\; p_{\theta}({\bf x} , {\bf z})\]

and also to make predictions for new data, which we can do with the posterior predictive distribution

\[p_{\theta}(x' | {\bf x}) = \int \! d{\bf z}\; p_{\theta}(x' | {\bf z}) p_{\theta}({\bf z} | {\bf x})\]

Finally, it is often desirable to learn the parameters \(\theta\) of our models from observed data \({\bf x}\), which we can do by maximizing the marginal likelihood:

\[\theta_{\rm{max}} = \rm{argmax}_\theta p_{\theta}({\bf x}) = \rm{argmax}_\theta \int \! d{\bf z}\; p_{\theta}({\bf x} , {\bf z})\]
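
As a toy illustration of what the marginal likelihood means computationally (a hypothetical one-dimensional Normal-Normal model, unrelated to the regression example below), it can be approximated by Monte Carlo, averaging the likelihood over samples drawn from the prior:

import math
import torch
import pyro.distributions as dist

# Toy model: z ~ Normal(0, 1), x ~ Normal(z, 1), with a single observation x_obs.
x_obs = torch.tensor(0.5)
num_samples = 10000

z = dist.Normal(0., 1.).sample((num_samples,))   # draws from the prior p(z)
log_like = dist.Normal(z, 1.).log_prob(x_obs)    # log p(x_obs | z) for each draw
log_evidence = torch.logsumexp(log_like, 0) - math.log(num_samples)
print(log_evidence)  # Monte Carlo estimate of log p(x_obs)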

Example: Geography and national income

The following example is adapted from Chapter 7 of the excellent book Statistical Rethinking by Richard McElreath, which readers are encouraged to review for an accessible introduction to the broader practice of Bayesian data analysis (Pyro code for all chapters is available).

We would like to explore the relationship between topographic heterogeneity of a nation as measured by the Terrain Ruggedness Index (variable rugged in the dataset) and its GDP per capita. In particular, it was noted by the authors of the original paper (“Ruggedness: The blessing of bad geography in Africa”) that terrain ruggedness or bad geography is related to poorer economic performance outside of Africa, but rugged terrains have had a reverse effect on income for African nations.

Let us look at the data and investigate this relationship. We will be focusing on three features from the dataset:

  • rugged: quantifies the Terrain Ruggedness Index;

  • cont_africa: whether the given nation is in Africa;

  • rgdppc_2000: Real GDP per capita for the year 2000;

[44]:
DATA_URL = "https://d2hg8soec8ck9v.cloudfront.net/datasets/rugged_data.csv"
data = pd.read_csv(DATA_URL, encoding="ISO-8859-1")
df = data[["cont_africa", "rugged", "rgdppc_2000"]]

The response variable GDP is highly skewed, so we will log-transform it before proceeding.

[45]:
df = df[np.isfinite(df.rgdppc_2000)]
df["rgdppc_2000"] = np.log(df["rgdppc_2000"])

We then convert the NumPy array backing this DataFrame to a torch.Tensor for analysis with PyTorch and Pyro.

[46]:
train = torch.tensor(df.values, dtype=torch.float)
is_cont_africa, ruggedness, log_gdp = train[:, 0], train[:, 1], train[:, 2]

Visualizing the data suggests that there is indeed a possible relationship between ruggedness and GDP, but that further analysis will be needed to confirm it. We will see how to do this in Pyro via Bayesian linear regression.

[47]:
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6), sharey=True)
african_nations = df[df["cont_africa"] == 1]
non_african_nations = df[df["cont_africa"] == 0]
sns.scatterplot(x=non_african_nations["rugged"],
                y=non_african_nations["rgdppc_2000"],
                ax=ax[0])
ax[0].set(xlabel="Terrain Ruggedness Index",
          ylabel="log GDP (2000)",
          title="Non African Nations")
sns.scatterplot(x=african_nations["rugged"],
                y=african_nations["rgdppc_2000"],
                ax=ax[1])
ax[1].set(xlabel="Terrain Ruggedness Index",
          ylabel="log GDP (2000)",
          title="African Nations");
_images/intro_long_20_0.png

Models in Pyro

Probabilistic models in Pyro are specified as Python functions model(*args, **kwargs) that generate observed data from latent variables using special primitive functions whose behavior can be changed by Pyro’s internals depending on the high-level computation being performed.

Specifically, the different mathematical pieces of model() are encoded via the mapping:

  1. latent random variables \(z\) \(\Longleftrightarrow\) pyro.sample

  2. observed random variables \(x\) \(\Longleftrightarrow\) pyro.sample with the obs keyword argument

  3. learnable parameters \(\theta\) \(\Longleftrightarrow\) pyro.param

  4. plates \(\Longleftrightarrow\) pyro.plate context managers

We examine each of these components in detail below in the context of a first Pyro model for linear regression.

Example model: Maximum-likelihood linear regression

If we write out the formula for our linear regression predictor \(\beta X + \alpha\) as a Python expression, we get the following:

mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness

To build this up into a full probabilistic model for our dataset, we need to make the regression coefficients \(\alpha\) and \(\beta\) learnable parameters and add observation noise around the predicted mean. We can express this using the Pyro primitives introduced above and visualize the resulting model using pyro.render_model():

[48]:
import pyro.distributions as dist
import pyro.distributions.constraints as constraints

def simple_model(is_cont_africa, ruggedness, log_gdp=None):
    a = pyro.param("a", lambda: torch.randn(()))
    b_a = pyro.param("bA", lambda: torch.randn(()))
    b_r = pyro.param("bR", lambda: torch.randn(()))
    b_ar = pyro.param("bAR", lambda: torch.randn(()))
    sigma = pyro.param("sigma", lambda: torch.ones(()), constraint=constraints.positive)

    mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness

    with pyro.plate("data", len(ruggedness)):
        return pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)

pyro.render_model(simple_model, model_args=(is_cont_africa, ruggedness, log_gdp), render_distributions=True)
[48]:
_images/intro_long_24_0.svg

The above plot does not show the model parameters a, b_a, b_r, b_ar, and sigma. We can set render_params=True to render the model parameters.

[49]:
pyro.render_model(simple_model, model_args=(is_cont_africa, ruggedness, log_gdp), render_distributions=True, render_params=True)
[49]:
_images/intro_long_26_0.svg

The argument render_distributions=True will show constraints on the parameters. For example, sigma is a standard deviation that should be non-negative. Thus, its constraint is shown as “sigma \(\in\) GreaterThan(lower_bound=0.0)”.

Learning the parameters of simple_model would constitute maximum likelihood estimation and produce point estimates of the regression coefficients. However, in this example our data visualization suggests we should not be too confident in any single value for the regression coefficients. By contrast, a fully Bayesian approach would produce uncertainty estimates over different possible parameter values as well as the model’s predictions.

Before making a Bayesian version of our linear model, let’s pause and take a closer look at this first piece of Pyro code.

Background: the pyro.sample primitive

Probabilistic programs in Pyro are built up around samples from primitive probability distributions, marked by pyro.sample:

def sample(
    name: str,
    fn: pyro.distributions.Distribution,
    *,
    obs: typing.Optional[torch.Tensor] = None,
    infer: typing.Optional[dict] = None
) -> torch.Tensor:
    ...

In our model simple_model above, the line

return pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)

can represent either a latent variable or an observed variable depending on whether simple_model is given a log_gdp value. When log_gdp is not provided and "obs" is latent, it is equivalent (ignoring the pyro.plate context for now by assuming len(ruggedness) == 1) to calling the distribution’s underlying .sample method:

return dist.Normal(mean, sigma).sample()

This interpretation is behind occasional references to Pyro programs as stochastic functions, a rather obscure term used in some of Pyro’s older documentation.

When simple_model is given a log_gdp argument and "obs" is observed, the pyro.sample statement

return pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)

will always return log_gdp:

return log_gdp

However, note that when any sample statement is observed, the cumulative effect of every other sample statement in a model changes following Bayes’ rule; it is the job of Pyro’s inference algorithms to “run the program backwards” and assign mathematically consistent values to all pyro.sample statements in a model.
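
Because every pyro.sample statement is named, Pyro’s internals can record each statement as a model runs. As a brief sketch (using pyro.poutine.trace, a Pyro utility not otherwise covered in this tutorial, together with simple_model and the tensors defined earlier), we can list every sample site and compute the log-probability of the data under the current parameter values:

trace = pyro.poutine.trace(simple_model).get_trace(is_cont_africa, ruggedness, log_gdp)
for name, site in trace.nodes.items():
    if site["type"] == "sample":
        print(name, tuple(site["value"].shape), "observed" if site["is_observed"] else "latent")
print(trace.log_prob_sum())  # total log-probability of the recorded sample sites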

A reasonable question to ask at this point is why pyro.sample and other primitives must have names. Names are leveraged by users and by Pyro’s internals to keep the specification of a model separate from the observations and inference algorithms applied to it, a key selling point of probabilistic programming languages. To see an example of this, we can look at the higher-order primitive pyro.condition, which addresses the problem of writing many queries against a single Pyro model:

def condition(
    model: Callable[..., T],
    data: Dict[str, torch.Tensor]
) -> Callable[..., T]:
    ...

pyro.condition takes a model and a (possibly empty) dictionary mapping names to observed values and passes each observation to the pyro.sample statement indicated by its name. In the context of our example simple_model, we could remove log_gdp as an argument and replace it with a simpler interface

def simpler_model(is_cont_africa, ruggedness): ...

conditioned_model = pyro.condition(simpler_model, data={"obs": log_gdp})

where conditioned_model is equivalent to

conditioned_model = functools.partial(simple_model, log_gdp=log_gdp)
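
As a usage sketch (filling in simpler_model with the same body as simple_model, minus the log_gdp argument), conditioning attaches the observed values to the sample site named "obs", which then simply returns them:

def simpler_model(is_cont_africa, ruggedness):
    a = pyro.param("a", lambda: torch.randn(()))
    b_a = pyro.param("bA", lambda: torch.randn(()))
    b_r = pyro.param("bR", lambda: torch.randn(()))
    b_ar = pyro.param("bAR", lambda: torch.randn(()))
    sigma = pyro.param("sigma", lambda: torch.ones(()), constraint=constraints.positive)
    mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness
    with pyro.plate("data", len(ruggedness)):
        return pyro.sample("obs", dist.Normal(mean, sigma))

conditioned_model = pyro.condition(simpler_model, data={"obs": log_gdp})
assert (conditioned_model(is_cont_africa, ruggedness) == log_gdp).all()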

Background: the pyro.param primitive

The next primitive used in our model, pyro.param, is a frontend for reading from and writing to Pyro’s key-value parameter store:

def param(
    name: str,
    init: Optional[Union[torch.Tensor, Callable[..., torch.Tensor]]] = None,
    *,
    constraint: torch.distributions.constraints.Constraint = constraints.real
) -> torch.Tensor:
    ...

Like pyro.sample, pyro.param is always called with a name as its first argument. The first time pyro.param is called with a particular name, it stores the initial value specified by the second argument init in the parameter store and then returns that value. After that, when it is called with that name, it returns the value from the parameter store regardless of any other arguments. After a parameter has been initialized, it is no longer necessary to specify init to retrieve its value (e.g. pyro.param("a")).

The second argument, init, can be either a torch.Tensor or a function that takes no arguments and returns a tensor. The second form is useful because it avoids repeatedly constructing initial values that are only used the first time a model is run.

Unlike PyTorch’s torch.nn.Parameter, parameters in Pyro can be explicitly constrained to various subsets of \(\mathbb{R}^n\), an important feature because many elementary probability distributions have parameters with restricted domains. For example, the scale parameter of a Normal distribution must be positive. The optional third argument to pyro.param, constraint, is a torch.distributions.constraints.Constraint object stored when a parameter is initialized; constraints are reapplied after every update. Pyro ships with a large number of predefined constraints.

pyro.param values persist across model calls, unless the parameter store is updated by an optimization algorithm or cleared via pyro.clear_param_store(). Unlike pyro.sample, pyro.param can be called with the same name multiple times in a model; every call with the same name will return the same value. The global parameter store itself is accessible by calling pyro.get_param_store().
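
This behavior is easy to check directly. A brief sketch, using a hypothetical parameter name theta_demo purely for illustration:

theta = pyro.param("theta_demo", lambda: torch.randn(()))
# Subsequent calls with the same name ignore init and return the stored value.
assert (pyro.param("theta_demo") == theta).all()

# The global parameter store can be inspected (and reset with pyro.clear_param_store()).
print(sorted(pyro.get_param_store().keys()))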

In our model simple_model above, the statement

a = pyro.param("a", lambda: torch.randn(()))

is similar conceptually to the code below, but with some additional tracking, serialization and constraint management functionality.

simple_param_store = {}
...
def simple_model():
    a_init = lambda: torch.randn(())
    a = simple_param_store.setdefault("a", a_init())  # store on first use, reuse thereafter
    ...

While this introductory tutorial uses pyro.param for parameter management, Pyro is also compatible with PyTorch’s familiar torch.nn.Module API via pyro.nn.PyroModule.

Background: the pyro.plate primitive

pyro.plate is Pyro’s formal encoding of plate notation, widely used in probabilistic machine learning to simplify visualization and analysis of models with lots of conditionally independent and identically distributed random variables.

def plate(
    name: str,
    size: int,
    *,
    dim: Optional[int] = None,
    **other_kwargs
) -> contextlib.AbstractContextManager:
    ...

Conceptually, pyro.plate statements are equivalent to a for-loop. In simple_model, we could replace the lines

with pyro.plate("data", len(ruggedness)):
    return pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)

with a Python for-loop

result = torch.empty_like(ruggedness)
for i in range(len(ruggedness)):
    result[i] = pyro.sample(f"obs_{i}", dist.Normal(mean[i], sigma),
                            obs=log_gdp[i] if log_gdp is not None else None)
return result

When the number of repeated variables (len(ruggedness) in this case) is large, using a Python loop would be quite slow. Since each iteration in the loop is independent of the others, pyro.plate uses PyTorch’s array broadcasting to perform the iterations in parallel in a single vectorized operation, as in this equivalent vectorized code:

mean = mean.expand((len(ruggedness),))    # mean already has this shape, since the covariates are vectors
sigma = sigma.expand((len(ruggedness),))  # expand the scalar sigma to one value per data point
return pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)

Practically speaking, when writing Pyro programs pyro.plate is most useful as a tool for vectorization. However, as described in part 2 of the SVI tutorial, it is also useful for managing data subsampling and as a signal to inference algorithms that certain variables are independent.
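
For instance, a hedged sketch of data subsampling (with a hypothetical mini-batch size of 64 and fixed, illustrative coefficients; this pattern is not used elsewhere in this tutorial) might look like:

def subsampled_model(is_cont_africa, ruggedness, log_gdp=None):
    # Fixed, illustrative coefficients; a real model would use pyro.param or pyro.sample.
    mean = 9. + 0.3 * is_cont_africa - 0.2 * ruggedness
    with pyro.plate("data", len(ruggedness), subsample_size=64) as ind:
        # ind is a tensor of 64 random indices; Pyro rescales the mini-batch
        # log-likelihood by len(ruggedness) / 64 to keep estimates unbiased.
        return pyro.sample("obs", dist.Normal(mean[ind], 1.),
                           obs=log_gdp[ind] if log_gdp is not None else None)

subsampled_model(is_cont_africa, ruggedness, log_gdp)  # returns a length-64 tensor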

Example: from maximum-likelihood regression to Bayesian regression

In order to make our linear regression model Bayesian, we need to specify prior distributions on the parameters \(\alpha \in \mathbb{R}\) and \(\beta \in \mathbb{R}^3\) (expanded here into scalars b_a, b_r, and b_ar). These are probability distributions representing our beliefs, prior to observing any data, about reasonable values for \(\alpha\) and \(\beta\). We will also add a random scale parameter \(\sigma\) that controls the observation noise.

Expressing this Bayesian model for linear regression in Pyro is very intuitive: we simply replace each of the pyro.param statements with pyro.sample statements equipped with Pyro Distribution objects describing prior beliefs about each parameter.

For the constant term \(\alpha\), we use a Normal prior with a large standard deviation to indicate our relative lack of prior knowledge about baseline GDP. For the other regression coefficients, we use standard Normal priors (centered at 0) to represent our lack of a priori knowledge of whether the relationship between covariates and GDP is positive or negative. For the observation noise \(\sigma\), we use a flat prior bounded below by 0 because this value must be positive to be a valid standard deviation.

[50]:
def model(is_cont_africa, ruggedness, log_gdp=None):
    a = pyro.sample("a", dist.Normal(0., 10.))
    b_a = pyro.sample("bA", dist.Normal(0., 1.))
    b_r = pyro.sample("bR", dist.Normal(0., 1.))
    b_ar = pyro.sample("bAR", dist.Normal(0., 1.))
    sigma = pyro.sample("sigma", dist.Uniform(0., 10.))

    mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness

    with pyro.plate("data", len(ruggedness)):
        return pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)

pyro.render_model(model, model_args=(is_cont_africa, ruggedness, log_gdp), render_distributions=True)
[50]:
_images/intro_long_34_0.svg

Inference in Pyro

Background: variational inference

Each of the computations introduced in the background sections above (the posterior distribution, the marginal likelihood and the posterior predictive distribution) requires performing integrals over the latent variables \({\bf z}\) that are often impossible or prohibitively expensive to compute exactly.

While Pyro includes support for many different exact and approximate inference algorithms, the best-supported is variational inference, which offers a unified scheme for finding \(\theta_{\rm{max}}\) and computing a tractable approximation \(q_{\phi}({\bf z})\) to the true, unknown posterior \(p_{\theta_{\rm{max}}}({\bf z} | {\bf x})\) by converting the intractable integrals into optimization of a functional of \(p\) and \(q\). The figure below depicts this process conceptually, while a more comprehensive mathematical introduction is available in the SVI tutorials.

Most probability distributions (the light ellipse in the figure below), especially those corresponding to Bayesian posterior distributions, are too complex to represent directly, so we must define a smaller subspace, indexed by real-valued parameters \(\phi\), of distributions \(q_{\phi}({\bf z})\) that are by construction guaranteed to be easy to sample from (the dark circle in the figure below) but may not include the true posterior distribution \(p_{\theta}({\bf z} | {\bf x})\) (the red star in the figure below).

Variational inference approximates the true posterior by searching the space of variational distributions to find one that is most similar to the true posterior (the yellow star in the figure below) according to some measure of distance or divergence (the black arrow in the figure below).

[Figure: schematic of variational inference — the space of all distributions (light ellipse), the tractable variational family (dark circle), the true posterior (red star), the closest variational approximation (yellow star), and an arrow indicating the divergence being minimized]

However, there are many different ways to measure distance or divergence between probability distributions. Which one should we choose? As indicated by the figure, a theoretically appealing choice is the Kullback-Leibler divergence \(KL(q_{\phi}({\bf z}) || p_{\theta}({\bf z} | {\bf x}))\), but computing this directly requires knowing the true posterior ahead of time, which would defeat the purpose.

What’s more, we actually want to minimize this divergence, which might sound even harder, but in fact it is possible to use Bayes’ theorem to rewrite the definition of \(KL(q_{\phi}({\bf z}) || p_{\theta}({\bf z} | {\bf x}))\) as the difference between an intractable constant that does not depend on \(q_{\phi}\) and a tractable term called the evidence lower bound (ELBO), defined below. Maximizing this tractable term therefore produces the same solution as minimizing the original KL divergence.
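
Concretely, in the notation above, the ELBO is the expected difference between the log joint density and the log variational density,

\[\rm{ELBO}(\theta, \phi) = \mathbb{E}_{q_{\phi}({\bf z})}\left[\log p_{\theta}({\bf x}, {\bf z}) - \log q_{\phi}({\bf z})\right]\]

and it satisfies \(\log p_{\theta}({\bf x}) = \rm{ELBO}(\theta, \phi) + KL(q_{\phi}({\bf z}) || p_{\theta}({\bf z} | {\bf x}))\). Since the KL divergence is never negative, the ELBO is a lower bound on the log evidence, and maximizing it over \(\phi\) minimizes the KL divergence to the true posterior.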

Background: “guide” programs as flexible approximate posteriors

In variational inference, we introduce a parameterized distribution \(q_{\phi}({\bf z})\) to approximate the true posterior, where \(\phi\) are known as the variational parameters. This distribution is called the variational distribution in much of the literature, and in the context of Pyro it’s called the guide (one syllable instead of nine!).

Just like the model, the guide is encoded as a Python program guide() that contains pyro.sample and pyro.param statements. It does not contain observed data, since the guide needs to be a properly normalized distribution that is easy to sample from. Note that Pyro enforces that model() and guide() take the same arguments. Allowing guides to be arbitrary Pyro programs opens up the possibility of writing guide families that capture more of the problem-specific structure of the true posterior, expanding the search space in only helpful directions, as depicted schematically in the figure below.
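
For example, a hand-written guide for the Bayesian regression model above could use an independent Normal for each latent variable, with learnable locations and scales. The sketch below is illustrative: the variational parameter names (a_loc, weights_loc, and so on) are our own, and the guide actually used later in the Pyro tutorials may differ in detail.

def custom_guide(is_cont_africa, ruggedness, log_gdp=None):
    # Variational parameters phi (names are illustrative).
    a_loc = pyro.param("a_loc", lambda: torch.tensor(0.))
    a_scale = pyro.param("a_scale", lambda: torch.tensor(1.), constraint=constraints.positive)
    weights_loc = pyro.param("weights_loc", lambda: torch.randn(3))
    weights_scale = pyro.param("weights_scale", lambda: torch.ones(3), constraint=constraints.positive)
    sigma_loc = pyro.param("sigma_loc", lambda: torch.tensor(1.), constraint=constraints.positive)

    # One pyro.sample statement per latent variable in model(), with matching names.
    a = pyro.sample("a", dist.Normal(a_loc, a_scale))
    b_a = pyro.sample("bA", dist.Normal(weights_loc[0], weights_scale[0]))
    b_r = pyro.sample("bR", dist.Normal(weights_loc[1], weights_scale[1]))
    b_ar = pyro.sample("bAR", dist.Normal(weights_loc[2], weights_scale[2]))
    # A narrow Normal keeps sigma well inside the support of its Uniform(0, 10) prior in practice.
    sigma = pyro.sample("sigma", dist.Normal(sigma_loc, torch.tensor(0.05)))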