Tensor shapes in Pyro 0.2¶
This tutorial introduces Pyro’s organization of tensor dimensions. Before starting, you should familiarize yourself with PyTorch broadcasting semantics.
Summary:¶

- While you are learning or debugging, set pyro.enable_validation(True).
- Tensors broadcast by aligning on the right: torch.ones(3,4,5) + torch.ones(5) (see the sketch just after this list).
- Distribution .sample().shape == batch_shape + event_shape.
- Distribution .log_prob(x).shape == batch_shape (but not event_shape!).
- Use my_dist.expand_by([2,3,4]) to draw a batch of samples.
- Use my_dist.independent(1) to declare a dimension as dependent.
- Use with pyro.iarange('name', size): to declare a dimension as independent.
- All dimensions must be declared either dependent or independent.
- Try to support batching on the left. This lets Pyro auto-parallelize.
  - use negative indices like x.sum(-1) rather than x.sum(2)
  - use ellipsis notation like pixel = image[..., i, j]
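As a quick check of the broadcasting rule, here is a minimal sketch in plain PyTorch (no Pyro involved):

import torch

x = torch.ones(3, 4, 5)
y = torch.ones(5)        # aligns with the rightmost dimension of x
assert (x + y).shape == (3, 4, 5)
z = torch.ones(4, 1)     # size-1 dimensions broadcast against any size
assert (x * z).shape == (3, 4, 5)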
Table of Contents¶

- Distribution shapes
- Declaring independence with iarange
- Subsampling inside iarange
- Broadcasting to allow parallel enumeration
- Automatic broadcasting via broadcast poutine
In [ ]:
import os
import torch
import pyro
from torch.distributions import constraints
from pyro.distributions import Bernoulli, Categorical, MultivariateNormal, Normal
from pyro.infer import Trace_ELBO, TraceEnum_ELBO, config_enumerate
import pyro.poutine as poutine
from pyro.optim import Adam

smoke_test = ('CI' in os.environ)
pyro.enable_validation(True)    # <---- This is always a good idea!

# We'll use this helper to check our models are correct.
def test_model(model, guide, loss):
    pyro.clear_param_store()
    loss.loss(model, guide)
Distribution shapes: batch_shape and event_shape¶
PyTorch Tensors have a single .shape attribute, but Distributions have two shape attributes with special meaning: .batch_shape and .event_shape. These two combine to define the total shape of a sample:

x = d.sample()
assert x.shape == d.batch_shape + d.event_shape
Indices over .batch_shape denote independent random variables, whereas indices over .event_shape denote dependent random variables. Because the dependent random variables define probability together, the .log_prob() method only produces a single number for each event of shape .event_shape. Thus the total shape of .log_prob() is .batch_shape:

assert d.log_prob(x).shape == d.batch_shape
Note that the Distribution.sample() method also takes a sample_shape parameter that indexes over independent identically distributed (iid) random variables, so that

x2 = d.sample(sample_shape)
assert x2.shape == sample_shape + batch_shape + event_shape

In summary:

      |      iid     | independent | dependent
------+--------------+-------------+------------
shape = sample_shape + batch_shape + event_shape
For example univariate distributions have empty event shape (because each number is an independent event). Distributions over vectors like MultivariateNormal have len(event_shape) == 1. Distributions over matrices like InverseWishart have len(event_shape) == 2.
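For example, here is a small sketch of how the three factors compose; the distributions are the same ones used in the examples below:

d = Normal(torch.zeros(3), torch.ones(3))             # batch_shape == (3,)
x2 = d.sample(sample_shape=(4,))
assert x2.shape == (4, 3)                             # sample_shape + batch_shape
assert d.log_prob(x2).shape == (4, 3)                 # sample_shape + batch_shape

d = MultivariateNormal(torch.zeros(3), torch.eye(3))  # event_shape == (3,)
x3 = d.sample(sample_shape=(4,))
assert x3.shape == (4, 3)                             # sample_shape + event_shape
assert d.log_prob(x3).shape == (4,)                   # sample_shape + batch_shape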
Examples¶
The simplest distribution shape is a single univariate distribution.
In [ ]:
d = Bernoulli(0.5)
assert d.batch_shape == ()
assert d.event_shape == ()
x = d.sample()
assert x.shape == ()
assert d.log_prob(x).shape == ()
Distributions can be batched by passing in batched parameters.
In [ ]:
d = Bernoulli(0.5 * torch.ones(3,4))
assert d.batch_shape == (3, 4)
assert d.event_shape == ()
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)
Another way to batch distributions is via the .expand_by() method. This only works if parameters are identical along the leftmost dimensions.
In [ ]:
d = Bernoulli(torch.tensor([0.1, 0.2, 0.3, 0.4])).expand_by([3])
assert d.batch_shape == (3, 4)
assert d.event_shape == ()
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)
Multivariate distributions have nonempty .event_shape. For these distributions, the shapes of .sample() and .log_prob(x) differ:
In [ ]:
d = MultivariateNormal(torch.zeros(3), torch.eye(3, 3))
assert d.batch_shape == ()
assert d.event_shape == (3,)
x = d.sample()
assert x.shape == (3,) # == batch_shape + event_shape
assert d.log_prob(x).shape == () # == batch_shape
Reshaping distributions¶
In Pyro you can treat a univariate distribution as multivariate by calling the .independent(n) method.
In [ ]:
d = Bernoulli(0.5 * torch.ones(3,4)).independent(1)
assert d.batch_shape == (3,)
assert d.event_shape == (4,)
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3,)
While you work with Pyro programs, keep in mind that samples have shape batch_shape + event_shape, whereas .log_prob(x) values have shape batch_shape. You’ll need to ensure that batch_shape is carefully controlled by either trimming it down with .independent(n) or by declaring dimensions as independent via pyro.iarange.
It is always safe to assume dependence¶
Often in Pyro we’ll declare some dimensions as dependent even though they are in fact independent, e.g.
pyro.sample("x", dist.Normal(0, 1).expand_by([10]).independent(1))
This is useful for two reasons: First it allows us to easily swap in a MultivariateNormal distribution later. Second it simplifies the code a bit since we don’t need an iarange (see below) as in

with pyro.iarange("x_iarange", 10):
    pyro.sample("x", dist.Normal(0, 1).expand_by([10]))
The difference between these two versions is that the second version with iarange informs Pyro that it can make use of independence information when estimating gradients, whereas in the first version Pyro must assume they are dependent (even though the normals are in fact independent). This is analogous to d-separation in graphical models: it is always safe to add edges and assume variables may be dependent (i.e. to widen the model class), but it is unsafe to assume independence when variables are actually dependent (i.e. narrowing the model class so the true model lies outside of the class, as in mean field). In practice Pyro’s SVI inference algorithm uses reparameterized gradient estimators for Normal distributions so both gradient estimators have the same performance.
Declaring independent dims with iarange¶
Pyro models can use the context manager pyro.iarange to declare that certain batch dimensions are independent. Inference algorithms can then take advantage of this independence to e.g. construct lower variance gradient estimators or to enumerate in linear space rather than exponential space. An example of an independent dimension is the index over data in a minibatch: each datum should be independent of all others.
The simplest way to declare a dimension as independent is to declare the rightmost batch dimension as independent via a simple

with pyro.iarange("my_iarange"):
    # within this context, batch dimension -1 is independent

We recommend always providing an optional size argument to aid in debugging shapes:

with pyro.iarange("my_iarange", len(my_data)):
    # within this context, batch dimension -1 is independent
Starting with Pyro 0.2 you can additionally nest iaranges, e.g. if you have per-pixel independence:

with pyro.iarange("x_axis", 320):
    # within this context, batch dimension -1 is independent
    with pyro.iarange("y_axis", 200):
        # within this context, batch dimensions -2 and -1 are independent

Note that we always count from the right by using negative indices like -2, -1.
Finally if you want to mix and match iaranges for e.g. noise that depends only on x, some noise that depends only on y, and some noise that depends on both, you can declare multiple iaranges and use them as reusable context managers. In this case Pyro cannot automatically allocate a dimension, so you need to provide a dim argument (again counting from the right):

x_axis = pyro.iarange("x_axis", 3, dim=-2)
y_axis = pyro.iarange("y_axis", 2, dim=-3)
with x_axis:
    # within this context, batch dimension -2 is independent
with y_axis:
    # within this context, batch dimension -3 is independent
with x_axis, y_axis:
    # within this context, batch dimensions -3 and -2 are independent
Let’s take a closer look at batch sizes within iaranges.
In [ ]:
def model1():
    a = pyro.sample("a", Normal(0, 1))
    b = pyro.sample("b", Normal(torch.zeros(2), 1).independent(1))
    with pyro.iarange("c_iarange", 2):
        c = pyro.sample("c", Normal(torch.zeros(2), 1))
    with pyro.iarange("d_iarange", 3):
        d = pyro.sample("d", Normal(torch.zeros(3,4,5), 1).independent(2))
    assert a.shape == ()       # batch_shape == ()     event_shape == ()
    assert b.shape == (2,)     # batch_shape == ()     event_shape == (2,)
    assert c.shape == (2,)     # batch_shape == (2,)   event_shape == ()
    assert d.shape == (3,4,5)  # batch_shape == (3,)   event_shape == (4,5)

    x_axis = pyro.iarange("x_axis", 3, dim=-2)
    y_axis = pyro.iarange("y_axis", 2, dim=-3)
    with x_axis:
        x = pyro.sample("x", Normal(0, 1).expand_by([3, 1]))
    with y_axis:
        y = pyro.sample("y", Normal(0, 1).expand_by([2, 1, 1]))
    with x_axis, y_axis:
        xy = pyro.sample("xy", Normal(0, 1).expand_by([2, 3, 1]))
        z = pyro.sample("z", Normal(0, 1).expand_by([2, 3, 1, 5]).independent(1))
    assert x.shape == (3, 1)        # batch_shape == (3,1)    event_shape == ()
    assert y.shape == (2, 1, 1)     # batch_shape == (2,1,1)  event_shape == ()
    assert xy.shape == (2, 3, 1)    # batch_shape == (2,3,1)  event_shape == ()
    assert z.shape == (2, 3, 1, 5)  # batch_shape == (2,3,1)  event_shape == (5,)

test_model(model1, model1, Trace_ELBO())
It is helpful to visualize the .shapes of each sample site by aligning them at the boundary between batch_shape and event_shape: dimensions to the right will be summed out in .log_prob() and dimensions to the left will remain.

batch dims | event dims
-----------+-----------
           |        a = sample("a", Normal(0, 1))
           |2       b = sample("b", Normal(zeros(2), 1)
           |                        .independent(1))
           |with iarange("c", 2):
          2|        c = sample("c", Normal(zeros(2), 1))
           |with iarange("d", 3):
          3|4 5     d = sample("d", Normal(zeros(3,4,5), 1)
           |                       .independent(2))
           |
           |x_axis = iarange("x", 3, dim=-2)
           |y_axis = iarange("y", 2, dim=-3)
           |with x_axis:
        3 1|        x = sample("x", Normal(0, 1).expand_by([3, 1]))
           |with y_axis:
      2 1 1|        y = sample("y", Normal(0, 1).expand_by([2, 1, 1]))
           |with x_axis, y_axis:
      2 3 1|        xy = sample("xy", Normal(0, 1).expand_by([2, 3, 1]))
      2 3 1|5       z = sample("z", Normal(0, 1).expand_by([2, 3, 1, 5])
           |                       .independent(1))
As an exercise, try to tabulate the shapes of sample sites in one of your own programs.
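One way to do this programmatically is to trace the model and print each sample site’s shapes; a minimal sketch using poutine.trace (the exact trace node fields may vary across Pyro versions):

trace = poutine.trace(model1).get_trace()
for name, site in trace.nodes.items():
    if site["type"] == "sample":
        print("{}: batch_shape={}, event_shape={}".format(
            name, tuple(site["fn"].batch_shape), tuple(site["fn"].event_shape)))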
Subsampling tensors inside an iarange¶
One of the main uses of iarange is to subsample data. This is possible within an iarange because data are independent, so the expected value of the loss on, say, half the data should be half the expected loss on the full data.
To subsample data, you need to inform Pyro of both the original data size and the subsample size; Pyro will then choose a random subset of data and yield the set of indices.
In [ ]:
data = torch.arange(100.)

def model2():
    mean = pyro.param("mean", torch.zeros(len(data)))
    with pyro.iarange("data", len(data), subsample_size=10) as ind:
        assert len(ind) == 10    # ind is a LongTensor that indexes the subsample.
        batch = data[ind]        # Select a minibatch of data.
        mean_batch = mean[ind]   # Take care to select the relevant per-datum parameters.
        # Do stuff with batch:
        x = pyro.sample("x", Normal(mean_batch, 1), obs=batch)
        assert len(x) == 10

test_model(model2, guide=lambda: None, loss=Trace_ELBO())
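If you need control over which indices are selected (e.g. for sequential minibatches), iarange also accepts a precomputed subsample in place of subsample_size; a sketch, assuming the optional subsample argument of pyro.iarange:

def model2_custom():
    mean = pyro.param("mean", torch.zeros(len(data)))
    ind = torch.arange(10).long()  # deterministically pick the first ten data points
    with pyro.iarange("data", len(data), subsample=ind):
        pyro.sample("x", Normal(mean[ind], 1), obs=data[ind])

test_model(model2_custom, guide=lambda: None, loss=Trace_ELBO())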
Broadcasting to allow parallel enumeration¶
Pyro 0.2 introduces the ability to enumerate discrete latent variables in parallel. This can significantly reduce the variance of gradient estimators when learning a posterior via SVI.
To use discrete enumeration, Pyro needs to allocate tensor dimensions that it can use for enumeration. To avoid conflicting with other dimensions that we want to use for iaranges, we need to declare a budget of the maximum number of tensor dimensions we’ll use. This budget is called max_iarange_nesting and is an argument to SVI (the argument is simply passed through to TraceEnum_ELBO).
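For instance, a training loop would wire this up roughly as follows (a sketch; the optimizer settings and step count are arbitrary):

from pyro.infer import SVI

def train(model, guide, max_iarange_nesting):
    pyro.clear_param_store()
    elbo = TraceEnum_ELBO(max_iarange_nesting=max_iarange_nesting)
    svi = SVI(model, guide, Adam({"lr": 0.01}), loss=elbo)
    for step in range(1 if smoke_test else 100):
        svi.step()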
To understand max_iarange_nesting and how Pyro allocates dimensions for enumeration, let’s revisit model1() from above. This time we’ll map out three types of dimensions: enumeration dimensions on the left (Pyro takes control of these), batch dimensions in the middle, and event dimensions on the right.

     max_iarange_nesting = 3
           |<--->|
enumeration|batch|event
-----------+-----+-----
           |. . .|      a = sample("a", Normal(0, 1))
           |. . .|2     b = sample("b", Normal(zeros(2), 1)
           |     |                      .independent(1))
           |     |with iarange("c", 2):
           |. . 2|      c = sample("c", Normal(zeros(2), 1))
           |     |with iarange("d", 3):
           |. . 3|4 5   d = sample("d", Normal(zeros(3,4,5), 1)
           |     |                     .independent(2))
           |     |
           |     |x_axis = iarange("x", 3, dim=-2)
           |     |y_axis = iarange("y", 2, dim=-3)
           |     |with x_axis:
           |. 3 1|      x = sample("x", Normal(0, 1).expand_by([3, 1]))
           |     |with y_axis:
           |2 1 1|      y = sample("y", Normal(0, 1).expand_by([2, 1, 1]))
           |     |with x_axis, y_axis:
           |2 3 1|      xy = sample("xy", Normal(0, 1).expand_by([2, 3, 1]))
           |2 3 1|5     z = sample("z", Normal(0, 1).expand_by([2, 3, 1, 5])
           |     |                     .independent(1))
Note that it is safe to overprovision max_iarange_nesting=4 but we cannot underprovision max_iarange_nesting=2 (or Pyro will error). Let’s see how this works in practice.
In [ ]:
@config_enumerate(default="parallel")
def model3():
    p = pyro.param("p", torch.arange(6.) / 6)
    locs = pyro.param("locs", torch.tensor([-1., 1.]))

    a = pyro.sample("a", Categorical(torch.ones(6) / 6))
    b = pyro.sample("b", Bernoulli(p[a]))  # Note this depends on a.
    with pyro.iarange("c_iarange", 4):
        c = pyro.sample("c", Bernoulli(0.3).expand_by([4]))
        with pyro.iarange("d_iarange", 5):
            d = pyro.sample("d", Bernoulli(0.4).expand_by([5,4]))
            e_loc = locs[d.long()].unsqueeze(-1)
            e_scale = torch.arange(1., 8.)
            e = pyro.sample("e", Normal(e_loc, e_scale)
                            .independent(1))  # Note this depends on d.

    #                  enumerated|batch|event dims
    assert a.shape == (         6, 1, 1   )  # Six enumerated values of the Categorical.
    assert b.shape == (      2, 6, 1, 1   )  # 2 enumerated Bernoullis x 6 Categoricals.
    assert c.shape == (   2, 1, 1, 1, 4   )  # Only 2 Bernoullis; does not depend on a or b.
    assert d.shape == (2, 1, 1, 1, 5, 4   )  # Only two Bernoullis.
    assert e.shape == (2, 1, 1, 1, 5, 4, 7)  # This is sampled and depends on d.

    assert e_loc.shape   == (2, 1, 1, 1, 5, 4, 1,)
    assert e_scale.shape == (                  7,)

test_model(model3, model3, TraceEnum_ELBO(max_iarange_nesting=2))
Let’s take a closer look at those dimensions. First note that Pyro allocates enumeration dims starting from the right at max_iarange_nesting: Pyro allocates dim -3 to enumerate a, then dim -4 to enumerate b, then dim -5 to enumerate c, and finally dim -6 to enumerate d. Next note that variables only have extent (size > 1) in dimensions they depend on. This helps keep tensors small and computation cheap. We can draw a similar map of the tensor dimensions:

     max_iarange_nesting = 2
            |<->|
 enumeration|batch|event
------------+-----+-----
           6|1 1 |      a = pyro.sample("a", Categorical(torch.ones(6) / 6))
         2 6|1 1 |      b = pyro.sample("b", Bernoulli(p[a]))
            |    |      with pyro.iarange("c_iarange", 4):
       2 1 1|1 4 |          c = pyro.sample("c", Bernoulli(0.3).expand_by([4]))
            |    |          with pyro.iarange("d_iarange", 5):
     2 1 1 1|5 4 |              d = pyro.sample("d", Bernoulli(0.4).expand_by([5,4]))
     2 1 1 1|5 4 |1             e_loc = locs[d.long()].unsqueeze(-1)
            |    |7             e_scale = torch.arange(1., 8.)
     2 1 1 1|5 4 |7             e = pyro.sample("e", Normal(e_loc, e_scale)
            |    |                              .independent(1))
Writing parallelizable code¶
It can be tricky to write Pyro models that correctly handle parallelized sample sites. Two tricks help: broadcasting and ellipsis slicing. Let’s look at a contrived model to see how these work in practice. Our aim is to write a model that works both with and without enumeration.
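Before diving into the model, here is the ellipsis trick in isolation (a tiny sketch in plain PyTorch): the same indexing code works no matter how many dimensions are added on the left.

image = torch.zeros(8, 10)
batched = torch.zeros(2, 2, 8, 10)
image[..., 3, 2] = 1    # writes to dims (-2, -1)
batched[..., 3, 2] = 1  # same code; the two extra dims on the left are untouched
assert batched[0, 0, 3, 2] == 1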
In [ ]:
width = 8
height = 10
sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])
enumerated = None  # set to either True or False below

def fun(observe):
    p_x = pyro.param("p_x", torch.tensor(0.1), constraint=constraints.unit_interval)
    p_y = pyro.param("p_y", torch.tensor(0.1), constraint=constraints.unit_interval)
    x_axis = pyro.iarange('x_axis', width, dim=-2)
    y_axis = pyro.iarange('y_axis', height, dim=-1)

    # Note that the shapes of these sites depend on whether Pyro is enumerating.
    with x_axis:
        x_active = pyro.sample("x_active", Bernoulli(p_x).expand_by([width, 1]))
    with y_axis:
        y_active = pyro.sample("y_active", Bernoulli(p_y).expand_by([height]))
    if enumerated:
        assert x_active.shape == (2, width, 1)
        assert y_active.shape == (2, 1, 1, height)
    else:
        assert x_active.shape == (width, 1)
        assert y_active.shape == (height,)

    # The first trick is to broadcast. This works with or without enumeration.
    p = 0.1 + 0.5 * x_active * y_active
    if enumerated:
        assert p.shape == (2, 2, width, height)
    else:
        assert p.shape == (width, height)

    # The second trick is to index using ellipsis slicing.
    # This allows Pyro to add arbitrary dimensions on the left.
    dense_pixels = torch.zeros_like(p)
    for x, y in sparse_pixels:
        dense_pixels[..., x, y] = 1
    if enumerated:
        assert dense_pixels.shape == (2, 2, width, height)
    else:
        assert dense_pixels.shape == (width, height)

    with x_axis, y_axis:
        if observe:
            pyro.sample("pixels", Bernoulli(p), obs=dense_pixels)

def model4():
    fun(observe=True)

@config_enumerate(default="parallel")
def guide4():
    fun(observe=False)

# Test without enumeration.
enumerated = False
test_model(model4, guide4, Trace_ELBO())

# Test with enumeration.
enumerated = True
test_model(model4, guide4, TraceEnum_ELBO(max_iarange_nesting=2))
Automatic broadcasting via broadcast poutine¶
Note that in all our model/guide specifications, we had to expand sample shapes by hand to satisfy the constraints on batch shape enforced by pyro.iarange statements. This code can be simplified by using poutine.broadcast, which automatically broadcasts the batch shape of pyro.sample statements when inside a single or nested iarange context.
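As a minimal sketch of the effect (the names here are illustrative, not from the model below):

def tiny_model():
    with pyro.iarange("data", 10):
        # Bernoulli(0.5) has empty batch_shape; the broadcast poutine expands it to (10,).
        x = pyro.sample("x", Bernoulli(0.5))
        assert x.shape == (10,)

poutine.broadcast(tiny_model)()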
We will demonstrate this using model4 from the previous section. Note the following changes to the code from earlier:

- For the purpose of this example, we will only consider "parallel" enumeration, but broadcasting should work as expected without enumeration or with "sequential" enumeration.
- We have separated out the sampling function which returns the tensors corresponding to the active pixels. Modularizing the model code into components is a common practice, and helps with maintainability of large models. The first sampling function is identical to what we had in model4, and the remaining sampling functions use poutine.broadcast to implicitly expand sample sites to conform to the shape requirements imposed by the iarange contexts in which they are embedded.
- We would also like to use the pyro.iarange construct to parallelize the ELBO estimator over num_particles. This is done by wrapping the contents of the model/guide inside an outermost pyro.iarange context.
In [ ]:
num_particles = 100  # Number of samples for the ELBO estimator
width = 8
height = 10
sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])

def sample_pixel_locations_no_broadcasting(p_x, p_y, x_axis, y_axis):
    with x_axis:
        x_active = pyro.sample("x_active", Bernoulli(p_x).expand_by([num_particles, width, 1]))
    with y_axis:
        y_active = pyro.sample("y_active", Bernoulli(p_y).expand_by([num_particles, 1, height]))
    return x_active, y_active

def sample_pixel_locations_automatic_broadcasting(p_x, p_y, x_axis, y_axis):
    with x_axis:
        x_active = pyro.sample("x_active", Bernoulli(p_x))
    with y_axis:
        y_active = pyro.sample("y_active", Bernoulli(p_y))
    return x_active, y_active

def sample_pixel_locations_partial_broadcasting(p_x, p_y, x_axis, y_axis):
    with x_axis:
        x_active = pyro.sample("x_active", Bernoulli(p_x).expand_by([width, 1]))
    with y_axis:
        y_active = pyro.sample("y_active", Bernoulli(p_y).expand_by([height]))
    return x_active, y_active

def fun(observe, sample_fn):
    p_x = pyro.param("p_x", torch.tensor(0.1), constraint=constraints.unit_interval)
    p_y = pyro.param("p_y", torch.tensor(0.1), constraint=constraints.unit_interval)
    x_axis = pyro.iarange('x_axis', width, dim=-2)
    y_axis = pyro.iarange('y_axis', height, dim=-1)

    with pyro.iarange("num_particles", 100, dim=-3):
        x_active, y_active = sample_fn(p_x, p_y, x_axis, y_axis)
        # Indices corresponding to "parallel" enumeration are appended
        # to the left of the "num_particles" iarange dim.
        assert x_active.shape == (2, num_particles, width, 1)
        assert y_active.shape == (2, 1, num_particles, 1, height)
        p = 0.1 + 0.5 * x_active * y_active
        assert p.shape == (2, 2, num_particles, width, height)

        dense_pixels = torch.zeros_like(p)
        for x, y in sparse_pixels:
            dense_pixels[..., x, y] = 1
        assert dense_pixels.shape == (2, 2, num_particles, width, height)

        with x_axis, y_axis:
            if observe:
                pyro.sample("pixels", Bernoulli(p), obs=dense_pixels)

def test_model_with_sample_fn(sample_fn, broadcast=False):
    def model():
        fun(observe=True, sample_fn=sample_fn)

    @config_enumerate(default="parallel")
    def guide():
        fun(observe=False, sample_fn=sample_fn)

    if broadcast:
        model = poutine.broadcast(model)
        guide = poutine.broadcast(guide)
    test_model(model, guide, TraceEnum_ELBO(max_iarange_nesting=3))

test_model_with_sample_fn(sample_pixel_locations_no_broadcasting)
test_model_with_sample_fn(sample_pixel_locations_automatic_broadcasting, broadcast=True)
test_model_with_sample_fn(sample_pixel_locations_partial_broadcasting, broadcast=True)
In the first sampling function, we had to do some manual bookkeeping and expand the Bernoulli distribution’s batch shape to account for the independent dimensions added by the pyro.iarange contexts. In particular, note how sample_pixel_locations needs knowledge of num_particles, width and height and is accessing these variables from the global scope, which is not ideal.

The next two sampling functions are used with poutine.broadcast, so that this expansion is achieved automatically via an effect handler. Note the following in the two modified sampling functions:
- The second argument to pyro.iarange, i.e. the optional size argument, needs to be provided for implicit broadcasting, so that poutine.broadcast can infer the batch shape requirement for each of the sample sites.
- The existing batch_shape of the sample site must be broadcastable with the size of the pyro.iarange contexts. In our particular example, Bernoulli(p_x) has an empty batch shape which is universally broadcastable.
- poutine.broadcast is idempotent, and is also safe to use when the sample sites have been partially broadcast to the size of some of the iaranges but not all. In the third sampling function, the user has partially expanded x_active and y_active, and the broadcast effect handler expands the other batch dimensions to the size of the remaining iaranges.
Note how simple it is to achieve parallelization via tensorized operations using pyro.iarange and poutine.broadcast! poutine.broadcast also helps in code modularization, because model components can be written agnostic of the iarange contexts in which they may subsequently be embedded.