# Example: Discrete Factor Graph Inference with Plated Einsum

# Copyright (c) 2017-2019 Uber Technologies, Inc.

"""
This example demonstrates how to use plated einsum with different backends
to compute logprob, gradients, MAP estimates, posterior samples, and marginals.

The interface for adjoint algorithms requires four steps:

1. Call require_backward() on all inputs.
2. Call x, = einsum(..., backend=...) with a nonstandard backend.
3. Call x._pyro_backward() on the einsum output.
4. Retrieve results from ._pyro_backward_result attributes of the inputs.

The results of these computations are returned, but this script does not
make use of them; instead we simply time the operations for profiling.
All profiling is done on jit-compiled functions. We exclude jit compilation
time from profiling results, assuming this can be done once.

You can measure complexity of different einsum problems by specifying
--equation and --plates.
"""

import argparse
import timeit

import torch

from pyro.ops.adjoint import require_backward
from pyro.ops.contract import einsum
from pyro.util import ignore_jit_warnings

# We will cache jit-compiled versions of each function.
_CACHE = {}

def jit_prob(equation, *operands, **kwargs):
    """
    Computes the partition function with a jit-traced einsum.

    No backend is set here, which is cheap but less numerically stable than
    using the torch_log backend.

    The traced callable is stored in ``_CACHE`` under a key made of the
    equation and ``kwargs["plates"]``; tracing only happens on a cache miss.
    All ``kwargs`` are forwarded to ``einsum``.
    """
    cache_key = ("prob", equation, kwargs["plates"])

    # Fast path: a traced version already exists for this problem.
    if cache_key in _CACHE:
        return _CACHE[cache_key](*operands)

    # Thin closure around einsum so that torch.jit.trace sees only tensors.
    def _traceable(*tensors):
        return einsum(equation, *tensors, **kwargs)

    compiled = torch.jit.trace(_traceable, operands, check_trace=False)
    _CACHE[cache_key] = compiled
    return compiled(*operands)

def jit_logprob(equation, *operands, **kwargs):
    """
    Computes the log partition function with a jit-traced einsum.

    This simulates evaluating an undirected graphical model. The einsum is run
    with the ``pyro.ops.einsum.torch_log`` backend.

    The traced callable is stored in ``_CACHE`` under a key made of the
    equation and ``kwargs["plates"]``; tracing only happens on a cache miss.
    All ``kwargs`` are forwarded to ``einsum``.
    """
    cache_key = ("logprob", equation, kwargs["plates"])

    # Fast path: a traced version already exists for this problem.
    if cache_key in _CACHE:
        return _CACHE[cache_key](*operands)

    # Thin closure around einsum so that torch.jit.trace sees only tensors.
    def _traceable(*tensors):
        log_backend = "pyro.ops.einsum.torch_log"
        return einsum(equation, *tensors, backend=log_backend, **kwargs)

    compiled = torch.jit.trace(_traceable, operands, check_trace=False)
    _CACHE[cache_key] = compiled
    return compiled(*operands)

def jit_gradient(equation, *operands, **kwargs):
    """
    Runs einsum and calls backward on the partition function.

    This simulates training an undirected graphical model.

    The forward pass is a jit-traced einsum using the
    ``pyro.ops.einsum.torch_log`` backend, cached in ``_CACHE`` under a key
    made of the equation and ``kwargs["plates"]``. Each output is then
    differentiated with respect to all ``operands``.

    :returns: a tuple with one entry per einsum output; each entry is the
        tuple returned by ``torch.autograd.grad`` for that output (entries may
        be ``None`` for operands an output does not depend on, since
        ``allow_unused=True``).
    """
    key = "gradient", equation, kwargs["plates"]
    if key not in _CACHE:

        # This wraps einsum for jit compilation, but we will call backward on the result.
        def _einsum(*operands):
            return einsum(
                equation, *operands, backend="pyro.ops.einsum.torch_log", **kwargs
            )

        _CACHE[key] = torch.jit.trace(_einsum, operands, check_trace=False)

    # Run forward pass.
    losses = _CACHE[key](*operands)

    # Work around PyTorch 1.0.0 bug https://github.com/pytorch/pytorch/issues/14875
    # whereby tuples of length 1 are unwrapped by the jit.
    if not isinstance(losses, tuple):
        losses = (losses,)

    # Run backward pass. retain_graph=True keeps the graph alive so that every
    # loss in turn can be differentiated through the shared forward pass.
    grads = tuple(
        torch.autograd.grad(loss, operands, retain_graph=True, allow_unused=True)
        for loss in losses
    )
    return grads

