Poutine: A Guide to Programming with Effect Handlers in Pyro
Note to readers: This tutorial is a guide to the API details of Pyro’s effect handling library, Poutine. We recommend readers first orient themselves with the simplified minipyro.py, which contains a minimal, readable implementation of Pyro’s runtime and the effect handler abstraction described here. Pyro’s effect handler library is more general than minipyro’s but also contains more layers of indirection; it helps to read them side-by-side.
[1]:
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.poutine.runtime import effectful
pyro.set_rng_seed(101)
Introduction
Inference in probabilistic programming involves manipulating or transforming probabilistic programs written as generative models. For example, nearly all approximate inference algorithms require computing the unnormalized joint probability of values of latent and observed variables under a generative model.
Consider the following example model:
[2]:
def scale(guess):
    weight = pyro.sample("weight", dist.Normal(guess, 1.0))
    return pyro.sample("measurement", dist.Normal(weight, 0.75))
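This model defines a joint probability distribution over "weight" and "measurement":
$$\mathrm{weight} \mid \mathrm{guess} \sim \mathrm{Normal}(\mathrm{guess}, 1)$$
$$\mathrm{measurement} \mid \mathrm{guess}, \mathrm{weight} \sim \mathrm{Normal}(\mathrm{weight}, 0.75)$$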
If we had access to the inputs and outputs of each pyro.sample site, we could compute their log-joint:
logp = dist.Normal(guess, 1.0).log_prob(weight).sum() + dist.Normal(weight, 0.75).log_prob(measurement).sum()
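As a quick check on this formula, the same log-joint can be computed by hand for the values used in the examples that follow (guess = 8.5, weight = 8.23, measurement = 9.5). This is a plain-Python sketch; `normal_log_prob` is a hypothetical helper written out for illustration, not part of Pyro.

```python
import math


def normal_log_prob(x, loc, scale):
    # Hypothetical helper (not Pyro): log-density of Normal(loc, scale) at x
    return -0.5 * math.log(2 * math.pi) - math.log(scale) - 0.5 * ((x - loc) / scale) ** 2


guess, weight, measurement = 8.5, 8.23, 9.5

# log p(weight, measurement | guess) = log p(weight | guess) + log p(measurement | weight)
logp = normal_log_prob(weight, guess, 1.0) + normal_log_prob(measurement, weight, 0.75)
print(round(logp, 4))  # -3.0203
```

The result agrees with the tensor(-3.0203) printed by the Pyro versions of this computation.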
However, the way we wrote scale above does not expose these intermediate distribution objects. Rewriting it to return them would be intrusive and would violate the separation of concerns between models and inference algorithms that a probabilistic programming language like Pyro is designed to enforce.
To resolve this conflict and facilitate inference algorithm development, Pyro exposes Poutine, a library of effect handlers, or composable building blocks for examining and modifying the behavior of Pyro programs. Most of Pyro’s internals are implemented on top of Poutine.
A first look at Poutine: Pyro’s library of algorithmic building blocks
Effect handlers, a common abstraction in the programming languages community, give nonstandard interpretations or side effects to the behavior of particular statements in a programming language, like pyro.sample or pyro.param. For background reading on effect handlers in programming language research, see the optional “References” section at the end of this tutorial.
Rather than reviewing more definitions, let’s look at a first example that addresses the problem above: we can compose two existing effect handlers, poutine.condition (which sets output values of pyro.sample statements) and poutine.trace (which records the inputs, distributions, and outputs of pyro.sample statements), to concisely define a new effect handler that computes the log-joint:
[3]:
def make_log_joint(model):
    def _log_joint(cond_data, *args, **kwargs):
        conditioned_model = poutine.condition(model, data=cond_data)
        trace = poutine.trace(conditioned_model).get_trace(*args, **kwargs)
        return trace.log_prob_sum()
    return _log_joint

scale_log_joint = make_log_joint(scale)
print(scale_log_joint({"measurement": torch.tensor(9.5), "weight": torch.tensor(8.23)}, torch.tensor(8.5)))
tensor(-3.0203)
That snippet is short, but still somewhat opaque: poutine.condition, poutine.trace, and trace.log_prob_sum are all black boxes. Let’s remove a layer of boilerplate from poutine.condition and poutine.trace and explicitly implement what trace.log_prob_sum is doing:
[4]:
from pyro.poutine.trace_messenger import TraceMessenger
from pyro.poutine.condition_messenger import ConditionMessenger

def make_log_joint_2(model):
    def _log_joint(cond_data, *args, **kwargs):
        with TraceMessenger() as tracer:
            with ConditionMessenger(data=cond_data):
                model(*args, **kwargs)

        trace = tracer.trace
        logp = 0.
        for name, node in trace.nodes.items():
            if node["type"] == "sample":
                if node["is_observed"]:
                    assert node["value"] is cond_data[name]
                logp = logp + node["fn"].log_prob(node["value"]).sum()
        return logp
    return _log_joint

scale_log_joint = make_log_joint_2(scale)
# observed values are passed as tensors, since log_prob validation rejects Python floats
print(scale_log_joint({"measurement": torch.tensor(9.5), "weight": torch.tensor(8.23)}, 8.5))
tensor(-3.0203)
This makes things a little clearer: we can now see that poutine.trace and poutine.condition are wrappers for context managers that presumably communicate with the model through something inside pyro.sample. We can also see that poutine.trace produces a data structure (a Trace) containing a dictionary whose keys are sample site names and whose values are dictionaries containing the distribution ("fn") and output ("value") at each site, and that the output values at each site are exactly the values specified in data.
Finally, TraceMessenger and ConditionMessenger are Pyro effect handlers, or Messengers: stateful context manager objects that are placed on a global stack and send messages (hence the name) up and down the stack at each effectful operation, like a pyro.sample call. A Messenger is placed at the bottom of the stack when its __enter__ method is called, i.e. when it is used in a “with” statement.
We’ll look at this process in more detail later in this tutorial. For a simplified implementation in only a few lines of code, see pyro.contrib.minipyro.
Implementing new effect handlers with the Messenger API
Although it’s easiest to build new effect handlers by composing the existing ones in pyro.poutine, implementing a new effect as a pyro.poutine.messenger.Messenger subclass is actually fairly straightforward. Before diving into the API, let’s look at another example: a version of our log-joint computation that performs the sum while the model is executing. We’ll then review what each part of the example is actually doing.
[5]:
class LogJointMessenger(poutine.messenger.Messenger):

    def __init__(self, cond_data):
        self.data = cond_data

    # __call__ is syntactic sugar for using Messengers as higher-order functions.
    # Messenger already defines __call__, but we re-define it here
    # for exposition and to change the return value:
    def __call__(self, fn):
        def _fn(*args, **kwargs):
            with self:
                fn(*args, **kwargs)
                return self.logp.clone()
        return _fn

    def __enter__(self):
        self.logp = torch.tensor(0.)
        # All Messenger subclasses must call the base Messenger.__enter__()
        # in their __enter__ methods
        return super().__enter__()

    # __exit__ takes the same arguments in all Python context managers
    def __exit__(self, exc_type, exc_value, traceback):
        self.logp = torch.tensor(0.)
        # All Messenger subclasses must call the base Messenger.__exit__ method
        # in their __exit__ methods.
        return super().__exit__(exc_type, exc_value, traceback)

    # _pyro_sample will be called once per pyro.sample site.
    # It takes a dictionary msg containing the name, distribution,
    # observation or sample value, and other metadata from the sample site.
    def _pyro_sample(self, msg):
        # Any unobserved random variables will trigger this assertion.
        # In the next section, we'll learn how to also handle sampled values.
        assert msg["name"] in self.data
        msg["value"] = self.data[msg["name"]]
        # Since we've observed a value for this site, we set the "is_observed" flag to True
        # This tells any other Messengers not to overwrite msg["value"] with a sample.
        msg["is_observed"] = True
        self.logp = self.logp + (msg["scale"] * msg["fn"].log_prob(msg["value"])).sum()

with LogJointMessenger(cond_data={"measurement": torch.tensor(9.5), "weight": torch.tensor(8.23)}) as m:
    scale(8.5)
    # read logp inside the with block: __exit__ resets it to zero
    print(m.logp.clone())

scale_log_joint = LogJointMessenger(cond_data={"measurement": torch.tensor(9.5), "weight": torch.tensor(8.23)})(scale)
print(scale_log_joint(8.5))
tensor(-3.0203)
tensor(-3.0203)
The following bit of boilerplate conveniently allows LogJointMessenger to be used as a context manager, decorator, or higher-order function. Most of the existing effect handlers in pyro.poutine, including poutine.trace and poutine.condition, which we used earlier, are Messengers wrapped this way in pyro.poutine.handlers.
[6]:
def log_joint(model=None, cond_data=None):
    msngr = LogJointMessenger(cond_data=cond_data)
    return msngr(model) if model is not None else msngr

scale_log_joint = log_joint(scale, cond_data={"measurement": torch.tensor(9.5), "weight": torch.tensor(8.23)})
print(scale_log_joint(8.5))
tensor(-3.0203)
The Messenger API in more detail
Our LogJointMessenger implementation has three important methods: __enter__, __exit__, and _pyro_sample.
__enter__ and __exit__ are special methods needed by any Python context manager. When implementing new Messenger classes, if we override __enter__ and __exit__, we always need to call the base Messenger’s __enter__ and __exit__ methods for the new Messenger to be applied correctly.
The last method, LogJointMessenger._pyro_sample, is called once at each sample site. It reads and modifies a message, which is a dictionary containing the sample site’s name, distribution, sampled or observed value, and other metadata. We’ll examine the contents of a message in more detail later in this tutorial.
Instead of _pyro_sample, a generic Messenger actually contains two methods that are called once per operation where side effects are performed:
1. _process_message modifies a message and sends the result to the Messenger just above on the stack.
2. _postprocess_message modifies a message and sends the result to the next Messenger down on the stack. It is always called after all active Messengers have had their _process_message method applied to the message.
Although custom Messengers can override _process_message and _postprocess_message, it’s convenient to avoid requiring all effect handlers to be aware of all possible effectful operation types. For this reason, by default Messenger._process_message will use msg["type"] to dispatch to a corresponding method Messenger._pyro_<type>, e.g. Messenger._pyro_sample as in LogJointMessenger. Just as exception handling code ignores unhandled exception types, this allows Messengers to simply forward operations they don’t know how to handle up to the next Messenger in the stack:
class Messenger:
    ...
    def _process_message(self, msg):
        method_name = "_pyro_{}".format(msg["type"])  # e.g. _pyro_sample when msg["type"] == "sample"
        if hasattr(self, method_name):
            getattr(self, method_name)(msg)
    ...
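To see the dispatch rule in isolation, here is a minimal pure-Python sketch that mimics the simplified _process_message above. ToyMessenger and SampleOnly are hypothetical classes for illustration, not Pyro’s Messenger: a handler that only defines _pyro_sample modifies sample messages and leaves every other message type untouched.

```python
class ToyMessenger:
    # Hypothetical stand-in for the simplified Messenger._process_message above
    def _process_message(self, msg):
        method_name = "_pyro_{}".format(msg["type"])  # e.g. "_pyro_sample"
        if hasattr(self, method_name):
            getattr(self, method_name)(msg)
        # No matching method: the message passes through unchanged


class SampleOnly(ToyMessenger):
    def _pyro_sample(self, msg):
        msg["seen_by"] = "SampleOnly"


handler = SampleOnly()
sample_msg = {"type": "sample", "name": "x"}
param_msg = {"type": "param", "name": "p"}
handler._process_message(sample_msg)
handler._process_message(param_msg)
print(sample_msg.get("seen_by"), param_msg.get("seen_by"))  # SampleOnly None
```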
Interlude: the global Messenger stack
See pyro.contrib.minipyro for an end-to-end implementation of the mechanism in this section.
The order in which Messengers are applied to an operation like a pyro.sample statement is determined by the order in which their __enter__ methods are called. Messenger.__enter__ appends a Messenger to the end (the bottom) of the global handler stack:
class Messenger:
    ...
    # __enter__ pushes a Messenger onto the stack
    def __enter__(self):
        ...
        _PYRO_STACK.append(self)
        ...

    # __exit__ removes a Messenger from the stack
    def __exit__(self, exc_type, exc_value, traceback):
        ...
        assert _PYRO_STACK[-1] is self
        _PYRO_STACK.pop()
        ...
pyro.poutine.runtime.apply_stack then traverses the stack twice at each operation, first from bottom to top to apply each _process_message and then from top to bottom to apply each _postprocess_message:
def apply_stack(msg):  # simplified
    for handler in reversed(_PYRO_STACK):
        handler._process_message(msg)
        ...
    default_process_message(msg)
    ...
    for handler in _PYRO_STACK:
        handler._postprocess_message(msg)
        ...
    return msg
Returning to the LogJointMessenger example
The second method, _postprocess_message, is necessary because some effects can only be applied after all other effect handlers have had a chance to update the message once. In the case of LogJointMessenger, other effects, like enumeration, may modify a sample site’s value or distribution (msg["value"] or msg["fn"]), so we move the log-probability computation to a new method, _pyro_post_sample, which is called by _postprocess_message (via a dispatch mechanism like the one used by _process_message) at each sample site after all active handlers’ _pyro_sample methods have been applied:
[7]:
class LogJointMessenger2(poutine.messenger.Messenger):

    def __init__(self, cond_data):
        self.data = cond_data

    def __call__(self, fn):
        def _fn(*args, **kwargs):
            with self:
                fn(*args, **kwargs)
                return self.logp.clone()
        return _fn

    def __enter__(self):
        self.logp = torch.tensor(0.)
        return super().__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        self.logp = torch.tensor(0.)
        return super().__exit__(exc_type, exc_value, traceback)

    def _pyro_sample(self, msg):
        if msg["name"] in self.data:
            msg["value"] = self.data[msg["name"]]
            msg["done"] = True

    def _pyro_post_sample(self, msg):
        assert msg["done"]  # the "done" flag asserts that no more modifications to value and fn will be performed.
        self.logp = self.logp + (msg["scale"] * msg["fn"].log_prob(msg["value"])).sum()

with LogJointMessenger2(cond_data={"measurement": torch.tensor(9.5), "weight": torch.tensor(8.23)}) as m:
    scale(8.5)
    print(m.logp)
tensor(-3.0203)
Inside the messages sent by Messengers
As the previous two examples mentioned, the actual messages sent up and down the stack are dictionaries with a particular set of keys. Consider the following sample statement:
pyro.sample("x", dist.Bernoulli(0.5), infer={"enumerate": "parallel"}, obs=None)
This sample statement is converted into an initial message before any effects are applied, and each effect handler’s _process_message and _postprocess_message may update fields in place or add new fields. We write out the full initial message here for completeness:
msg = {
    # The following fields contain the name, inputs, function, and output of a site.
    # These are generally the only fields you'll need to think about.
    "name": "x",
    "fn": dist.Bernoulli(0.5),
    "value": None,  # msg["value"] will eventually contain the value returned by pyro.sample
    "is_observed": False,  # because obs=None by default; only used by sample sites
    "args": (),  # positional arguments passed to "fn" when it is called; usually empty for sample sites
    "kwargs": {},  # keyword arguments passed to "fn" when it is called; usually empty for sample sites
    # This field typically contains metadata needed or stored by a particular inference algorithm
    "infer": {"enumerate": "parallel"},
    # The remaining fields are generally only used by Pyro's internals,
    # or for implementing more advanced effects beyond the scope of this tutorial
    "type": "sample",  # label used by Messenger._process_message to dispatch, in this case to _pyro_sample
    "done": False,
    "stop": False,
    "scale": torch.tensor(1.),  # Multiplicative scale factor that can be applied to each site's log_prob
    "mask": None,
    "continuation": None,
    "cond_indep_stack": (),  # Will contain metadata from each pyro.plate enclosing this sample site.
}
Note that when we use poutine.trace or TraceMessenger as in our first two versions of make_log_joint, the contents of msg are exactly the information stored in the trace for each sample and param site.
Implementing inference algorithms with existing effect handlers: examples
It turns out that many inference operations, like our first version of make_log_joint above, have strikingly short implementations in terms of existing effect handlers in pyro.poutine.
Example: Variational inference with a Monte Carlo ELBO
For example, here is an implementation of variational inference with a Monte Carlo ELBO that uses poutine.trace, poutine.condition, and poutine.replay. This is very similar to the simple ELBO in pyro.contrib.minipyro.
[8]:
def monte_carlo_elbo(model, guide, batch, *args, **kwargs):
    # assuming batch is a dictionary, we use poutine.condition to fix values of observed variables
    conditioned_model = poutine.condition(model, data=batch)

    # we'll approximate the expectation in the ELBO with a single sample:
    # first, we run the guide forward unmodified and record values and distributions
    # at each sample site using poutine.trace
    guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)

    # we use poutine.replay to set the values of latent variables in the model
    # to the values sampled above by our guide, and use poutine.trace
    # to record the distributions that appear at each sample site in the model
    model_trace = poutine.trace(
        poutine.replay(conditioned_model, trace=guide_trace)
    ).get_trace(*args, **kwargs)

    elbo = 0.
    for name, node in model_trace.nodes.items():
        if node["type"] == "sample":
            elbo = elbo + node["fn"].log_prob(node["value"]).sum()
            if not node["is_observed"]:
                elbo = elbo - guide_trace.nodes[name]["fn"].log_prob(node["value"]).sum()
    return -elbo
We use poutine.trace and poutine.block to record pyro.param calls for optimization:
[9]:
def train(model, guide, data):
    optimizer = pyro.optim.Adam({})
    for batch in data:
        # this poutine.trace will record all of the parameters that appear in the model and guide
        # during the execution of monte_carlo_elbo
        with poutine.trace() as param_capture:
            # we use poutine.block here so that only parameters appear in the trace above
            with poutine.block(hide_fn=lambda node: node["type"] != "param"):
                loss = monte_carlo_elbo(model, guide, batch)

        loss.backward()
        params = set(node["value"].unconstrained()
                     for node in param_capture.trace.nodes.values())
        # Pyro optimizers take a gradient step when called on a collection of parameters
        optimizer(params)
        pyro.infer.util.zero_grads(params)
Example: exact inference via sequential enumeration
Here is an example of a very different inference algorithm, exact inference via enumeration, implemented with pyro.poutine. A complete explanation of this algorithm is beyond the scope of this tutorial and may be found in Chapter 3 of the short online book Design and Implementation of Probabilistic Programming Languages. This example uses poutine.queue, itself implemented using poutine.trace, poutine.replay, and poutine.block, to enumerate over possible values of all discrete variables in a model and compute a marginal distribution over all possible return values or the possible values at a particular sample site:
[10]:
def sequential_discrete_marginal(model, data, site_name="_RETURN"):
    import queue  # standard-library queue data structures

    q = queue.Queue()  # Instantiate a first-in first-out queue
    q.put(poutine.Trace())  # seed the queue with an empty trace

    # as before, we fix the values of observed random variables with poutine.condition
    # assuming data is a dictionary whose keys are names of sample sites in model
    conditioned_model = poutine.condition(model, data=data)

    # we wrap the conditioned model in a poutine.queue,
    # which repeatedly pushes and pops partially completed executions from a Queue()
    # to perform breadth-first enumeration over the set of values of all discrete sample sites in model
    enum_model = poutine.queue(conditioned_model, queue=q)

    # actually perform the enumeration by repeatedly tracing enum_model
    # and accumulate samples and trace log-probabilities for postprocessing
    samples, log_weights = [], []
    while not q.empty():
        trace = poutine.trace(enum_model).get_trace()
        samples.append(trace.nodes[site_name]["value"])
        log_weights.append(trace.log_prob_sum())

    # we take the samples and log-joints and turn them into a histogram:
    samples = torch.stack(samples, 0)
    log_weights = torch.stack(log_weights, 0)
    log_weights = log_weights - torch.logsumexp(log_weights, dim=0)
    return dist.Empirical(samples, log_weights)
(Note that sequential_discrete_marginal is very general, but is also quite slow. For high-performance parallel enumeration that applies to a less general class of models, see the enumeration tutorial.)
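The enumeration strategy behind poutine.queue can be sketched without Pyro: keep a first-in first-out queue of partial executions, re-run the model from the start for each one, replay the choices it already fixes, and stop at the first unseen discrete site, pushing one extension per possible value. This is a hypothetical pure-Python illustration, not Pyro’s implementation: enumerate_marginal, two_coins, and the sample(name, probs) convention are invented here, and only latent discrete sites without observations are handled.

```python
import queue
from collections import defaultdict


class _NewSite(Exception):
    """Raised to stop a run when it reaches a site not fixed by the partial trace."""
    def __init__(self, name, probs):
        super().__init__(name)
        self.name, self.probs = name, probs


def enumerate_marginal(model):
    """Hypothetical sketch: exact marginal over model's return values by
    breadth-first enumeration, re-running the model for each partial trace."""
    q = queue.Queue()  # first-in first-out queue of partial traces
    q.put({})  # seed with an empty partial trace: site name -> (value, probability)
    marginal = defaultdict(float)
    while not q.empty():
        partial = q.get()

        def sample(name, probs):
            if name in partial:
                return partial[name][0]  # replay a choice fixed by the partial trace
            raise _NewSite(name, probs)  # unseen site: abandon this run

        try:
            result = model(sample)
        except _NewSite as site:
            # push one extended partial trace per possible value of the new site
            for value, p in site.probs.items():
                q.put({**partial, site.name: (value, p)})
            continue

        # a complete execution: its weight is the product of its choices' probabilities
        weight = 1.0
        for _, p in partial.values():
            weight *= p
        marginal[result] += weight

    total = sum(marginal.values())  # normalize, as sequential_discrete_marginal does
    return {value: w / total for value, w in marginal.items()}


def two_coins(sample):
    a = sample("a", {0: 0.5, 1: 0.5})
    b = sample("b", {0: 0.3, 1: 0.7})
    return a + b


print(enumerate_marginal(two_coins))
```

For two_coins, the marginal over a + b is 0.15, 0.5, and 0.35 for the values 0, 1, and 2 (up to floating-point rounding). Re-running the model from scratch for every partial trace is what makes this approach general but slow.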
Example: implementing lazy evaluation with the Messenger API
Now that we’ve learned more about the internals of Messenger, let’s use it to implement a slightly more complicated effect: lazy evaluation. We first define a LazyValue class that we will use to build up a computation graph:
[11]:
class LazyValue:
    def __init__(self, fn, *args, **kwargs):
        self._expr = (fn, args, kwargs)
        self._value = None

    def __str__(self):
        return "({} {})".format(str(self._expr[0]), " ".join(map(str, self._expr[1])))

    def evaluate(self):
        if self._value is None:
            fn, args, kwargs = self._expr
            fn = fn.evaluate() if isinstance(fn, LazyValue) else fn
            args = tuple(arg.evaluate() if isinstance(arg, LazyValue) else arg
                         for arg in args)
            kwargs = {k: v.evaluate() if isinstance(v, LazyValue) else v
                      for k, v in kwargs.items()}
            self._value = fn(*args, **kwargs)
        return self._value
With LazyValue, implementing lazy evaluation as a Messenger compatible with other effect handlers is surprisingly easy. We just make each msg["value"] a LazyValue and introduce a new operation type "apply" for deterministic operations:
[12]:
class LazyMessenger(pyro.poutine.messenger.Messenger):
    def _process_message(self, msg):
        if msg["type"] in ("apply", "sample") and not msg["done"]:
            msg["done"] = True
            msg["value"] = LazyValue(msg["fn"], *msg["args"], **msg["kwargs"])
Finally, just like torch.autograd overloads torch tensor operations to record an autograd graph, we need to wrap any operations we’d like to be lazy. We’ll use pyro.poutine.runtime.effectful as a decorator to expose these operations to LazyMessenger. effectful constructs a message much like the one above and sends it up and down the effect handler stack, but allows us to set the type (in this case, to "apply" instead of "sample") so that these operations aren’t mistaken for sample statements by other effect handlers like TraceMessenger:
[13]:
@effectful(type="apply")
def add(x, y):
    return x + y

@effectful(type="apply")
def mul(x, y):
    return x * y

@effectful(type="apply")
def sigmoid(x):
    return torch.sigmoid(x)

@effectful(type="apply")
def normal(loc, scale):
    return dist.Normal(loc, scale)
Applied to another model:
[14]:
def biased_scale(guess):
    weight = pyro.sample("weight", normal(guess, 1.))
    tolerance = pyro.sample("tolerance", normal(0., 0.25))
    return pyro.sample("measurement", normal(add(mul(weight, 0.8), 1.), sigmoid(tolerance)))

with LazyMessenger():
    v = biased_scale(8.5)
    print(v)
    print(v.evaluate())
((<function normal at 0x7fc41cbfdc80> (<function add at 0x7fc41cbf91e0> (<function mul at 0x7fc41cbfda60> ((<function normal at 0x7fc41cbfdc80> 8.5 1.0) ) 0.8) 1.0) (<function sigmoid at 0x7fc41cbfdb70> ((<function normal at 0x7fc41cbfdc80> 0.0 0.25) ))) )
tensor(6.5436)
Together with other effect handlers like TraceMessenger and ConditionMessenger, with which it freely composes, LazyMessenger demonstrates how to use Poutine to quickly and concisely implement state-of-the-art PPL techniques like delayed sampling with Rao-Blackwellization.
References: algebraic effects and handlers in programming language research
This section contains some references to PL papers for readers interested in this direction.
Algebraic effects and handlers, which were developed starting in the early 2000s and are a subject of active research in the programming languages community, are a versatile abstraction for building modular implementations of nonstandard interpreters of particular statements in a programming language, like pyro.sample or pyro.param. They were originally introduced to address the difficulty of composing nonstandard interpreters implemented with monads and monad transformers.
For an accessible introduction to the effect handlers literature, see the excellent review/tutorial paper “Handlers in Action” by Ohad Kammar, Sam Lindley, and Nicolas Oury, and the references therein.
Algebraic effect handlers were originally introduced by Gordon Plotkin and Matija Pretnar in the paper “Handlers of Algebraic Effects”.
A useful mental model of effect handlers is as exception handlers that are capable of resuming computation in the try block after raising an exception and performing some processing in the except block. This metaphor is explored further in the experimental programming language Eff and its companion paper “Programming with Algebraic Effects and Handlers” by Andrej Bauer and Matija Pretnar.
Most effect handlers in Pyro are “linear,” meaning that they only resume once per effectful operation and do not alter the order of execution of the original program. One exception is poutine.queue, which uses an inefficient implementation strategy for multiple resumptions like the one described for delimited continuations in the paper “Capturing the Future by Replaying the Past” by James Koppel, Gabriel Scherer, and Armando Solar-Lezama.
More efficient implementation strategies for effect handlers in mainstream programming languages like Python or JavaScript are an area of active research. One promising line of work involves selective continuation-passing style transforms as in the paper “Type-Directed Compilation of Row-Typed Algebraic Effects” by Daan Leijen.