# Poutine: A Guide to Programming with Effect Handlers in Pyro

**Note to readers**: This tutorial is a guide to the API details of Pyro’s effect handling library, Poutine. We recommend that readers first orient themselves with the simplified `minipyro.py`, which contains a minimal, readable implementation of Pyro’s runtime and the effect handler abstraction described here. Pyro’s effect handler library is more general than minipyro’s but also contains more layers of indirection; it helps to read them side by side.


```
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.poutine.runtime import effectful
pyro.set_rng_seed(101)
```

## Introduction

Inference in probabilistic programming involves manipulating or transforming probabilistic programs written as generative models. For example, nearly all approximate inference algorithms require computing the unnormalized joint probability of values of latent and observed variables under a generative model.

Consider the following example model from the introductory inference tutorial:


```
def scale(guess):
    weight = pyro.sample("weight", dist.Normal(guess, 1.0))
    return pyro.sample("measurement", dist.Normal(weight, 0.75))
```

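This model defines a joint probability distribution over `"weight"` and `"measurement"`:

$$
\begin{aligned}
\text{weight} \mid \text{guess} &\sim \operatorname{Normal}(\text{guess}, 1) \\
\text{measurement} \mid \text{guess}, \text{weight} &\sim \operatorname{Normal}(\text{weight}, 0.75)
\end{aligned}
$$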

If we had access to the inputs and outputs of each `pyro.sample` site, we could compute their log-joint:

```
logp = dist.Normal(guess, 1.0).log_prob(weight).sum() + dist.Normal(weight, 0.75).log_prob(measurement).sum()
```
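For concreteness, here is the same sum evaluated by hand with only Python's standard library, using the values that the examples below pass to `scale` (guess 8.5, weight 8.23, measurement 9.5). The helper `normal_log_prob` is written here for illustration and is not part of Pyro; the result agrees with the `tensor(-3.0203)` those examples print.

```python
import math

def normal_log_prob(x, loc, scale):
    """Log-density of a univariate Normal(loc, scale) evaluated at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)

guess, weight, measurement = 8.5, 8.23, 9.5

# One term per pyro.sample site in `scale`: each site's distribution scores its value.
logp = normal_log_prob(weight, guess, 1.0) + normal_log_prob(measurement, weight, 0.75)
print(round(logp, 4))  # -3.0203
```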

However, the way we wrote `scale` above does not seem to expose these intermediate distribution objects, and rewriting it to return them would be intrusive and would violate the separation of concerns between models and inference algorithms that a probabilistic programming language like Pyro is designed to enforce.

To resolve this conflict and facilitate inference algorithm development,
Pyro exposes Poutine, a
library of *effect handlers*, or composable building blocks for
examining and modifying the behavior of Pyro programs. Most of Pyro’s
internals are implemented on top of Poutine.

## A first look at Poutine: Pyro’s library of algorithmic building blocks

Effect handlers, a common abstraction in the programming languages community, give *nonstandard interpretations* or *side effects* to the behavior of particular statements in a programming language, like `pyro.sample` or `pyro.param`. For background reading on effect handlers in programming language research, see the optional “References” section at the end of this tutorial.

Rather than reviewing more definitions, let’s look at a first example that addresses the problem above: we can compose two existing effect handlers, `poutine.condition` (which sets output values of `pyro.sample` statements) and `poutine.trace` (which records the inputs, distributions, and outputs of `pyro.sample` statements), to concisely define a new effect handler that computes the log-joint:


```
def make_log_joint(model):
    def _log_joint(cond_data, *args, **kwargs):
        conditioned_model = poutine.condition(model, data=cond_data)
        trace = poutine.trace(conditioned_model).get_trace(*args, **kwargs)
        return trace.log_prob_sum()
    return _log_joint

scale_log_joint = make_log_joint(scale)
print(scale_log_joint({"measurement": 9.5, "weight": 8.23}, 8.5))
```

```
tensor(-3.0203)
```

That snippet is short, but still somewhat opaque: `poutine.condition`, `poutine.trace`, and `trace.log_prob_sum` are all black boxes. Let’s remove a layer of boilerplate from `poutine.condition` and `poutine.trace` and explicitly implement what `trace.log_prob_sum` is doing:


```
from pyro.poutine.trace_messenger import TraceMessenger
from pyro.poutine.condition_messenger import ConditionMessenger

def make_log_joint_2(model):
    def _log_joint(cond_data, *args, **kwargs):
        with TraceMessenger() as tracer:
            with ConditionMessenger(data=cond_data):
                model(*args, **kwargs)

        trace = tracer.trace
        logp = 0.
        for name, node in trace.nodes.items():
            if node["type"] == "sample":
                if node["is_observed"]:
                    assert node["value"] is cond_data[name]
                logp = logp + node["fn"].log_prob(node["value"]).sum()
        return logp
    return _log_joint

scale_log_joint = make_log_joint_2(scale)
print(scale_log_joint({"measurement": 9.5, "weight": 8.23}, 8.5))
```

```
tensor(-3.0203)
```

This makes things a little clearer: we can now see that `poutine.trace` and `poutine.condition` are wrappers for context managers that presumably communicate with the model through something inside `pyro.sample`. We can also see that `poutine.trace` produces a data structure (a [`Trace`](http://docs.pyro.ai/en/dev/poutine.html#trace)) containing a dictionary whose keys are `sample` site names and whose values are dictionaries containing the distribution (`"fn"`) and output (`"value"`) at each site, and that the output values at each site are exactly the values specified in `cond_data`.

Finally, `TraceMessenger` and `ConditionMessenger` are Pyro effect handlers, or `Messenger`s: stateful context manager objects that are placed on a global stack and send messages (hence the name) up and down the stack at each effectful operation, like a `pyro.sample` call. A `Messenger` is placed at the bottom of the stack when its `__enter__` method is called, i.e. when it is used in a `with` statement.

We’ll look at this process in more detail later in this tutorial. For a simplified implementation in only a few lines of code, see pyro.contrib.minipyro.

## Implementing new effect handlers with the `Messenger` API

Although it’s easiest to build new effect handlers by composing the existing ones in `pyro.poutine`, implementing a new effect as a `pyro.poutine.messenger.Messenger` subclass is actually fairly straightforward. Before diving into the API, let’s look at another example: a version of our log-joint computation that performs the sum while the model is executing. We’ll then review what each part of the example is actually doing.


```
class LogJointMessenger(poutine.messenger.Messenger):

    def __init__(self, cond_data):
        self.data = cond_data

    # __call__ is syntactic sugar for using Messengers as higher-order functions.
    # Messenger already defines __call__, but we re-define it here
    # for exposition and to change the return value:
    def __call__(self, fn):
        def _fn(*args, **kwargs):
            with self:
                fn(*args, **kwargs)
                # read logp before __exit__ resets it
                return self.logp.clone()
        return _fn

    def __enter__(self):
        self.logp = torch.tensor(0.)
        # All Messenger subclasses must call the base Messenger.__enter__()
        # in their __enter__ methods
        return super().__enter__()

    # __exit__ takes the same arguments in all Python context managers
    def __exit__(self, exc_type, exc_value, traceback):
        self.logp = torch.tensor(0.)
        # All Messenger subclasses must call the base Messenger.__exit__ method
        # in their __exit__ methods.
        return super().__exit__(exc_type, exc_value, traceback)

    # _pyro_sample will be called once per pyro.sample site.
    # It takes a dictionary msg containing the name, distribution,
    # observation or sample value, and other metadata from the sample site.
    def _pyro_sample(self, msg):
        assert msg["name"] in self.data
        msg["value"] = self.data[msg["name"]]
        # Since we've observed a value for this site, we set the "is_observed" flag to True
        # This tells any other Messengers not to overwrite msg["value"] with a sample.
        msg["is_observed"] = True
        self.logp = self.logp + (msg["scale"] * msg["fn"].log_prob(msg["value"])).sum()

with LogJointMessenger(cond_data={"measurement": 9.5, "weight": 8.23}) as m:
    scale(8.5)
    print(m.logp.clone())

scale_log_joint = LogJointMessenger(cond_data={"measurement": 9.5, "weight": 8.23})(scale)
print(scale_log_joint(8.5))
```

```
tensor(-3.0203)
tensor(-3.0203)
```

The following bit of boilerplate allows `LogJointMessenger` to be used as a context manager, decorator, or higher-order function. Most of the existing effect handlers in `pyro.poutine`, including `poutine.trace` and `poutine.condition`, which we used earlier, are `Messenger`s wrapped this way in `pyro.poutine.handlers`.


```
def log_joint(model=None, cond_data=None):
    msngr = LogJointMessenger(cond_data=cond_data)
    return msngr(model) if model is not None else msngr

scale_log_joint = log_joint(scale, cond_data={"measurement": 9.5, "weight": 8.23})
print(scale_log_joint(8.5))
```

```
tensor(-3.0203)
```

## The `Messenger` API in more detail

Our `LogJointMessenger` implementation has three important methods: `__enter__`, `__exit__`, and `_pyro_sample`.

`__enter__` and `__exit__` are special methods needed by any Python context manager. When implementing new `Messenger` classes, if we override `__enter__` and `__exit__`, we always need to call the base `Messenger`’s `__enter__` and `__exit__` methods for the new `Messenger` to be applied correctly.

The last method, `LogJointMessenger._pyro_sample`, is called once at each sample site. It reads and modifies a *message*, which is a dictionary containing the sample site’s name, distribution, sampled or observed value, and other metadata. We’ll examine the contents of a message in more detail in the next section.

Instead of `_pyro_sample`, a generic `Messenger` actually contains two methods that are called once per operation where side effects are performed:

1. `_process_message` modifies a message and sends the result to the `Messenger` just above on the stack.
2. `_postprocess_message` modifies a message and sends the result to the next `Messenger` down on the stack. It is always called after all active `Messenger`s have had their `_process_message` method applied to the message.

Although custom `Messenger`s can override `_process_message` and `_postprocess_message`, it’s convenient to avoid requiring all effect handlers to be aware of all possible effectful operation types. For this reason, by default `Messenger._process_message` will use `msg["type"]` to dispatch to a corresponding method `Messenger._pyro_<type>`, e.g. `Messenger._pyro_sample` as in `LogJointMessenger`. Just as exception handling code ignores unhandled exception types, this allows `Messenger`s to simply forward operations they don’t know how to handle up to the next `Messenger` in the stack:

```
class Messenger(object):
    ...
    def _process_message(self, msg):
        method_name = "_pyro_{}".format(msg["type"])  # e.g. _pyro_sample when msg["type"] == "sample"
        if hasattr(self, method_name):
            getattr(self, method_name)(msg)
    ...
```
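A self-contained toy version of this dispatch rule, with no Pyro dependency, makes the fall-through behavior concrete. `DispatchingMessenger` and `ObserveEverything` are hypothetical classes written for illustration, and the plain dictionaries stand in for real messages.

```python
class DispatchingMessenger:
    """Toy version of the type-based dispatch in Messenger._process_message."""

    def _process_message(self, msg):
        method = getattr(self, "_pyro_{}".format(msg["type"]), None)
        if method is not None:
            method(msg)
        # Messages of any other type pass through unchanged.


class ObserveEverything(DispatchingMessenger):
    """Hypothetical handler that only knows how to treat "sample" messages."""

    def _pyro_sample(self, msg):
        msg["value"] = 1.0
        msg["is_observed"] = True


handler = ObserveEverything()
sample_msg = {"type": "sample", "name": "x", "value": None, "is_observed": False}
param_msg = {"type": "param", "name": "loc", "value": None}
handler._process_message(sample_msg)
handler._process_message(param_msg)  # no _pyro_param method, so it is ignored
print(sample_msg["value"], param_msg["value"])  # 1.0 None
```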

### Interlude: the global `Messenger` stack

See pyro.contrib.minipyro for an end-to-end implementation of the mechanism in this section.

The order in which `Messenger`s are applied to an operation like a `pyro.sample` statement is determined by the order in which their `__enter__` methods are called. `Messenger.__enter__` appends a `Messenger` to the end (the bottom) of the global handler stack:

```
class Messenger(object):
    ...
    # __enter__ pushes a Messenger onto the stack
    def __enter__(self):
        ...
        _PYRO_STACK.append(self)
        ...

    # __exit__ removes a Messenger from the stack
    def __exit__(self, *args):
        ...
        assert _PYRO_STACK[-1] is self
        _PYRO_STACK.pop()
        ...
```

`pyro.poutine.runtime.apply_stack` then traverses the stack twice at each operation, first from bottom to top to apply each `_process_message` and then from top to bottom to apply each `_postprocess_message`:

```
def apply_stack(msg):  # simplified
    for handler in reversed(_PYRO_STACK):
        handler._process_message(msg)
        ...
    default_process_message(msg)
    ...
    for handler in _PYRO_STACK:
        handler._postprocess_message(msg)
        ...
    return msg
```

### Returning to the `LogJointMessenger` example

The second method, `_postprocess_message`, is necessary because some effects can only be applied after all other effect handlers have had a chance to update the message once. In the case of `LogJointMessenger`, other effects, like enumeration, may modify a sample site’s value or distribution (`msg["value"]` or `msg["fn"]`), so we move the log-probability computation to a new method, `_pyro_post_sample`, which is called by `_postprocess_message` (via a dispatch mechanism like the one used by `_process_message`) at each `sample` site after all active handlers’ `_pyro_sample` methods have been applied:


```
class LogJointMessenger2(poutine.messenger.Messenger):

    def __init__(self, cond_data):
        self.data = cond_data

    def __call__(self, fn):
        def _fn(*args, **kwargs):
            with self:
                fn(*args, **kwargs)
                return self.logp.clone()
        return _fn

    def __enter__(self):
        self.logp = torch.tensor(0.)
        return super().__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        self.logp = torch.tensor(0.)
        return super().__exit__(exc_type, exc_value, traceback)

    def _pyro_sample(self, msg):
        if msg["name"] in self.data:
            msg["value"] = self.data[msg["name"]]
            msg["done"] = True

    def _pyro_post_sample(self, msg):
        assert msg["done"]  # the "done" flag asserts that no more modifications to value and fn will be performed.
        self.logp = self.logp + (msg["scale"] * msg["fn"].log_prob(msg["value"])).sum()

with LogJointMessenger2(cond_data={"measurement": 9.5, "weight": 8.23}) as m:
    scale(8.5)
    print(m.logp)
```

```
tensor(-3.0203)
```

## Inside the messages sent by `Messenger`s

As the previous two examples mentioned, the actual messages sent up and down the stack are dictionaries with a particular set of keys. Consider the following sample statement:

```
pyro.sample("x", dist.Bernoulli(0.5), infer={"enumerate": "parallel"}, obs=None)
```

This sample statement is converted into an initial message before any effects are applied, and each effect handler’s `_process_message` and `_postprocess_message` may update fields in place or add new fields. We write out the full initial message here for completeness:

```
msg = {
    # The following fields contain the name, inputs, function, and output of a site.
    # These are generally the only fields you'll need to think about.
    "name": "x",
    "fn": dist.Bernoulli(0.5),
    "value": None,  # msg["value"] will eventually contain the value returned by pyro.sample
    "is_observed": False,  # because obs=None by default; only used by sample sites
    "args": (),  # positional arguments passed to "fn" when it is called; usually empty for sample sites
    "kwargs": {},  # keyword arguments passed to "fn" when it is called; usually empty for sample sites
    # This field typically contains metadata needed or stored by a particular inference algorithm
    "infer": {"enumerate": "parallel"},
    # The remaining fields are generally only used by Pyro's internals,
    # or for implementing more advanced effects beyond the scope of this tutorial
    "type": "sample",  # label used by Messenger._process_message to dispatch, in this case to _pyro_sample
    "done": False,
    "stop": False,
    "scale": torch.tensor(1.),  # Multiplicative scale factor that can be applied to each site's log_prob
    "mask": None,
    "continuation": None,
    "cond_indep_stack": (),  # Will contain metadata from each pyro.plate enclosing this sample site.
}
```
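As a small illustration of how the arguments of `pyro.sample` map onto this dictionary, here is a hypothetical helper, `make_sample_msg`, that builds an initial message in plain Python. It is a sketch for illustration rather than Pyro's own code: `fn` is a placeholder string instead of a distribution, and `"scale"` is a plain float rather than a tensor. The point it shows is that an `obs` value, when given, becomes the message's `"value"` and sets `"is_observed"` to `True`.

```python
def make_sample_msg(name, fn, obs=None, infer=None):
    """Hypothetical helper: build the initial message for a sample statement."""
    return {
        "type": "sample",
        "name": name,
        "fn": fn,
        "args": (),
        "kwargs": {},
        "value": obs,                    # an observed value is used as-is; otherwise None
        "is_observed": obs is not None,  # True exactly when obs was supplied
        "infer": infer if infer is not None else {},
        "scale": 1.0,                    # a tensor in the listing above
        "mask": None,
        "cond_indep_stack": (),
        "done": False,
        "stop": False,
        "continuation": None,
    }


latent = make_sample_msg("x", fn="Bernoulli(0.5)", infer={"enumerate": "parallel"})
observed = make_sample_msg("x", fn="Bernoulli(0.5)", obs=1.0)
print(latent["value"], latent["is_observed"])      # None False
print(observed["value"], observed["is_observed"])  # 1.0 True
```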

Note that when we use `poutine.trace` or `TraceMessenger` as in our first two versions of `make_log_joint`, the contents of `msg` are exactly the information stored in the trace for each sample and param site.

## Implementing inference algorithms with existing effect handlers: examples

It turns out that many inference operations, like our first version of `make_log_joint` above, have strikingly short implementations in terms of existing effect handlers in `pyro.poutine`.

### Example: Variational inference with a Monte Carlo ELBO

For example, here is an implementation of variational inference with a Monte Carlo ELBO that uses `poutine.trace`, `poutine.condition`, and `poutine.replay`. This is very similar to the simple ELBO in `pyro.contrib.minipyro`.


```
def monte_carlo_elbo(model, guide, batch, *args, **kwargs):
    # assuming batch is a dictionary, we use poutine.condition to fix values of observed variables
    conditioned_model = poutine.condition(model, data=batch)

    # we'll approximate the expectation in the ELBO with a single sample:
    # first, we run the guide forward unmodified and record values and distributions
    # at each sample site using poutine.trace
    guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)

    # we use poutine.replay to set the values of latent variables in the model
    # to the values sampled above by our guide, and use poutine.trace
    # to record the distributions that appear at each sample site in the model
    model_trace = poutine.trace(
        poutine.replay(conditioned_model, trace=guide_trace)
    ).get_trace(*args, **kwargs)

    elbo = 0.
    for name, node in model_trace.nodes.items():
        if node["type"] == "sample":
            elbo = elbo + node["fn"].log_prob(node["value"]).sum()
            if not node["is_observed"]:
                elbo = elbo - guide_trace.nodes[name]["fn"].log_prob(node["value"]).sum()
    return -elbo
```
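The same single-sample estimator can be written without Pyro for the `scale` model from the introduction, using a Normal guide over `"weight"`. This is a hand-written sketch, not Pyro code: `normal_log_prob` and `single_sample_neg_elbo` are hypothetical helpers, drawing the latent from the guide plays the role of `guide_trace`, and scoring that same value under the model plays the role of the replayed, conditioned `model_trace`. Because this model is conjugate, we can plug in the exact posterior as the guide; in that case the estimate has zero variance and every draw equals the negative log-evidence, which makes a convenient correctness check.

```python
import math
import random


def normal_log_prob(x, loc, scale):
    """Log-density of Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def single_sample_neg_elbo(guess, measurement, q_loc, q_scale, rng):
    """One-sample Monte Carlo estimate of -ELBO for the `scale` model (hypothetical helper)."""
    # Like guide_trace: draw the latent weight from the guide q and score it under q.
    weight = rng.gauss(q_loc, q_scale)
    log_q = normal_log_prob(weight, q_loc, q_scale)
    # Like the replayed, conditioned model_trace: score that weight and the observed measurement.
    log_p = normal_log_prob(weight, guess, 1.0) + normal_log_prob(measurement, weight, 0.75)
    return -(log_p - log_q)


guess, measurement = 8.5, 9.5

# Exact Gaussian posterior over weight (conjugate Normal-Normal update).
post_var = 1.0 / (1.0 / 1.0**2 + 1.0 / 0.75**2)                 # 0.36
post_loc = post_var * (guess / 1.0**2 + measurement / 0.75**2)  # 9.14

rng = random.Random(0)
estimates = [single_sample_neg_elbo(guess, measurement, post_loc, math.sqrt(post_var), rng)
             for _ in range(5)]

# With q equal to the exact posterior, log p(weight, measurement) - log q(weight)
# equals log p(measurement) for every draw, where p(measurement) = Normal(guess, 1.25).
neg_log_evidence = -normal_log_prob(measurement, guess, math.sqrt(1.0**2 + 0.75**2))
print([round(e, 4) for e in estimates], round(neg_log_evidence, 4))
```

With a guide that is not the exact posterior, the draws would scatter and their average would upper-bound the negative log-evidence, which is what minimizing `monte_carlo_elbo` pushes down.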

We use `poutine.trace` and `poutine.block` to record `pyro.param` calls for optimization:


```
def train(model, guide, data):
    optimizer = pyro.optim.Adam({})
    for batch in data:
        # this poutine.trace will record all of the parameters that appear in the model and guide
        # during the execution of monte_carlo_elbo
        with poutine.trace() as param_capture:
            # we use poutine.block here so that only parameters appear in the trace above
            with poutine.block(hide_fn=lambda node: node["type"] != "param"):
                loss = monte_carlo_elbo(model, guide, batch)

        loss.backward()
        params = set(node["value"].unconstrained()
                     for node in param_capture.trace.nodes.values())
        # Pyro optimizers take a gradient step when called on a collection of parameters
        optimizer(params)
        pyro.infer.util.zero_grads(params)
```

### Example: exact inference via sequential enumeration

Here is an example of a very different inference algorithm, exact inference via enumeration, implemented with `pyro.poutine`. A complete explanation of this algorithm is beyond the scope of this tutorial and may be found in Chapter 3 of the short online book *Design and Implementation of Probabilistic Programming Languages*. This example uses `poutine.queue`, itself implemented using `poutine.trace`, `poutine.replay`, and `poutine.block`, to enumerate over possible values of all discrete variables in a model and compute a marginal distribution over all possible return values or the possible values at a particular sample site:


```
def sequential_discrete_marginal(model, data, site_name="_RETURN"):
    import queue  # queue data structures from the Python 3 standard library
    q = queue.Queue()  # Instantiate a first-in first-out queue
    q.put(poutine.Trace())  # seed the queue with an empty trace

    # as before, we fix the values of observed random variables with poutine.condition
    # assuming data is a dictionary whose keys are names of sample sites in model
    conditioned_model = poutine.condition(model, data=data)

    # we wrap the conditioned model in a poutine.queue,
    # which repeatedly pushes and pops partially completed executions from a Queue()
    # to perform breadth-first enumeration over the set of values of all discrete sample sites in model
    enum_model = poutine.queue(conditioned_model, queue=q)

    # actually perform the enumeration by repeatedly tracing enum_model
    # and accumulate samples and trace log-probabilities for postprocessing
    samples, log_weights = [], []
    while not q.empty():
        trace = poutine.trace(enum_model).get_trace()
        samples.append(trace.nodes[site_name]["value"])
        log_weights.append(trace.log_prob_sum())

    # we take the samples and log-joints and turn them into a histogram:
    samples = torch.stack(samples, 0)
    log_weights = torch.stack(log_weights, 0)
    log_weights = log_weights - torch.logsumexp(log_weights, dim=0)
    return dist.Empirical(samples, log_weights)
```

(Note that `sequential_discrete_marginal` is very general, but is also quite slow. For high-performance parallel enumeration that applies to a less general class of models, see the enumeration tutorial.)
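The queue-based strategy can be imitated in a few lines of plain Python. In this sketch (hypothetical names throughout, not Pyro's API), a model receives a `sample` function, each queue entry is a partial assignment of discrete sites, and reaching an unassigned site aborts the run and enqueues one extension per possible value. Complete runs contribute their joint probability to the marginal of the return value. Every extension re-runs the model from the start, which mirrors why replay-based enumeration is general but slow. Conditioning on observed data is omitted for brevity.

```python
import math
from collections import deque


class _NeedsValue(Exception):
    """Raised when the model reaches a site the partial assignment has not fixed yet."""

    def __init__(self, name, probs):
        super().__init__(name)
        self.name, self.probs = name, probs


def sequential_marginal(model):
    """Breadth-first enumeration over discrete sites (toy analogue of poutine.queue)."""
    pending = deque([{}])  # partial assignments: site name -> value
    weights = {}
    while pending:
        partial = pending.popleft()
        log_p = 0.0

        def sample(name, probs):  # probs: dict mapping value -> probability
            nonlocal log_p
            if name not in partial:
                raise _NeedsValue(name, probs)
            log_p += math.log(probs[partial[name]])
            return partial[name]

        try:
            ret = model(sample)
        except _NeedsValue as e:
            # Extend the partial assignment with every value of the new site; re-run later.
            for value in e.probs:
                pending.append({**partial, e.name: value})
            continue
        weights[ret] = weights.get(ret, 0.0) + math.exp(log_p)
    total = sum(weights.values())
    return {ret: w / total for ret, w in weights.items()}


def two_coins(sample):
    a = sample("a", {0: 0.5, 1: 0.5})
    b = sample("b", {0: 0.2, 1: 0.8})
    return a + b


marginal = sequential_marginal(two_coins)
print({k: round(v, 2) for k, v in sorted(marginal.items())})  # {0: 0.1, 1: 0.5, 2: 0.4}
```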

## Example: implementing lazy evaluation with the `Messenger` API

Now that we’ve learned more about the internals of `Messenger`, let’s use it to implement a slightly more complicated effect: lazy evaluation. We first define a `LazyValue` class that we will use to build up a computation graph:


```
class LazyValue(object):
    def __init__(self, fn, *args, **kwargs):
        self._expr = (fn, args, kwargs)
        self._value = None

    def __str__(self):
        return "({} {})".format(str(self._expr[0]), " ".join(map(str, self._expr[1])))

    def evaluate(self):
        if self._value is None:
            fn, args, kwargs = self._expr
            fn = fn.evaluate() if isinstance(fn, LazyValue) else fn
            args = tuple(arg.evaluate() if isinstance(arg, LazyValue) else arg
                         for arg in args)
            kwargs = {k: v.evaluate() if isinstance(v, LazyValue) else v
                      for k, v in kwargs.items()}
            self._value = fn(*args, **kwargs)
        return self._value
```

With `LazyValue`, implementing lazy evaluation as a `Messenger` compatible with other effect handlers is surprisingly easy. We just make each `msg["value"]` a `LazyValue` and introduce a new operation type `"apply"` for deterministic operations:


```
class LazyMessenger(pyro.poutine.messenger.Messenger):
    def _process_message(self, msg):
        if msg["type"] in ("apply", "sample") and not msg["done"]:
            msg["done"] = True
            msg["value"] = LazyValue(msg["fn"], *msg["args"], **msg["kwargs"])
```

Finally, just as `torch.autograd` overloads `torch` tensor operations to record an autograd graph, we need to wrap any operations we’d like to be lazy. We’ll use `pyro.poutine.runtime.effectful` as a decorator to expose these operations to `LazyMessenger`. `effectful` constructs a message much like the one above and sends it up and down the effect handler stack, but allows us to set the type (in this case, to `"apply"` instead of `"sample"`) so that these operations aren’t mistaken for `sample` statements by other effect handlers like `TraceMessenger`:


```
@effectful(type="apply")
def add(x, y):
    return x + y

@effectful(type="apply")
def mul(x, y):
    return x * y

@effectful(type="apply")
def sigmoid(x):
    return torch.sigmoid(x)

@effectful(type="apply")
def normal(loc, scale):
    return dist.Normal(loc, scale)
```

Applied to another model:


```
def biased_scale(guess):
    weight = pyro.sample("weight", normal(guess, 1.))
    tolerance = pyro.sample("tolerance", normal(0., 0.25))
    return pyro.sample("measurement", normal(add(mul(weight, 0.8), 1.), sigmoid(tolerance)))

with LazyMessenger():
    v = biased_scale(8.5)
    print(v)
    print(v.evaluate())
```

```
((<function normal at 0x7fc41cbfdc80> (<function add at 0x7fc41cbf91e0> (<function mul at 0x7fc41cbfda60> ((<function normal at 0x7fc41cbfdc80> 8.5 1.0) ) 0.8) 1.0) (<function sigmoid at 0x7fc41cbfdb70> ((<function normal at 0x7fc41cbfdc80> 0.0 0.25) ))) )
tensor(6.5436)
```

Together with other effect handlers like `TraceMessenger` and `ConditionMessenger`, with which it freely composes, `LazyMessenger` demonstrates how to use Poutine to quickly and concisely implement state-of-the-art PPL techniques like delayed sampling with Rao-Blackwellization.

## References: algebraic effects and handlers in programming language research

This section contains some references to PL papers for readers interested in this direction.

Algebraic effects and handlers, which were developed starting in the early 2000s and are a subject of active research in the programming languages community, are a versatile abstraction for building modular implementations of nonstandard interpreters of particular statements in a programming language, like `pyro.sample` or `pyro.param`. They were originally introduced to address the difficulty of composing nonstandard interpreters implemented with monads and monad transformers.

- For an accessible introduction to the effect handlers literature, see the excellent review/tutorial paper “Handlers in Action” by Ohad Kammar, Sam Lindley, and Nicolas Oury, and the references therein.
- Algebraic effect handlers were originally introduced by Gordon Plotkin and Matija Pretnar in the paper “Handlers of Algebraic Effects”.
- A useful mental model of effect handlers is as exception handlers that are capable of resuming computation in the `try` block after raising an exception and performing some processing in the `except` block. This metaphor is explored further in the experimental programming language Eff and its companion paper “Programming with Algebraic Effects and Handlers” by Andrej Bauer and Matija Pretnar.
- Most effect handlers in Pyro are “linear,” meaning that they only resume once per effectful operation and do not alter the order of execution of the original program. One exception is `poutine.queue`, which uses an inefficient implementation strategy for multiple resumptions like the one described for delimited continuations in the paper “Capturing the Future by Replaying the Past” by James Koppel, Gabriel Scherer, and Armando Solar-Lezama.
- More efficient implementation strategies for effect handlers in mainstream programming languages like Python or JavaScript are an area of active research. One promising line of work involves selective continuation-passing style transforms as in the paper “Type-Directed Compilation of Row-Typed Algebraic Effects” by Daan Leijen.