Example: Epidemiological inference via HMC

This tutorial is in the form of a script (see below).

Example Usage

The following example runs the script to generate outbreak data for a population of 10000 observed for 60 days, then infer infection parameters and forecast new infections for another 30 days. This takes about 3 minutes on my laptop.

$ python examples/sir_hmc.py -p 10000 -d 60 -f 30 --plot
Generating data...
Observed 452/871 infections:
0 0 2 1 2 0 0 3 2 0 1 3 1 3 0 1 0 6 4 3 6 4 4 3 3 3 5 3 3 3 5 1 4 6 4 2 6 8 7 4 11 8 14 9 17 13 9 14 10 15 16 22 20 22 19 20 28 25 23 21
Running inference...
Sample: 100%|=========================| 300/300 [02:35,  1.93it/s, step size=9.67e-02, acc. prob=0.878]

                mean       std    median      5.0%     95.0%     n_eff     r_hat
        R0      1.40      0.07      1.40      1.28      1.49     26.56      1.06
       rho      0.47      0.02      0.47      0.44      0.52      7.08      1.22
  S_aux[0]   9998.74      0.64   9998.75   9997.84   9999.67     28.74      1.00
  S_aux[1]   9998.37      0.72   9998.38   9997.28   9999.44     52.24      1.02
  ...
  I_aux[0]      1.11      0.64      0.99      0.19      2.02     22.01      1.00
  I_aux[1]      1.55      0.74      1.65      0.05      2.47     10.05      1.10
  ...

Number of divergences: 0
R0: truth = 1.5, estimate = 1.4 ± 0.0654
rho: truth = 0.5, estimate = 0.475 ± 0.023
Figures: Forecast of new infections; Posterior distribution over parameters; Trace plot of potential energy.

Tutorial Script

View sir_hmc.py on github

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

# Introduction
# ============
#
# This advanced Pyro tutorial demonstrates a number of inference and prediction
# tricks in the context of epidemiological models, specifically stochastic
# discrete time compartmental models with large discrete state spaces. This
# tutorial assumes the reader has completed all introductory tutorials and
# additionally the tutorials on enumeration and effect handlers (poutines):
# http://pyro.ai/examples/enumeration.html
# http://pyro.ai/examples/effect_handlers.html

import argparse
import logging
import math
import re
from collections import OrderedDict

import torch

import pyro
import pyro.distributions as dist
import pyro.distributions.hmm
import pyro.poutine as poutine
from pyro.infer import MCMC, NUTS, config_enumerate, infer_discrete
from pyro.infer.autoguide import init_to_value
from pyro.ops.special import safe_log
from pyro.ops.tensor_utils import convolve
from pyro.util import warn_if_nan

# Log bare messages (format is just "%(message)s") at INFO level and above.
logging.basicConfig(format="%(message)s", level=logging.INFO)


# A Discrete SIR Model
# ====================
#
# Let's consider one of the simplest compartmental models: an SIR model
# https://en.wikipedia.org/wiki/Compartmental_models_in_epidemiology#The_SIR_model
# This models the dynamics of three groups within a population:
#
#   population = Susceptible + Infected + Recovered
#
# At each discrete time step, each infected person infects a random number of
# susceptible people, and then randomly may recover. We noisily observe the
# number of people newly infected at each time step, assuming an unknown false
# negative rate, but no false positives. Our eventual objective is to estimate
# the global model parameters R0 (the basic reproduction number) and rho (the
# mean response rate = 1 - false negative rate); tau (the expected recovery
# time) is assumed to be known exactly. Having estimated these we will then
# estimate latent time series and forecast forward.
#
# We'll start by defining a discrete_model that uses a helper global_model to
# sample global parameters.
#
# Note we need to use ExtendedBinomial rather than Binomial because the data
# may lie outside of the predicted support. For these values,
# Binomial.log_prob() will error, whereas ExtendedBinomial.log_prob() will
# return -inf.


