Modules in Pyro

This tutorial introduces PyroModule, Pyro’s Bayesian extension of PyTorch’s nn.Module class. Before starting you should understand the basics of Pyro models and inference, understand the two primitives pyro.sample() and pyro.param(), and understand the basics of Pyro’s effect handlers (e.g. by browsing minipyro.py).

Summary:

  • PyroModules are like nn.Modules but allow Pyro effects for sampling and constraints.

  • PyroModule is a mixin subclass of nn.Module that overrides attribute access (e.g. .__getattr__()).

  • There are three different ways to create a PyroModule:

    • create a new subclass: class MyModule(PyroModule): ...,

    • Pyro-ize an existing class: MyModule = PyroModule[OtherModule], or

    • Pyro-ize an existing nn.Module instance in-place: to_pyro_module_(my_module).

  • Usual nn.Parameter attributes of a PyroModule become Pyro parameters.

  • Parameters of a PyroModule synchronize with Pyro’s global param store.

  • You can add constrained parameters by creating PyroParam objects.

  • You can add stochastic attributes by creating PyroSample objects.

  • Parameters and stochastic attributes are named automatically (no string required).

  • PyroSample attributes are sampled once per .__call__() of the outermost PyroModule.

  • To enable Pyro effects on methods other than .__call__(), decorate them with @pyro_method.

  • A PyroModule model may contain nn.Module attributes.

  • An nn.Module model may contain at most one PyroModule attribute (see naming section).

  • An nn.Module may contain both a PyroModule model and PyroModule guide (e.g. Predictive).


[1]:
import os
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from torch.distributions import constraints
from pyro.nn import PyroModule, PyroParam, PyroSample
from pyro.nn.module import to_pyro_module_
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.9.0')

How PyroModule works

PyroModule aims to combine Pyro’s primitives and effect handlers with PyTorch’s nn.Module idiom, thereby enabling Bayesian treatment of existing nn.Modules and enabling model serving via jit.trace_module. Before you start using PyroModules it will help to understand how they work, so you can avoid pitfalls.

PyroModule is a subclass of nn.Module. PyroModule enables Pyro effects by inserting effect handling logic on module attribute access, overriding the .__getattr__(), .__setattr__(), and .__delattr__() methods. Additionally, because some effects (like sampling) apply only once per model invocation, PyroModule overrides the .__call__() method to ensure samples are generated at most once per .__call__() invocation (note nn.Module subclasses typically implement a .forward() method that is called by .__call__()).
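The toy below is plain Python and is meant only to make this mechanism concrete. It is a hypothetical sketch, not Pyro's implementation: the real PyroModule also handles nn.Parameter, PyroParam, nested modules and effect handlers. It shows just two ideas from the paragraph above. First, attribute access is intercepted by overriding .__setattr__() and .__getattr__(). Second, a value drawn inside .__call__() is cached, so every access during that call sees the same value. The names ToyModule and Demo are invented for this sketch, and a counter stands in for a random sampler so the output is predictable.

```python
import itertools


class ToyModule:
    """Hypothetical toy, NOT Pyro's code: stochastic attributes are registered
    in __setattr__, drawn lazily in __getattr__, and cached per __call__."""

    def __init__(self):
        # Bypass our own __setattr__ while creating internal state.
        object.__setattr__(self, "_samplers", {})
        object.__setattr__(self, "_cache", None)  # None means "not inside __call__"

    def __setattr__(self, name, value):
        # In this toy, any callable assigned as an attribute plays the role of a prior.
        if callable(value):
            self._samplers[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Python only calls __getattr__ when normal lookup fails, which is
        # exactly the case for attributes stored in self._samplers.
        samplers = object.__getattribute__(self, "_samplers")
        if name not in samplers:
            raise AttributeError(name)
        cache = object.__getattribute__(self, "_cache")
        if cache is None:          # outside __call__: draw a fresh value
            return samplers[name]()
        if name not in cache:      # first access inside __call__: draw once
            cache[name] = samplers[name]()
        return cache[name]         # later accesses reuse the cached value

    def __call__(self, *args, **kwargs):
        object.__setattr__(self, "_cache", {})
        try:
            return self.forward(*args, **kwargs)
        finally:
            object.__setattr__(self, "_cache", None)


class Demo(ToyModule):
    def __init__(self):
        super().__init__()
        counter = itertools.count()
        self.weight = lambda: next(counter)  # successive "draws" are 0, 1, 2, ...

    def forward(self):
        return self.weight, self.weight


demo = Demo()
print(demo())                    # (0, 0): one draw shared by both accesses
print(demo.weight, demo.weight)  # 1 2: outside a call, every access draws again
print(demo())                    # (3, 3)
```

Outside a call there is no cache, so each read draws again. Real PyroSample attributes behave the same way, as the two consecutive prints of linear.weight later in this tutorial show.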

How to create a PyroModule

There are three ways to create a PyroModule. Let’s start with an nn.Module that is not a PyroModule:

[2]:
class Linear(nn.Module):
    def __init__(self, in_size, out_size):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_size, out_size))
        self.bias = nn.Parameter(torch.randn(out_size))

    def forward(self, input_):
        return self.bias + input_ @ self.weight

linear = Linear(5, 2)
assert isinstance(linear, nn.Module)
assert not isinstance(linear, PyroModule)

example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)

The first way to create a PyroModule is to create a subclass of PyroModule. You can update any nn.Module you’ve written to be a PyroModule, e.g.

- class Linear(nn.Module):
+ class Linear(PyroModule):
      def __init__(self, in_size, out_size):
          super().__init__()
          self.weight = ...
          self.bias = ...
      ...

Alternatively, if you want to use third-party code like the Linear above, you can subclass it, using PyroModule as a mixin class:

[3]:
class PyroLinear(Linear, PyroModule):
    pass

linear = PyroLinear(5, 2)
assert isinstance(linear, nn.Module)
assert isinstance(linear, Linear)
assert isinstance(linear, PyroModule)

example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)

The second way to create a PyroModule is to use the bracket syntax PyroModule[-], which automatically creates a trivial mixin class like the one above.

- linear = Linear(5, 2)
+ linear = PyroModule[Linear](5, 2)

In our case we can write

[4]:
linear = PyroModule[Linear](5, 2)
assert isinstance(linear, nn.Module)
assert isinstance(linear, Linear)
assert isinstance(linear, PyroModule)

example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)

The one difference between manual subclassing and using PyroModule[-] is that PyroModule[-] also ensures that all nn.Module superclasses become PyroModules, which is important for class hierarchies in library code. For example, since nn.GRU is a subclass of nn.RNNBase, PyroModule[nn.GRU] will also be a subclass of PyroModule[nn.RNNBase].

The third way to create a PyroModule is to change the type of an existing nn.Module instance in-place using to_pyro_module_(). This is useful if you’re using a third-party module factory helper or updating an existing script, e.g.

[5]:
linear = Linear(5, 2)
assert isinstance(linear, nn.Module)
assert not isinstance(linear, PyroModule)

to_pyro_module_(linear)  # this operates in-place
assert isinstance(linear, nn.Module)
assert isinstance(linear, Linear)
assert isinstance(linear, PyroModule)

example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)

How effects work

So far we’ve created PyroModules but haven’t made use of Pyro effects. But already the nn.Parameter attributes of our PyroModules act like pyro.param statements: they synchronize with Pyro’s param store, and they can be recorded in traces.

[6]:
pyro.clear_param_store()

# This is not traced:
linear = Linear(5, 2)
with poutine.trace() as tr:
    linear(example_input)
print(type(linear).__name__)
print(list(tr.trace.nodes.keys()))
print(list(pyro.get_param_store().keys()))

# Now this is traced:
to_pyro_module_(linear)
with poutine.trace() as tr:
    linear(example_input)
print(type(linear).__name__)
print(list(tr.trace.nodes.keys()))
print(list(pyro.get_param_store().keys()))
Linear
[]
[]
PyroLinear
['bias', 'weight']
['bias', 'weight']

How to constrain parameters

Pyro parameters allow constraints, and often we want our nn.Module parameters to obey constraints. You can constrain a PyroModule’s parameters by replacing nn.Parameter with a PyroParam attribute. For example, to ensure the .bias attribute is positive, we can set it to:

[7]:
print("params before:", [name for name, _ in linear.named_parameters()])

linear.bias = PyroParam(torch.randn(2).exp(), constraint=constraints.positive)
print("params after:", [name for name, _ in linear.named_parameters()])
print("bias:", linear.bias)

example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)
params before: ['weight', 'bias']
params after: ['weight', 'bias_unconstrained']
bias: tensor([0.9777, 0.8773], grad_fn=<AddBackward0>)

Now PyTorch will optimize the .bias_unconstrained parameter, and each time we access the .bias attribute it will read and transform the .bias_unconstrained parameter (similar to a Python @property).

If you know the constraint beforehand, you can build it into the module constructor, e.g.

  class Linear(PyroModule):
      def __init__(self, in_size, out_size):
          super().__init__()
          self.weight = ...
-         self.bias = nn.Parameter(torch.randn(out_size))
+         self.bias = PyroParam(torch.randn(out_size).exp(),
+                               constraint=constraints.positive)
      ...

How to make a PyroModule Bayesian

So far our Linear module is still deterministic. To make it randomized and Bayesian, we’ll replace nn.Parameter and PyroParam attributes with PyroSample attributes, specifying a prior. Let’s put a simple prior over the weights, taking care to expand its shape to [5,2] and declare event dimensions with .to_event() (as explained in the tensor shapes tutorial).

[8]:
print("params before:", [name for name, _ in linear.named_parameters()])

linear.weight = PyroSample(dist.Normal(0, 1).expand([5, 2]).to_event(2))
print("params after:", [name for name, _ in linear.named_parameters()])
print("weight:", linear.weight)
print("weight:", linear.weight)

example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)
params before: ['weight', 'bias_unconstrained']
params after: ['bias_unconstrained']
weight: tensor([[-0.8668, -0.0150],
        [ 3.4642,  1.9076],
        [ 0.4717,  1.0565],
        [-1.2032,  1.0821],
        [-0.1712,  0.4711]])
weight: tensor([[-1.2577, -0.5242],
        [-0.7785, -1.0806],
        [ 0.6239, -0.4884],
        [-0.2580, -1.2288],
        [-0.7540, -1.9375]])

Notice that the .weight parameter now disappears, and a new weight is sampled from the prior each time we call linear() (or, as the two prints above show, each time we read .weight outside of a call). In fact, the weight is sampled when Linear.forward() accesses the .weight attribute: this attribute now has the special behavior of sampling from the prior.

We can see all the Pyro effects that appear in the trace:

[9]:
with poutine.trace() as tr:
    linear(example_input)
for site in tr.trace.nodes.values():
    print(site["type"], site["name"], site["value"])
param bias tensor([0.9777, 0.8773], grad_fn=<AddBackward0>)
sample weight tensor([[ 1.8043,  1.5494],
        [ 0.0128,  1.4100],
        [-0.2155,  0.6375],
        [ 1.1202,  1.9672],
        [-0.1576, -0.6957]])

So far we’ve modified a third-party module to be Bayesian:

linear = Linear(...)
to_pyro_module_(linear)
linear.bias = PyroParam(...)
linear.weight = PyroSample(...)

If you are creating a model from scratch, you could instead define a new class:

[10]:
class BayesianLinear(PyroModule):
    def __init__(self, in_size, out_size):
        super().__init__()
        self.bias = PyroSample(
            prior=dist.LogNormal(0, 1).expand([out_size]).to_event(1))
        self.weight = PyroSample(
            prior=dist.Normal(0, 1).expand([in_size, out_size]).to_event(2))

    def forward(self, input):
        return self.bias + input @ self.weight  # this line samples bias and weight

Note that samples are drawn at most once per .__call__() invocation, for example:

class BayesianLinear(PyroModule):
    ...
    def forward(self, input):
        weight1 = self.weight      # Draws a sample.
        weight2 = self.weight      # Reads previous sample.
        assert weight2 is weight1  # All accesses should agree.
        ...

⚠ Caution: accessing attributes inside plates

Because PyroSample and PyroParam attributes are modified by Pyro effects, we need to take care where these attributes are accessed. For example, pyro.plate contexts can change the shape of sample and param sites. Consider a model with one latent variable and a batched observation statement. The only difference between the two models below is where the .loc attribute is accessed.

[11]:
class NormalModel(PyroModule):
    def __init__(self):
        super().__init__()
        self.loc = PyroSample(dist.Normal(0, 1))

class GlobalModel(NormalModel):
    def forward(self, data):
        # If .loc is accessed (for the first time) outside the plate,
        # then it will have empty shape ().
        loc = self.loc
        assert loc.shape == ()
        with pyro.plate("data", len(data)):
            pyro.sample("obs", dist.Normal(loc, 1), obs=data)

class LocalModel(NormalModel):
    def forward(self, data):
        with pyro.plate("data", len(data)):
            # If .loc is accessed (for the first time) inside the plate,
            # then it will be expanded by the plate to shape (plate.size,).
            loc = self.loc
            assert loc.shape == (len(data),)
            pyro.sample("obs", dist.Normal(loc, 1), obs=data)

data = torch.randn(10)
LocalModel()(data)
GlobalModel()(data)

How to create a complex nested PyroModule

To perform inference with the above BayesianLinear module, we’ll need to wrap it in a probabilistic model with a likelihood; that wrapper will also be a PyroModule.

[12]:
class Model(PyroModule):
    def __init__(self, in_size, out_size):
        super().__init__()
        self.linear = BayesianLinear(in_size, out_size)  # this is a PyroModule
        self.obs_scale = PyroSample(dist.LogNormal(0, 1))

    def forward(self, input, output=None):
        obs_loc = self.linear(input)  # this samples linear.bias and linear.weight
        obs_scale = self.obs_scale    # this samples self.obs_scale
        with pyro.plate("instances", len(input)):
            return pyro.sample("obs", dist.Normal(obs_loc, obs_scale).to_event(1),
                               obs=output)

Whereas a usual nn.Module can be trained with a simple PyTorch optimizer, a Pyro model requires probabilistic inference, e.g. using SVI and an AutoNormal guide. See the Bayesian regression tutorial for details.

[13]:
%%time
pyro.clear_param_store()
pyro.set_rng_seed(1)

model = Model(5, 2)
x = torch.randn(100, 5)
y = model(x)

guide = AutoNormal(model)
svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
for step in range(2 if smoke_test else 501):
    loss = svi.step(x, y) / y.numel()
    if step % 100 == 0:
        print("step {} loss = {:0.4g}".format(step, loss))
step 0 loss = 7.186
step 100 loss = 2.185
step 200 loss = 1.87
step 300 loss = 1.739
step 400 loss = 1.691
step 500 loss = 1.673
CPU times: user 2.35 s, sys: 24.8 ms, total: 2.38 s
Wall time: 2.39 s

PyroSample statements may also depend on other sample statements or parameters. In this case the prior can be a callable depending on self, rather than a constant distribution. For example, consider the hierarchical model:

[14]:
class Model(PyroModule):
    def __init__(self):
        super().__init__()
        self.dof = PyroSample(dist.Gamma(3, 1))
        self.loc = PyroSample(dist.Normal(0, 1))
        self.scale = PyroSample(lambda self: dist.InverseGamma(self.dof, 1))
        self.x = PyroSample(lambda self: dist.Normal(self.loc, self.scale))

    def forward(self):
        return self.x

Model()()
[14]:
tensor(0.5387)

How naming works

In the above code we saw a BayesianLinear model embedded inside another Model. Both were PyroModules. Whereas simple pyro.sample statements require name strings, PyroModule attributes handle naming automatically. Let’s see how that works with the above model and guide (since AutoNormal is also a PyroModule).

Let’s trace executions of the model and the guide.

[15]:
with poutine.trace() as tr:
    model(x)
for site in tr.trace.nodes.values():
    print(site["type"], site["name"], site["value"].shape)
sample linear.bias torch.Size([2])
sample linear.weight torch.Size([5, 2])
sample obs_scale torch.Size([])
sample instances torch.Size([100])
sample obs torch.Size([100, 2])

Observe that model.linear.bias corresponds to the linear.bias name, and similarly for the model.linear.weight and model.obs_scale attributes. The “instances” site corresponds to the plate, and the “obs” site corresponds to the likelihood. Next examine the guide:

[16]:
with poutine.trace() as tr:
    guide(x)
for site in tr.trace.nodes.values():
    print(site["type"], site["name"], site["value"].shape)
param AutoNormal.locs.linear.bias torch.Size([2])
param AutoNormal.scales.linear.bias torch.Size([2])
sample linear.bias_unconstrained torch.Size([2])
sample linear.bias torch.Size([2])
param AutoNormal.locs.linear.weight torch.Size([5, 2])
param AutoNormal.scales.linear.weight torch.Size([5, 2])
sample linear.weight_unconstrained torch.Size([5, 2])
sample linear.weight torch.Size([5, 2])
param AutoNormal.locs.obs_scale torch.Size([])
param AutoNormal.scales.obs_scale torch.Size([])
sample obs_scale_unconstrained torch.Size([])
sample obs_scale torch.Size([])

We see the guide learns posteriors over three random variables: linear.bias, linear.weight, and obs_scale. For each of these, the guide learns a (loc, scale) pair of parameters, which are stored internally in nested PyroModules:

class AutoNormal(...):
    def __init__(self, ...):
        self.locs = PyroModule()
        self.scales = PyroModule()
        ...

Finally, AutoNormal contains a pyro.sample statement for each unconstrained latent site followed by a pyro.deterministic statement to map the unconstrained sample to a constrained posterior sample.

⚠ Caution: avoiding duplicate names

PyroModules name their attributes automatically, even for attributes nested deeply in other PyroModules. However, care must be taken when mixing usual nn.Modules with PyroModules, because nn.Modules do not support automatic site naming.

Within a single model (or guide):

If there is only a single PyroModule, then you are safe.

  class Model(nn.Module):        # not a PyroModule
      def __init__(self):
          super().__init__()
          self.x = PyroModule()
-         self.y = PyroModule()  # Could lead to name conflict.
+         self.y = nn.Module()  # Has no Pyro names, so avoids conflict.

If there are only two PyroModules, then one must be an attribute of the other.

class Model(PyroModule):
    def __init__(self):
        super().__init__()
        self.x = PyroModule()  # ok

If you have two PyroModules that are not attributes of each other, then they must be connected by attribute links through other PyroModules. These can be sibling links:

- class Model(nn.Module):     # Could lead to name conflict.
+ class Model(PyroModule):    # Ensures names are unique.
      def __init__(self):
          super().__init__()
          self.x = PyroModule()
          self.y = PyroModule()

or ancestor links:

  class Model(PyroModule):
      def __init__(self):
          super().__init__()
-         self.x = nn.Module()    # Could lead to name conflict.
+         self.x = PyroModule()   # Ensures y is connected to root Model.
          self.x.y = PyroModule()

Sometimes you may want to store a (model,guide) pair in a single nn.Module, e.g. to serve them from C++. In this case it is safe to make them attributes of a container nn.Module, but that container should not be a PyroModule.

class Container(nn.Module):            # This cannot be a PyroModule.
    def __init__(self, model, guide):  # These may be PyroModules.
        super().__init__()
        self.model = model
        self.guide = guide
    # This is a typical trace-replay pattern seen in model serving.
    def forward(self, data):
        tr = poutine.trace(self.guide).get_trace(data)
        return poutine.replay(self.model, tr)(data)