def _jit_adjoint(equation, *operands, **kwargs):
    """
    Runs einsum in forward-backward mode using pyro.ops.adjoint.

    This simulates serving predictions from an undirected graphical model.

    The ``backend`` keyword (default ``"pyro.ops.einsum.torch_sample"``) is
    popped from ``kwargs`` and selects the adjoint backend; remaining kwargs
    are forwarded to ``einsum``. The traced forward-backward function is
    cached in ``_CACHE`` under a key made of the backend, the equation, the
    operand shapes and ``kwargs["plates"]``.

    :returns: the result of the traced function, which collects the
        ``._pyro_backward_result`` of every operand, in operand order.
    """
    backend = kwargs.pop("backend", "pyro.ops.einsum.torch_sample")
    key = backend, equation, tuple(x.shape for x in operands), kwargs["plates"]
    if key not in _CACHE:

        # This wraps a complete adjoint algorithm call.
        @ignore_jit_warnings()
        def _forward_backward(*operands):
            # First we request backward results on each input operand.
            for operand in operands:
                require_backward(operand)

            # Next we run the forward pass.
            results = einsum(equation, *operands, backend=backend, **kwargs)

            # Then we run a backward pass.
            for result in results:
                result._pyro_backward()

            # Finally we retrieve results from the ._pyro_backward_result attribute
            # that has been set on each input operand. If you only want results on a
            # subset of operands, you can call require_backward() on only those.
            results = []
            for x in operands:
                results.append(x._pyro_backward_result)
                # Reset the attribute so that no stale result stays attached
                # to the operand after this call.
                x._pyro_backward_result = None

            return tuple(results)

        _CACHE[key] = torch.jit.trace(_forward_backward, operands, check_trace=False)

    return _CACHE[key](*operands)


def jit_map(equation, *operands, **kwargs):
    """
    Runs the forward-backward einsum with the ``pyro.ops.einsum.torch_map``
    backend (the MAP-estimate case named in the module docstring).
    """
    return _jit_adjoint(
        equation, *operands, backend="pyro.ops.einsum.torch_map", **kwargs
    )


def jit_sample(equation, *operands, **kwargs):
    """
    Runs the forward-backward einsum with the ``pyro.ops.einsum.torch_sample``
    backend (the posterior-sample case named in the module docstring).
    """
    return _jit_adjoint(
        equation, *operands, backend="pyro.ops.einsum.torch_sample", **kwargs
    )


def jit_marginal(equation, *operands, **kwargs):
    """
    Runs the forward-backward einsum with the ``pyro.ops.einsum.torch_marginal``
    backend (the marginals case named in the module docstring).
    """
    return _jit_adjoint(
        equation, *operands, backend="pyro.ops.einsum.torch_marginal", **kwargs
    )

def time_fn(fn, equation, *operands, **kwargs):
    """
    Measures the average wall-clock time of one call to ``fn``.

    The required ``iters`` keyword is removed from ``kwargs`` and gives the
    number of timed calls; everything else is passed through as
    ``fn(equation, *operands, **kwargs)``.

    :returns: mean seconds per call over the ``iters`` timed calls.
    """
    num_iters = kwargs.pop("iters")
    clock = timeit.default_timer

    # Avoid memory leaks.
    _CACHE.clear()

    # One untimed warm-up call, so that any one-time work done by fn on a
    # cold cache (e.g. jit compilation) is excluded from the measurement.
    fn(equation, *operands, **kwargs)

    began = clock()
    for _ in range(num_iters):
        fn(equation, *operands, **kwargs)
    elapsed = clock() - began

    return elapsed / num_iters

def main(args):
    """
    Profiles one einsum method (or all of them) over a range of plate sizes.

    Reads ``cuda``, ``method``, ``equation``, ``plates``, ``dim_size``,
    ``max_plate_size`` and ``iters`` from ``args``. When ``args.method`` is
    ``"all"``, ``args.method`` is overwritten with each concrete method in turn
    and ``main`` is re-run for it. Otherwise the function named
    ``jit_<method>`` is looked up in this module and timed, printing one line
    per plate size with the time per iteration in milliseconds.
    """
    if args.cuda:
        torch.set_default_tensor_type("torch.cuda.FloatTensor")
    else:
        torch.set_default_tensor_type("torch.FloatTensor")

    if args.method == "all":
        for method in ["prob", "logprob", "gradient", "marginal", "map", "sample"]:
            args.method = method
            main(args)
        return

    print("Plate size  Time per iteration of {} (ms)".format(args.method))
    fn = globals()["jit_{}".format(args.method)]
    equation = args.equation
    plates = args.plates
    # Only the input side of the equation is needed to build operands; the
    # two-way unpacking still rejects equations without exactly one "->".
    inputs, _ = equation.split("->")
    inputs = inputs.split(",")

    # Vary all plate sizes at the same time.
    for plate_size in range(8, 1 + args.max_plate_size, 8):
        operands = []
        for dims in inputs:
            # Plate dims get the current plate size; all other dims get dim_size.
            shape = torch.Size(
                [plate_size if d in plates else args.dim_size for d in dims]
            )
            # requires_grad=True so that the gradient method can differentiate
            # with respect to the operands.
            operands.append(torch.randn(shape, requires_grad=True))

        elapsed = time_fn(
            fn, equation, *operands, plates=plates, modulo_total=True, iters=args.iters
        )
        print(
            "{: <11s} {:0.4g}".format(
                "{} ** {}".format(plate_size, len(args.plates)), elapsed * 1e3
            )
        )

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="plated einsum profiler")