def global_model(population):
    """
    Sample the global parameters and convert them to transition parameters.

    ``R0`` is drawn from a LogNormal(0, 1) prior and ``rho`` from a
    Uniform(0, 1) prior. The recovery time is not sampled: it is read from the
    module-level ``args`` (``args.recovery_time``), so ``args`` must be defined
    before this function is called.

    Returns the tuple ``(rate_s, prob_i, rho)``.
    """
    # The recovery time is assumed to be exactly measurable, so it is fixed.
    recovery_time = args.recovery_time
    R0 = pyro.sample("R0", dist.LogNormal(0.0, 1.0))
    rho = pyro.sample("rho", dist.Uniform(0, 1))

    # Map the interpretable parameters onto distribution parameters.
    infection_rate = -R0 / (recovery_time * population)
    recovery_prob = 1 / (1 + recovery_time)
    return infection_rate, recovery_prob, rho


def discrete_model(args, data):
    """
    Generative discrete-time SIR model.

    Starts from ``args.population - 1`` susceptible and one infected
    individual, then for each element of ``data`` samples the number of new
    infections (``S2I_t``) and recoveries (``I2R_t``), records the updated
    compartments as deterministic sites ``S_t`` and ``I_t``, and scores the
    observation ``obs_t`` (sampled instead when the datum is ``None``).
    """
    # Global parameters come first.
    rate_s, prob_i, rho = global_model(args.population)

    # Walk forward in time, one step per datum.
    susceptible = torch.tensor(args.population - 1.0)
    infected = torch.tensor(1.0)
    for t, datum in enumerate(data):
        infection_prob = -(rate_s * infected).expm1()
        S2I = pyro.sample("S2I_{}".format(t), dist.Binomial(susceptible, infection_prob))
        I2R = pyro.sample("I2R_{}".format(t), dist.Binomial(infected, prob_i))
        susceptible = pyro.deterministic("S_{}".format(t), susceptible - S2I)
        infected = pyro.deterministic("I_{}".format(t), infected + S2I - I2R)
        pyro.sample("obs_{}".format(t), dist.ExtendedBinomial(S2I, rho), obs=datum)


# We can use this model to simulate data. We'll use poutine.condition to pin
# parameter values and poutine.trace to record sample observations.


def generate_data(args):
    """
    Simulate an outbreak from ``discrete_model`` with pinned parameters.

    ``R0`` and ``rho`` are conditioned to ``args.basic_reproduction_number``
    and ``args.response_rate``. Simulation is retried up to 100 times until
    the first ``args.duration`` steps contain at least
    ``args.min_observations`` observed infections.

    Returns a dict with the stacked ``"S2I"`` and ``"obs"`` series, each of
    length ``args.duration + args.forecast``.

    Raises ValueError if no attempt produces enough observations.
    """
    logging.info("Generating data...")
    true_params = {
        "R0": torch.tensor(args.basic_reproduction_number),
        "rho": torch.tensor(args.response_rate),
    }
    empty_data = [None] * (args.duration + args.forecast)

    def stack_sites(trace, pattern):
        # Gather the per-step sites whose names match into a single tensor.
        values = []
        for name, site in trace.nodes.items():
            if re.match(pattern, name):
                values.append(site["value"])
        return torch.stack(values)

    # Keep simulating until an actual outbreak shows up.
    for _ in range(100):
        with poutine.trace() as tr, poutine.condition(data=true_params):
            discrete_model(args, empty_data)

        obs = stack_sites(tr.trace, "obs_[0-9]+")
        S2I = stack_sites(tr.trace, "S2I_[0-9]+")
        assert len(obs) == len(empty_data)

        obs_sum = int(obs[: args.duration].sum())
        if obs_sum < args.min_observations:
            continue

        S2I_sum = int(S2I[: args.duration].sum())
        observed = " ".join([str(int(x)) for x in obs[: args.duration]])
        logging.info(
            "Observed {:d}/{:d} infections:\n{}".format(obs_sum, S2I_sum, observed)
        )
        return {"S2I": S2I, "obs": obs}

    raise ValueError(
        "Failed to generate {} observations. Try increasing "
        "--population or decreasing --min-observations".format(args.min_observations)
    )


# Inference
# =========
#
# While the above discrete_model is easy to understand, its discrete latent
# variables pose a challenge for inference. One of the most popular inference
# strategies for such models is Sequential Monte Carlo. However since Pyro and
# PyTorch are stronger in gradient based vectorizable inference algorithms, we
# will instead pursue inference based on Hamiltonian Monte Carlo (HMC).
#
# Our general inference strategy will be to:
# 1. Introduce auxiliary variables to make the model Markov.
# 2. Introduce more auxiliary variables to create a discrete parameterization.
# 3. Marginalize out all remaining discrete latent variables.
# 4. Vectorize to enable parallel-scan temporal filtering.
#
# Let's consider reparameterizing in terms of the variables (S, I) rather than
# (S2I, I2R). Since these may lead to inconsistent states, we need to replace
# the Binomial transition factors (S2I, I2R) with ExtendedBinomial.
#
# The following model is equivalent to the discrete_model:


