Example: Regional epidemiological models

View regional.py on GitHub

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import argparse
import logging

import torch

import pyro
from pyro.contrib.epidemiology.models import RegionalSIRModel

logging.basicConfig(format="%(message)s", level=logging.INFO)


def Model(args, data):
    """Build a ``RegionalSIRModel`` from parsed command-line arguments.

    Every region is given the same population ``args.population``. The
    coupling matrix is the identity with every off-diagonal zero raised to
    ``args.coupling`` (the diagonal stays 1 because ``coupling <= 1``).

    :param args: Parsed arguments; reads ``coupling``, ``num_regions``,
        ``population`` and ``recovery_time``.
    :param data: Time series passed straight through to the model; the
        simulation path in this file passes a list of ``None`` values.
    :returns: A new ``RegionalSIRModel``.
    :raises ValueError: If ``args.coupling`` is not within ``[0, 1]``
        (NaN is rejected as well).
    """
    # Use an explicit exception instead of ``assert`` so the check is not
    # stripped when Python runs with ``-O``.
    if not 0 <= args.coupling <= 1:
        raise ValueError(
            "Expected coupling in [0, 1], but got {}".format(args.coupling)
        )
    population = torch.full((args.num_regions,), float(args.population))
    # Identity matrix with off-diagonal entries lifted to args.coupling.
    coupling = torch.eye(args.num_regions).clamp(min=args.coupling)
    return RegionalSIRModel(population, coupling, args.recovery_time, data)


def generate_data(args):
    """Simulate a regional outbreak with enough observed cases to fit.

    A model is built over an all-``None`` series spanning
    ``args.duration + args.forecast`` steps and sampled with fixed ``R0``,
    ``rho_c1`` and ``rho_c0`` values. Up to 100 simulations are drawn. The
    first one whose observations over the first ``args.duration`` steps sum
    to at least ``args.min_observations`` is returned.

    :returns: ``{"S2I": ..., "obs": ...}``. ``S2I`` is the full simulated
        series, including the forecast window. ``obs`` is the observation
        series truncated to ``args.duration`` steps.
    :raises ValueError: If none of the 100 simulations reaches the threshold.
    """
    horizon = args.duration + args.forecast
    simulator = Model(args, [None] * horizon)
    logging.info("Simulating from a {}".format(type(simulator).__name__))

    for _ in range(100):
        # NOTE(review): rho_c1 + rho_c0 == 10 and rho_c1 / 10 equals
        # args.response_rate; these look like concentration parameters for
        # the response rate -- confirm against RegionalSIRModel.
        fixed = {
            "R0": args.basic_reproduction_number,
            "rho_c1": 10 * args.response_rate,
            "rho_c0": 10 * (1 - args.response_rate),
        }
        draw = simulator.generate(fixed)
        observed = draw["obs"][: args.duration]
        new_infections = draw["S2I"]

        total_observed = int(observed.sum())
        total_infected = int(new_infections[: args.duration].sum())
        if total_observed >= args.min_observations:
            # Totals cover all regions, but only region 0's series is printed.
            first_region = " ".join(str(int(count)) for count in observed[:, 0])
            logging.info(
                "Observed {:d}/{:d} infections:\n{}".format(
                    total_observed, total_infected, first_region
                )
            )
            return {"S2I": new_infections, "obs": observed}

    raise ValueError(
        "Failed to generate {} observations. Try increasing "
        "--population or decreasing --min-observations".format(args.min_observations)
    )


def infer_mcmc(args, model):
    """Fit ``model`` with MCMC, print a posterior summary, optionally plot.

    A hook records ``kernel._potential_energy_last`` each time ``fit_mcmc``
    invokes it, and logs the value when ``args.verbose`` is set. When
    ``args.plot`` is set, the recorded values are drawn on a new figure.
    The figure is not shown here.

    :param args: Parsed arguments supplying the ``fit_mcmc`` settings and the
        ``verbose`` / ``plot`` flags.
    :param model: Object exposing ``fit_mcmc(**options)``. Its return value
        must provide ``summary()``.
    """
    energy_trace = []

    def hook_fn(kernel, *unused):
        # Reads a private kernel attribute; presumably the potential energy
        # of the most recent state -- confirm against the kernel.
        potential = float(kernel._potential_energy_last)
        energy_trace.append(potential)
        if args.verbose:
            logging.info("potential = {:0.6g}".format(potential))

    options = dict(
        heuristic_num_particles=args.smc_particles,
        heuristic_ess_threshold=args.ess_threshold,
        warmup_steps=args.warmup_steps,
        num_samples=args.num_samples,
        max_tree_depth=args.max_tree_depth,
        num_quant_bins=args.num_bins,
        haar=args.haar,
        haar_full_mass=args.haar_full_mass,
        jit_compile=args.jit,
        hook_fn=hook_fn,
    )
    fitted = model.fit_mcmc(**options)
    fitted.summary()

    if not args.plot:
        return

    import matplotlib.pyplot as plt

    plt.figure(figsize=(6, 3))
    plt.plot(energy_trace)
    plt.xlabel("MCMC step")
    plt.ylabel("potential energy")
    plt.title("MCMC energy trace")
    plt.tight_layout()


