# Attend Infer Repeat

In this tutorial we will implement the model and inference strategy described in “Attend, Infer, Repeat: Fast Scene Understanding with Generative Models” (AIR) [1] and apply it to the multi-mnist dataset.

A standalone implementation is also available.


```
%pylab inline
from collections import namedtuple
from observations import multi_mnist
import pyro
import pyro.optim as optim
from pyro.infer import SVI
import pyro.distributions as dist
from pyro.util import ng_zeros, ng_ones
import torch
from torch.autograd import Variable
import torch.nn as nn
from torch.nn.functional import relu, sigmoid, softplus, grid_sample, affine_grid
import numpy as np
```

## Introduction

The model described in [1] is a generative model of scenes. In this tutorial we will use it to model images from a dataset that is similar to the multi-mnist dataset in [1]. Here are some data points from this data set:


```
inpath = '../../examples/air/data'
(X_np, _), _ = multi_mnist(inpath, max_digits=2, canvas_size=50, seed=42)
X_np = X_np.astype(np.float32)
X_np /= 255.0
mnist = Variable(torch.from_numpy(X_np))
def show_images(imgs):
    figure(figsize=(12, 4))
    for i, img in enumerate(imgs):
        subplot(1, len(imgs), i + 1)
        imshow(img.data.numpy(), cmap='binary')

show_images(mnist[9:14])
```

To get an idea of where we’re heading, we first give a brief overview of the model and the approach we’ll take to inference. We’ll follow the naming conventions used in [1] as closely as possible.

AIR decomposes the process of generating an image into discrete steps, each of which generates only part of the image. More specifically, at each step the model will generate a small image (`y_att`) by passing a latent “code” variable (`z_what`) through a neural network. We’ll refer to these small images as “objects”. In the case of AIR applied to the multi-mnist dataset we expect each of these objects to represent a single digit. The model also includes uncertainty about the location and size of each object. We’ll describe an object’s location and size as its “pose” (`z_where`). To produce the final image, each object will first be located within a larger image (`y`) using the pose information `z_where`. Finally, the `y`s from all time steps will be combined additively to produce the final image `x`.
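To make the additive composition concrete, here is a toy plain-Python sketch. It assumes objects are pasted at integer pixel offsets, which stands in for the differentiable spatial transformer used later. The helpers `blank`, `place` and `compose` are hypothetical names for illustration only. Overlapping objects simply add, so a pixel covered by two objects receives both contributions.

```python
# Toy sketch of AIR's additive composition. Assumption: objects are pasted at
# integer pixel offsets, standing in for the differentiable STN used later.

def blank(h, w):
    """Return an h x w canvas of zeros as a list of lists."""
    return [[0.0] * w for _ in range(h)]


def place(obj, top, left, h, w):
    """Return an h x w image y containing obj with its top-left corner at (top, left)."""
    y = blank(h, w)
    for i, row in enumerate(obj):
        for j, v in enumerate(row):
            r, c = top + i, left + j
            if 0 <= r < h and 0 <= c < w:  # drop pixels that fall off the canvas
                y[r][c] = v
    return y


def compose(ys):
    """Add the per-step images elementwise to obtain the final image x."""
    h, w = len(ys[0]), len(ys[0][0])
    x = blank(h, w)
    for y in ys:
        for r in range(h):
            for c in range(w):
                x[r][c] += y[r][c]
    return x


y_att = [[1.0, 1.0], [1.0, 1.0]]   # a 2x2 "object" shared by both steps
y0 = place(y_att, 0, 0, 4, 4)      # pose for step 0
y1 = place(y_att, 1, 1, 4, 4)      # pose for step 1, overlapping step 0
x = compose([y0, y1])
print(x[1][1])  # 2.0: the overlapping pixel receives both contributions
```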

Here’s a picture (reproduced from [1]) that shows two steps of this process:

Inference is performed in this model using amortized stochastic variational inference (SVI). The parameters of the neural network are also optimized during inference. Performing inference in such rich models is always difficult, but the presence of discrete choices (the number of steps in this case) makes inference in this model particularly tricky. For this reason the authors use a technique called data dependent baselines to achieve good performance. This technique can be implemented in Pyro, and we’ll see how later in the tutorial.

## Model

### Generating a single object

Let’s look at the model more closely. At the core of the model is the generative process for a single object. Recall that:

- At each step a single object is generated.
- Each object is generated by passing its latent code through a neural network.
- We maintain uncertainty about the latent code used to generate each object, as well as its pose.

This can be expressed in Pyro like so:


```
# Create the neural network. This takes a latent code, z_what, to pixel intensities.
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.l1 = nn.Linear(50, 200)
        self.l2 = nn.Linear(200, 400)

    def forward(self, z_what):
        h = relu(self.l1(z_what))
        return sigmoid(self.l2(h))

decode = Decoder()

z_where_prior_mu = Variable(torch.Tensor([3, 0, 0]))
z_where_prior_sigma = Variable(torch.Tensor([0.1, 1, 1]))
z_what_prior_mu = ng_zeros(50)
z_what_prior_sigma = ng_ones(50)

def model_step_sketch(t):
    # Sample object pose. This is a 3-dimensional vector representing x,y position and size.
    z_where = pyro.sample('z_where_{}'.format(t),
                          dist.normal,
                          z_where_prior_mu,
                          z_where_prior_sigma,
                          batch_size=1)
    # Sample object code. This is a 50-dimensional vector.
    z_what = pyro.sample('z_what_{}'.format(t),
                         dist.normal,
                         z_what_prior_mu,
                         z_what_prior_sigma,
                         batch_size=1)
    # Map code to pixel space using the neural network.
    y_att = decode(z_what)
    # Position/scale object within larger image.
    y = object_to_image(z_where, y_att)
    return y
```

Hopefully the use of `pyro.sample` and PyTorch networks within a model seems familiar at this point. If not, you might want to review the VAE tutorial. One thing to note is that we include the current step `t` in the name passed to `pyro.sample` to ensure that names are unique across steps.

The `object_to_image` function is specific to this model and warrants further attention. Recall that the neural network (`decode` here) will output a small image, and that we would like to add this to the output image after performing any translation and scaling required to achieve the pose (location and size) described by `z_where`. It’s not clear how to do this, and in particular it’s not obvious that this can be implemented in a way that preserves the differentiability of our model, which we require in order to perform SVI. However, it turns out we can do this using a spatial transformer network (STN) [2].

Happily for us, PyTorch makes it easy to implement an STN using its `grid_sample` and `affine_grid` functions. `object_to_image` is a simple function that calls these, doing a little extra work to massage `z_where` into the expected format.


```
def expand_z_where(z_where):
    # Takes 3-dimensional vectors, and massages them into 2x3 matrices with elements like so:
    # [s,x,y] -> [[s,0,x],
    #             [0,s,y]]
    n = z_where.size(0)
    expansion_indices = Variable(torch.LongTensor([1, 0, 2, 0, 1, 3]))
    out = torch.cat((ng_zeros([1, 1]).expand(n, 1), z_where), 1)
    return torch.index_select(out, 1, expansion_indices).view(n, 2, 3)

def object_to_image(z_where, obj):
    n = obj.size(0)
    theta = expand_z_where(z_where)
    grid = affine_grid(theta, torch.Size((n, 1, 50, 50)))
    out = grid_sample(obj.view(n, 1, 20, 20), grid)
    return out.view(n, 50, 50)
```
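The `index_select` call in `expand_z_where` is easy to misread. The sketch below applies the same gather to a single vector in plain Python. The helper `expand` and the constant `EXPANSION_INDICES` are illustrative names and are not part of the tutorial's code. A zero is prepended to `[s, x, y]`, and the index list then picks out `[s, 0, x, 0, s, y]`, which is reshaped into two rows.

```python
# Plain-Python illustration of the gather in expand_z_where.
EXPANSION_INDICES = [1, 0, 2, 0, 1, 3]


def expand(z_where):
    """Map one [s, x, y] vector to the 2x3 matrix [[s, 0, x], [0, s, y]]."""
    s, x, y = z_where
    padded = [0.0, s, x, y]                          # prepend a zero: [0, s, x, y]
    flat = [padded[i] for i in EXPANSION_INDICES]    # gather -> [s, 0, x, 0, s, y]
    return [flat[0:3], flat[3:6]]                    # reshape to 2 rows of 3


theta = expand([3.0, 0.5, -0.25])
print(theta)  # [[3.0, 0.0, 0.5], [0.0, 3.0, -0.25]]
```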

A discussion of the details of the STN is beyond the scope of this tutorial. For our purposes, however, it suffices to keep in mind that `object_to_image` takes the small image generated by the neural network and places it within a larger image with the desired pose.

Let’s visualize the results of calling `model_step_sketch` a few times to clarify this:


```
pyro.set_rng_seed(0)
samples = [model_step_sketch(0)[0] for _ in range(5)]
show_images(samples)
```

### Generating an entire image

Having completed the implementation of a single step, we next consider how we can use this to generate an entire image. Recall that we would like to maintain uncertainty over the number of steps used to generate each data point. One choice we could make for the prior over the number of steps is the geometric distribution, which can be expressed as follows:


```
pyro.set_rng_seed(0)

def geom(num_trials=0):
    p = Variable(torch.Tensor([0.5]))
    x = pyro.sample('x{}'.format(num_trials), dist.bernoulli, p)
    if x.data[0] == 1:
        return num_trials
    else:
        return geom(num_trials + 1)

# Generate some samples.
for _ in range(5):
    print('sampled {}'.format(geom()))
```

```
sampled 8
sampled 2
sampled 0
sampled 0
sampled 1
```

This is a direct translation of the definition of the geometric distribution as the number of failures before a success in a series of Bernoulli trials. Here we express this as a recursive function that passes around a counter representing the number of trials made, `num_trials`. This function samples from the Bernoulli and returns `num_trials` if `x == 1` (which represents success); otherwise it makes a recursive call, incrementing the counter.
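The same distribution can be written as a loop and checked against its probability mass function, `P(K = k) = (1 - p)^k p`, whose mean is `(1 - p) / p`. The sketch below uses Python's `random` module instead of `pyro.sample`. The names `geom_iterative` and `geom_pmf` are hypothetical. With `p = 0.5` the sample mean should be close to 1.

```python
import random


def geom_iterative(p=0.5, rng=random):
    """Count Bernoulli(p) failures before the first success (a loop version of geom)."""
    num_trials = 0
    while rng.random() >= p:   # rng.random() < p counts as a success
        num_trials += 1
    return num_trials


def geom_pmf(k, p=0.5):
    """P(K = k) = (1 - p)**k * p for k = 0, 1, 2, ..."""
    return (1 - p) ** k * p


rng = random.Random(0)
samples = [geom_iterative(0.5, rng) for _ in range(20000)]
mean = sum(samples) / len(samples)   # exact mean is (1 - p) / p, i.e. 1 for p = 0.5
```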

The use of a geometric prior is appealing because it does not bound the number of steps the model can use a priori. It’s also convenient, because by extending `geom` to generate an object before each recursive call, we turn this from a geometric distribution over counts into a distribution over images with a geometrically distributed number of steps.


```
def geom_prior(x, step=0):
    p = Variable(torch.Tensor([0.5]))
    i = pyro.sample('i{}'.format(step), dist.bernoulli, p)
    if i.data[0] == 1:
        return x
    else:
        x = x + model_step_sketch(step)
        return geom_prior(x, step + 1)
```

Let’s visualize some samples from this distribution:


```
pyro.set_rng_seed(13)
x_empty = ng_zeros(1, 50, 50)
samples = [geom_prior(x_empty)[0] for _ in range(5)]
show_images(samples)
```

#### Aside: Vectorized mini-batches

In our final implementation we would like to generate a mini batch of samples in parallel for efficiency. While Pyro supports vectorized mini batches with `iarange`, it currently assumes that each `sample` statement within `iarange` makes a choice for all samples in the mini batch. This is problematic for us because, as we have just seen, each sample from our model can take a different number of steps, and hence make a different number of choices.

One way around this is to arrange for all samples to take the same number of steps. Of course, we still want to have differing numbers of objects in the images we generate, so we will let the outcome of a Bernoulli trial signal that we should stop adding the objects we generate to the output image, and use some other criterion to decide when to stop making further steps.

Even though this approach performs redundant computation, the gains from using mini batches are so large that this is still a win overall. (Eventually though, we’d like to be able to express the model in a way that avoids this redundant computation.)

Following [1] we choose to take a fixed number of steps for each sample. (By doing so we no longer specify a geometric distribution over the number of steps, since the number of steps is now bounded. It would be interesting to explore the alternative of having each sample in the batch take steps until a successful Bernoulli trial has occurred in each, as this would retain the geometric prior.)
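The masking idea can be simulated in plain Python before we look at the Pyro version. In this sketch every sample takes exactly `num_steps` steps. Each presence indicator is drawn with probability `p * prev_z_pres`, so once a 0 appears for a sample, all of its later indicators are forced to 0. The function `sample_z_pres` is a hypothetical helper, and Python's `random` module stands in for `pyro.sample`.

```python
import random


def sample_z_pres(batch_size, num_steps, p=0.5, seed=0):
    """Sample z_pres for a batch in which every sample takes num_steps steps.

    At step t each z_pres is drawn from Bernoulli(p * prev_z_pres), so once a 0
    is drawn all later values are forced to 0 and no further objects are added.
    """
    rng = random.Random(seed)
    prev = [1] * batch_size          # every sample starts "present"
    steps = []
    for _ in range(num_steps):
        cur = [1 if rng.random() < p * prev_i else 0 for prev_i in prev]
        steps.append(cur)
        prev = cur
    return [list(row) for row in zip(*steps)]   # one row per sample


batch = sample_z_pres(batch_size=8, num_steps=3)
num_objects = [sum(row) for row in batch]   # objects actually added per sample
```

All rows have the same length. The redundant steps after the first 0 are still taken but contribute nothing.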

Here’s an updated model step function that implements this idea. The only changes from `model_step_sketch` are that we now conditionally add the object to the output image based on a value sampled from a Bernoulli distribution, and we’ve added a new parameter `n` that specifies the size of the mini batch.


```
def model_step(n, t, prev_x, prev_z_pres):
    # Sample variable indicating whether to add this object to the output.
    # We multiply the success probability of 0.5 by the value sampled for this
    # choice in the previous step. By doing so we add objects to the output until
    # the first 0 is sampled, after which we add no further objects.
    z_pres = pyro.sample('z_pres_{}'.format(t), dist.bernoulli, 0.5 * prev_z_pres)
    z_where = pyro.sample('z_where_{}'.format(t),
                          dist.normal,
                          z_where_prior_mu,
                          z_where_prior_sigma,
                          batch_size=n)
    z_what = pyro.sample('z_what_{}'.format(t),
                         dist.normal,
                         z_what_prior_mu,
                         z_what_prior_sigma,
                         batch_size=n)
    y_att = decode(z_what)
    y = object_to_image(z_where, y_att)
    # Combine the image generated at this step with the image so far.
    x = prev_x + y * z_pres.view(-1, 1, 1)
    return x, z_pres
```

By iterating this step function we can produce an entire image, composed of multiple objects. Since each image in the multi-mnist dataset contains zero, one or two digits, we will allow the model to use up to (and including) three steps. This will allow us to observe whether inference avoids using the unnecessary final step, and to test the model’s ability to generalize to images with more digits than are present in the dataset.


```
def prior(n):
    x = ng_zeros(n, 50, 50)
    z_pres = ng_ones(n, 1)
    for t in range(3):
        x, z_pres = model_step(n, t, x, z_pres)
    return x
```
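Because each `z_pres` is drawn with probability `0.5 * prev_z_pres`, this prior is a truncated geometric distribution over the number of objects. The exact probabilities follow from a short calculation:

- For `k < 3`, we need `k` presences followed by one absence, which has probability `p**k * (1 - p)`.
- For `k = 3`, every step must be present, which has probability `p**3`.

The helper name below is hypothetical. With `p = 0.5` the prior puts 1/8 of its mass on three objects, even though the data contain at most two.

```python
def num_objects_pmf(num_steps=3, p=0.5):
    """Exact prior over the number of objects under the fixed-step model.

    k < num_steps objects: k presences then one absence, probability p**k * (1 - p).
    k == num_steps objects: every step present, probability p**num_steps.
    """
    pmf = [p ** k * (1 - p) for k in range(num_steps)]
    pmf.append(p ** num_steps)
    return pmf


pmf = num_objects_pmf()
expected = sum(k * q for k, q in enumerate(pmf))
print(pmf)       # [0.5, 0.25, 0.125, 0.125]
print(expected)  # 0.875
```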

We have now fully specified the prior for our model. Let’s visualize some samples to get a feel for this distribution:


```
pyro.set_rng_seed(87678)
show_images(prior(5))
```

#### Specifying the likelihood

The last thing we need in order to complete the specification of the model is a likelihood function. Following [1] we will use a Gaussian likelihood with a fixed standard deviation of 0.3. This is straightforward to implement using `pyro.observe`.

When we later come to perform inference we will find it convenient to package the prior and likelihood into a single function. This is also a convenient place to introduce `iarange`, which we use to implement data subsampling, and to register the networks we would like to optimize with `pyro.module`.


```
def model(data):
    # Register network for optimization.
    pyro.module("decode", decode)
    with pyro.iarange('data', data.size(0)) as indices:
        batch = data[indices]
        x = prior(batch.size(0)).view(-1, 50 * 50)
        sd = (0.3 * ng_ones(1)).expand_as(x)
        pyro.observe('obs', dist.normal, batch, x, sd)
```
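The `observe` statement scores each image under independent per-pixel Gaussians centered on the rendered canvas `x`. The plain-Python re-derivation below writes out that log density. It is not Pyro's implementation. It shows that with `sd = 0.3`, a pixel that is off by 1.0 costs `0.5 * (1 / 0.3)**2`, or about 5.56 nats, relative to a perfect reconstruction. This is one reason the model is strongly encouraged to explain every digit.

```python
import math


def gaussian_log_likelihood(obs, mean, sd=0.3):
    """Sum of independent per-pixel Normal(mean, sd) log densities."""
    total = 0.0
    for o, m in zip(obs, mean):
        z = (o - m) / sd
        total += -0.5 * z * z - math.log(sd) - 0.5 * math.log(2 * math.pi)
    return total


pixels = 50 * 50
perfect = gaussian_log_likelihood([0.0] * pixels, [0.0] * pixels)
one_off = gaussian_log_likelihood([1.0] + [0.0] * (pixels - 1), [0.0] * pixels)
penalty = perfect - one_off   # 0.5 * (1 / 0.3)**2, roughly 5.56 nats
```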

## Guide

Following [1] we will perform amortized stochastic variational inference in this model. Pyro provides general-purpose machinery that implements most of this inference strategy, but as we have seen in earlier tutorials we are required to provide a model-specific guide. What we call a guide in Pyro is exactly the entity called the “inference network” in the paper.

We will structure the guide around a recurrent network to allow the guide to capture (some of) the dependencies we expect to be present in the true posterior. At each step the recurrent network will generate the parameters for the choices made within the step. The values sampled will be fed back into the recurrent network so that this information can be used when computing the parameters for the next step. The guide for the Deep Markov Model shares a similar structure.

As in the model, the core of the guide is the logic for a single step. Here’s a sketch of an implementation of this:


```
def guide_step_basic(t, data, prev):
    # The RNN takes the images and choices from the previous step as input.
    rnn_input = torch.cat((data, prev.z_where, prev.z_what, prev.z_pres), 1)
    h, c = rnn(rnn_input, (prev.h, prev.c))
    # Compute parameters for all choices made this step, by passing
    # the RNN hidden state through another neural network.
    z_pres_p, z_where_mu, z_where_sigma, z_what_mu, z_what_sigma = predict_basic(h)
    z_pres = pyro.sample('z_pres_{}'.format(t),
                         dist.bernoulli, z_pres_p * prev.z_pres)
    z_where = pyro.sample('z_where_{}'.format(t),
                          dist.normal, z_where_mu, z_where_sigma)
    z_what = pyro.sample('z_what_{}'.format(t),
                         dist.normal, z_what_mu, z_what_sigma)
    return # values for next step
```

This would be a reasonable guide to use with this model, but the paper describes a crucial improvement we can make to the code above. Recall that the guide will output information about an object’s pose and its latent code at each step. The improvement we can make is based on the observation that once we have inferred the pose of an object, we can do a better job of inferring its latent code if we use the pose information to crop the object from the input image, and pass the result (which we’ll call a “window”) through an additional network in order to compute the parameters of the latent code. We’ll call this additional network the “encoder” below.

Here’s how we can implement this improved guide, and a fleshed-out implementation of the networks involved:


```
rnn = nn.LSTMCell(2554, 256)

# Takes pixel intensities of the attention window to parameters (mean,
# standard deviation) of the distribution over the latent code,
# z_what.
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.l1 = nn.Linear(400, 200)
        self.l2 = nn.Linear(200, 100)

    def forward(self, data):
        h = relu(self.l1(data))
        a = self.l2(h)
        return a[:, 0:50], softplus(a[:, 50:])

encode = Encoder()

# Takes the guide RNN hidden state to parameters of
# the guide distributions over z_where and z_pres.
class Predict(nn.Module):
    def __init__(self):
        super(Predict, self).__init__()
        self.l = nn.Linear(256, 7)

    def forward(self, h):
        a = self.l(h)
        z_pres_p = sigmoid(a[:, 0:1]) # Squish to [0,1]
        z_where_mu = a[:, 1:4]
        z_where_sigma = softplus(a[:, 4:]) # Squish to >0
        return z_pres_p, z_where_mu, z_where_sigma

predict = Predict()

def guide_step_improved(t, data, prev):
    rnn_input = torch.cat((data, prev.z_where, prev.z_what, prev.z_pres), 1)
    h, c = rnn(rnn_input, (prev.h, prev.c))
    z_pres_p, z_where_mu, z_where_sigma = predict(h)
    z_pres = pyro.sample('z_pres_{}'.format(t),
                         dist.bernoulli, z_pres_p * prev.z_pres)
    z_where = pyro.sample('z_where_{}'.format(t),
                          dist.normal, z_where_mu, z_where_sigma)
    # New. Crop a small window from the input.
    x_att = image_to_object(z_where, data)
    # Compute the parameters of the distribution over z_what
    # by passing the window through the encoder network.
    z_what_mu, z_what_sigma = encode(x_att)
    z_what = pyro.sample('z_what_{}'.format(t),
                         dist.normal, z_what_mu, z_what_sigma)
    return # values for next step
```
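The layer sizes above follow from simple bookkeeping, which is worth checking when modifying the networks:

- The guide RNN reads the flattened 50x50 image plus the previous `z_where` (3), `z_what` (50) and `z_pres` (1), giving an input size of 2554.
- `Predict` emits 1 value for `z_pres` and 3 + 3 for the mean and scale of `z_where`, giving 7 outputs.
- `Encoder` maps a 20x20 window to a mean and a scale for each of the 50 dimensions of `z_what`, giving 100 outputs.

The constants below are illustrative names, not part of the tutorial's code.

```python
# Sizes used throughout this tutorial.
IMAGE_PIXELS = 50 * 50     # flattened multi-mnist canvas
WINDOW_PIXELS = 20 * 20    # attention window / decoder output
Z_WHERE_DIM = 3            # [s, x, y]
Z_WHAT_DIM = 50            # latent code
Z_PRES_DIM = 1             # presence indicator

# The guide RNN reads the image plus the previous step's samples.
rnn_input_size = IMAGE_PIXELS + Z_WHERE_DIM + Z_WHAT_DIM + Z_PRES_DIM
# Predict emits one pre-sigmoid value for z_pres plus a mean and a
# pre-softplus scale per z_where dimension.
predict_output_size = Z_PRES_DIM + 2 * Z_WHERE_DIM
# Encoder maps a window to a mean and a pre-softplus scale per z_what dimension.
encoder_output_size = 2 * Z_WHAT_DIM

print(rnn_input_size, predict_output_size, encoder_output_size)  # 2554 7 100
```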

Since we would like to maintain differentiability of the guide, we again use an STN to perform the required “cropping”. The `image_to_object` function performs the opposite transform to the `object_to_image` function used in the model. That is, `object_to_image` takes a small image and places it on a larger image, and `image_to_object` crops a small image from a larger image.


```
def z_where_inv(z_where):
    # Take a batch of z_where vectors, and compute their "inverse".
    # That is, for each row compute:
    # [s,x,y] -> [1/s,-x/s,-y/s]
    # These are the parameters required to perform the inverse of the
    # spatial transform performed in the generative model.
    n = z_where.size(0)
    out = torch.cat((ng_ones([1, 1]).type_as(z_where).expand(n, 1), -z_where[:, 1:]), 1)
    out = out / z_where[:, 0:1]
    return out

def image_to_object(z_where, image):
    n = image.size(0)
    theta_inv = expand_z_where(z_where_inv(z_where))
    grid = affine_grid(theta_inv, torch.Size((n, 1, 20, 20)))
    out = grid_sample(image.view(n, 1, 50, 50), grid)
    return out.view(n, -1)
```
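We can check that `[1/s, -x/s, -y/s]` really is the inverse of `[s, x, y]`. Embed both affine maps as 3x3 homogeneous matrices and multiply them. The product should be the identity, so cropping with the inverse parameters undoes the placement, up to resampling and the different window and canvas sizes. This is a plain-Python sketch, and the helper names are hypothetical.

```python
def affine_3x3(z_where):
    """Embed theta = [[s, 0, x], [0, s, y]] as a 3x3 homogeneous matrix."""
    s, x, y = z_where
    return [[s, 0.0, x], [0.0, s, y], [0.0, 0.0, 1.0]]


def z_where_inv_py(z_where):
    """[s, x, y] -> [1/s, -x/s, -y/s], mirroring z_where_inv for a single vector."""
    s, x, y = z_where
    return [1.0 / s, -x / s, -y / s]


def matmul(a, b):
    """Multiply two 3x3 matrices given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]


z = [3.0, 0.6, -0.9]
forward = affine_3x3(z)
inverse = affine_3x3(z_where_inv_py(z))
product = matmul(forward, inverse)   # should be the identity up to rounding
```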

### Another perspective

So far we’ve considered the model and the guide in isolation, but we gain an interesting perspective if we zoom out and look at the model and guide computation as a whole. Doing so, we see that at each step AIR includes a sub-computation that has the same structure as a Variational Auto-encoder (VAE).

To see this, notice that the guide passes the window through a neural network (the encoder) to generate the parameters of the distribution over a latent code, and the model passes samples from this latent code distribution through another neural network (the decoder) to generate an output window. This structure is highlighted in the following figure, reproduced from [1]:

From this perspective AIR is seen as a sequential variant of the VAE. The act of cropping a small window from the input image serves to restrict the attention of a VAE to a small region of the input image at each step; hence “Attend, Infer, Repeat”.

## Inference

As we mentioned in the introduction, successfully performing inference in this model is a challenge. In particular, the presence of discrete choices in the model makes inference trickier than in a model in which all choices can be reparameterized. The underlying problem we face is that the gradient estimates we use in the optimization performed by variational inference have much higher variance in the presence of discrete choices.

To bring this variance under control, the paper applies a technique called “data dependent baselines” (AKA “neural baselines”) to the discrete choices in the model.

### Data dependent baselines

Happily for us, Pyro includes support for data dependent baselines. If you are not already familiar with this idea, you might want to read our introduction before continuing. As model authors we only have to implement the neural network, pass it our data as input, and feed its output to `pyro.sample`. Pyro’s inference back-end will ensure that the baseline is included in the gradient estimator used for inference, and that the network parameters are updated appropriately.
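A toy example shows why a baseline helps. We compute the exact mean and variance of the score-function estimator for `d/dp E[f(z)]` with `z ~ Bernoulli(p)` by enumerating both outcomes. This is done once without a baseline and once with a constant baseline `b = E[f(z)]`. Subtracting `b` leaves the mean unchanged, because the score `d/dp log q(z)` has expectation zero, but it shrinks the variance dramatically.

This sketch uses a constant baseline and a made-up cost `f`. A data-dependent baseline instead predicts `b` from the input, as the baseline network below does, and it is not how Pyro computes its estimator internally.

```python
def score_function_stats(p, f, baseline=0.0):
    """Exact mean and variance of the score-function estimator
    (f(z) - baseline) * d/dp log q(z) of d/dp E_{z ~ Bernoulli(p)}[f(z)].
    """
    # (z, probability of z, d/dp log q(z)) for z = 1 and z = 0.
    outcomes = [(1, p, 1.0 / p), (0, 1.0 - p, -1.0 / (1.0 - p))]
    mean = sum(q * (f(z) - baseline) * score for z, q, score in outcomes)
    second = sum(q * ((f(z) - baseline) * score) ** 2 for z, q, score in outcomes)
    return mean, second - mean ** 2


def f(z):
    return 10.0 + z   # large constant offset; d/dp E[f(z)] = 1


p = 0.3
no_baseline = score_function_stats(p, f)                         # mean 1, variance ~545
with_baseline = score_function_stats(p, f, baseline=10.0 + p)    # mean 1, variance ~0.76
```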

Let’s see how we can add data dependent baselines to our AIR implementation. We need a neural network that can output a (scalar) baseline value at each discrete choice in the guide, having received a multi-mnist image and the values sampled by the guide so far as input. Notice that this is very similar to the structure of the guide network, and indeed we will again use a recurrent network.

To implement this we will first write a short helper function that implements a single step of the RNN we’ve just described:


```
bl_rnn = nn.LSTMCell(2554, 256)
bl_predict = nn.Linear(256, 1)

# Use an RNN to compute the baseline value. This network takes the
# input images and the values sampled so far as input.
def baseline_step(x, prev):
    rnn_input = torch.cat((x,
                           prev.z_where.detach(),
                           prev.z_what.detach(),
                           prev.z_pres.detach()), 1)
    bl_h, bl_c = bl_rnn(rnn_input, (prev.bl_h, prev.bl_c))
    bl_value = bl_predict(bl_h)
    return bl_value, bl_h, bl_c
```

Notice that we `detach` values sampled by the guide before passing them to the baseline network. This is important because the baseline network and the guide network are entirely separate networks optimized with different objectives. Without this, gradients would flow from the baseline network into the guide network. When using data dependent baselines we must do this whenever we feed values sampled by the guide into the baseline network. (If we don’t, we’ll trigger a PyTorch run-time error.)

We now have everything we need to complete the implementation of the guide. Our final `guide_step` function will be very similar to `guide_step_improved` introduced above. The only change is that we will call the `baseline_step` helper and pass the baseline value it returns to `pyro.sample`, completing the baseline implementation. We’ll also write a `guide` function that will iterate `guide_step` in order to provide a guide for the whole model.


```
GuideState = namedtuple('GuideState', ['h', 'c', 'bl_h', 'bl_c', 'z_pres', 'z_where', 'z_what'])

def initial_guide_state(n):
    return GuideState(h=ng_zeros(n, 256),
                      c=ng_zeros(n, 256),
                      bl_h=ng_zeros(n, 256),
                      bl_c=ng_zeros(n, 256),
                      z_pres=ng_ones(n, 1),
                      z_where=ng_zeros(n, 3),
                      z_what=ng_zeros(n, 50))

def guide_step(t, data, prev):
    rnn_input = torch.cat((data, prev.z_where, prev.z_what, prev.z_pres), 1)
    h, c = rnn(rnn_input, (prev.h, prev.c))
    z_pres_p, z_where_mu, z_where_sigma = predict(h)
    # Here we compute the baseline value, and pass it to sample.
    baseline_value, bl_h, bl_c = baseline_step(data, prev)
    z_pres = pyro.sample('z_pres_{}'.format(t),
                         dist.bernoulli,
                         z_pres_p * prev.z_pres,
                         baseline=dict(baseline_value=baseline_value))
    z_where = pyro.sample('z_where_{}'.format(t),
                          dist.normal, z_where_mu, z_where_sigma)
    x_att = image_to_object(z_where, data)
    z_what_mu, z_what_sigma = encode(x_att)
    z_what = pyro.sample('z_what_{}'.format(t),
                         dist.normal, z_what_mu, z_what_sigma)
    return GuideState(h=h, c=c, bl_h=bl_h, bl_c=bl_c, z_pres=z_pres, z_where=z_where, z_what=z_what)

def guide(data):
    # Register networks for optimization.
    pyro.module('rnn', rnn)
    pyro.module('predict', predict)
    pyro.module('encode', encode)
    pyro.module('bl_rnn', bl_rnn)
    pyro.module('bl_predict', bl_predict)
    with pyro.iarange('data', data.size(0), subsample_size=64) as indices:
        batch = data[indices]
        state = initial_guide_state(batch.size(0))
        steps = []
        for t in range(3):
            state = guide_step(t, batch, state)
            steps.append(state)
        return steps
```

### Putting it all together

We have now completed the implementation of the model and the guide. As we have already seen in earlier tutorials, we need to write only a few more lines of code to begin performing inference:


```
data = mnist.view(-1, 50 * 50)
svi = SVI(model,
guide,
optim.Adam({'lr': 1e-4}),
loss='ELBO',
trace_graph=True)
for i in range(5):
    loss = svi.step(data)
    print('i={}, elbo={:.2f}'.format(i, loss / data.size(0)))
```

One key detail here is that we pass the `trace_graph=True` option to `SVI`. This enables a more sophisticated gradient estimator (implicitly used in [1]) that further reduces the variance of gradient estimates by making use of independence information included in the model. Use of this feature is necessary in order to achieve good results in the presence of discrete choices.

## Improvements

Our standalone AIR implementation includes a few simple improvements to the basic recipe given in this tutorial:

- It is reported to be useful in practice to use a different learning rate for the baseline network. In [1] a learning rate of `1e-4` was used for the guide network, and a learning rate of `1e-3` was used for the baseline network. This is straightforward to implement in Pyro by tagging modules associated with the baseline network and passing multiple learning rates to the optimizer. (See the section on optimizers in part I of the SVI tutorial for more detail.)
- Use of larger neural networks.
- Use of optimizable parameters for the initial guide state.
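The per-network learning rates can be expressed as a small lookup keyed on the names passed to `pyro.module`. The sketch below is hypothetical. It assumes the optimizer can be told which module a parameter belongs to. How such a function is wired into Pyro's optimizer wrapper depends on the Pyro version, so check the optimizer documentation rather than relying on this shape. The text gives rates only for the guide and baseline networks, so here the decoder is assumed to share the guide's `1e-4`.

```python
# Hypothetical helper: pick Adam arguments from the name a module was
# registered under with pyro.module (names as used in model() and guide()).
BASELINE_MODULES = {"bl_rnn", "bl_predict"}


def adam_args_for(module_name, guide_lr=1e-4, baseline_lr=1e-3):
    """Return optimizer arguments for parameters belonging to module_name."""
    lr = baseline_lr if module_name in BASELINE_MODULES else guide_lr
    return {"lr": lr}


for name in ["rnn", "predict", "encode", "decode", "bl_rnn", "bl_predict"]:
    print(name, adam_args_for(name)["lr"])
```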

## Results

The following images show the progress made by our standalone implementation during an inference run. The top image shows four data points from the training set. The bottom image is a visualization of a sample from the guide (for these data points) that shows the values sampled for `z_pres` and `z_where`. It also shows reconstructions of the input obtained by passing the sample from the guide back through the model to generate an output image.

Note that at this stage of inference the locations of most of the digits are inferred accurately but object counts are not.

## References

[1] “Attend, Infer, Repeat: Fast Scene Understanding with Generative Models”, S. M. Ali Eslami, Nicolas Heess, Theophane Weber, Yuval Tassa, Koray Kavukcuoglu and Geoffrey E. Hinton.

[2] “Spatial Transformer Networks”, Max Jaderberg, Karen Simonyan and Andrew Zisserman.