@config_enumerate
def reparameterized_discrete_model(args, data):
    """
    Discrete SIR model parameterized by the states (S, I) at each step.

    The compartment sizes ``S_t`` and ``I_t`` are latent sample sites; the
    transitions ``S2I_t`` and ``I2R_t`` are recovered from consecutive states
    and scored as observations under ExtendedBinomial factors.
    """
    # Global parameters come first.
    rate_s, prob_i, rho = global_model(args.population)

    # Initial state: everyone susceptible except one infected individual.
    S_now = torch.tensor(args.population - 1.0)
    I_now = torch.tensor(1.0)
    for t, datum in enumerate(data):
        S_before, I_before = S_now, I_now

        # The state sites are masked via .mask(False), so their Binomial
        # distributions contribute no density and serve only for
        # initialization.
        S_now = pyro.sample(
            "S_{}".format(t), dist.Binomial(args.population, 0.5).mask(False)
        )
        I_now = pyro.sample(
            "I_{}".format(t), dist.Binomial(args.population, 0.5).mask(False)
        )

        # Recover the transitions from consecutive states and score them.
        new_infections = S_before - S_now
        recoveries = I_before - I_now + new_infections
        infection_prob = -(rate_s * I_before).expm1()
        pyro.sample(
            "S2I_{}".format(t),
            dist.ExtendedBinomial(S_before, infection_prob),
            obs=new_infections,
        )
        pyro.sample(
            "I2R_{}".format(t),
            dist.ExtendedBinomial(I_before, prob_i),
            obs=recoveries,
        )
        pyro.sample(
            "obs_{}".format(t),
            dist.ExtendedBinomial(new_infections, rho),
            obs=datum,
        )


# By reparameterizing, we have converted to coordinates that make the model
# Markov. We have also replaced dynamic integer_interval constraints with
# easier static integer_interval constraints (although we'll still need good
# initialization to avoid NANs). Since the discrete latent variables are
# bounded (by population size), we can enumerate out discrete latent variables
# and perform HMC inference over the global latents. However enumeration
# complexity is O(population^4), so this is only feasible for very small
# populations.
#
# Here is an inference approach using an MCMC sampler.


def infer_hmc_enum(args, data):
    """Run HMC inference on ``reparameterized_discrete_model``."""
    return _infer_hmc(args, data, reparameterized_discrete_model)


def _infer_hmc(args, data, model, init_values=None):
    """
    Run NUTS on ``model`` and return the posterior samples.

    Args:
        args: parsed command-line options; reads ``max_tree_depth``, ``jit``,
            ``verbose``, ``num_samples``, ``warmup_steps`` and ``plot``.
        data: observations, forwarded to the model as ``model(args, data)``.
        model: the Pyro model to sample from.
        init_values: optional dict of initial site values passed to
            ``init_to_value``; defaults to an empty dict.

    Returns:
        The result of ``mcmc.get_samples()``.
    """
    # Use a None sentinel rather than a mutable ``{}`` default, which would be
    # a single dict shared across every call.
    if init_values is None:
        init_values = {}

    logging.info("Running inference...")
    kernel = NUTS(
        model,
        full_mass=[("R0", "rho")],
        max_tree_depth=args.max_tree_depth,
        init_strategy=init_to_value(values=init_values),
        jit_compile=args.jit,
        ignore_jit_warnings=True,
    )

    # We'll define a hook_fn to log potential energy values during inference.
    # This is helpful to diagnose whether the chain is mixing.
    energies = []

    def hook_fn(kernel, *unused):
        e = float(kernel._potential_energy_last)
        energies.append(e)
        if args.verbose:
            logging.info("potential = {:0.6g}".format(e))

    mcmc = MCMC(
        kernel,
        hook_fn=hook_fn,
        num_samples=args.num_samples,
        warmup_steps=args.warmup_steps,
    )
    mcmc.run(args, data)
    mcmc.summary()
    if args.plot:
        # Imported lazily so matplotlib is only required when plotting.
        import matplotlib.pyplot as plt

        plt.figure(figsize=(6, 3))
        plt.plot(energies)
        plt.xlabel("MCMC step")
        plt.ylabel("potential energy")
        plt.title("MCMC energy trace")
        plt.tight_layout()

    samples = mcmc.get_samples()
    return samples