def infer_svi(args, model):
    """Fit ``model`` with SVI and optionally plot the returned loss curve.

    :param args: Parsed arguments supplying ``smc_particles``,
        ``ess_threshold``, ``num_samples``, ``svi_steps``, ``svi_particles``,
        ``haar``, ``jit`` and the ``plot`` flag.
    :param model: Object exposing ``fit_svi(**options)``. Its return value is
        plotted as the loss sequence when ``args.plot`` is set. The figure is
        not shown here.
    """
    options = {
        "heuristic_num_particles": args.smc_particles,
        "heuristic_ess_threshold": args.ess_threshold,
        "num_samples": args.num_samples,
        "num_steps": args.svi_steps,
        "num_particles": args.svi_particles,
        "haar": args.haar,
        "jit": args.jit,
    }
    loss_history = model.fit_svi(**options)

    if not args.plot:
        return

    import matplotlib.pyplot as plt

    plt.figure(figsize=(6, 3))
    plt.plot(loss_history)
    plt.xlabel("SVI step")
    plt.ylabel("loss")
    plt.title("SVI Convergence")
    plt.tight_layout()


def predict(args, model, truth):
    """Forecast new infections, log per-region medians, optionally plot.

    :param args: Parsed arguments; reads ``forecast``, ``num_regions``,
        ``duration``, ``population`` and ``plot``.
    :param model: Fitted model. ``model.predict(forecast=...)`` must return a
        dict with an ``"S2I"`` tensor, which the indexing below treats as
        ``[sample, day, region]``. ``model.data`` is used only for plotting
        and is treated as ``[day, region]`` over ``args.duration`` days.
    :param truth: Simulated new infections, used only for plotting and
        treated as ``[day, region]`` over ``args.duration + args.forecast``
        days.
    """
    samples = model.predict(forecast=args.forecast)
    S2I = samples["S2I"]
    median = S2I.median(dim=0).values
    lines = ["Median prediction of new infections (starting on day 0):"]
    for r in range(args.num_regions):
        lines.append(
            "Region {}: {}".format(r, " ".join(map(str, map(int, median[:, r]))))
        )
    logging.info("\n".join(lines))

    # Optionally plot the latent and forecasted series of new infections.
    if not args.plot:
        return

    import matplotlib.pyplot as plt

    # squeeze=False always yields a 2-D array of Axes. Without it, a single
    # region returns a bare Axes object, which the loop and the indexing
    # below cannot handle.
    _, axes = plt.subplots(
        args.num_regions,
        sharex=True,
        figsize=(6, 1 + args.num_regions),
        squeeze=False,
    )
    axes = axes[:, 0]
    days = torch.arange(args.duration + args.forecast)
    # 90% band from empirical order statistics (kthvalue's k is 1-based).
    # Size k by the draws actually returned so it stays in [1, num_draws].
    num_draws = S2I.size(0)
    p05 = S2I.kthvalue(int(round(0.5 + 0.05 * num_draws)), dim=0).values
    p95 = S2I.kthvalue(int(round(0.5 + 0.95 * num_draws)), dim=0).values
    for r, ax in enumerate(axes):
        ax.fill_between(
            days, p05[:, r], p95[:, r], color="red", alpha=0.3, label="90% CI"
        )
        ax.plot(days, median[:, r], "r-", label="median")
        ax.plot(days[: args.duration], model.data[:, r], "k.", label="observed")
        ax.plot(days, truth[:, r], "k--", label="truth")
        # Vertical line separating observed days from forecast days.
        ax.axvline(args.duration - 0.5, color="gray", lw=1)
        ax.set_xlim(0, len(days) - 1)
        ax.set_ylim(0, None)
    axes[0].set_title(
        "New infections among {} regions each of size {}".format(
            args.num_regions, args.population
        )
    )
    axes[args.num_regions // 2].set_ylabel("inf./day")
    axes[-1].set_xlabel("day after first infection")
    axes[-1].legend(loc="upper left")
    plt.tight_layout()
    plt.subplots_adjust(hspace=0)


def main(args):
    """Run the full example: simulate data, fit a model, then forecast.

    Seeds Pyro's RNG with ``args.rng_seed`` and simulates a dataset. It then
    builds a model conditioned on the simulated observations and fits it with
    the method named by ``args.infer`` (``"mcmc"`` or ``"svi"``). Finally it
    forecasts, comparing against the simulated ``S2I`` series.
    """
    pyro.set_rng_seed(args.rng_seed)

    # Simulate a synthetic dataset and condition a fresh model on it.
    dataset = generate_data(args)
    model = Model(args, dataset["obs"])

    # Dispatch on the requested inference method (unknown names -> KeyError).
    inference_methods = {"mcmc": infer_mcmc, "svi": infer_svi}
    fit = inference_methods[args.infer]
    fit(args, model)

    # Forecast latent new infections against the simulated ground truth.
    predict(args, model, truth=dataset["S2I"])


if __name__ == "__main__":
    # Fail fast unless running under a Pyro 1.9.1 release.
    assert pyro.__version__.startswith("1.9.1")
    parser = argparse.ArgumentParser(
        description="Regional compartmental epidemiology modeling using HMC"
    )
    # Simulation settings.
    parser.add_argument("-p", "--population", default=1000, type=int)
    parser.add_argument("-r", "--num-regions", default=2, type=int)
    parser.add_argument("-c", "--coupling", default=0.1, type=float)
    parser.add_argument("-m", "--min-observations", default=3, type=int)
    parser.add_argument("-d", "--duration", default=20, type=int)
    parser.add_argument("-f", "--forecast", default=10, type=int)
    parser.add_argument("-R0", "--basic-reproduction-number", default=1.5, type=float)
    parser.add_argument("-tau", "--recovery-time", default=7.0, type=float)
    parser.add_argument("-rho", "--response-rate", default=0.5, type=float)
    # Inference settings. --mcmc / --svi are shorthands for --infer. The
    # ``choices`` list rejects an unknown method at parse time, instead of a
    # KeyError in main() after the data has already been simulated.
    parser.add_argument("--infer", default="mcmc", choices=["mcmc", "svi"])
    parser.add_argument("--mcmc", action="store_const", const="mcmc", dest="infer")
    parser.add_argument("--svi", action="store_const", const="svi", dest="infer")
    parser.add_argument("--haar", action="store_true")
    parser.add_argument("-hfm", "--haar-full-mass", default=0, type=int)
    parser.add_argument("-n", "--num-samples", default=200, type=int)
    parser.add_argument("-np", "--smc-particles", default=1024, type=int)
    parser.add_argument("-ss", "--svi-steps", default=5000, type=int)
    parser.add_argument("-sp", "--svi-particles", default=32, type=int)
    parser.add_argument("-ess", "--ess-threshold", default=0.5, type=float)
    parser.add_argument("-w", "--warmup-steps", type=int)
    parser.add_argument("-t", "--max-tree-depth", default=5, type=int)
    parser.add_argument("-nb", "--num-bins", default=1, type=int)
    # Runtime settings. --double and --jit are on by default; --single and
    # --nojit turn them off.
    parser.add_argument("--double", action="store_true", default=True)
    parser.add_argument("--single", action="store_false", dest="double")
    parser.add_argument("--rng-seed", default=0, type=int)
    parser.add_argument("--cuda", action="store_true")
    parser.add_argument("--jit", action="store_true", default=True)
    parser.add_argument("--nojit", action="store_false", dest="jit")
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--plot", action="store_true")
    args = parser.parse_args()

    # Warmup length defaults to the number of posterior samples.
    if args.warmup_steps is None:
        args.warmup_steps = args.num_samples
    if args.double:
        torch.set_default_dtype(torch.float64)
    if args.cuda:
        torch.set_default_device("cuda")

    main(args)

    # Figures are created inside the inference/predict helpers; show them all.
    if args.plot:
        import matplotlib.pyplot as plt

        plt.show()