# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
import argparse
import logging
import torch
import pyro
from pyro.contrib.epidemiology.models import RegionalSIRModel
# Emit bare messages (no level/logger prefix) at INFO and above, so the
# script's progress output reads like plain prints.
logging.basicConfig(format="%(message)s", level=logging.INFO)
def Model(args, data):
    """Build a ``RegionalSIRModel`` from parsed command-line arguments.

    :param args: Parsed arguments; reads ``coupling``, ``num_regions``,
        ``population`` and ``recovery_time``.
    :param data: Series handed straight to ``RegionalSIRModel``. In this
        script it is either a list of ``None`` placeholders (simulation) or
        the observed tensor (inference).
    :raises ValueError: If ``args.coupling`` is not within ``[0, 1]``.
    """
    # Raise rather than assert so the check survives ``python -O``. The
    # chained comparison is False for NaN, so NaN is rejected as well.
    if not 0 <= args.coupling <= 1:
        raise ValueError("coupling must be in [0, 1], got {}".format(args.coupling))
    # Every region gets the same population size.
    population = torch.full((args.num_regions,), float(args.population))
    # 1 on the diagonal and ``args.coupling`` everywhere off the diagonal
    # (clamping the identity from below; valid because coupling <= 1).
    coupling = torch.eye(args.num_regions).clamp(min=args.coupling)
    return RegionalSIRModel(population, coupling, args.recovery_time, data)
def generate_data(args):
    """Simulate a synthetic regional outbreak to serve as data and ground truth.

    Up to 100 forward simulations are drawn over ``duration + forecast``
    days. The first draw whose first ``duration`` days contain at least
    ``args.min_observations`` observed infections (summed over regions) is
    returned.

    :param args: Parsed arguments; reads ``duration``, ``forecast``,
        ``basic_reproduction_number``, ``response_rate`` and
        ``min_observations`` (plus whatever ``Model`` reads).
    :returns: ``{"S2I": full simulated S2I series,
        "obs": observations truncated to the first duration days}``.
    :raises ValueError: If none of the 100 draws reaches the threshold.
    """
    horizon = args.duration + args.forecast
    simulator = Model(args, [None] * horizon)
    logging.info("Simulating from a {}".format(type(simulator).__name__))

    rate = args.response_rate
    for _ in range(100):
        draw = simulator.generate(
            {
                "R0": args.basic_reproduction_number,
                "rho_c1": 10 * rate,
                "rho_c0": 10 * (1 - rate),
            }
        )
        observed = draw["obs"][: args.duration]
        new_infections = draw["S2I"]
        num_observed = int(observed.sum())
        num_infected = int(new_infections[: args.duration].sum())
        if num_observed < args.min_observations:
            continue  # too few observations; resimulate

        # Only region 0's observed series is printed, alongside the totals.
        first_region = " ".join(str(int(count)) for count in observed[:, 0])
        logging.info(
            "Observed {:d}/{:d} infections:\n{}".format(
                num_observed, num_infected, first_region
            )
        )
        return {"S2I": new_infections, "obs": observed}

    raise ValueError(
        "Failed to generate {} observations. Try increasing "
        "--population or decreasing --min-observations".format(args.min_observations)
    )
def infer_mcmc(args, model):
    """Fit ``model`` with ``fit_mcmc``, print its summary, and optionally plot.

    A hook records the kernel's potential energy after each step. It logs
    each value when ``args.verbose`` is set, and the recorded trace is
    plotted when ``args.plot`` is set.

    :param args: Parsed arguments supplying the MCMC/SMC settings and the
        ``verbose``/``plot`` flags.
    :param model: Object exposing ``fit_mcmc(**kwargs)`` that returns
        something with a ``summary()`` method.
    """
    potential_trace = []

    def record_potential(kernel, *_):
        # Reads a private attribute of the kernel passed to the hook.
        energy = float(kernel._potential_energy_last)
        potential_trace.append(energy)
        if args.verbose:
            logging.info("potential = {:0.6g}".format(energy))

    fit_options = dict(
        heuristic_num_particles=args.smc_particles,
        heuristic_ess_threshold=args.ess_threshold,
        warmup_steps=args.warmup_steps,
        num_samples=args.num_samples,
        max_tree_depth=args.max_tree_depth,
        num_quant_bins=args.num_bins,
        haar=args.haar,
        haar_full_mass=args.haar_full_mass,
        jit_compile=args.jit,
        hook_fn=record_potential,
    )
    model.fit_mcmc(**fit_options).summary()

    if not args.plot:
        return
    import matplotlib.pyplot as plt

    plt.figure(figsize=(6, 3))
    plt.plot(potential_trace)
    plt.xlabel("MCMC step")
    plt.ylabel("potential energy")
    plt.title("MCMC energy trace")
    plt.tight_layout()