# To scale to large populations, we'll continue to reparameterize, this time
# replacing each of (S_aux,I_aux) with a combination of a bounded real
# variable and a Categorical variable with only four values.
#
# This is the crux: we can now perform HMC over the real variable and
# marginalize out the Categorical variables using variable elimination.
#
# We first define a helper to create enumerated Categorical sites.


def quantize(name, x_real, min, max):
    """
    Randomly round ``x_real`` to one of its four nearest integers.

    The choice is a Categorical sample site named ``"Q_" + name`` whose
    probabilities are cubic spline weights (a piecewise polynomial spline of
    order 3). The resulting integer is reflected into ``[min, max]`` and
    recorded as a deterministic site called ``name``.
    """
    assert min < max
    floor = x_real.detach().floor()

    # Cubic spline weights over the nearest four integers; being cubic in the
    # fractional part, they have piecewise quadratic gradients.
    frac = x_real - floor
    frac_sq = frac * frac
    rest = 1 - frac
    rest_sq = rest * rest
    weights = [
        rest * rest_sq,
        4 + frac_sq * (3 * frac - 6),
        4 + rest_sq * (3 * rest - 6),
        frac * frac_sq,
    ]
    probs = torch.stack(weights, dim=-1) * (1 / 6)
    offset = pyro.sample("Q_" + name, dist.Categorical(probs)).type_as(x_real)

    # Candidates are floor - 1 .. floor + 2. Reflect any that fall outside the
    # bounds, e.g. min - 1 maps to min and max + 1 maps to max.
    value = floor + offset - 1
    value = torch.max(value, 2 * min - 1 - value)
    value = torch.min(value, 2 * max + 1 - value)

    return pyro.deterministic(name, value)


# Now we can define another equivalent model.


@config_enumerate
def continuous_model(args, data):
    """
    SIR model over continuous auxiliary variables ``S_aux`` and ``I_aux``.

    Each time step's state is obtained by quantizing the corresponding
    auxiliary value with ``quantize``; the transitions are then recovered
    from consecutive states and scored as observations.
    """
    # Global parameters come first.
    rate_s, prob_i, rho = global_model(args.population)

    def sample_aux(site_name):
        # Masked prior over [-0.5, population + 0.5], one value per time step.
        prior = (
            dist.Uniform(-0.5, args.population + 0.5)
            .mask(False)
            .expand(data.shape)
            .to_event(1)
        )
        return pyro.sample(site_name, prior)

    # Reparameterizing variables, sampled for all time steps at once.
    S_aux = sample_aux("S_aux")
    I_aux = sample_aux("I_aux")

    # Initial state: everyone susceptible except one infected individual.
    S_now = torch.tensor(args.population - 1.0)
    I_now = torch.tensor(1.0)
    for t, datum in poutine.markov(enumerate(data)):
        S_before, I_before = S_now, I_now
        S_now = quantize("S_{}".format(t), S_aux[..., t], min=0, max=args.population)
        I_now = quantize("I_{}".format(t), I_aux[..., t], min=0, max=args.population)

        # Recover the transitions from consecutive states and score them.
        new_infections = S_before - S_now
        recoveries = I_before - I_now + new_infections
        infection_prob = -(rate_s * I_before).expm1()
        pyro.sample(
            "S2I_{}".format(t),
            dist.ExtendedBinomial(S_before, infection_prob),
            obs=new_infections,
        )
        pyro.sample(
            "I2R_{}".format(t),
            dist.ExtendedBinomial(I_before, prob_i),
            obs=recoveries,
        )
        pyro.sample(
            "obs_{}".format(t),
            dist.ExtendedBinomial(new_infections, rho),
            obs=datum,
        )


# Now all latent variables in the continuous_model are either continuous or
# enumerated, so we can use HMC. However we need to take special care with
# constraints because the above Markov reparameterization covers regions of
# hypothesis space that are infeasible (i.e. whose log_prob is -infinity). We
# thus heuristically initialize to a feasible point.


