pyro.contrib.funsor, a new backend for Pyro - New primitives (Part 1)¶
Introduction¶
In this tutorial we’ll cover the basics of pyro.contrib.funsor, a new backend for the Pyro probabilistic programming system that is intended to replace the current internals of Pyro and significantly expand its capabilities as both a modelling tool and an inference research platform.

This tutorial is aimed at readers interested in developing custom inference algorithms and understanding Pyro’s current and future internals. As such, the material here assumes some familiarity with the generic Pyro API package pyroapi and with Funsor. Additional documentation for Funsor can be found on the Pyro website, on GitHub, and in the research paper “Functional Tensors for Probabilistic Programming.” Those less interested in such details should find that they can already use the general-purpose algorithms in contrib.funsor with their existing Pyro models via pyroapi.
Reinterpreting existing Pyro models with pyroapi¶
The new backend uses the pyroapi package to integrate with existing Pyro code.
First, we import some dependencies:
[1]:
from collections import OrderedDict
import torch
import funsor
from pyro import set_rng_seed as pyro_set_rng_seed
funsor.set_backend("torch")
torch.set_default_dtype(torch.float32)
pyro_set_rng_seed(101)
Importing pyro.contrib.funsor registers the "contrib.funsor" backend with pyroapi, which can now be passed as an argument to the pyroapi.pyro_backend context manager.
[2]:
import pyro.contrib.funsor
import pyroapi
from pyroapi import handlers, infer, ops, optim, pyro
from pyroapi import distributions as dist
# this is already done in pyro.contrib.funsor, but we repeat it here
pyroapi.register_backend("contrib.funsor", dict(
    distributions="pyro.distributions",
    handlers="pyro.contrib.funsor.handlers",
    infer="pyro.contrib.funsor.infer",
    ops="torch",
    optim="pyro.optim",
    pyro="pyro.contrib.funsor",
))
And we’re off! From here on, any pyro.(...) statement should be understood as dispatching to the new backend.
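For a quick sanity check, consider this minimal sketch (an illustration we add here, not a cell from the original notebook): inside the context manager, the generic pyro alias resolves to pyro.contrib.funsor, so a bare pyro.sample statement with no effect handlers active simply draws a sample.

with pyroapi.pyro_backend("contrib.funsor"):
    # `pyro` and `dist` are the generic pyroapi aliases imported above;
    # inside this block they dispatch to the contrib.funsor backend.
    z = pyro.sample("z", dist.Normal(0., 1.))
    print(type(z))  # a plain torch.Tensor, since no effect handlers are active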
Two new primitives: to_funsor and to_data¶
The first and most important new concept in pyro.contrib.funsor is the new pair of primitives pyro.to_funsor and pyro.to_data. These are effectful versions of funsor.to_funsor and funsor.to_data, i.e. versions whose behavior can be intercepted, controlled, or used to trigger side effects by Pyro’s library of algebraic effect handlers. Let’s briefly review these two underlying functions before diving into the effectful versions in pyro.contrib.funsor.
As one might expect from the name, to_funsor takes as input objects that are not funsor.Funsors and attempts to convert them into Funsor terms. For example, calling funsor.to_funsor on a Python number converts it to a funsor.terms.Number object:
[3]:
funsor_one = funsor.to_funsor(float(1))
print(funsor_one, type(funsor_one))
funsor_two = funsor.to_funsor(torch.tensor(2.))
print(funsor_two, type(funsor_two))
1.0 <class 'funsor.terms.Number'>
tensor(2.) <class 'funsor.tensor.Tensor'>
Similarly, calling funsor.to_data on an atomic funsor.Funsor converts it to a regular Python object like a float or a torch.Tensor:
[4]:
data_one = funsor.to_data(funsor.terms.Number(float(1), 'real'))
print(data_one, type(data_one))
data_two = funsor.to_data(funsor.Tensor(torch.tensor(2.), OrderedDict(), 'real'))
print(data_two, type(data_two))
1.0 <class 'float'>
tensor(2.) <class 'torch.Tensor'>
In many cases it is necessary to provide an output type to uniquely convert a piece of data to a funsor.Funsor. This also means that, strictly speaking, funsor.to_funsor and funsor.to_data are not inverses. For example, funsor.to_funsor will automatically convert Python strings to funsor.Variables, but only when given an output funsor.domains.Domain, which serves as the type of the variable:
[5]:
var_x = funsor.to_funsor("x", output=funsor.Reals[2])
print(var_x, var_x.inputs, var_x.output)
x OrderedDict([('x', Reals[2])]) Reals[2]
However, it is often impossible to convert objects to and from Funsor expressions uniquely without additional type information about inputs, as in the following example of a torch.Tensor, which could be converted to a funsor.Tensor in several ways.

To resolve this ambiguity, we need to provide to_funsor and to_data with type information that describes how to convert positional dimensions to and from unordered named Funsor dimensions. This information comes in the form of dictionaries mapping batch dimensions to dimension names or vice versa.

A key property of these mappings is that they use the convention that dimension indices refer to batch dimensions, i.e. dimensions not included in the output shape, which is treated as referring to the rightmost portion of the underlying PyTorch tensor shape, as illustrated in the example below.
[6]:
ambiguous_tensor = torch.zeros((3, 1, 2))
print("Ambiguous tensor: shape = {}".format(ambiguous_tensor.shape))
# case 1: treat all dimensions as output/event dimensions
funsor1 = funsor.to_funsor(ambiguous_tensor, output=funsor.Reals[3, 1, 2])
print("Case 1: inputs = {}, output = {}".format(funsor1.inputs, funsor1.output))
# case 2: treat the leftmost dimension as a batch dimension
# note that dimension -1 in dim_to_name here refers to the rightmost *batch dimension*,
# i.e. dimension -3 of ambiguous_tensor, the rightmost dimension not included in the output shape.
funsor2 = funsor.to_funsor(ambiguous_tensor, output=funsor.Reals[1, 2], dim_to_name={-1: "a"})
print("Case 2: inputs = {}, output = {}".format(funsor2.inputs, funsor2.output))
# case 3: treat the leftmost 2 dimensions as batch dimensions; empty batch dimensions are ignored
# note that dimensions -1 and -2 in dim_to_name here refer to the rightmost *batch dimensions*,
# i.e. dimensions -2 and -3 of ambiguous_tensor, the rightmost dimensions not included in the output shape.
funsor3 = funsor.to_funsor(ambiguous_tensor, output=funsor.Reals[2], dim_to_name={-1: "b", -2: "a"})
print("Case 3: inputs = {}, output = {}".format(funsor3.inputs, funsor3.output))
# case 4: treat all dimensions as batch dimensions; empty batch dimensions are ignored
# note that dimensions -1, -2 and -3 in dim_to_name here refer to the rightmost *batch dimensions*,
# i.e. dimensions -1, -2 and -3 of ambiguous_tensor, the rightmost dimensions not included in the output shape.
funsor4 = funsor.to_funsor(ambiguous_tensor, output=funsor.Real, dim_to_name={-1: "c", -2: "b", -3: "a"})
print("Case 4: inputs = {}, output = {}".format(funsor4.inputs, funsor4.output))
Ambiguous tensor: shape = torch.Size([3, 1, 2])
Case 1: inputs = OrderedDict(), output = Reals[3,1,2]
Case 2: inputs = OrderedDict([('a', Bint[3, ])]), output = Reals[1,2]
Case 3: inputs = OrderedDict([('a', Bint[3, ])]), output = Reals[2]
Case 4: inputs = OrderedDict([('a', Bint[3, ]), ('c', Bint[2, ])]), output = Real
Similar ambiguity exists for to_data: the inputs of a funsor.Funsor are ordered arbitrarily, and empty dimensions in the data are squeezed away, so a mapping from names to batch dimensions must be provided to ensure unique conversion:
[7]:
ambiguous_funsor = funsor.Tensor(torch.zeros((3, 2)), OrderedDict(a=funsor.Bint[3], b=funsor.Bint[2]), 'real')
print("Ambiguous funsor: inputs = {}, shape = {}".format(ambiguous_funsor.inputs, ambiguous_funsor.output))
# case 1: the simplest version
tensor1 = funsor.to_data(ambiguous_funsor, name_to_dim={"a": -2, "b": -1})
print("Case 1: shape = {}".format(tensor1.shape))
# case 2: an empty dimension between a and b
tensor2 = funsor.to_data(ambiguous_funsor, name_to_dim={"a": -3, "b": -1})
print("Case 2: shape = {}".format(tensor2.shape))
# case 3: permuting the input dimensions
tensor3 = funsor.to_data(ambiguous_funsor, name_to_dim={"a": -1, "b": -2})
print("Case 3: shape = {}".format(tensor3.shape))
Ambiguous funsor: inputs = OrderedDict([('a', Bint[3, ]), ('b', Bint[2, ])]), shape = Real
Case 1: shape = torch.Size([3, 2])
Case 2: shape = torch.Size([3, 1, 2])
Case 3: shape = torch.Size([2, 3])
Maintaining and updating this information efficiently becomes tedious and error-prone as the number of conversions increases. Fortunately, it can be automated away completely. Consider the following example:
[8]:
name_to_dim = OrderedDict()
funsor_x = funsor.Tensor(torch.ones((2,)), OrderedDict(x=funsor.Bint[2]), 'real')
name_to_dim.update({"x": -1})
tensor_x = funsor.to_data(funsor_x, name_to_dim=name_to_dim)
print(name_to_dim, funsor_x.inputs, tensor_x.shape)
funsor_y = funsor.Tensor(torch.ones((3, 2)), OrderedDict(y=funsor.Bint[3], x=funsor.Bint[2]), 'real')
name_to_dim.update({"y": -2})
tensor_y = funsor.to_data(funsor_y, name_to_dim=name_to_dim)
print(name_to_dim, funsor_y.inputs, tensor_y.shape)
funsor_z = funsor.Tensor(torch.ones((2, 3)), OrderedDict(z=funsor.Bint[2], y=funsor.Bint[3]), 'real')
name_to_dim.update({"z": -3})
tensor_z = funsor.to_data(funsor_z, name_to_dim=name_to_dim)
print(name_to_dim, funsor_z.inputs, tensor_z.shape)
OrderedDict([('x', -1)]) OrderedDict([('x', Bint[2, ])]) torch.Size([2])
OrderedDict([('x', -1), ('y', -2)]) OrderedDict([('y', Bint[3, ]), ('x', Bint[2, ])]) torch.Size([3, 2])
OrderedDict([('x', -1), ('y', -2), ('z', -3)]) OrderedDict([('z', Bint[2, ]), ('y', Bint[3, ])]) torch.Size([2, 3, 1])
This is exactly the functionality provided by pyro.to_funsor and pyro.to_data, as we can see by using them in the previous example and removing the manual updates. We must also wrap the code in a handlers.named effect handler to ensure that the dimension dictionaries do not persist beyond the end of the block.
[9]:
with pyroapi.pyro_backend("contrib.funsor"), handlers.named():
    funsor_x = funsor.Tensor(torch.ones((2,)), OrderedDict(x=funsor.Bint[2]), 'real')
    tensor_x = pyro.to_data(funsor_x)
    print(funsor_x.inputs, tensor_x.shape)

    funsor_y = funsor.Tensor(torch.ones((3, 2)), OrderedDict(y=funsor.Bint[3], x=funsor.Bint[2]), 'real')
    tensor_y = pyro.to_data(funsor_y)
    print(funsor_y.inputs, tensor_y.shape)

    funsor_z = funsor.Tensor(torch.ones((2, 3)), OrderedDict(z=funsor.Bint[2], y=funsor.Bint[3]), 'real')
    tensor_z = pyro.to_data(funsor_z)
    print(funsor_z.inputs, tensor_z.shape)
OrderedDict([('x', Bint[2, ])]) torch.Size([2, 1, 1, 1, 1])
OrderedDict([('y', Bint[3, ]), ('x', Bint[2, ])]) torch.Size([3, 2, 1, 1, 1, 1])
OrderedDict([('z', Bint[2, ]), ('y', Bint[3, ])]) torch.Size([2, 3, 1, 1, 1, 1, 1])
Critically, pyro.to_funsor
and pyro.to_data
use and update the same bidirectional mapping between names and dimensions, allowing them to be combined intuitively. A typical usage pattern, and one that pyro.contrib.funsor
uses heavily in its inference algorithm implementations, is to create a funsor.Funsor
term directly with a new named dimension and call pyro.to_data
on it, perform some PyTorch computations, and call pyro.to_funsor
on the result:
[10]:
with pyroapi.pyro_backend("contrib.funsor"), handlers.named():
    probs = funsor.Tensor(torch.tensor([0.5, 0.4, 0.7]), OrderedDict(batch=funsor.Bint[3]))
    print(type(probs), probs.inputs, probs.output)

    x = funsor.Tensor(torch.tensor([0., 1., 0., 1.]), OrderedDict(x=funsor.Bint[4]))
    print(type(x), x.inputs, x.output)

    dx = dist.Bernoulli(pyro.to_data(probs))
    print(type(dx), dx.shape())

    px = pyro.to_funsor(dx.log_prob(pyro.to_data(x)), output=funsor.Real)
    print(type(px), px.inputs, px.output)
<class 'funsor.tensor.Tensor'> OrderedDict([('batch', Bint[3, ])]) Real
<class 'funsor.tensor.Tensor'> OrderedDict([('x', Bint[4, ])]) Real
<class 'pyro.distributions.torch.Bernoulli'> torch.Size([3, 1, 1, 1, 1])
<class 'funsor.tensor.Tensor'> OrderedDict([('x', Bint[4, ]), ('batch', Bint[3, ])]) Real
pyro.to_funsor and pyro.to_data treat the dimensions in their dim_to_name/name_to_dim arguments as references to the input’s batch shape, but treat the names as references to the globally consistent name-dimension mapping. This may be useful for complicated computations that involve a mixture of PyTorch and Funsor operations.
[11]:
with pyroapi.pyro_backend("contrib.funsor"), handlers.named():
    x = pyro.to_funsor(torch.tensor([0., 1.]), funsor.Real, dim_to_name={-1: "x"})
    print("x: ", type(x), x.inputs, x.output)

    px = pyro.to_funsor(torch.ones(2, 3), funsor.Real, dim_to_name={-2: "x", -1: "y"})
    print("px: ", type(px), px.inputs, px.output)
x: <class 'funsor.tensor.Tensor'> OrderedDict([('x', Bint[2, ])]) Real
px: <class 'funsor.tensor.Tensor'> OrderedDict([('x', Bint[2, ]), ('y', Bint[3, ])]) Real
Dealing with large numbers of variables: (re-)introducing pyro.markov¶
So far, so good. However, what if the number of distinct named dimensions continues to increase? We face two problems: first, reusing the fixed number of positional dimensions available (25 in PyTorch), and second, computing shape information with time complexity that is independent of the number of variables.

A fully general automated solution to this problem would require deeper integration with Python or PyTorch. Instead, as an intermediate solution, we introduce the second key concept in pyro.contrib.funsor: the pyro.markov annotation, a way of indicating the shelf life of certain variables. pyro.markov is already part of Pyro (see the enumeration tutorial), but the implementation in pyro.contrib.funsor is brand new.

The primary constraint on the design of pyro.markov is backwards compatibility: in order for pyro.contrib.funsor to be compatible with the wide range of existing Pyro models, the new implementation had to match the shape semantics of Pyro’s existing enumeration machinery as closely as possible.
[12]:
with pyroapi.pyro_backend("contrib.funsor"), handlers.named():
    for i in pyro.markov(range(10)):
        x = pyro.to_data(funsor.Tensor(torch.tensor([0., 1.]), OrderedDict({"x{}".format(i): funsor.Bint[2]})))
        print("Shape of x[{}]: ".format(str(i)), x.shape)
Shape of x[0]: torch.Size([2, 1, 1, 1, 1])
Shape of x[1]: torch.Size([2, 1, 1, 1, 1, 1])
Shape of x[2]: torch.Size([2, 1, 1, 1, 1])
Shape of x[3]: torch.Size([2, 1, 1, 1, 1, 1])
Shape of x[4]: torch.Size([2, 1, 1, 1, 1])
Shape of x[5]: torch.Size([2, 1, 1, 1, 1, 1])
Shape of x[6]: torch.Size([2, 1, 1, 1, 1])
Shape of x[7]: torch.Size([2, 1, 1, 1, 1, 1])
Shape of x[8]: torch.Size([2, 1, 1, 1, 1])
Shape of x[9]: torch.Size([2, 1, 1, 1, 1, 1])
pyro.markov is a versatile piece of syntax that can be used as a context manager, a decorator, or an iterator. It is important to understand that pyro.markov’s only functionality at present is tracking variable usage; it does not directly indicate conditional independence properties to inference algorithms. As such, it is only necessary to add enough annotations to ensure that tensors have correct shapes, rather than attempting to encode as much dependency information as possible by hand.
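To make these three modes concrete, here is a minimal sketch (the pass bodies are stand-ins for sample statements; this is an illustration we add here, not a cell from the original notebook):

with pyroapi.pyro_backend("contrib.funsor"), handlers.named():
    # 1. as an iterator over time steps, as in the cells in this section
    for i in pyro.markov(range(3)):
        pass  # per-step sample statements would go here

    # 2. as a context manager around a block of statements
    with pyro.markov():
        pass

    # 3. as a decorator, e.g. on a recursive transition function
    @pyro.markov
    def step(t):
        pass

    step(0)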
pyro.markov takes an additional argument, history, that determines the number of previous pyro.markov contexts to take into account when building the mapping between names and dimensions at a given pyro.to_funsor/pyro.to_data call.
[13]:
with pyroapi.pyro_backend("contrib.funsor"), handlers.named():
    for i in pyro.markov(range(10), history=2):
        x = pyro.to_data(funsor.Tensor(torch.tensor([0., 1.]), OrderedDict({"x{}".format(i): funsor.Bint[2]})))
        print("Shape of x[{}]: ".format(str(i)), x.shape)
Shape of x[0]: torch.Size([2, 1, 1, 1, 1])
Shape of x[1]: torch.Size([2, 1, 1, 1, 1, 1])
Shape of x[2]: torch.Size([2, 1, 1, 1, 1, 1, 1])
Shape of x[3]: torch.Size([2, 1, 1, 1, 1])
Shape of x[4]: torch.Size([2, 1, 1, 1, 1, 1])
Shape of x[5]: torch.Size([2, 1, 1, 1, 1, 1, 1])
Shape of x[6]: torch.Size([2, 1, 1, 1, 1])
Shape of x[7]: torch.Size([2, 1, 1, 1, 1, 1])
Shape of x[8]: torch.Size([2, 1, 1, 1, 1, 1, 1])
Shape of x[9]: torch.Size([2, 1, 1, 1, 1])
Use cases beyond enumeration: global and visible dimensions¶
Global dimensions¶
It is sometimes useful to have dimensions and variables ignore the pyro.markov structure of a program and remain active in arbitrarily deeply nested markov and named contexts. For example, suppose we wanted to draw a batch of samples from a Pyro model’s joint distribution. To accomplish this, we indicate to pyro.to_data that a dimension should be treated as “global” (DimType.GLOBAL) via the dim_type keyword argument.
[14]:
from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, DimType
with pyroapi.pyro_backend("contrib.funsor"), handlers.named():
    funsor_particle_ids = funsor.Tensor(torch.arange(10), OrderedDict(n=funsor.Bint[10]))
    tensor_particle_ids = pyro.to_data(funsor_particle_ids, dim_type=DimType.GLOBAL)
    print("New global dimension: ", funsor_particle_ids.inputs, tensor_particle_ids.shape)
New global dimension: OrderedDict([('n', Bint[10, ])]) torch.Size([10, 1, 1, 1, 1])
pyro.markov does the hard work of automatically managing local dimensions, but because global dimensions ignore this structure, they must be deallocated manually or they will persist until the last active effect handler exits, just as global variables in Python persist until a program finishes executing.
[15]:
from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, DimType

with pyroapi.pyro_backend("contrib.funsor"), handlers.named():
    funsor_plate1_ids = funsor.Tensor(torch.arange(10), OrderedDict(plate1=funsor.Bint[10]))
    tensor_plate1_ids = pyro.to_data(funsor_plate1_ids, dim_type=DimType.GLOBAL)
    print("New global dimension: ", funsor_plate1_ids.inputs, tensor_plate1_ids.shape)

    funsor_plate2_ids = funsor.Tensor(torch.arange(9), OrderedDict(plate2=funsor.Bint[9]))
    tensor_plate2_ids = pyro.to_data(funsor_plate2_ids, dim_type=DimType.GLOBAL)
    print("Another new global dimension: ", funsor_plate2_ids.inputs, tensor_plate2_ids.shape)

    # free the dimension allocated to "plate1" so that it can be recycled
    del _DIM_STACK.global_frame["plate1"]

    funsor_plate3_ids = funsor.Tensor(torch.arange(10), OrderedDict(plate3=funsor.Bint[10]))
    tensor_plate3_ids = pyro.to_data(funsor_plate3_ids, dim_type=DimType.GLOBAL)
    print("A third new global dimension after recycling: ", funsor_plate3_ids.inputs, tensor_plate3_ids.shape)
New global dimension: OrderedDict([('plate1', Bint[10, ])]) torch.Size([10, 1, 1, 1, 1])
Another new global dimension: OrderedDict([('plate2', Bint[9, ])]) torch.Size([9, 1, 1, 1, 1, 1])
A third new global dimension after recycling: OrderedDict([('plate3', Bint[10, ])]) torch.Size([10, 1, 1, 1, 1])
Performing this deallocation directly is often unnecessary; we include this interaction primarily to illuminate the internals of pyro.contrib.funsor. Instead, effect handlers that introduce global dimensions, like pyro.plate, may inherit from the GlobalNamedMessenger effect handler, which deallocates global dimensions generically upon entry and exit. We will see an example of this in the next tutorial.
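As a rough sketch of that pattern (a hypothetical handler we add for illustration, assuming GlobalNamedMessenger is importable from pyro.contrib.funsor.handlers.named_messenger alongside NamedMessenger and that its default entry/exit bookkeeping suffices; this is not the actual pyro.plate implementation), a trivial subclass is enough to get automatically scoped global dimensions:

from pyro.contrib.funsor.handlers.named_messenger import GlobalNamedMessenger

class ScopedGlobals(GlobalNamedMessenger):  # hypothetical handler
    pass  # all dimension bookkeeping is inherited from the base class

with pyroapi.pyro_backend("contrib.funsor"), handlers.named():
    with ScopedGlobals():
        ids = funsor.Tensor(torch.arange(10), OrderedDict(n=funsor.Bint[10]))
        print(pyro.to_data(ids, dim_type=DimType.GLOBAL).shape)
    # on exiting ScopedGlobals, the dimension bound to "n" should be freed
    # without a manual `del _DIM_STACK.global_frame[...]`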
Visible dimensions¶
We might also wish to preserve the meaning of the shape of a tensor of data. For this, we indicate to pyro.to_data that a dimension should be treated as not merely global but “visible” (DimType.VISIBLE). By default, the 4 rightmost batch dimensions are reserved as “visible” dimensions, but this can be changed by setting the first_available_dim attribute of the global state object _DIM_STACK.

Users who have come across pyro.infer.TraceEnum_ELBO’s max_plate_nesting argument are already familiar with this distinction.
[16]:
prev_first_available_dim = _DIM_STACK.set_first_available_dim(-2)

with pyroapi.pyro_backend("contrib.funsor"), handlers.named():
    funsor_local_ids = funsor.Tensor(torch.arange(9), OrderedDict(k=funsor.Bint[9]))
    tensor_local_ids = pyro.to_data(funsor_local_ids, dim_type=DimType.LOCAL)
    print("Tensor with new local dimension: ", funsor_local_ids.inputs, tensor_local_ids.shape)

    funsor_global_ids = funsor.Tensor(torch.arange(10), OrderedDict(n=funsor.Bint[10]))
    tensor_global_ids = pyro.to_data(funsor_global_ids, dim_type=DimType.GLOBAL)
    print("Tensor with new global dimension: ", funsor_global_ids.inputs, tensor_global_ids.shape)

    funsor_data_ids = funsor.Tensor(torch.arange(11), OrderedDict(m=funsor.Bint[11]))
    tensor_data_ids = pyro.to_data(funsor_data_ids, dim_type=DimType.VISIBLE)
    print("Tensor with new visible dimension: ", funsor_data_ids.inputs, tensor_data_ids.shape)

# we also need to reset the first_available_dim after we're done
_DIM_STACK.set_first_available_dim(prev_first_available_dim)
Tensor with new local dimension: OrderedDict([('k', Bint[9, ])]) torch.Size([9, 1])
Tensor with new global dimension: OrderedDict([('n', Bint[10, ])]) torch.Size([10, 1, 1])
Tensor with new visible dimension: OrderedDict([('m', Bint[11, ])]) torch.Size([11])
[16]:
-5
Visible dimensions are also global and must therefore be deallocated manually or they will persist until the last effect handler exits, as in the previous example. You may be thinking by now that Funsor’s dimension names behave sort of like Python variables, with scopes and persistent meanings across expressions; indeed, this observation is the key insight behind the design of Funsor.
Fortunately, interacting directly with the dimension allocator is almost always unnecessary; as in the previous section, we include it here only to illuminate the inner workings of pyro.contrib.funsor. Rather, effect handlers like pyro.handlers.enum that may introduce non-visible dimensions that could conflict with visible dimensions should inherit from the base pyro.contrib.funsor.handlers.named_messenger.NamedMessenger effect handler.

However, building a bit of intuition for the inner workings of the dimension allocator will make it easier to use the new primitives in contrib.funsor to build powerful new custom inference engines. We will see an example of one such inference engine in the next tutorial.