pyro.contrib.funsor, a new backend for Pyro - New primitives (Part 1)

Introduction

In this tutorial we’ll cover the basics of pyro.contrib.funsor, a new backend for the Pyro probabilistic programming system that is intended to replace the current internals of Pyro and significantly expand its capabilities as both a modelling tool and an inference research platform.

This tutorial is aimed at readers interested in developing custom inference algorithms and understanding Pyro’s current and future internals. As such, the material here assumes some familiarity with the generic Pyro API package pyroapi and with Funsor. Additional documentation for Funsor can be found on the Pyro website, on GitHub, and in the research paper “Functional Tensors for Probabilistic Programming.” Those who are less interested in such details should find that they can already use the general-purpose algorithms in contrib.funsor with their existing Pyro models via pyroapi.

Reinterpreting existing Pyro models with pyroapi

The new backend uses the pyroapi package to integrate with existing Pyro code.

First, we import some dependencies:

[1]:
from collections import OrderedDict

import torch
import funsor
from pyro import set_rng_seed as pyro_set_rng_seed

funsor.set_backend("torch")
torch.set_default_dtype(torch.float32)
pyro_set_rng_seed(101)

Importing pyro.contrib.funsor registers the "contrib.funsor" backend with pyroapi, which can now be passed as an argument to the pyroapi.pyro_backend context manager.

[2]:
import pyro.contrib.funsor
import pyroapi
from pyroapi import handlers, infer, ops, optim, pyro
from pyroapi import distributions as dist

# this is already done in pyro.contrib.funsor, but we repeat it here
pyroapi.register_backend("contrib.funsor", dict(
    distributions="pyro.distributions",
    handlers="pyro.contrib.funsor.handlers",
    infer="pyro.contrib.funsor.infer",
    ops="torch",
    optim="pyro.optim",
    pyro="pyro.contrib.funsor",
))

And we’re off! From here on, any pyro.(...) statement should be understood as dispatching to the new backend.
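
For example, an ordinary Pyro sample statement runs unchanged under the new backend once wrapped in the pyroapi.pyro_backend context manager. The following is a minimal sketch; the site name “z” and the distribution are illustrative, not part of the tutorial’s running example:

# a minimal sketch: a single sample statement dispatched to the new backend
with pyroapi.pyro_backend("contrib.funsor"):
    z = pyro.sample("z", dist.Normal(0., 1.))
    print(type(z), z.shape)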

Two new primitives: to_funsor and to_data

The first and most important new concept in pyro.contrib.funsor is the new pair of primitives pyro.to_funsor and pyro.to_data.

These are effectful versions of funsor.to_funsor and funsor.to_data, i.e. versions whose behavior can be intercepted, controlled, or used to trigger side effects by Pyro’s library of algebraic effect handlers. Let’s briefly review these two underlying functions before diving into the effectful versions in pyro.contrib.funsor.

As one might expect from the name, to_funsor takes as inputs objects that are not funsor.Funsors and attempts to convert them into Funsor terms. For example, calling funsor.to_funsor on a Python number converts it to a funsor.terms.Number object:

[3]:
funsor_one = funsor.to_funsor(float(1))
print(funsor_one, type(funsor_one))

funsor_two = funsor.to_funsor(torch.tensor(2.))
print(funsor_two, type(funsor_two))
1.0 <class 'funsor.terms.Number'>
tensor(2.) <class 'funsor.tensor.Tensor'>

Similarly, calling funsor.to_data on an atomic funsor.Funsor converts it to a regular Python object like a float or a torch.Tensor:

[4]:
data_one = funsor.to_data(funsor.terms.Number(float(1), 'real'))
print(data_one, type(data_one))

data_two = funsor.to_data(funsor.Tensor(torch.tensor(2.), OrderedDict(), 'real'))
print(data_two, type(data_two))
1.0 <class 'float'>
tensor(2.) <class 'torch.Tensor'>

In many cases it is necessary to provide an output type to uniquely convert a piece of data to a funsor.Funsor. This also means that, strictly speaking, funsor.to_funsor and funsor.to_data are not inverses. For example, funsor.to_funsor will automatically convert Python strings to funsor.Variables, but only when given an output funsor.domains.Domain, which serves as the type of the variable:

[5]:
var_x = funsor.to_funsor("x", output=funsor.Reals[2])
print(var_x, var_x.inputs, var_x.output)
x OrderedDict([('x', Reals[2])]) Reals[2]

However, it is often impossible to convert objects to and from Funsor expressions uniquely without additional type information about inputs, as in the following example of a torch.Tensor, which could be converted to a funsor.Tensor in several ways.

To resolve this ambiguity, we need to provide to_funsor and to_data with type information that describes how to convert positional dimensions to and from unordered named Funsor dimensions. This information comes in the form of dictionaries mapping batch dimensions to dimension names or vice versa.

A key property of these mappings is the convention that dimension indices refer to batch dimensions, i.e. dimensions not included in the output shape, and that these indices count leftward from the rightmost batch dimension of the underlying PyTorch tensor, as illustrated in the example below.

[6]:
ambiguous_tensor = torch.zeros((3, 1, 2))
print("Ambiguous tensor: shape = {}".format(ambiguous_tensor.shape))

# case 1: treat all dimensions as output/event dimensions
funsor1 = funsor.to_funsor(ambiguous_tensor, output=funsor.Reals[3, 1, 2])
print("Case 1: inputs = {}, output = {}".format(funsor1.inputs, funsor1.output))

# case 2: treat the leftmost dimension as a batch dimension
# note that dimension -1 in dim_to_name here refers to the rightmost *batch dimension*,
# i.e. dimension -3 of ambiguous_tensor, the rightmost dimension not included in the output shape.
funsor2 = funsor.to_funsor(ambiguous_tensor, output=funsor.Reals[1, 2], dim_to_name={-1: "a"})
print("Case 2: inputs = {}, output = {}".format(funsor2.inputs, funsor2.output))

# case 3: treat the leftmost 2 dimensions as batch dimensions; empty batch dimensions are ignored
# note that dimensions -1 and -2 in dim_to_name here refer to the rightmost *batch dimensions*,
# i.e. dimensions -2 and -3 of ambiguous_tensor, the rightmost dimensions not included in the output shape.
funsor3 = funsor.to_funsor(ambiguous_tensor, output=funsor.Reals[2], dim_to_name={-1: "b", -2: "a"})
print("Case 3: inputs = {}, output = {}".format(funsor3.inputs, funsor3.output))

# case 4: treat all dimensions as batch dimensions; empty batch dimensions are ignored
# note that dimensions -1, -2 and -3 in dim_to_name here refer to the rightmost *batch dimensions*,
# i.e. dimensions -1, -2 and -3 of ambiguous_tensor, the rightmost dimensions not included in the output shape.
funsor4 = funsor.to_funsor(ambiguous_tensor, output=funsor.Real, dim_to_name={-1: "c", -2: "b", -3: "a"})
print("Case 4: inputs = {}, output = {}".format(funsor4.inputs, funsor4.output))
Ambiguous tensor: shape = torch.Size([3, 1, 2])
Case 1: inputs = OrderedDict(), output = Reals[3,1,2]
Case 2: inputs = OrderedDict([('a', Bint[3, ])]), output = Reals[1,2]
Case 3: inputs = OrderedDict([('a', Bint[3, ])]), output = Reals[2]
Case 4: inputs = OrderedDict([('a', Bint[3, ]), ('c', Bint[2, ])]), output = Real

Similar ambiguity exists for to_data: the inputs of a funsor.Funsor are ordered arbitrarily, and empty dimensions in the data are squeezed away, so a mapping from names to batch dimensions must be provided to ensure unique conversion:

[7]:
ambiguous_funsor = funsor.Tensor(torch.zeros((3, 2)), OrderedDict(a=funsor.Bint[3], b=funsor.Bint[2]), 'real')
print("Ambiguous funsor: inputs = {}, shape = {}".format(ambiguous_funsor.inputs, ambiguous_funsor.output))

# case 1: the simplest version
tensor1 = funsor.to_data(ambiguous_funsor, name_to_dim={"a": -2, "b": -1})
print("Case 1: shape = {}".format(tensor1.shape))

# case 2: an empty dimension between a and b
tensor2 = funsor.to_data(ambiguous_funsor, name_to_dim={"a": -3, "b": -1})
print("Case 2: shape = {}".format(tensor2.shape))

# case 3: permuting the input dimensions
tensor3 = funsor.to_data(ambiguous_funsor, name_to_dim={"a": -1, "b": -2})
print("Case 3: shape = {}".format(tensor3.shape))
Ambiguous funsor: inputs = OrderedDict([('a', Bint[3, ]), ('b', Bint[2, ])]), shape = Real
Case 1: shape = torch.Size([3, 2])
Case 2: shape = torch.Size([3, 1, 2])
Case 3: shape = torch.Size([2, 3])

Maintaining and updating this information efficiently becomes tedious and error-prone as the number of conversions increases. Fortunately, it can be automated away completely. Consider the following example:

[8]:
name_to_dim = OrderedDict()

funsor_x = funsor.Tensor(torch.ones((2,)), OrderedDict(x=funsor.Bint[2]), 'real')
name_to_dim.update({"x": -1})
tensor_x = funsor.to_data(funsor_x, name_to_dim=name_to_dim)
print(name_to_dim, funsor_x.inputs, tensor_x.shape)

funsor_y = funsor.Tensor(torch.ones((3, 2)), OrderedDict(y=funsor.Bint[3], x=funsor.Bint[2]), 'real')
name_to_dim.update({"y": -2})
tensor_y = funsor.to_data(funsor_y, name_to_dim=name_to_dim)
print(name_to_dim, funsor_y.inputs, tensor_y.shape)

funsor_z = funsor.Tensor(torch.ones((2, 3)), OrderedDict(z=funsor.Bint[2], y=funsor.Bint[3]), 'real')
name_to_dim.update({"z": -3})
tensor_z = funsor.to_data(funsor_z, name_to_dim=name_to_dim)
print(name_to_dim, funsor_z.inputs, tensor_z.shape)
OrderedDict([('x', -1)]) OrderedDict([('x', Bint[2, ])]) torch.Size([2])
OrderedDict([('x', -1), ('y', -2)]) OrderedDict([('y', Bint[3, ]), ('x', Bint[2, ])]) torch.Size([3, 2])
OrderedDict([('x', -1), ('y', -2), ('z', -3)]) OrderedDict([('z', Bint[2, ]), ('y', Bint[3, ])]) torch.Size([2, 3, 1])

This is exactly the functionality provided by pyro.to_funsor and pyro.to_data, as we can see by repeating the previous example with the manual updates removed. We must also wrap the code in a handlers.named effect handler to ensure that the dimension dictionaries do not persist beyond its scope.

[9]:
with pyroapi.pyro_backend("contrib.funsor"), handlers.named():
    funsor_x = funsor.Tensor(torch.ones((2,)), OrderedDict(x=funsor.Bint[2]), 'real')
    tensor_x = pyro.to_data(funsor_x)
    print(funsor_x.inputs, tensor_x.shape)

    funsor_y = funsor.Tensor(torch.ones((3, 2)), OrderedDict(y=funsor.Bint[3], x=funsor.Bint[2]), 'real')
    tensor_y = pyro.to_data(funsor_y)
    print(funsor_y.inputs, tensor_y.shape)

    funsor_z = funsor.Tensor(torch.ones((2, 3)), OrderedDict(z=funsor.Bint[2], y=funsor.Bint[3]), 'real')
    tensor_z = pyro.to_data(funsor_z)
    print(funsor_z.inputs, tensor_z.shape)
OrderedDict([('x', Bint[2, ])]) torch.Size([2, 1, 1, 1, 1])
OrderedDict([('y', Bint[3, ]), ('x', Bint[2, ])]) torch.Size([3, 2, 1, 1, 1, 1])
OrderedDict([('z', Bint[2, ]), ('y', Bint[3, ])]) torch.Size([2, 3, 1, 1, 1, 1, 1])

Critically, pyro.to_funsor and pyro.to_data use and update the same bidirectional mapping between names and dimensions, allowing them to be combined intuitively. A typical usage pattern, and one that pyro.contrib.funsor uses heavily in its inference algorithm implementations, is to create a funsor.Funsor term directly with a new named dimension and call pyro.to_data on it, perform some PyTorch computations, and call pyro.to_funsor on the result:

[10]:
with pyroapi.pyro_backend("contrib.funsor"), handlers.named():

    probs = funsor.Tensor(torch.tensor([0.5, 0.4, 0.7]), OrderedDict(batch=funsor.Bint[3]))
    print(type(probs), probs.inputs, probs.output)

    x = funsor.Tensor(torch.tensor([0., 1., 0., 1.]), OrderedDict(x=funsor.Bint[4]))
    print(type(x), x.inputs, x.output)

    dx = dist.Bernoulli(pyro.to_data(probs))
    print(type(dx), dx.shape())

    px = pyro.to_funsor(dx.log_prob(pyro.to_data(x)), output=funsor.Real)
    print(type(px), px.inputs, px.output)
<class 'funsor.tensor.Tensor'> OrderedDict([('batch', Bint[3, ])]) Real
<class 'funsor.tensor.Tensor'> OrderedDict([('x', Bint[4, ])]) Real
<class 'pyro.distributions.torch.Bernoulli'> torch.Size([3, 1, 1, 1, 1])
<class 'funsor.tensor.Tensor'> OrderedDict([('x', Bint[4, ]), ('batch', Bint[3, ])]) Real

pyro.to_funsor and pyro.to_data treat the dimension indices in their dim_to_name/name_to_dim arguments as references to the input’s batch shape, but treat the names as references to the globally consistent name-dimension mapping. This may be useful for complicated computations that involve a mixture of PyTorch and Funsor operations.

[11]:
with pyroapi.pyro_backend("contrib.funsor"), handlers.named():

    x = pyro.to_funsor(torch.tensor([0., 1.]), funsor.Real, dim_to_name={-1: "x"})
    print("x: ", type(x), x.inputs, x.output)

    px = pyro.to_funsor(torch.ones(2, 3), funsor.Real, dim_to_name={-2: "x", -1: "y"})
    print("px: ", type(px), px.inputs, px.output)
x:  <class 'funsor.tensor.Tensor'> OrderedDict([('x', Bint[2, ])]) Real
px:  <class 'funsor.tensor.Tensor'> OrderedDict([('x', Bint[2, ]), ('y', Bint[3, ])]) Real

Dealing with large numbers of variables: (re-)introducing pyro.markov

So far, so good. However, what if the number of different named dimensions continues to increase? We face two problems: first, reusing the fixed number of available positional dimensions (25 in PyTorch), and second, computing shape information with time complexity that is independent of the number of variables.

A fully general automated solution to this problem would require deeper integration with Python or PyTorch. Instead, as an intermediate solution, we introduce the second key concept in pyro.contrib.funsor: the pyro.markov annotation, a way of indicating how long certain variables remain in use. pyro.markov is already part of Pyro (see the enumeration tutorial), but the implementation in pyro.contrib.funsor is new.

The primary constraint on the design of pyro.markov is backwards compatibility: in order for pyro.contrib.funsor to be compatible with the large range of existing Pyro models, the new implementation had to match the shape semantics of Pyro’s existing enumeration machinery as closely as possible.

[12]:
with pyroapi.pyro_backend("contrib.funsor"), handlers.named():
    for i in pyro.markov(range(10)):
        x = pyro.to_data(funsor.Tensor(torch.tensor([0., 1.]), OrderedDict({"x{}".format(i): funsor.Bint[2]})))
        print("Shape of x[{}]: ".format(str(i)), x.shape)
Shape of x[0]:  torch.Size([2, 1, 1, 1, 1])
Shape of x[1]:  torch.Size([2, 1, 1, 1, 1, 1])
Shape of x[2]:  torch.Size([2, 1, 1, 1, 1])
Shape of x[3]:  torch.Size([2, 1, 1, 1, 1, 1])
Shape of x[4]:  torch.Size([2, 1, 1, 1, 1])
Shape of x[5]:  torch.Size([2, 1, 1, 1, 1, 1])
Shape of x[6]:  torch.Size([2, 1, 1, 1, 1])
Shape of x[7]:  torch.Size([2, 1, 1, 1, 1, 1])
Shape of x[8]:  torch.Size([2, 1, 1, 1, 1])
Shape of x[9]:  torch.Size([2, 1, 1, 1, 1, 1])

pyro.markov is a versatile piece of syntax that can be used as a context manager, a decorator, or an iterator. It is important to understand that pyro.markov’s only functionality at present is tracking variable usage; it does not directly indicate conditional independence properties to inference algorithms. As such, it is only necessary to add enough annotations to ensure that tensors have correct shapes, rather than attempting to encode as much dependency information as possible by hand.
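
For instance, the iterator form used above can be written with nested context managers instead. The following is a minimal sketch, assuming the context-manager form tracks usage like the iterator form; the names y0, y1, y2 are illustrative:

# a minimal sketch of pyro.markov as a context manager: nested contexts play
# the role of successive loop iterations
with pyroapi.pyro_backend("contrib.funsor"), handlers.named():
    with pyro.markov():
        y0 = pyro.to_data(funsor.Tensor(torch.tensor([0., 1.]), OrderedDict(y0=funsor.Bint[2])))
        with pyro.markov():
            y1 = pyro.to_data(funsor.Tensor(torch.tensor([0., 1.]), OrderedDict(y1=funsor.Bint[2])))
            with pyro.markov():
                # with the default history=1, y0 is out of scope here, so its
                # positional dimension may be recycled for y2
                y2 = pyro.to_data(funsor.Tensor(torch.tensor([0., 1.]), OrderedDict(y2=funsor.Bint[2])))
                print(y0.shape, y1.shape, y2.shape)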

pyro.markov takes an additional argument history that determines the number of previous pyro.markov contexts to take into account when building the mapping between names and dimensions at a given pyro.to_funsor/pyro.to_data call.

[13]:
with pyroapi.pyro_backend("contrib.funsor"), handlers.named():
    for i in pyro.markov(range(10), history=2):
        x = pyro.to_data(funsor.Tensor(torch.tensor([0., 1.]), OrderedDict({"x{}".format(i): funsor.Bint[2]})))
        print("Shape of x[{}]: ".format(str(i)), x.shape)
Shape of x[0]:  torch.Size([2, 1, 1, 1, 1])
Shape of x[1]:  torch.Size([2, 1, 1, 1, 1, 1])
Shape of x[2]:  torch.Size([2, 1, 1, 1, 1, 1, 1])
Shape of x[3]:  torch.Size([2, 1, 1, 1, 1])
Shape of x[4]:  torch.Size([2, 1, 1, 1, 1, 1])
Shape of x[5]:  torch.Size([2, 1, 1, 1, 1, 1, 1])
Shape of x[6]:  torch.Size([2, 1, 1, 1, 1])
Shape of x[7]:  torch.Size([2, 1, 1, 1, 1, 1])
Shape of x[8]:  torch.Size([2, 1, 1, 1, 1, 1, 1])
Shape of x[9]:  torch.Size([2, 1, 1, 1, 1])

Use cases beyond enumeration: global and visible dimensions

Global dimensions

It is sometimes useful to have dimensions and variables ignore the pyro.markov structure of a program and remain active in arbitrarily deeply nested markov and named contexts. For example, suppose we wanted to draw a batch of samples from a Pyro model’s joint distribution. To accomplish this we indicate to pyro.to_data that a dimension should be treated as “global” (DimType.GLOBAL) via the dim_type keyword argument.

[14]:
from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, DimType

with pyroapi.pyro_backend("contrib.funsor"), handlers.named():
    funsor_particle_ids = funsor.Tensor(torch.arange(10), OrderedDict(n=funsor.Bint[10]))
    tensor_particle_ids = pyro.to_data(funsor_particle_ids, dim_type=DimType.GLOBAL)
    print("New global dimension: ", funsor_particle_ids.inputs, tensor_particle_ids.shape)
New global dimension:  OrderedDict([('n', Bint[10, ])]) torch.Size([10, 1, 1, 1, 1])

pyro.markov does the hard work of automatically managing local dimensions, but because global dimensions ignore this structure, they must be deallocated manually or they will persist until the last active effect handler exits, just as global variables in Python persist until a program execution finishes.

[15]:
from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, DimType

with pyroapi.pyro_backend("contrib.funsor"), handlers.named():

    funsor_plate1_ids = funsor.Tensor(torch.arange(10), OrderedDict(plate1=funsor.Bint[10]))
    tensor_plate1_ids = pyro.to_data(funsor_plate1_ids, dim_type=DimType.GLOBAL)
    print("New global dimension: ", funsor_plate1_ids.inputs, tensor_plate1_ids.shape)

    funsor_plate2_ids = funsor.Tensor(torch.arange(9), OrderedDict(plate2=funsor.Bint[9]))
    tensor_plate2_ids = pyro.to_data(funsor_plate2_ids, dim_type=DimType.GLOBAL)
    print("Another new global dimension: ", funsor_plate2_ids.inputs, tensor_plate2_ids.shape)

    del _DIM_STACK.global_frame["plate1"]

    funsor_plate3_ids = funsor.Tensor(torch.arange(10), OrderedDict(plate3=funsor.Bint[10]))
    tensor_plate3_ids = pyro.to_data(funsor_plate3_ids, dim_type=DimType.GLOBAL)
    print("A third new global dimension after recycling: ", funsor_plate3_ids.inputs, tensor_plate3_ids.shape)
New global dimension:  OrderedDict([('plate1', Bint[10, ])]) torch.Size([10, 1, 1, 1, 1])
Another new global dimension:  OrderedDict([('plate2', Bint[9, ])]) torch.Size([9, 1, 1, 1, 1, 1])
A third new global dimension after recycling:  OrderedDict([('plate3', Bint[10, ])]) torch.Size([10, 1, 1, 1, 1])

Performing this deallocation directly is often unnecessary, and we include this interaction primarily to illuminate the internals of pyro.contrib.funsor. Instead, effect handlers that introduce global dimensions, like pyro.plate, may inherit from the GlobalNamedMessenger effect handler, which deallocates global dimensions generically upon entry and exit. We will see an example of this in the next tutorial.
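
In the meantime, here is a rough sketch of the idea, assuming GlobalNamedMessenger frees dimensions on exit as described above; the name plate4 is illustrative:

from pyro.contrib.funsor.handlers.named_messenger import GlobalNamedMessenger

# a minimal sketch: a global dimension allocated inside a GlobalNamedMessenger
# context is recycled automatically when the handler exits
with pyroapi.pyro_backend("contrib.funsor"), handlers.named():
    with GlobalNamedMessenger():
        funsor_plate4_ids = funsor.Tensor(torch.arange(10), OrderedDict(plate4=funsor.Bint[10]))
        tensor_plate4_ids = pyro.to_data(funsor_plate4_ids, dim_type=DimType.GLOBAL)
        print("Global dimension inside the handler: ", tensor_plate4_ids.shape)
    # on exit, the dimension allocated for "plate4" is deallocated automatically,
    # with no need for a manual `del _DIM_STACK.global_frame["plate4"]`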

Visible dimensions

We might also wish to preserve the meaning of the shape of a tensor of data. For this we indicate to pyro.to_data that a dimension should be treated as not merely global but “visible” (DimType.VISIBLE). By default, the 4 rightmost batch dimensions are reserved as “visible” dimensions; this can be changed by calling the set_first_available_dim method of the global state object _DIM_STACK.

Users who have come across pyro.infer.TraceEnum_ELBO’s max_plate_nesting argument are already familiar with this distinction.
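
The correspondence between the two is roughly the following (a sketch of the convention, not the exact internals):

# reserving max_plate_nesting visible dimensions on the right places the
# allocator's first internal (non-visible) dimension just to their left
max_plate_nesting = 4                         # matches the default of 4 visible dims
first_available_dim = -max_plate_nesting - 1  # == -5, the allocator's default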

[16]:
prev_first_available_dim = _DIM_STACK.set_first_available_dim(-2)

with pyroapi.pyro_backend("contrib.funsor"), handlers.named():

    funsor_local_ids = funsor.Tensor(torch.arange(9), OrderedDict(k=funsor.Bint[9]))
    tensor_local_ids = pyro.to_data(funsor_local_ids, dim_type=DimType.LOCAL)
    print("Tensor with new local dimension: ", funsor_local_ids.inputs, tensor_local_ids.shape)

    funsor_global_ids = funsor.Tensor(torch.arange(10), OrderedDict(n=funsor.Bint[10]))
    tensor_global_ids = pyro.to_data(funsor_global_ids, dim_type=DimType.GLOBAL)
    print("Tensor with new global dimension: ", funsor_global_ids.inputs, tensor_global_ids.shape)

    funsor_data_ids = funsor.Tensor(torch.arange(11), OrderedDict(m=funsor.Bint[11]))
    tensor_data_ids = pyro.to_data(funsor_data_ids, dim_type=DimType.VISIBLE)
    print("Tensor with new visible dimension: ", funsor_data_ids.inputs, tensor_data_ids.shape)

# we also need to reset the first_available_dim after we're done
_DIM_STACK.set_first_available_dim(prev_first_available_dim)
Tensor with new local dimension:  OrderedDict([('k', Bint[9, ])]) torch.Size([9, 1])
Tensor with new global dimension:  OrderedDict([('n', Bint[10, ])]) torch.Size([10, 1, 1])
Tensor with new visible dimension:  OrderedDict([('m', Bint[11, ])]) torch.Size([11])
[16]:
-5

Visible dimensions are also global and must therefore be deallocated manually or they will persist until the last effect handler exits, as in the previous example. You may be thinking by now that Funsor’s dimension names behave sort of like Python variables, with scopes and persistent meanings across expressions; indeed, this observation is the key insight behind the design of Funsor.

Fortunately, interacting directly with the dimension allocator is almost always unnecessary; as in the previous section, we include it here only to illuminate the inner workings of pyro.contrib.funsor. Instead, effect handlers like pyro.handlers.enum that may introduce non-visible dimensions conflicting with visible ones should inherit from the base pyro.contrib.funsor.handlers.named_messenger.NamedMessenger effect handler.

However, building a bit of intuition for the inner workings of the dimension allocator will make it easier to use the new primitives in contrib.funsor to build powerful new custom inference engines. We will see an example of one such inference engine in the next tutorial.