def heuristic_init(args, data):
    """
    Build a heuristic starting point for HMC.

    Returns a dict of initial values for the sites ``"R0"``, ``"rho"``,
    ``"S_aux"`` and ``"I_aux"``, derived from the observed series ``data``.
    """
    # The outbreak starts with a single infection.
    initial_susceptible = args.population - 1

    # Inflate the observations by at most 2x, i.e. assume a response rate of
    # at least 50%.
    inflation = min(2.0, (initial_susceptible / data.sum()).sqrt())
    new_infections = data * inflation
    S_aux = (initial_susceptible - new_infections.cumsum(-1)).clamp(min=0.5)

    # Add the single initial infection; S_aux above was computed without it.
    new_infections[0] += 1

    # Exponential recovery profile truncated to 30 steps (infections are
    # assumed to last less than a month).
    recovery_profile = torch.arange(30.0).div(args.recovery_time).neg().exp()
    I_aux = convolve(new_infections, recovery_profile)[: len(data)].clamp(min=0.5)

    init_values = {
        "R0": torch.tensor(2.0),
        "rho": torch.tensor(0.5),
        "S_aux": S_aux,
        "I_aux": I_aux,
    }
    return init_values


def infer_hmc_cont(model, args, data):
    """Run HMC on ``model``, starting from the ``heuristic_init`` values."""
    return _infer_hmc(args, data, model, init_values=heuristic_init(args, data))


# Our final inference trick is to vectorize. We can repurpose DiscreteHMM's
# implementation here, but we'll need to manually represent a Markov
# neighborhood of multiple Categoricals of size 4 as a single joint Categorical
# with 4 * 4 = 16 states, and then manually perform variable elimination (the
# factors here don't quite conform to DiscreteHMM's interface).


def quantize_enumerate(x_real, min, max):
    """
    Enumerate the four integer quantizations of ``x_real`` with their weights.

    This is the enumerated counterpart of ``quantize``: instead of sampling
    one of the four nearest integers, it returns all of them together with
    the log of their cubic spline weights (a piecewise polynomial spline of
    order 3).

    Args:
        x_real: real-valued tensor to quantize.
        min: lower bound that candidate integers are reflected into.
        max: upper bound that candidate integers are reflected into.

    Returns:
        A pair ``(x, logits)``, each with one extra trailing dimension of
        size 4 relative to ``x_real``: the candidate integers and the log
        weights computed with ``safe_log``.
    """
    assert min < max
    lb = x_real.detach().floor()

    # This cubic spline interpolates over the nearest four integers, ensuring
    # piecewise quadratic gradients.
    s = x_real - lb
    ss = s * s
    t = 1 - s
    tt = t * t
    probs = torch.stack(
        [
            t * tt,
            4 + ss * (3 * s - 6),
            4 + tt * (3 * t - 6),
            s * ss,
        ],
        dim=-1,
    ) * (1 / 6)
    logits = safe_log(probs)

    # Offsets -1, 0, 1, 2 relative to the floor. They are created with
    # x_real's dtype and device so this also works when x_real does not live
    # on the default device.
    q = torch.arange(-1.0, 3.0, dtype=x_real.dtype, device=x_real.device)

    x = lb.unsqueeze(-1) + q
    # Reflect out-of-range candidates back inside the bounds, e.g. min - 1
    # maps to min and max + 1 maps to max.
    x = torch.max(x, 2 * min - 1 - x)
    x = torch.min(x, 2 * max + 1 - x)
    return x, logits


