Poutine: A Guide to Programming with Effect Handlers in Pyro

Note to readers: This tutorial is a guide to the API details of Pyro’s effect handling library, Poutine. We recommend readers first orient themselves with the simplified minipyro.py which contains a minimal, readable implementation of Pyro’s runtime and the effect handler abstraction described here. Pyro’s effect handler library is more general than minipyro’s but also contains more layers of indirection; it helps to read them side-by-side.

import torch

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

from pyro.poutine.runtime import effectful

pyro.set_rng_seed(101)

Introduction

Inference in probabilistic programming involves manipulating or transforming probabilistic programs written as generative models. For example, nearly all approximate inference algorithms require computing the unnormalized joint probability of values of latent and observed variables under a generative model.

Consider the following example model:

def scale(guess):
    weight = pyro.sample("weight", dist.Normal(guess, 1.0))
    return pyro.sample("measurement", dist.Normal(weight, 0.75))

This model defines a joint probability distribution over "weight" and "measurement":

\[{\sf weight} \, | \, {\sf guess} \sim {\sf Normal}({\sf guess}, 1)\]
\[{\sf measurement} \, | \, {\sf guess}, {\sf weight} \sim {\sf Normal}({\sf weight}, 0.75)\]

If we had access to the inputs and outputs of each pyro.sample site, we could compute their log-joint:

logp = dist.Normal(guess, 1.0).log_prob(weight).sum() + dist.Normal(weight, 0.75).log_prob(measurement).sum()
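As a sanity check on this formula, here is a standard-library sketch that evaluates it by hand for the values used in the examples below (guess = 8.5, weight = 8.23, measurement = 9.5). The helper `normal_log_prob` is our own, written out from the Normal density rather than taken from Pyro.

```python
import math


def normal_log_prob(x, loc, scale):
    """Log density of Normal(loc, scale) at x, where scale is a standard deviation."""
    z = (x - loc) / scale
    return -0.5 * math.log(2 * math.pi) - math.log(scale) - 0.5 * z * z


guess, weight, measurement = 8.5, 8.23, 9.5
# log p(weight | guess) + log p(measurement | weight)
logp = normal_log_prob(weight, guess, 1.0) + normal_log_prob(measurement, weight, 0.75)
print(round(logp, 4))  # -3.0203
```

This matches the tensor(-3.0203) that the Pyro versions of the log-joint print later in this tutorial.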

However, the way we wrote scale above does not seem to expose these intermediate distribution objects, and rewriting it to return them would be intrusive and would violate the separation of concerns between models and inference algorithms that a probabilistic programming language like Pyro is designed to enforce.

To resolve this conflict and facilitate inference algorithm development, Pyro exposes Poutine, a library of effect handlers, or composable building blocks for examining and modifying the behavior of Pyro programs. Most of Pyro’s internals are implemented on top of Poutine.

A first look at Poutine: Pyro’s library of algorithmic building blocks

Effect handlers, a common abstraction in the programming languages community, give nonstandard interpretations or side effects to the behavior of particular statements in a programming language, like pyro.sample or pyro.param. For background reading on effect handlers in programming language research, see the optional “References” section at the end of this tutorial.

Rather than reviewing more definitions, let’s look at a first example that addresses the problem above: we can compose two existing effect handlers, poutine.condition (which sets output values of pyro.sample statements) and poutine.trace (which records the inputs, distributions, and outputs of pyro.sample statements), to concisely define a new effect handler that computes the log-joint:

def make_log_joint(model):
    def _log_joint(cond_data, *args, **kwargs):
        conditioned_model = poutine.condition(model, data=cond_data)
        trace = poutine.trace(conditioned_model).get_trace(*args, **kwargs)
        return trace.log_prob_sum()
    return _log_joint

scale_log_joint = make_log_joint(scale)
print(scale_log_joint({"measurement": torch.tensor(9.5), "weight": torch.tensor(8.23)}, torch.tensor(8.5)))
tensor(-3.0203)

That snippet is short, but still somewhat opaque: poutine.condition, poutine.trace, and trace.log_prob_sum are all black boxes. Let’s remove a layer of boilerplate from poutine.condition and poutine.trace and explicitly implement what trace.log_prob_sum is doing:

from pyro.poutine.trace_messenger import TraceMessenger
from pyro.poutine.condition_messenger import ConditionMessenger

def make_log_joint_2(model):
    def _log_joint(cond_data, *args, **kwargs):
        with TraceMessenger() as tracer:
            with ConditionMessenger(data=cond_data):
                model(*args, **kwargs)

        trace = tracer.trace
        logp = 0.
        for name, node in trace.nodes.items():
            if node["type"] == "sample":
                if node["is_observed"]:
                    assert node["value"] is cond_data[name]
                logp = logp + node["fn"].log_prob(node["value"]).sum()
        return logp
    return _log_joint

scale_log_joint = make_log_joint_2(scale)
print(scale_log_joint({"measurement": 9.5, "weight": 8.23}, 8.5))
tensor(-3.0203)

This makes things a little clearer: we can now see that poutine.trace and poutine.condition are wrappers for context managers that presumably communicate with the model through something inside pyro.sample. We can also see that poutine.trace produces a data structure (a Trace) containing a dictionary whose keys are sample site names and whose values are dictionaries containing the distribution ("fn") and output ("value") at each site, and that the output values at each site are exactly the values specified in cond_data.

Finally, TraceMessenger and ConditionMessenger are Pyro effect handlers, or Messengers: stateful context manager objects that are placed on a global stack and send messages (hence the name) up and down the stack at each effectful operation, like a pyro.sample call. A Messenger is placed at the bottom of the stack when its __enter__ method is called, i.e. when it is used in a “with” statement.

We’ll look at this process in more detail later in this tutorial. For a simplified implementation in only a few lines of code, see pyro.contrib.minipyro.

Implementing new effect handlers with the Messenger API

Although it’s easiest to build new effect handlers by composing the existing ones in pyro.poutine, implementing a new effect as a pyro.poutine.messenger.Messenger subclass is actually fairly straightforward. Before diving into the API, let’s look at another example: a version of our log-joint computation that performs the sum while the model is executing. We’ll then review what each part of the example is actually doing.

class LogJointMessenger(poutine.messenger.Messenger):

    def __init__(self, cond_data):
        self.data = cond_data

    # __call__ is syntactic sugar for using Messengers as higher-order functions.
    # Messenger already defines __call__, but we re-define it here
    # for exposition and to change the return value:
    def __call__(self, fn):
        def _fn(*args, **kwargs):
            with self:
                fn(*args, **kwargs)
                return self.logp.clone()
        return _fn

    def __enter__(self):
        self.logp = torch.tensor(0.)
        # All Messenger subclasses must call the base Messenger.__enter__()
        # in their __enter__ methods
        return super().__enter__()

    # __exit__ takes the same arguments in all Python context managers
    def __exit__(self, exc_type, exc_value, traceback):
        self.logp = torch.tensor(0.)
        # All Messenger subclasses must call the base Messenger.__exit__ method
        # in their __exit__ methods.
        return super().__exit__(exc_type, exc_value, traceback)

    # _pyro_sample will be called once per pyro.sample site.
    # It takes a dictionary msg containing the name, distribution,
    # observation or sample value, and other metadata from the sample site.
    def _pyro_sample(self, msg):
        # Any unobserved random variables will trigger this assertion.
        # In the next section, we'll learn how to also handle sampled values.
        assert msg["name"] in self.data
        msg["value"] = self.data[msg["name"]]
        # Since we've observed a value for this site, we set the "is_observed" flag to True.
        # This tells any other Messengers not to overwrite msg["value"] with a sample.
        msg["is_observed"] = True
        self.logp = self.logp + (msg["scale"] * msg["fn"].log_prob(msg["value"])).sum()

with LogJointMessenger(cond_data={"measurement": 9.5, "weight": 8.23}) as m:
    scale(8.5)
    print(m.logp.clone())

scale_log_joint = LogJointMessenger(cond_data={"measurement": 9.5, "weight": 8.23})(scale)
print(scale_log_joint(8.5))
tensor(-3.0203)
tensor(-3.0203)

The following convenient bit of boilerplate allows LogJointMessenger to be used as a context manager, decorator, or higher-order function. Most of the existing effect handlers in pyro.poutine, including poutine.trace and poutine.condition which we used earlier, are Messengers wrapped this way in pyro.poutine.handlers.

def log_joint(model=None, cond_data=None):
    msngr = LogJointMessenger(cond_data=cond_data)
    return msngr(model) if model is not None else msngr

scale_log_joint = log_joint(scale, cond_data={"measurement": 9.5, "weight": 8.23})
print(scale_log_joint(8.5))
tensor(-3.0203)
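The same wrapping pattern can be exercised without Pyro. The sketch below uses a hypothetical toy `Messenger` whose `__call__` runs the wrapped function inside `with self:`, and a hypothetical `depth` handler built with the same boilerplate as log_joint. It is then used as a higher-order function, a decorator, and a context manager.

```python
import functools

_STACK = []  # hypothetical global handler stack


class Messenger:
    def __enter__(self):
        _STACK.append(self)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        _STACK.pop()

    def __call__(self, fn):
        # wrap fn so that every call runs inside this handler
        @functools.wraps(fn)
        def _fn(*args, **kwargs):
            with self:
                return fn(*args, **kwargs)
        return _fn


class DepthMessenger(Messenger):
    """Toy handler with no behavior of its own; it only occupies a stack slot."""


def depth(fn=None):
    # same shape as log_joint: wrap fn if given, otherwise return the Messenger itself
    msngr = DepthMessenger()
    return msngr(fn) if fn is not None else msngr


def current_depth():
    return len(_STACK)


wrapped = depth(current_depth)  # 1. higher-order function


@depth()  # 2. decorator: depth() returns a Messenger, which is then called on the function
def decorated():
    return current_depth()


with depth():  # 3. context manager
    in_context = current_depth()

print(wrapped(), decorated(), in_context, current_depth())  # 1 1 1 0
```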

The Messenger API in more detail

Our LogJointMessenger implementation has three important methods: __enter__, __exit__, and _pyro_sample.

__enter__ and __exit__ are special methods needed by any Python context manager. When implementing new Messenger classes, if we override __enter__ and __exit__, we always need to call the base Messenger’s __enter__ and __exit__ methods for the new Messenger to be applied correctly.

The last method, LogJointMessenger._pyro_sample, is called once at each sample site. It reads and modifies a message, which is a dictionary containing the sample site’s name, distribution, sampled or observed value, and other metadata. We’ll examine the contents of a message in more detail later in this tutorial.

Rather than a single _pyro_sample method, a generic Messenger actually has two methods that are called once per operation where side effects are performed:
1. _process_message modifies a message and sends the result to the Messenger just above it on the stack.
2. _postprocess_message modifies a message and sends the result to the next Messenger down on the stack. It is always called after all active Messengers have had their _process_message method applied to the message.

Although custom Messengers can override _process_message and _postprocess_message, it’s convenient to avoid requiring all effect handlers to be aware of all possible effectful operation types. For this reason, by default Messenger._process_message will use msg["type"] to dispatch to a corresponding method Messenger._pyro_<type>, e.g. Messenger._pyro_sample as in LogJointMessenger. Just as exception handling code ignores unhandled exception types, this allows Messengers to simply forward operations they don’t know how to handle up to the next Messenger in the stack:

class Messenger:
    ...
    def _process_message(self, msg):
        method_name = "_pyro_{}".format(msg["type"])  # e.g. _pyro_sample when msg["type"] == "sample"
        if hasattr(self, method_name):
            getattr(self, method_name)(msg)
    ...
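Here is a slightly fuller, stand-alone sketch of this dispatch, including the analogous `_pyro_post_<type>` naming that `_postprocess_message` uses (the LogJointMessenger2 example later relies on `_pyro_post_sample`). The `Messenger` here is a toy class rather than Pyro's, and `SampleCounter` is hypothetical.

```python
class Messenger:
    def _process_message(self, msg):
        # dispatch on the message type, e.g. "sample" -> self._pyro_sample(msg)
        method = getattr(self, "_pyro_{}".format(msg["type"]), None)
        if method is not None:
            method(msg)
        # a message with no matching method is left unchanged and simply passes through

    def _postprocess_message(self, msg):
        # the same idea for the second pass, e.g. "sample" -> self._pyro_post_sample(msg)
        method = getattr(self, "_pyro_post_{}".format(msg["type"]), None)
        if method is not None:
            method(msg)


class SampleCounter(Messenger):
    """Handles only "sample" messages; "param" and any other types fall through."""

    def __init__(self):
        self.seen = []

    def _pyro_sample(self, msg):
        self.seen.append(msg["name"])


counter = SampleCounter()
messages = [
    {"type": "sample", "name": "x"},
    {"type": "param", "name": "loc"},
    {"type": "sample", "name": "y"},
]
for msg in messages:
    counter._process_message(msg)
    counter._postprocess_message(msg)  # no _pyro_post_sample defined, so nothing happens

print(counter.seen)  # ['x', 'y']
```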

Interlude: the global Messenger stack

See pyro.contrib.minipyro for an end-to-end implementation of the mechanism in this section.

The order in which Messengers are applied to an operation like a pyro.sample statement is determined by the order in which their __enter__ methods are called. Messenger.__enter__ appends a Messenger to the end (the bottom) of the global handler stack:

_PYRO_STACK = []  # the global handler stack, shared by all Messengers

class Messenger:
    ...
    # __enter__ pushes a Messenger onto the stack
    def __enter__(self):
        ...
        _PYRO_STACK.append(self)
        ...
        return self

    # __exit__ removes a Messenger from the stack
    def __exit__(self, exc_type, exc_value, traceback):
        ...
        assert _PYRO_STACK[-1] is self
        _PYRO_STACK.pop()
        ...

pyro.poutine.runtime.apply_stack then traverses the stack twice at each operation, first from bottom to top to apply each _process_message and then from top to bottom to apply each _postprocess_message:

def apply_stack(msg):  # simplified
    for handler in reversed(_PYRO_STACK):
        handler._process_message(msg)
    ...
    default_process_message(msg)
    ...
    for handler in _PYRO_STACK:
        handler._postprocess_message(msg)
    ...
    return msg

Returning to the LogJointMessenger example

The second method _postprocess_message is necessary because some effects can only be applied after all other effect handlers have had a chance to update the message once. In the case of LogJointMessenger, other effects, like enumeration, may modify a sample site’s value or distribution (msg["value"] or msg["fn"]), so we move the log-probability computation to a new method, _pyro_post_sample, which is called by _postprocess_message (via a dispatch mechanism like the one used by _process_message) at each sample site after all active handlers’ _pyro_sample methods have been applied:

class LogJointMessenger2(poutine.messenger.Messenger):

    def __init__(self, cond_data):
        self.data = cond_data

    def __call__(self, fn):
        def _fn(*args, **kwargs):
            with self:
                fn(*args, **kwargs)
                return self.logp.clone()
        return _fn

    def __enter__(self):
        self.logp = torch.tensor(0.)
        return super().__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        self.logp = torch.tensor(0.)
        return super().__exit__(exc_type, exc_value, traceback)

    def _pyro_sample(self, msg):
        if msg["name"] in self.data:
            msg["value"] = self.data[msg["name"]]
            msg["done"] = True

    def _pyro_post_sample(self, msg):
        assert msg["done"]  # the "done" flag indicates that no further modifications to value and fn will be performed
        self.logp = self.logp + (msg["scale"] * msg["fn"].log_prob(msg["value"])).sum()


with LogJointMessenger2(cond_data={"measurement": 9.5, "weight": 8.23}) as m:
    scale(8.5)
    print(m.logp)
tensor(-3.0203)

Inside the messages sent by Messengers

As the previous two examples mentioned, the actual messages sent up and down the stack are dictionaries with a particular set of keys. Consider the following sample statement:

pyro.sample("x", dist.Bernoulli(0.5), infer={"enumerate": "parallel"}, obs=None)

This sample statement is converted into an initial message before any effects are applied, and each effect handler’s _process_message and _postprocess_message may update fields in place or add new fields. We write out the full initial message here for completeness:

msg = {
    # The following fields contain the name, inputs, function, and output of a site.
    # These are generally the only fields you'll need to think about.
    "name": "x",
    "fn": dist.Bernoulli(0.5),
    "value": None,  # msg["value"] will eventually contain the value returned by pyro.sample
    "is_observed": False,  # because obs=None by default; only used by sample sites
    "args": (),  # positional arguments passed to "fn" when it is called; usually empty for sample sites
    "kwargs": {},  # keyword arguments passed to "fn" when it is called; usually empty for sample sites
    # This field typically contains metadata needed or stored by a particular inference algorithm
    "infer": {"enumerate": "parallel"},
    # The remaining fields are generally only used by Pyro's internals,
    # or for implementing more advanced effects beyond the scope of this tutorial
    "type": "sample",  # label used by Messenger._process_message to dispatch, in this case to _pyro_sample
    "done": False,
    "stop": False,
    "scale": torch.tensor(1.),  # Multiplicative scale factor that can be applied to each site's log_prob
    "mask": None,
    "continuation": None,
    "cond_indep_stack": (),  # Will contain metadata from each pyro.plate enclosing this sample site.
}
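To show how these fields get populated, here is a hypothetical helper `make_initial_msg` that builds a message with the same fourteen keys as above. It assumes, consistent with the comments on "value" and "is_observed", that a non-None obs is stored in "value" from the start and sets "is_observed". "fn" is a placeholder string so the sketch runs without Pyro, and "scale" is a float rather than a tensor.

```python
def make_initial_msg(name, fn, obs=None, infer=None):
    """Hypothetical helper: builds a sample message shaped like the one shown above."""
    return {
        "name": name,
        "fn": fn,
        "value": obs,  # assumption: an observed value is placed here up front
        "is_observed": obs is not None,
        "args": (),
        "kwargs": {},
        "infer": infer if infer is not None else {},
        "type": "sample",
        "done": False,
        "stop": False,
        "scale": 1.0,  # a plain float here; Pyro uses a tensor
        "mask": None,
        "continuation": None,
        "cond_indep_stack": (),
    }


msg = make_initial_msg("x", "Bernoulli(0.5)", infer={"enumerate": "parallel"})
obs_msg = make_initial_msg("y", "Bernoulli(0.5)", obs=1)
print(msg["is_observed"], obs_msg["is_observed"], obs_msg["value"])  # False True 1
```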

Note that when we use poutine.trace or TraceMessenger as in our first two versions of make_log_joint, the contents of msg are exactly the information stored in the trace for each sample and param site.

Implementing inference algorithms with existing effect handlers: examples

It turns out that many inference operations, like our first version of make_log_joint above, have strikingly short implementations in terms of existing effect handlers in pyro.poutine.

Example: Variational inference with a Monte Carlo ELBO

For example, here is an implementation of variational inference with a Monte Carlo ELBO that uses poutine.trace, poutine.condition, and poutine.replay. This is very similar to the simple ELBO in pyro.contrib.minipyro.

def monte_carlo_elbo(model, guide, batch, *args, **kwargs):
    # assuming batch is a dictionary, we use poutine.condition to fix values of observed variables
    conditioned_model = poutine.condition(model, data=batch)

    # we'll approximate the expectation in the ELBO with a single sample:
    # first, we run the guide forward unmodified and record values and distributions
    # at each sample site using poutine.trace
    guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)

    # we use poutine.replay to set the values of latent variables in the model
    # to the values sampled above by our guide, and use poutine.trace
    # to record the distributions that appear at each sample site in the model
    model_trace = poutine.trace(
        poutine.replay(conditioned_model, trace=guide_trace)
    ).get_trace(*args, **kwargs)

    elbo = 0.
    for name, node in model_trace.nodes.items():
        if node["type"] == "sample":
            elbo = elbo + node["fn"].log_prob(node["value"]).sum()
            if not node["is_observed"]:
                elbo = elbo - guide_trace.nodes[name]["fn"].log_prob(node["value"]).sum()
    return -elbo
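To see why this estimator makes sense, here is a standard-library sketch of the same single-sample ELBO for a hypothetical conjugate model, z ~ Normal(0, 1) and x | z ~ Normal(z, 1), with a Normal guide over z. Plain dicts stand in for traces, and the guide's recorded value is reused ("replayed") when scoring the model. Unlike monte_carlo_elbo, it returns the ELBO itself rather than its negation.

```python
import math
import random


def normal_log_prob(x, loc, scale):
    z = (x - loc) / scale
    return -0.5 * math.log(2 * math.pi) - math.log(scale) - 0.5 * z * z


def model_log_joint(trace, x):
    # hypothetical model: z ~ Normal(0, 1), x | z ~ Normal(z, 1), with x observed
    z = trace["z"]
    return normal_log_prob(z, 0.0, 1.0) + normal_log_prob(x, z, 1.0)


def single_sample_elbo(guide_loc, guide_scale, x, rng=random):
    # 1. run the "guide": draw z ~ Normal(guide_loc, guide_scale) and record it
    guide_trace = {"z": rng.gauss(guide_loc, guide_scale)}
    log_q = normal_log_prob(guide_trace["z"], guide_loc, guide_scale)
    # 2. "replay" the model with the guide's value: log p(x, z) - log q(z)
    return model_log_joint(guide_trace, x) - log_q


x = 1.3
posterior_loc, posterior_scale = x / 2, math.sqrt(0.5)  # exact posterior of this model
estimate = single_sample_elbo(posterior_loc, posterior_scale, x)
log_evidence = normal_log_prob(x, 0.0, math.sqrt(2.0))  # marginal: x ~ Normal(0, sqrt(2))
print(abs(estimate - log_evidence) < 1e-9)  # True
```

When the guide equals the exact posterior, every single-sample estimate equals log p(x), which is what the last lines check. For any other guide the estimate is noisy, and its expectation falls below log p(x) by the KL divergence from the guide to the posterior.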

We use poutine.trace and poutine.block to record pyro.param calls for optimization:

def train(model, guide, data):
    optimizer = pyro.optim.Adam({})
    for batch in data:
        # this poutine.trace will record all of the parameters that appear in the model and guide
        # during the execution of monte_carlo_elbo
        with poutine.trace() as param_capture:
            # we use poutine.block here so that only parameters appear in the trace above
            with poutine.block(hide_fn=lambda node: node["type"] != "param"):
                loss = monte_carlo_elbo(model, guide, batch)

        loss.backward()
        params = set(node["value"].unconstrained()
                     for node in param_capture.trace.nodes.values())
        optimizer(params)  # PyroOptim instances are called directly on an iterable of parameters
        pyro.infer.util.zero_grads(params)

Example: exact inference via sequential enumeration

Here is an example of a very different inference algorithm, exact inference via enumeration, implemented with pyro.poutine. A complete explanation of this algorithm is beyond the scope of this tutorial and may be found in Chapter 3 of the short online book Design and Implementation of Probabilistic Programming Languages. This example uses poutine.queue, itself implemented using poutine.trace, poutine.replay, and poutine.block, to enumerate over possible values of all discrete variables in a model and compute a marginal distribution over all possible return values or the possible values at a particular sample site:

def sequential_discrete_marginal(model, data, site_name="_RETURN"):

    import queue  # standard-library queue data structures
    q = queue.Queue()  # Instantiate a first-in first-out queue
    q.put(poutine.Trace())  # seed the queue with an empty trace

    # as before, we fix the values of observed random variables with poutine.condition
    # assuming data is a dictionary whose keys are names of sample sites in model
    conditioned_model = poutine.condition(model, data=data)

    # we wrap the conditioned model in a poutine.queue,
    # which repeatedly pushes and pops partially completed executions from a Queue()
    # to perform breadth-first enumeration over the set of values of all discrete sample sites in model
    enum_model = poutine.queue(conditioned_model, queue=q)

    # actually perform the enumeration by repeatedly tracing enum_model
    # and accumulate samples and trace log-probabilities for postprocessing
    samples, log_weights = [], []
    while not q.empty():
        trace = poutine.trace(enum_model).get_trace()
        samples.append(trace.nodes[site_name]["value"])
        log_weights.append(trace.log_prob_sum())

    # we take the samples and log-joints and turn them into a histogram:
    samples = torch.stack(samples, 0)
    log_weights = torch.stack(log_weights, 0)
    log_weights = log_weights - torch.logsumexp(log_weights, dim=0)
    return dist.Empirical(samples, log_weights)

(Note that sequential_discrete_marginal is very general, but is also quite slow. For high-performance parallel enumeration that applies to a less general class of models, see the enumeration tutorial.)
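To see why this approach is general but slow, here is a standard-library sketch of the same breadth-first, replay-based strategy (it is not Pyro's implementation of poutine.queue). Each run re-executes the model from scratch, replaying values fixed in a partial trace. On reaching a discrete site with no recorded value it raises an exception, loosely analogous to how poutine.queue escapes from an execution, and the partial trace is extended once per value in that site's support. The names `sequential_enumerate`, `_Escape`, `sample`, and `observe` are hypothetical, and discrete distributions are plain lists of probabilities indexed by value.

```python
import math
from collections import defaultdict, deque


class _Escape(Exception):
    """Raised when execution reaches a discrete site the partial trace does not cover."""

    def __init__(self, name, support):
        super().__init__(name)
        self.name, self.support = name, support


def sequential_enumerate(model):
    """Exact marginal over the model's return value by breadth-first replay."""
    queue = deque([{}])  # FIFO queue of partial traces: site name -> chosen value
    weights = defaultdict(float)
    while queue:
        partial = queue.popleft()
        log_w = 0.0

        def sample(name, probs):
            nonlocal log_w
            if name not in partial:
                # unseen discrete site: abandon this run and schedule one extension per value
                raise _Escape(name, range(len(probs)))
            value = partial[name]  # replay the value fixed by an earlier run
            log_w += math.log(probs[value])
            return value

        def observe(probs, value):
            nonlocal log_w
            log_w += math.log(probs[value])  # conditioning just reweights the run

        try:
            result = model(sample, observe)
        except _Escape as e:
            for v in e.support:
                queue.append({**partial, e.name: v})
            continue
        weights[result] += math.exp(log_w)

    total = sum(weights.values())
    return {value: w / total for value, w in weights.items()}


def model(sample, observe):
    # hypothetical model: z ~ Bernoulli(0.3); x | z ~ Bernoulli(0.9 if z else 0.2); x observed as 1
    z = sample("z", [0.7, 0.3])
    observe([0.1, 0.9] if z else [0.8, 0.2], 1)
    return z


posterior = sequential_enumerate(model)
print(posterior)  # P(z=1 | x=1) = 0.27 / 0.41, roughly 0.659
```

Every extension re-runs the model from its first statement, so the number of executions grows with the number of partial traces, which is why this style of enumeration is slow for models with many discrete sites.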

Example: implementing lazy evaluation with the Messenger API

Now that we’ve learned more about the internals of Messenger, let’s use it to implement a slightly more complicated effect: lazy evaluation. We first define a LazyValue class that we will use to build up a computation graph:

class LazyValue:
    def __init__(self, fn, *args, **kwargs):
        self._expr = (fn, args, kwargs)
        self._value = None

    def __str__(self):
        return "({} {})".format(str(self._expr[0]), " ".join(map(str, self._expr[1])))

    def evaluate(self):
        if self._value is None:
            fn, args, kwargs = self._expr
            fn = fn.evaluate() if isinstance(fn, LazyValue) else fn
            args = tuple(arg.evaluate() if isinstance(arg, LazyValue) else arg
                         for arg in args)
            kwargs = {k: v.evaluate() if isinstance(v, LazyValue) else v
                      for k, v in kwargs.items()}
            self._value = fn(*args, **kwargs)
        return self._value

With LazyValue, implementing lazy evaluation as a Messenger compatible with other effect handlers is surprisingly easy. We just make each msg["value"] a LazyValue and introduce a new operation type "apply" for deterministic operations:

class LazyMessenger(pyro.poutine.messenger.Messenger):
    def _process_message(self, msg):
        if msg["type"] in ("apply", "sample") and not msg["done"]:
            msg["done"] = True
            msg["value"] = LazyValue(msg["fn"], *msg["args"], **msg["kwargs"])

Finally, just like torch.autograd overloads torch tensor operations to record an autograd graph, we need to wrap any operations we’d like to be lazy. We’ll use pyro.poutine.runtime.effectful as a decorator to expose these operations to LazyMessenger. effectful constructs a message much like the one above and sends it up and down the effect handler stack, but allows us to set the type (in this case, to "apply" instead of "sample") so that these operations aren’t mistaken for sample statements by other effect handlers like TraceMessenger:

@effectful(type="apply")
def add(x, y):
    return x + y

@effectful(type="apply")
def mul(x, y):
    return x * y

@effectful(type="apply")
def sigmoid(x):
    return torch.sigmoid(x)

@effectful(type="apply")
def normal(loc, scale):
    return dist.Normal(loc, scale)

Applied to another model:

def biased_scale(guess):
    weight = pyro.sample("weight", normal(guess, 1.))
    tolerance = pyro.sample("tolerance", normal(0., 0.25))
    return pyro.sample("measurement", normal(add(mul(weight, 0.8), 1.), sigmoid(tolerance)))

with LazyMessenger():
    v = biased_scale(8.5)
    print(v)
    print(v.evaluate())
((<function normal at 0x7fc41cbfdc80> (<function add at 0x7fc41cbf91e0> (<function mul at 0x7fc41cbfda60> ((<function normal at 0x7fc41cbfdc80> 8.5 1.0) ) 0.8) 1.0) (<function sigmoid at 0x7fc41cbfdb70> ((<function normal at 0x7fc41cbfdc80> 0.0 0.25) ))) )
tensor(6.5436)

Together with other effect handlers like TraceMessenger and ConditionMessenger, with which it freely composes, LazyMessenger illustrates the kind of building block from which state-of-the-art PPL techniques like delayed sampling with Rao-Blackwellization can be implemented quickly and concisely with Poutine.

References: algebraic effects and handlers in programming language research

This section contains some references to PL papers for readers interested in this direction.

Algebraic effects and handlers, which were developed starting in the early 2000s and are a subject of active research in the programming languages community, are a versatile abstraction for building modular implementations of nonstandard interpreters of particular statements in a programming language, like pyro.sample or pyro.param. They were originally introduced to address the difficulty of composing nonstandard interpreters implemented with monads and monad transformers.

  • For an accessible introduction to the effect handlers literature, see the excellent review/tutorial paper “Handlers in Action” by Ohad Kammar, Sam Lindley, and Nicolas Oury, and the references therein.

  • Algebraic effect handlers were originally introduced by Gordon Plotkin and Matija Pretnar in the paper “Handlers of Algebraic Effects”.

  • A useful mental model of effect handlers is as exception handlers that are capable of resuming computation in the try block after raising an exception and performing some processing in the except block. This metaphor is explored further in the experimental programming language Eff and its companion paper “Programming with Algebraic Effects and Handlers” by Andrej Bauer and Matija Pretnar.

  • Most effect handlers in Pyro are “linear,” meaning that they only resume once per effectful operation and do not alter the order of execution of the original program. One exception is poutine.queue, which uses an inefficient implementation strategy for multiple resumptions like the one described for delimited continuations in the paper “Capturing the Future by Replaying the Past” by James Koppel, Gabriel Scherer, and Armando Solar-Lezama.

  • More efficient implementation strategies for effect handlers in mainstream programming languages like Python or JavaScript are an area of active research. One promising line of work involves selective continuation-passing style transforms as in the paper “Type-Directed Compilation of Row-Typed Algebraic Effects” by Daan Leijen.