def infer_svi(args, model):
    """Fit ``model`` with ``fit_svi`` and optionally plot the loss curve.

    :param args: Parsed arguments supplying the SVI/SMC settings and the
        ``plot`` flag.
    :param model: Object exposing ``fit_svi(**kwargs)``. Its return value
        is plotted as the loss series when ``args.plot`` is set.
    """
    svi_options = dict(
        heuristic_num_particles=args.smc_particles,
        heuristic_ess_threshold=args.ess_threshold,
        num_samples=args.num_samples,
        num_steps=args.svi_steps,
        num_particles=args.svi_particles,
        haar=args.haar,
        jit=args.jit,
    )
    loss_history = model.fit_svi(**svi_options)

    if not args.plot:
        return
    import matplotlib.pyplot as plt

    plt.figure(figsize=(6, 3))
    plt.plot(loss_history)
    plt.xlabel("SVI step")
    plt.ylabel("loss")
    plt.title("SVI Convergence")
    plt.tight_layout()
def predict(args, model, truth):
    """Forecast new infections, log per-region medians, and optionally plot.

    :param args: Parsed arguments; reads ``forecast``, ``num_regions`` and
        ``plot``, and also ``duration`` and ``population`` when plotting.
    :param model: A fitted model exposing ``predict(forecast=...)`` and a
        ``data`` attribute holding the observed series.
    :param truth: Ground-truth ``S2I`` series. Only used when plotting.
    """
    samples = model.predict(forecast=args.forecast)
    # Dim 0 indexes samples (median/kthvalue reduce over it). The remaining
    # dims are indexed below as [time, region].
    S2I = samples["S2I"]
    median = S2I.median(dim=0).values
    lines = ["Median prediction of new infections (starting on day 0):"]
    lines.extend(
        "Region {}: {}".format(r, " ".join(map(str, map(int, median[:, r]))))
        for r in range(args.num_regions)
    )
    logging.info("\n".join(lines))

    # Optionally plot the latent and forecasted series of new infections.
    if not args.plot:
        return
    import matplotlib.pyplot as plt

    # squeeze=False keeps ``axes`` 2-D even for a single region. Otherwise
    # plt.subplots(1, ...) returns a bare Axes that cannot be iterated or
    # indexed by the code below.
    _, axes = plt.subplots(
        args.num_regions,
        sharex=True,
        squeeze=False,
        figsize=(6, 1 + args.num_regions),
    )
    axes = axes[:, 0]

    # 5%/95% order statistics over the samples actually returned, so the
    # rank can never exceed the sample count.
    num_samples = S2I.size(0)
    p05 = S2I.kthvalue(int(round(0.5 + 0.05 * num_samples)), dim=0).values
    p95 = S2I.kthvalue(int(round(0.5 + 0.95 * num_samples)), dim=0).values

    # Matplotlib cannot convert CUDA tensors (under --cuda the default device
    # is "cuda"), so move everything that gets plotted to the CPU.
    time = torch.arange(args.duration + args.forecast, device="cpu")
    p05, p95, median = p05.cpu(), p95.cpu(), median.cpu()
    observed = model.data.cpu()
    truth = truth.cpu()

    for r, ax in enumerate(axes):
        ax.fill_between(
            time, p05[:, r], p95[:, r], color="red", alpha=0.3, label="90% CI"
        )
        ax.plot(time, median[:, r], "r-", label="median")
        ax.plot(time[: args.duration], observed[:, r], "k.", label="observed")
        ax.plot(time, truth[:, r], "k--", label="truth")
        # Vertical line separating the observed window from the forecast.
        ax.axvline(args.duration - 0.5, color="gray", lw=1)
        ax.set_xlim(0, len(time) - 1)
        ax.set_ylim(0, None)
    axes[0].set_title(
        "New infections among {} regions each of size {}".format(
            args.num_regions, args.population
        )
    )
    axes[args.num_regions // 2].set_ylabel("inf./day")
    axes[-1].set_xlabel("day after first infection")
    axes[-1].legend(loc="upper left")
    plt.tight_layout()
    plt.subplots_adjust(hspace=0)
def main(args):
    """Run the end-to-end example: simulate, fit, then forecast.

    :param args: Parsed command-line arguments. ``args.infer`` selects the
        inference routine (``"mcmc"`` or ``"svi"``).
    """
    pyro.set_rng_seed(args.rng_seed)

    # Simulate an outbreak that provides both the observations and the truth.
    dataset = generate_data(args)

    # Fit a fresh model to the observed window with the chosen algorithm.
    fitted = Model(args, dataset["obs"])
    algorithms = {"mcmc": infer_mcmc, "svi": infer_svi}
    algorithms[args.infer](args, fitted)

    # Forecast and compare against the simulated ground truth.
    predict(args, fitted, truth=dataset["S2I"])
if __name__ == "__main__":
    # This example is pinned to a specific Pyro release.
    assert pyro.__version__.startswith("1.9.1")
    parser = argparse.ArgumentParser(
        description="Regional compartmental epidemiology modeling using HMC"
    )
    # Simulation / model settings.
    parser.add_argument("-p", "--population", default=1000, type=int)
    parser.add_argument("-r", "--num-regions", default=2, type=int)
    parser.add_argument("-c", "--coupling", default=0.1, type=float)
    parser.add_argument("-m", "--min-observations", default=3, type=int)
    parser.add_argument("-d", "--duration", default=20, type=int)
    parser.add_argument("-f", "--forecast", default=10, type=int)
    parser.add_argument("-R0", "--basic-reproduction-number", default=1.5, type=float)
    parser.add_argument("-tau", "--recovery-time", default=7.0, type=float)
    parser.add_argument("-rho", "--response-rate", default=0.5, type=float)
    # Inference algorithm. choices= makes argparse reject typos up front,
    # instead of failing with a bare KeyError in main() after data generation.
    # --mcmc/--svi are shorthands writing the same destination.
    parser.add_argument("--infer", default="mcmc", choices=["mcmc", "svi"])
    parser.add_argument("--mcmc", action="store_const", const="mcmc", dest="infer")
    parser.add_argument("--svi", action="store_const", const="svi", dest="infer")
    # Inference settings.
    parser.add_argument("--haar", action="store_true")
    parser.add_argument("-hfm", "--haar-full-mass", default=0, type=int)
    parser.add_argument("-n", "--num-samples", default=200, type=int)
    parser.add_argument("-np", "--smc-particles", default=1024, type=int)
    parser.add_argument("-ss", "--svi-steps", default=5000, type=int)
    parser.add_argument("-sp", "--svi-particles", default=32, type=int)
    parser.add_argument("-ess", "--ess-threshold", default=0.5, type=float)
    parser.add_argument("-w", "--warmup-steps", type=int)
    parser.add_argument("-t", "--max-tree-depth", default=5, type=int)
    parser.add_argument("-nb", "--num-bins", default=1, type=int)
    # Runtime settings. --double/--jit default to on, so the --single/--nojit
    # flags are the way to switch them off.
    parser.add_argument("--double", action="store_true", default=True)
    parser.add_argument("--single", action="store_false", dest="double")
    parser.add_argument("--rng-seed", default=0, type=int)
    parser.add_argument("--cuda", action="store_true")
    parser.add_argument("--jit", action="store_true", default=True)
    parser.add_argument("--nojit", action="store_false", dest="jit")
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--plot", action="store_true")
    args = parser.parse_args()

    # Default the warmup length to the number of posterior samples.
    if args.warmup_steps is None:
        args.warmup_steps = args.num_samples
    if args.double:
        torch.set_default_dtype(torch.float64)
    if args.cuda:
        torch.set_default_device("cuda")

    main(args)

    if args.plot:
        import matplotlib.pyplot as plt

        plt.show()