def vectorized_model(args, data):
    """
    Vectorized variant of the continuous model.

    Rather than creating per-time-step sample sites, this enumerates the four
    quantized candidates of every ``S_aux`` / ``I_aux`` entry with
    ``quantize_enumerate``, evaluates all transition and observation factors
    in one broadcasted computation, eliminates the enumerated variables by
    hand, and registers the resulting scalar log probability as a single
    ``pyro.factor`` site named ``"obs"``.

    ``data`` is reshaped to ``(len(data), 1, 1, 1, 1)`` below, so it must
    contain exactly ``len(data)`` elements (i.e. be one-dimensional).
    """
    # Sample global parameters.
    rate_s, prob_i, rho = global_model(args.population)

    # Sample reparameterizing variables.
    S_aux = pyro.sample(
        "S_aux",
        dist.Uniform(-0.5, args.population + 0.5)
        .mask(False)
        .expand(data.shape)
        .to_event(1),
    )
    I_aux = pyro.sample(
        "I_aux",
        dist.Uniform(-0.5, args.population + 0.5)
        .mask(False)
        .expand(data.shape)
        .to_event(1),
    )

    # Manually enumerate.
    S_curr, S_logp = quantize_enumerate(S_aux, min=0, max=args.population)
    I_curr, I_logp = quantize_enumerate(I_aux, min=0, max=args.population)
    # Truncate final value from the right then pad initial value onto the left.
    # The pad values are the initial state: population - 1 susceptible and one
    # infected.
    S_prev = torch.nn.functional.pad(
        S_curr[:-1], (0, 0, 1, 0), value=args.population - 1
    )
    I_prev = torch.nn.functional.pad(I_curr[:-1], (0, 0, 1, 0), value=1)
    # Reshape to support broadcasting, similar to EnumMessenger.
    # Axis layout after reshaping: (time, S_prev, I_prev, S_curr, I_curr).
    T = len(data)
    Q = 4
    S_prev = S_prev.reshape(T, Q, 1, 1, 1)
    I_prev = I_prev.reshape(T, 1, Q, 1, 1)
    S_curr = S_curr.reshape(T, 1, 1, Q, 1)
    S_logp = S_logp.reshape(T, 1, 1, Q, 1)
    I_curr = I_curr.reshape(T, 1, 1, 1, Q)
    I_logp = I_logp.reshape(T, 1, 1, 1, Q)
    data = data.reshape(T, 1, 1, 1, 1)

    # Reverse the S2I,I2R computation.
    S2I = S_prev - S_curr
    I2R = I_prev - I_curr + S2I

    # Compute probability factors.
    S2I_logp = dist.ExtendedBinomial(S_prev, -(rate_s * I_prev).expm1()).log_prob(S2I)
    I2R_logp = dist.ExtendedBinomial(I_prev, prob_i).log_prob(I2R)
    obs_logp = dist.ExtendedBinomial(S2I, rho).log_prob(data)

    # Manually perform variable elimination.
    logp = S_logp + (I_logp + obs_logp) + S2I_logp + I2R_logp
    # Flatten (S_prev, I_prev) into rows and (S_curr, I_curr) into columns,
    # giving one (Q*Q) x (Q*Q) matrix of log factors per time step.
    logp = logp.reshape(-1, Q * Q, Q * Q)
    # NOTE(review): _sequential_logmatmulexp is a private Pyro helper;
    # presumably it combines the per-step matrices along the time axis in log
    # space -- confirm against the Pyro source.
    logp = pyro.distributions.hmm._sequential_logmatmulexp(logp)
    # Sum out every remaining entry to obtain a scalar.
    logp = logp.reshape(-1).logsumexp(0)
    logp = logp - math.log(4)  # Account for S,I initial distributions.
    warn_if_nan(logp)
    pyro.factor("obs", logp)


# We can fit vectorized_model exactly as we fit the original continuous_model,
# using our infer_hmc_cont helper. The vectorized model is more than an order
# of magnitude faster than the sequential version, and scales logarithmically
# in time (up to your machine's parallelism).
#
# After inference we have samples of all latent variables. Let's define a
# helper to examine the inferred posterior distributions.


def evaluate(args, samples):
    """
    Report how well the posterior recovers the true global parameters.

    For ``R0`` and ``rho``, logs the true value taken from ``args`` next to
    the mean and standard deviation of ``samples``. When ``args.plot`` is
    set, also draws one posterior histogram per parameter with the truth
    marked (matplotlib and seaborn are imported only in that case).
    """
    names = {"basic_reproduction_number": "R0", "response_rate": "rho"}

    # Log the estimates.
    template = "{}: truth = {:0.3g}, estimate = {:0.3g} \u00B1 {:0.3g}"
    for name, key in names.items():
        posterior = samples[key]
        mean = posterior.mean().item()
        std = posterior.std().item()
        logging.info(template.format(key, getattr(args, name), mean, std))

    if not args.plot:
        return

    # Plot one histogram per parameter.
    import matplotlib.pyplot as plt
    import seaborn as sns

    _, axes = plt.subplots(2, 1, figsize=(5, 5))
    axes[0].set_title("Posterior parameter estimates")
    for ax, (name, key) in zip(axes, names.items()):
        truth = getattr(args, name)
        sns.distplot(samples[key], ax=ax, label="posterior")
        ax.axvline(truth, color="k", label="truth")
        ax.set_xlabel(key + " = " + name.replace("_", " "))
        ax.set_yticks(())
        ax.legend(loc="best")
    plt.tight_layout()


