Example: Discrete Factor Graph Inference with Plated Einsum

View einsum.py on github

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

"""
This example demonstrates how to use plated ``einsum`` with different backends
to compute logprob, gradients, MAP estimates, posterior samples, and marginals.

The interface for adjoint algorithms requires four steps (sketched in the example below):

1. Call ``require_backward()`` on all inputs.
2. Call ``x, = einsum(..., backend=...)`` with a nonstandard backend.
3. Call ``x._pyro_backward()`` on the einsum output.
4. Retrieve results from ``._pyro_backward_result`` attributes of the inputs.
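
For example, using the marginal backend on two hypothetical factors ``f`` and
``g`` (a minimal sketch that mirrors the calls made in this script)::

    require_backward(f)
    require_backward(g)
    x, = einsum("a,ab->", f, g, modulo_total=True,
                backend="pyro.ops.einsum.torch_marginal")
    x._pyro_backward()
    marginal_f = f._pyro_backward_result
    marginal_g = g._pyro_backward_result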

The results of these computations are returned, but this script does not
make use of them; instead we simply time the operations for profiling.
All profiling is done on jit-compiled functions. We exclude jit compilation
time from the profiling results, assuming compilation need only be done once.

You can measure complexity of different einsum problems by specifying
``--equation`` and ``--plates``.
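
For example, to profile only the marginal computation::

    python einsum.py --equation "a,abi,bcij,adj,deij->" --plates ij --method marginal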
"""

import argparse
import timeit

import torch
from torch.autograd import grad

from pyro.ops.contract import einsum
from pyro.ops.einsum.adjoint import require_backward
from pyro.util import ignore_jit_warnings

# We will cache jit-compiled versions of each function.
_CACHE = {}


def jit_prob(equation, *operands, **kwargs):
    """
    Runs einsum to compute the partition function.

    This is cheap but less numerically stable than using the ``torch_log`` backend.
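
    For example, with ``equation = "a,ab->"`` and operands ``f, g``, this
    computes ``sum_{a,b} f[a] * g[a,b]``, the total unnormalized mass.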
    """
    key = "prob", equation, kwargs["plates"]
    if key not in _CACHE:
        # This simply wraps einsum for jit compilation.
        def _einsum(*operands):
            return einsum(equation, *operands, **kwargs)

        _CACHE[key] = torch.jit.trace(_einsum, operands, check_trace=False)

    return _CACHE[key](*operands)


def jit_logprob(equation, *operands, **kwargs):
    """
    Runs einsum to compute the log partition function.

    This simulates evaluating an undirected graphical model.
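
    With the ``torch_log`` backend the operands are treated as log factors, so
    ``"a,ab->"`` yields ``log sum_{a,b} exp(f[a] + g[a,b])``, the log partition
    function.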
    """
    key = "logprob", equation, kwargs["plates"]
    if key not in _CACHE:
        # This simply wraps einsum for jit compilation.
        def _einsum(*operands):
            return einsum(
                equation, *operands, backend="pyro.ops.einsum.torch_log", **kwargs
            )

        _CACHE[key] = torch.jit.trace(_einsum, operands, check_trace=False)

    return _CACHE[key](*operands)


def jit_gradient(equation, *operands, **kwargs):
    """
    Runs einsum and calls backward on the partition function.

    This simulates training an undirected graphical model.
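
    For log-space factors, the gradient of the log partition function with
    respect to each factor entry is the corresponding marginal probability,
    which is the quantity needed for gradient-based training.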
    """
    key = "gradient", equation, kwargs["plates"]
    if key not in _CACHE:
        # This wraps einsum for jit compilation, but we will call backward on the result.
        def _einsum(*operands):
            return einsum(
                equation, *operands, backend="pyro.ops.einsum.torch_log", **kwargs
            )

        _CACHE[key] = torch.jit.trace(_einsum, operands, check_trace=False)

    # Run forward pass.
    losses = _CACHE[key](*operands)

    # Work around PyTorch 1.0.0 bug https://github.com/pytorch/pytorch/issues/14875
    # whereby tuples of length 1 are unwrapped by the jit.
    if not isinstance(losses, tuple):
        losses = (losses,)

    # Run backward pass.
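    # retain_graph=True lets us call grad() once per output on the same traced graph.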
    grads = tuple(
        grad(loss, operands, retain_graph=True, allow_unused=True) for loss in losses
    )
    return grads


def _jit_adjoint(equation, *operands, **kwargs):
    """
    Runs einsum in forward-backward mode using ``pyro.ops.einsum.adjoint``.

    This simulates serving predictions from an undirected graphical model.
    """
    backend = kwargs.pop("backend", "pyro.ops.einsum.torch_sample")
    key = backend, equation, tuple(x.shape for x in operands), kwargs["plates"]
    if key not in _CACHE:
        # This wraps a complete adjoint algorithm call.
        @ignore_jit_warnings()
        def _forward_backward(*operands):
            # First we request backward results on each input operand.
            # This is the pyro.ops.einsum.adjoint equivalent of torch's .requires_grad_().
            for operand in operands:
                require_backward(operand)

            # Next we run the forward pass.
            results = einsum(equation, *operands, backend=backend, **kwargs)

            # Then we run the backward pass.
            for result in results:
                result._pyro_backward()

            # Finally we retrieve results from the ._pyro_backward_result attribute
            # that has been set on each input operand. If you only want results on a
            # subset of operands, you can call require_backward() on only those.
            results = []
            for x in operands:
                results.append(x._pyro_backward_result)
                x._pyro_backward_result = None

            return tuple(results)

        _CACHE[key] = torch.jit.trace(_forward_backward, operands, check_trace=False)

    return _CACHE[key](*operands)


def jit_map(equation, *operands, **kwargs):
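    """
    Runs the adjoint algorithm with the ``pyro.ops.einsum.torch_map`` backend
    to compute MAP estimates.
    """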
    return _jit_adjoint(
        equation, *operands, backend="pyro.ops.einsum.torch_map", **kwargs
    )


def jit_sample(equation, *operands, **kwargs):
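    """
    Runs the adjoint algorithm with the ``pyro.ops.einsum.torch_sample``
    backend to draw posterior samples.
    """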
    return _jit_adjoint(
        equation, *operands, backend="pyro.ops.einsum.torch_sample", **kwargs
    )


def jit_marginal(equation, *operands, **kwargs):
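    """
    Runs the adjoint algorithm with the ``pyro.ops.einsum.torch_marginal``
    backend to compute marginals.
    """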
    return _jit_adjoint(
        equation, *operands, backend="pyro.ops.einsum.torch_marginal", **kwargs
    )


def time_fn(fn, equation, *operands, **kwargs):
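    """
    Returns the average time in seconds per call of ``fn``, measured over
    ``iters`` calls after one untimed warm-up call.
    """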
    iters = kwargs.pop("iters")
    _CACHE.clear()  # Avoid memory leaks.
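    # Warm-up call; this triggers jit tracing, which we exclude from timing.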
    fn(equation, *operands, **kwargs)

    time_start = timeit.default_timer()
    for i in range(iters):
        fn(equation, *operands, **kwargs)
    time_end = timeit.default_timer()

    return (time_end - time_start) / iters


def main(args):
    if args.cuda:
        torch.set_default_device("cuda")

    if args.method == "all":
        for method in ["prob", "logprob", "gradient", "marginal", "map", "sample"]:
            args.method = method
            main(args)
        return

    print("Plate size  Time per iteration of {} (ms)".format(args.method))
    fn = globals()["jit_{}".format(args.method)]
    equation = args.equation
    plates = args.plates
    inputs, outputs = equation.split("->")
    inputs = inputs.split(",")

    # Vary all plate sizes at the same time.
    for plate_size in range(8, 1 + args.max_plate_size, 8):
        operands = []
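        # Build one positive factor per input: plate dims get size plate_size,
        # all other dims get size args.dim_size.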
        for dims in inputs:
            shape = torch.Size(
                [plate_size if d in plates else args.dim_size for d in dims]
            )
            operands.append((torch.empty(shape).uniform_() + 0.5).requires_grad_())

        time = time_fn(
            fn, equation, *operands, plates=plates, modulo_total=True, iters=args.iters
        )
        print(
            "{: <11s} {:0.4g}".format(
                "{} ** {}".format(plate_size, len(args.plates)), time * 1e3
            )
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="plated einsum profiler")
    parser.add_argument("-e", "--equation", default="a,abi,bcij,adj,deij->")
    parser.add_argument("-p", "--plates", default="ij")
    parser.add_argument("-d", "--dim-size", default=32, type=int)
    parser.add_argument("-s", "--max-plate-size", default=32, type=int)
    parser.add_argument("-n", "--iters", default=10, type=int)
    parser.add_argument("--cuda", action="store_true")
    parser.add_argument(
        "-m",
        "--method",
        default="all",
        help="one of: prob, logprob, gradient, marginal, map, sample",
    )
    args = parser.parse_args()
    main(args)