# Prediction and Forecasting
# ==========================
#
# So far we've written four models that each describe the same probability
# distribution. Each successive model made inference cheaper. Next let's move
# beyond inference and consider predicting latent infection rate and
# forecasting future infections.
#
# We'll use Pyro's effect handlers to combine multiple of the above models,
# leveraging the vectorized_model for inference, then the continuous_model to
# compute local latent variables, and finally the original discrete_model to
# forecast forward in time. Let's assume posterior samples have already been
# generated via infer_hmc_cont(vectorized_model, ...).


@torch.no_grad()
def predict(args, data, samples, truth=None):
    """
    Sample latent time series and forecast ``args.forecast`` steps ahead.

    Combines two models: ``continuous_model`` (conditioned on ``samples`` and
    wrapped in ``infer_discrete``) supplies the discrete states for the
    observed steps, then ``discrete_model`` (conditioned on those samples) is
    run over the observed steps plus ``args.forecast`` unobserved ones.

    Args:
        args: parsed options; reads ``forecast``, ``num_samples``,
            ``duration``, ``population`` and ``plot``.
        data: the observed series.
        samples: dict of posterior samples to condition on.
        truth: optional series plotted as the ground truth when
            ``args.plot`` is set.

    Returns:
        An OrderedDict of all sample sites from the final trace, extended
        with the keys ``"S"``, ``"I"``, ``"S2I"`` and ``"I2R"``, each holding
        the per-step values stacked along the last dimension.
    """
    logging.info("Forecasting {} steps ahead...".format(args.forecast))
    # One particle per posterior sample; the same plate wraps both models.
    particle_plate = pyro.plate("particles", args.num_samples, dim=-1)

    # First we sample discrete auxiliary variables from the continuous
    # variables sampled in vectorized_model. This samples only time steps
    # [0:duration]. Here infer_discrete runs a forward-filter backward-sample
    # algorithm. We'll add these new samples to the existing dict of samples.
    model = poutine.condition(continuous_model, samples)
    model = particle_plate(model)
    model = infer_discrete(model, first_available_dim=-2)
    with poutine.trace() as tr:
        model(args, data)
    samples = OrderedDict(
        (name, site["value"])
        for name, site in tr.trace.nodes.items()
        if site["type"] == "sample"
    )

    # Next we'll run the forward generative process in discrete_model. This
    # samples time steps [duration:duration+forecast]. Again we'll update the
    # dict of samples.
    extended_data = list(data) + [None] * args.forecast
    model = poutine.condition(discrete_model, samples)
    model = particle_plate(model)
    with poutine.trace() as tr:
        model(args, extended_data)
    samples = OrderedDict(
        (name, site["value"])
        for name, site in tr.trace.nodes.items()
        if site["type"] == "sample"
    )

    # Finally we'll concatenate the sequentially sampled values into contiguous
    # tensors. This operates on the entire time interval [0:duration+forecast].
    for key in ("S", "I", "S2I", "I2R"):
        pattern = key + "_[0-9]+"
        series = [value for name, value in samples.items() if re.match(pattern, name)]
        assert len(series) == args.duration + args.forecast
        # NOTE(review): the first element is expanded to the shape of the
        # second, presumably because the step-0 value can lack the particle
        # dimension -- confirm. This indexing needs at least two time steps.
        series[0] = series[0].expand(series[1].shape)
        samples[key] = torch.stack(series, dim=-1)
    S2I = samples["S2I"]
    # Reduce over dim 0 to get one median value per time step.
    median = S2I.median(dim=0).values
    logging.info(
        "Median prediction of new infections (starting on day 0):\n{}".format(
            " ".join(map(str, map(int, median)))
        )
    )

    # Optionally plot the latent and forecasted series of new infections.
    if args.plot:
        import matplotlib.pyplot as plt

        plt.figure()
        time = torch.arange(args.duration + args.forecast)
        # Order statistics over dim 0 near the 5% and 95% ranks, used as the
        # bounds of the band labeled "90% CI".
        p05 = S2I.kthvalue(int(round(0.5 + 0.05 * args.num_samples)), dim=0).values
        p95 = S2I.kthvalue(int(round(0.5 + 0.95 * args.num_samples)), dim=0).values
        plt.fill_between(time, p05, p95, color="red", alpha=0.3, label="90% CI")
        plt.plot(time, median, "r-", label="median")
        plt.plot(time[: args.duration], data, "k.", label="observed")
        if truth is not None:
            plt.plot(time, truth, "k--", label="truth")
        # Vertical line separating the observed steps from the forecast.
        plt.axvline(args.duration - 0.5, color="gray", lw=1)
        plt.xlim(0, len(time) - 1)
        plt.ylim(0, None)
        plt.xlabel("day after first infection")
        plt.ylabel("new infections per day")
        plt.title("New infections in population of {}".format(args.population))
        plt.legend(loc="upper left")
        plt.tight_layout()

    return samples


# Experiments
# ===========
#
# Finally we'll define an experiment runner. For example we can simulate 60
# days of infection on a population of 10000 and forecast forward another 30
# days, and plot the results as follows (takes about 3 minutes on my laptop):
#
# python sir_hmc.py -p 10000 -d 60 -f 30 --plot


def main(args):
    """
    Run one experiment: simulate data, infer parameters, evaluate, forecast.

    The inference method is chosen by ``args.enum`` (full enumeration),
    ``args.sequential`` (sequential continuous model), or otherwise the
    vectorized model. Forecasting runs only when ``args.forecast`` is
    nonzero.

    Returns the posterior samples, or the extended samples from ``predict``
    when forecasting.
    """
    pyro.set_rng_seed(args.rng_seed)

    dataset = generate_data(args)
    obs = dataset["obs"][: args.duration]

    # Pick the inference method.
    if args.enum:
        samples = infer_hmc_enum(args, obs)
    else:
        model = continuous_model if args.sequential else vectorized_model
        samples = infer_hmc_cont(model, args, obs)

    # Compare the fit against the true parameters.
    evaluate(args, samples)

    # Only predict the latent time series when a forecast was requested.
    if not args.forecast:
        return samples
    return predict(args, obs, samples, truth=dataset["S2I"])


if __name__ == "__main__":
    # This example is pinned to one specific Pyro release.
    assert pyro.__version__.startswith("1.9.0")

    parser = argparse.ArgumentParser(description="SIR epidemiology modeling using HMC")

    # Simulation settings.
    parser.add_argument("-p", "--population", type=int, default=10)
    parser.add_argument("-m", "--min-observations", type=int, default=3)
    parser.add_argument("-d", "--duration", type=int, default=10)
    parser.add_argument("-f", "--forecast", type=int, default=0)
    parser.add_argument("-R0", "--basic-reproduction-number", type=float, default=1.5)
    parser.add_argument("-tau", "--recovery-time", type=float, default=7.0)
    parser.add_argument("-rho", "--response-rate", type=float, default=0.5)

    # Choice of inference model (the vectorized model is used by default).
    parser.add_argument(
        "-e",
        "--enum",
        action="store_true",
        help="use the full enumeration model",
    )
    parser.add_argument(
        "-s",
        "--sequential",
        action="store_true",
        help="use the sequential continuous model",
    )

    # Sampler settings.
    parser.add_argument("-n", "--num-samples", type=int, default=200)
    parser.add_argument("-w", "--warmup-steps", type=int, default=100)
    parser.add_argument("-t", "--max-tree-depth", type=int, default=5)
    parser.add_argument("-r", "--rng-seed", type=int, default=0)

    # Plain on/off switches.
    for flag in ("--double", "--jit", "--cuda", "--verbose", "--plot"):
        parser.add_argument(flag, action="store_true")

    # ``args`` is deliberately a module-level name: global_model reads it.
    args = parser.parse_args()

    if args.double:
        torch.set_default_dtype(torch.float64)
    if args.cuda:
        torch.set_default_device("cuda")

    main(args)

    # Figures are created during main(); display them all at the end.
    if args.plot:
        import matplotlib.pyplot as plt

        plt.show()