Epidemiological models: Introduction

This tutorial introduces the pyro.contrib.epidemiology module, an epidemiological modeling language with a number of black box inference algorithms. This tutorial assumes the reader is already familiar with modeling, inference, and distribution shapes.

See also the following scripts:


  • To create a new model, inherit from the CompartmentalModel base class.

  • Override methods .global_model(), .initialize(params), and .transition(params, state, t).

  • Take care to support broadcasting and vectorized interpretation in those methods.

  • For single time series, set population to an integer.

  • For batched time series, let population be a vector, and use self.region_plate.

  • For models with complex inter-compartment flows, override the .compute_flows() method.

  • Flows with loops (undirected or directed) are not currently supported.

  • To perform cheap approximate inference via SVI, call the .fit_svi() method.

  • To perform more expensive inference via MCMC, call the .fit_mcmc() method.

  • To stochastically predict latent and future variables, call the .predict() method.


import os
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import pyro
import pyro.distributions as dist
from pyro.contrib.epidemiology import CompartmentalModel, binomial_dist, infection_dist

%matplotlib inline
assert pyro.__version__.startswith('1.9.1')
torch.set_default_dtype(torch.double)  # Required for MCMC inference.
smoke_test = ('CI' in os.environ)

Basic workflow

The pyro.contrib.epidemiology module provides a modeling language for a class of stochastic discrete-time discrete-count compartmental models, together with a number of black box inference algorithms to perform joint inference on global parameters and latent variables. This modeling language is more restrictive than the full Pyro probabilistic programming language:

  • control flow must be static;

  • compartmental distributions are restricted to binomial_dist(), beta_binomial_dist(), and infection_dist();

  • plates are not allowed, except for the single optional .region_plate;

  • all random variables must be either global or Markov and sampled at every time step, so e.g. time-windowed random variables are not supported;

  • models must support broadcasting and vectorization of time t.

These restrictions allow inference algorithms to vectorize over the time dimension, leading to inference algorithms with per-iteration parallel complexity sublinear in length of the time axis. The restriction on distributions allows inference algorithms to approximate parts of the model as Gaussian via moment matching, further speeding up inference. Finally, because real data is so often overdispersed relative to Binomial idealizations, the three distribution helpers provide an overdispersion parameter calibrated so that in the large-population limit all distribution helpers converge to log-normal.
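To make the moment-matching idea concrete, here is a minimal standalone sketch (plain Python, not Pyro's internal implementation) that replaces a Binomial with a Gaussian of the same mean and variance. The helper names `binomial_moments`, `binomial_pmf`, and `gaussian_pdf` are hypothetical and chosen only for illustration.

```python
import math

def binomial_moments(n, p):
    """Mean and variance of Binomial(n, p)."""
    return n * p, n * p * (1.0 - p)

def binomial_pmf(n, p, k):
    """Exact Binomial pmf, computed in log space for numerical stability."""
    log_pmf = (math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
               + k * math.log(p) + (n - k) * math.log1p(-p))
    return math.exp(log_pmf)

def gaussian_pdf(x, mean, var):
    """Density of a Gaussian with the given mean and variance."""
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2.0 * math.pi * var)

n, p = 1000, 0.3
mean, var = binomial_moments(n, p)
exact = binomial_pmf(n, p, 300)      # exact probability at the mean
approx = gaussian_pdf(300, mean, var)  # moment-matched Gaussian density
print(mean, var, exact, approx)
```

For large counts the two values agree closely, which is the regime where such an approximation is most useful.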

Black box inference algorithms currently include: SVI with a moment-matching approximation, and NUTS either with a moment-matched approximation or with an exact auxiliary variable method detailed in the SIR HMC tutorial. All three algorithms initialize using SMC and reparameterize time dependent variables using a fast Haar wavelet transform. Default inference parameters are set for cheap approximate results; accurate results will require more steps and ideally comparison among different inference algorithms. We recommend that, when running MCMC inference, you use multiple chains, thus making it easier to diagnose mixing issues.

While MCMC inference can be more accurate for a given model, SVI is much faster and thus allows richer model structure (e.g. incorporating neural networks) and more rapid model iteration. We recommend starting model exploration using mean field SVI (via .fit_svi(guide_rank=0)), then optionally increasing accuracy using a low-rank multivariate normal guide (via .fit_svi(guide_rank=None)). For even more accurate posteriors you could then try moment-matched MCMC (via .fit_mcmc(num_quant_bins=1)), or the most accurate and most expensive enumerated MCMC (via .fit_mcmc(num_quant_bins=4)). We recommend that, when fitting models with neural networks, you train via .fit_svi(), then freeze the network (say by omitting a pyro.module() statement) before optionally running MCMC inference.


The pyro.contrib.epidemiology.models module provides a number of example models. While in principle these are reusable, we recommend forking and modifying these models for your task. Let’s take a look at one of the simplest examples, SimpleSIRModel. This model derives from the CompartmentalModel base class and overrides the three standard methods using familiar Pyro modeling code in each method.

  • .global_model() samples global parameters and packs them into a single return value (here a tuple, but any structure is allowed). The return value is available as the params argument to the other two methods.

  • .initialize(params) samples (or deterministically sets) initial values of time series, returning a dictionary mapping time series name to initial value.

  • .transition(params, state, t) inputs global params, the state at the previous time step, and the time index t (which may be a slice!). It then samples flows and updates the state dict.

class SimpleSIRModel(CompartmentalModel):
    def __init__(self, population, recovery_time, data):
        compartments = ("S", "I")  # R is implicit.
        duration = len(data)
        super().__init__(compartments, duration, population)
        assert isinstance(recovery_time, float)
        assert recovery_time > 1
        self.recovery_time = recovery_time
        self.data = data

    def global_model(self):
        tau = self.recovery_time
        R0 = pyro.sample("R0", dist.LogNormal(0., 1.))
        rho = pyro.sample("rho", dist.Beta(100, 100))
        return R0, tau, rho

    def initialize(self, params):
        # Start with a single infection.
        return {"S": self.population - 1, "I": 1}

    def transition(self, params, state, t):
        R0, tau, rho = params

        # Sample flows between compartments.
        S2I = pyro.sample("S2I_{}".format(t),
                          infection_dist(individual_rate=R0 / tau,
                                         num_susceptible=state["S"],
                                         num_infectious=state["I"],
                                         population=self.population))
        I2R = pyro.sample("I2R_{}".format(t),
                          binomial_dist(state["I"], 1 / tau))

        # Update compartments with flows.
        state["S"] = state["S"] - S2I
        state["I"] = state["I"] + S2I - I2R

        # Condition on observations.
        t_is_observed = isinstance(t, slice) or t < self.duration
        pyro.sample("obs_{}".format(t),
                    binomial_dist(S2I, rho),
                    obs=self.data[t] if t_is_observed else None)

Note that we’ve stored data in the model. These models have a scikit-learn-like interface: we instantiate a model class with data, then call a .fit_*() method to train, then call .predict() on a trained model.

Note also that we’ve taken special care so that t can be either an integer or a slice. Under the hood, t is an integer during SMC initialization, a slice during SVI or MCMC inference, and an integer again during prediction.
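The integer-or-slice handling can be illustrated without Pyro. The following sketch uses a hypothetical helper `select_observations` and a plain list as data; it mirrors the `t_is_observed` test in `.transition()` and relies on the fact that Python indexing accepts both integers and slices.

```python
def select_observations(data, t, duration):
    """Return data[t] when t is observed, else None (as when forecasting)."""
    # A slice always covers the observed window; an integer must be in range.
    t_is_observed = isinstance(t, slice) or t < duration
    return data[t] if t_is_observed else None

data = [3, 1, 4, 1, 5]
duration = len(data)

one_step = select_observations(data, 2, duration)                    # integer t
all_steps = select_observations(data, slice(0, duration), duration)  # vectorized t
forecast = select_observations(data, duration + 1, duration)         # beyond the data
```

Because `isinstance(t, slice)` is checked first, the comparison `t < duration` is never evaluated for a slice, which would otherwise raise a `TypeError`.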

Generating data

To check that our model generates plausible data, we can create a model with empty data and call the model’s .generate() method. This method first calls .global_model(), then calls .initialize(), then calls .transition() once per time step (based on the length of our empty data).

population = 10000
recovery_time = 10.
empty_data = [None] * 90
model = SimpleSIRModel(population, recovery_time, empty_data)

# We'll repeatedly generate data until a desired number of infections is found.
for attempt in range(100):
    synth_data = model.generate({"R0": 2.0})
    total_infections = synth_data["S2I"].sum().item()
    if 4000 <= total_infections <= 6000:
        break
print("Simulated {} infections after {} attempts".format(total_infections, 1 + attempt))
Simulated 4055.0 infections after 6 attempts
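The call order used during generation (initialize once, then transition once per time step) can be mimicked in plain Python. This is a rough sketch, not Pyro's implementation: the function `generate_sir` is hypothetical, and the per-susceptible infection probability `1 - exp(-rate * I / N)` is an assumed textbook form, not necessarily the formula that `infection_dist()` uses.

```python
import math
import random

def binomial(rng, n, p):
    """Draw from Binomial(n, p) by summing Bernoulli trials (fine for small n)."""
    return sum(rng.random() < p for _ in range(n))

def generate_sir(population, recovery_time, R0, rho, duration, seed=0):
    rng = random.Random(seed)
    # Analogue of initialize(): start with a single infection.
    S, I = population - 1, 1
    series = {"S": [], "I": [], "S2I": [], "I2R": [], "obs": []}
    for t in range(duration):
        # Analogue of transition(): sample flows, update state, observe.
        # Assumed infection probability (illustrative only).
        p_infect = 1.0 - math.exp(-(R0 / recovery_time) * I / population)
        S2I = binomial(rng, S, p_infect)
        I2R = binomial(rng, I, 1.0 / recovery_time)
        S, I = S - S2I, I + S2I - I2R
        obs = binomial(rng, S2I, rho)
        # Dict keys iterate in insertion order: S, I, S2I, I2R, obs.
        for name, value in zip(series, (S, I, S2I, I2R, obs)):
            series[name].append(value)
    return series

sim = generate_sir(population=1000, recovery_time=10.0, R0=2.0, rho=0.5, duration=30)
```

Since the implicit R compartment only receives the I2R flow, S + I plus the cumulative I2R stays equal to the population at every step.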

The generated data contains both global variables and time series, packed into tensors.

for key, value in sorted(synth_data.items()):
    print("{}.shape = {}".format(key, tuple(value.shape)))
I.shape = (90,)
I2R.shape = (90,)
R0.shape = ()
S.shape = (90,)
S2I.shape = (90,)
obs.shape = (90,)
rho.shape = ()
for name, value in sorted(synth_data.items()):
    if value.dim():
        plt.plot(value, label=name)
plt.xlim(0, len(empty_data) - 1)
plt.ylim(0.8, None)
plt.xlabel("time step")
plt.title("Synthetic time series")


Next let’s recover estimates of the latent variables given only observations obs. To do this we’ll create a new model instance from the synthetic observations.

obs = synth_data["obs"]
model = SimpleSIRModel(population, recovery_time, obs)

The CompartmentalModel provides a number of inference algorithms. The cheapest and most scalable algorithm is SVI, available via the .fit_svi() method. This method returns a list of losses to help us diagnose convergence; the fitted parameters are stored in the model object.

losses = model.fit_svi(num_steps=101 if smoke_test else 2001)
INFO     Heuristic init: R0=1.83, rho=0.546
INFO     Running inference...
INFO     step 0 loss = 6.808
INFO     step 200 loss = 9.099
INFO     step 400 loss = 7.384
INFO     step 600 loss = 4.401
INFO     step 800 loss = 3.428
INFO     step 1000 loss = 3.242
INFO     step 1200 loss = 3.13
INFO     step 1400 loss = 3.016
INFO     step 1600 loss = 3.029
INFO     step 1800 loss = 3.05
INFO     step 2000 loss = 3.017
INFO     SVI took 12.7 seconds, 157.9 step/sec
CPU times: user 12.8 s, sys: 278 ms, total: 13.1 s
Wall time: 13.2 s
plt.figure(figsize=(8, 3))
plt.plot(losses)
plt.xlabel("SVI step")
plt.ylim(min(losses), max(losses[50:]));

After inference, samples of latent variables are stored in the .samples attribute. These are primarily for internal use, and do not contain the full set of latent variables.

for key, value in sorted(model.samples.items()):
    print("{}.shape = {}".format(key, tuple(value.shape)))
R0.shape = (100, 1)
auxiliary.shape = (100, 1, 2, 90)
rho.shape = (100, 1)


After inference we can both examine latent variables and forecast forward using the .predict() method. First let’s simply predict latent variables.

samples = model.predict()
INFO     Predicting latent variables for 90 time steps...
CPU times: user 113 ms, sys: 2.82 ms, total: 116 ms
Wall time: 115 ms
for key, value in sorted(samples.items()):
    print("{}.shape = {}".format(key, tuple(value.shape)))
I.shape = (100, 90)
I2R.shape = (100, 90)
R0.shape = (100, 1)
S.shape = (100, 90)
S2I.shape = (100, 90)
auxiliary.shape = (100, 1, 2, 90)
obs.shape = (100, 90)
rho.shape = (100, 1)
names = ["R0", "rho"]
fig, axes = plt.subplots(2, 1, figsize=(5, 5))
axes[0].set_title("Posterior estimates of global parameters")
for ax, name in zip(axes, names):
    truth = synth_data[name]
    sns.distplot(samples[name], ax=ax, label="posterior")
    ax.axvline(truth, color="k", label="truth")

Notice that while inference recovers the basic reproductive number R0, it poorly estimates the response rate rho and underestimates its uncertainty. While perfect inference would provide better uncertainty estimates, the response rate is known to be difficult to recover from data. Ideally the model would incorporate a narrower prior on rho, obtained either by testing a random sample of the population or from more accurate observations, e.g. counting deaths rather than confirmed infections.


We can forecast forward by passing a forecast argument to the .predict() method, specifying the number of time steps ahead we’d like to forecast. The returned sample will contain time values during both the first observed time interval (here 90 days) and the forecasted window (say 30 days).

samples = model.predict(forecast=30)
INFO     Predicting latent variables for 90 time steps...
INFO     Forecasting 30 steps ahead...
CPU times: user 2 µs, sys: 1 µs, total: 3 µs
Wall time: 5.96 µs
def plot_forecast(samples):
    duration = len(empty_data)
    forecast = samples["S"].size(-1) - duration
    num_samples = len(samples["R0"])

    time = torch.arange(duration + forecast)
    S2I = samples["S2I"]
    median = S2I.median(dim=0).values
    p05 = S2I.kthvalue(int(round(0.5 + 0.05 * num_samples)), dim=0).values
    p95 = S2I.kthvalue(int(round(0.5 + 0.95 * num_samples)), dim=0).values

    plt.figure(figsize=(8, 4))
    plt.fill_between(time, p05, p95, color="red", alpha=0.3, label="90% CI")
    plt.plot(time, median, "r-", label="median")
    plt.plot(time[:duration], obs, "k.", label="observed")
    plt.plot(time[:duration], synth_data["S2I"], "k--", label="truth")
    plt.axvline(duration - 0.5, color="gray", lw=1)
    plt.xlim(0, len(time) - 1)
    plt.ylim(0, None)
    plt.xlabel("day after first infection")
    plt.ylabel("new infections per day")
    plt.title("New infections in population of {}".format(population))
    plt.legend(loc="upper left")
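The percentile bands in plot_forecast come from selecting order statistics with a 1-indexed k, as torch.kthvalue does. The index arithmetic can be checked in isolation with plain Python; `kth_smallest` and `band` are hypothetical helpers and the draws below are made up for illustration.

```python
def kth_smallest(values, k):
    """Return the k-th smallest value (1-indexed), like torch.kthvalue."""
    return sorted(values)[k - 1]

def band(samples_at_t):
    """Approximate 5th and 95th percentiles of a list of draws."""
    num_samples = len(samples_at_t)
    # Same index arithmetic as in plot_forecast.
    k05 = int(round(0.5 + 0.05 * num_samples))
    k95 = int(round(0.5 + 0.95 * num_samples))
    return kth_smallest(samples_at_t, k05), kth_smallest(samples_at_t, k95)

draws = list(range(1, 101))  # 100 hypothetical posterior draws: 1, 2, ..., 100
p05, p95 = band(draws)
```

With very few samples (for example under the smoke test) the two indices move close together, so the band is only meaningful for a reasonably large number of draws.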


It looks like the mean field guide underestimates uncertainty. To improve uncertainty estimates we can instead try MCMC inference. In this simple model MCMC is only a small factor slower than SVI; in more complex models MCMC can be multiple orders of magnitude slower than SVI.

model = SimpleSIRModel(population, recovery_time, obs)
mcmc = model.fit_mcmc(num_samples=4 if smoke_test else 400)
INFO     Running inference...
Warmup:   0%|          | 0/800 [00:00, ?it/s]INFO        Heuristic init: R0=2.05, rho=0.437
Sample: 100%|██████████| 800/800 [01:44,  7.63it/s, step size=1.07e-01, acc. prob=0.890]
CPU times: user 1min 43s, sys: 1.2 s, total: 1min 44s
Wall time: 1min 44s

samples = model.predict(forecast=30)
INFO     Predicting latent variables for 90 time steps...
INFO     Forecasting 30 steps ahead...

Advanced modeling

So far we’ve seen how to create a simple univariate model, fit the model to data, and predict and forecast future data. Next let’s consider more advanced modeling techniques: regional models, phylogenetic likelihoods, heterogeneous models, and complex compartment flow.

Regional models

Epidemiology models vary in their level of detail. At the coarse-grained extreme are univariate aggregate models as we saw above. At the fine-grained extreme are network models where each individual’s state is tracked and infections occur along edges of a sparse graph (pyro.contrib.epidemiology does not implement network models). We now consider a mid-level model where each of many regions (e.g. countries or zip codes) is tracked in aggregate, and infections occur both within regions and between pairs of regions. In Pyro we model multiple regions with a plate. Pyro’s CompartmentalModel class does not support general pyro.plate syntax, but it does support a single special self.region_plate for regional models. This plate is available iff a CompartmentalModel is initialized with a vector population, and the size of the region_plate will be the length of the population vector.

Let’s take a look at the example RegionalSIRModel:

class RegionalSIRModel(CompartmentalModel):
    def __init__(self, population, coupling, recovery_time, data):
        duration = len(data)
        num_regions, = population.shape
        assert coupling.shape == (num_regions, num_regions)
        assert (0 <= coupling).all()
        assert (coupling <= 1).all()
        assert isinstance(recovery_time, float)
        assert recovery_time > 1
        if isinstance(data, torch.Tensor):
            # Data tensors should be oriented as (time, region).
            assert data.shape == (duration, num_regions)
        compartments = ("S", "I")  # R is implicit.

        # We create a regional model by passing a vector of populations.
        super().__init__(compartments, duration, population, approximate=("I",))

        self.coupling = coupling
        self.recovery_time = recovery_time
        self.data = data

    def global_model(self):
        # Assume recovery time is a known constant.
        tau = self.recovery_time

        # Assume reproductive number is unknown but homogeneous.
        R0 = pyro.sample("R0", dist.LogNormal(0., 1.))

        # Assume response rate is heterogeneous and model it with a
        # hierarchical Gamma-Beta prior.
        rho_c1 = pyro.sample("rho_c1", dist.Gamma(10, 1))
        rho_c0 = pyro.sample("rho_c0", dist.Gamma(10, 1))
        with self.region_plate:
            rho = pyro.sample("rho", dist.Beta(rho_c1, rho_c0))

        return R0, tau, rho

    def initialize(self, params):
        # Start with a single infection in region 0.
        I = torch.zeros_like(self.population)
        I[0] += 1
        S = self.population - I
        return {"S": S, "I": I}

    def transition(self, params, state, t):
        R0, tau, rho = params

        # Account for infections from all regions. This uses approximate (point
        # estimate) counts I_approx for infection from other regions, but uses
        # the exact (enumerated) count I for infections from one's own region.
        I_coupled = state["I_approx"] @ self.coupling
        I_coupled = I_coupled + (state["I"] - state["I_approx"]) * self.coupling.diag()
        I_coupled = I_coupled.clamp(min=0)  # In case I_approx is negative.
        pop_coupled = self.population @ self.coupling

        with self.region_plate:
            # Sample flows between compartments.
            S2I = pyro.sample("S2I_{}".format(t),
                              infection_dist(individual_rate=R0 / tau,
                                             num_susceptible=state["S"],
                                             num_infectious=I_coupled,
                                             population=pop_coupled))
            I2R = pyro.sample("I2R_{}".format(t),
                              binomial_dist(state["I"], 1 / tau))

            # Update compartments with flows.
            state["S"] = state["S"] - S2I
            state["I"] = state["I"] + S2I - I2R

            # Condition on observations.
            t_is_observed = isinstance(t, slice) or t < self.duration
            pyro.sample("obs_{}".format(t),
                        binomial_dist(S2I, rho),
                        obs=self.data[t] if t_is_observed else None)

The main differences from the earlier univariate model are that: we assume population is a vector of length num_regions, we sample all compartmental variables and some global variables inside the region_plate, and we compute coupled vectors I_coupled and pop_coupled of the effective number of infected individuals and population accounting for both intra-region and inter-region infections. Among global variables we have chosen for demonstration purposes to make tau a fixed single number, R0 a single latent variable shared among all regions, and rho a local latent variable that can take a different value for each region. Note that while rho is not shared among regions, we have created a hierarchical model whereby rho’s parent variables are shared among regions. While some of our variables are region-global and some region-local, only the compartmental variables are both region-local and time-dependent; all other parameters are fixed for all time. See the heterogeneous models section below for time-dependent latent variables.
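To see what the coupled vectors represent, the following plain-Python sketch computes them for two regions. The numbers and the helper `vec_mat` are hypothetical; the real model performs the same product on tensors via the `@` operator.

```python
def vec_mat(vector, matrix):
    """Compute vector @ matrix for plain Python lists."""
    num_cols = len(matrix[0])
    return [sum(v * row[j] for v, row in zip(vector, matrix))
            for j in range(num_cols)]

population = [100.0, 200.0]
coupling = [[1.0, 0.1],   # region 0 couples fully to itself, weakly to region 1
            [0.1, 1.0]]
infected = [5.0, 0.0]     # all current infections are in region 0

I_coupled = vec_mat(infected, coupling)      # effective infectious count per region
pop_coupled = vec_mat(population, coupling)  # effective population per region
```

Here region 1 has no infections of its own, yet its effective infectious count is nonzero because of the off-diagonal coupling, which is how an outbreak can spread between regions.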

Note that Pyro’s enumerated MCMC strategy (.fit_mcmc() with num_quant_bins > 1) requires extra logic to use a mean-field approximation across compartments: we pass approximate=("I",) to the constructor and force compartments to interact via state["I_approx"] rather than state["I"]. This code is not required for SVI inference or for moment-matched MCMC inference (.fit_mcmc() with num_quant_bins=1).

See the Epidemiology: regional models example for a demonstration of how to generate data, train, predict, and forecast with regional models.

Phylogenetic likelihoods

Epidemiological parameters can be difficult to identify from aggregate observations alone. However some parameters like the superspreading parameter k can be more accurately identified by combining aggregate count data with viral phylogenetic trees reconstructed from viral genetic sequencing data (Li et al. 2017). Pyro implements a CoalescentRateLikelihood class to compute a population likelihood p(I|phylogeny) given statistics of a phylogenetic tree (or a batch of tree samples). The statistics needed are exactly the times of each sampling event (i.e. when a viral genome was sequenced) and the times of genetic coalescent events in a binary phylogenetic tree; let us call these two vectors leaf_times and coal_times, respectively, where len(leaf_times) == 1 + len(coal_times) for binary trees. Pyro provides a helper bio_phylo_to_times() to extract these statistics from Bio.Phylo tree objects; in turn Bio.Phylo can parse many file formats of phylogenetic trees.

Let’s take a look at the SuperspreadingSEIRModel which includes a phylogenetic likelihood. We’ll focus on the phylogenetic parts of the model:

class SuperspreadingSEIRModel(CompartmentalModel):
    def __init__(self, population, incubation_time, recovery_time, data, *,
                 leaf_times=None, coal_times=None):
        compartments = ("S", "E", "I")  # R is implicit.
        duration = len(data)
        super().__init__(compartments, duration, population)
        self.coal_likelihood = dist.CoalescentRateLikelihood(
            leaf_times, coal_times, duration)

    def transition(self, params, state, t):
        ...
        # Condition on observations.
        t_is_observed = isinstance(t, slice) or t < self.duration
        R = R0 * state["S"] / self.population
        coal_rate = R * (1. + 1. / k) / (tau_i * state["I"] + 1e-8)
        pyro.factor("coalescent_{}".format(t),
                    self.coal_likelihood(coal_rate, t)
                    if t_is_observed else torch.tensor(0.))

We first constructed a CoalescentRateLikelihood object to be used throughout inference and prediction; this performs preprocessing work once so that it is cheap to evaluate self.coal_likelihood(...). Note that (leaf_times, coal_times) should be in units of time steps, the same time steps as the time index t and duration. Typically leaf_times are in [0, duration), but coal_times precede leaf_times (as points of common ancestry), and may be negative. The likelihood involves the coalescent rate coal_rate in a coalescent process; we can compute this from an epidemiological model. In this superspreading model coal_rate depends on the reproductive number R, the superspreading parameter k, the incubation time tau_i, and the current number of infected individuals state["I"] (Li et al. 2017).
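As a quick numerical check of the coalescent rate expression in the excerpt above, here is a standalone sketch with scalar inputs. The function names and all numeric values are hypothetical; only the formula and the binary-tree length relation are taken from the text.

```python
def coalescent_rate(R0, S, population, k, tau_i, I):
    """Coalescent rate following the formula in the model excerpt."""
    R = R0 * S / population  # effective reproductive number
    return R * (1.0 + 1.0 / k) / (tau_i * I + 1e-8)

def check_binary_tree_times(leaf_times, coal_times):
    """A binary tree with n leaves has n - 1 coalescent events."""
    return len(leaf_times) == 1 + len(coal_times)

leaf_times = [10.0, 12.0, 15.0]  # hypothetical sampling times, in time steps
coal_times = [4.0, -2.0]         # hypothetical coalescent times; may be negative
rate = coalescent_rate(R0=2.0, S=1000, population=1000, k=0.5, tau_i=10.0, I=3)
```

A smaller k (more superspreading) increases the rate, while a larger infected population decreases it; the small constant 1e-8 only guards against division by zero when I is zero.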

Heterogeneous models

Epidemiological parameters often vary in time, due to human interventions, changes in weather, and other external factors. We can model real-valued time-varying latent variables in CompartmentalModel by moving static latent variables from .global_model() to .initialize() and .transition(). For example we can model a reproductive number under Brownian drift in log-space by initializing at a random R0 and multiplying by a drifting factor, as in the HeterogeneousSIRModel example:

class HeterogeneousSIRModel(CompartmentalModel):
    def global_model(self):
        tau = self.recovery_time
        R0 = pyro.sample("R0", dist.LogNormal(0., 1.))
        rho = ...
        return R0, tau, rho

    def initialize(self, params):
        # Start with a single infection.
        # We also store the initial beta value in the state dict.
        return {"S": self.population - 1, "I": 1, "beta": torch.tensor(1.)}

    def transition(self, params, state, t):
        R0, tau, rho = params
        # Sample heterogeneous variables.
        # This assumes beta slowly drifts via Brownian motion in log space.
        beta = pyro.sample("beta_{}".format(t),
                           dist.LogNormal(state["beta"].log(), 0.1))
        Rt = pyro.deterministic("Rt_{}".format(t), R0 * beta)

        # Sample flows between compartments.
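        S2I = pyro.sample("S2I_{}".format(t),
                          infection_dist(individual_rate=Rt / tau,
                                         num_susceptible=state["S"],
                                         num_infectious=state["I"],
                                         population=self.population))
        I2R = pyro.sample("I2R_{}".format(t),
                          binomial_dist(state["I"], 1 / tau))

        # Update compartments and heterogeneous variables.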
        state["S"] = state["S"] - S2I
        state["I"] = state["I"] + S2I - I2R
        state["beta"] = beta  # We store the latest beta value in the state dict.

Here we deterministically initialize a scale factor beta = 1 in .initialize() then let it drift via log-Brownian motion. We also need to update state["beta"] just as we update the compartmental variables. Now beta will be provided as a time series when we .predict(). While we could have written Rt = R0 * beta, we instead wrapped this computation in a pyro.deterministic thereby exposing Rt as another time series provided by .predict(). Note that we could have instead sampled R0 in .initialize() and let Rt drift directly, rather than introducing a scale factor beta. However separating the two into a non-centered form improves geometry (Betancourt and Girolami 2013).
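The drift of the scale factor can be simulated on its own with the standard library. This is an illustrative sketch, not the model itself: `simulate_Rt` is a hypothetical helper, and R0 is fixed rather than sampled.

```python
import math
import random

def simulate_Rt(R0, duration, scale=0.1, seed=0):
    """Simulate Rt = R0 * beta where log(beta) follows a random walk."""
    rng = random.Random(seed)
    beta = 1.0  # deterministic initial scale factor, as in .initialize()
    betas, Rts = [], []
    for _ in range(duration):
        # LogNormal(log(beta_prev), scale): one Brownian step in log space.
        beta = rng.lognormvariate(math.log(beta), scale)
        betas.append(beta)
        Rts.append(R0 * beta)
    return betas, Rts

betas, Rts = simulate_Rt(R0=2.0, duration=90)
```

Because the walk happens in log space, beta (and hence Rt) stays strictly positive no matter how far it drifts.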

It is also easy to pass in time-varying covariates as tensors, in the same way we have passed in data to the constructors of all example models. To predict the effects of different causal interventions, you can pass in a covariate that is longer than duration, run inference (looking only at the first [0,duration) entries), then mutate entries of the covariate after duration and generate different .predict()ions.

Complex compartment flow

The CompartmentalModel class assumes by default that the compartments are arranged linearly and terminate in an implicit terminal compartment named “R”, for example S-I-R, S-E-I-R or boxcar models like S-E1-E2-I1-I2-I3-R. To describe more complex flows between compartments, you can override the .compute_flows() method. However, flows with loops (e.g. S-I-S) are not currently supported.

Let’s create a branching SIRD model with possible flows

S → I → R
      ↘
        D

As with other models, we’ll keep the “R” state implicit (although we could equally keep the “D” state implicit and the “R” state explicit). In the .compute_flows() method, we’ll input a pair of states and we’ll need to compute three flow values: S2I, I2R, and I2D.

class SIRDModel(CompartmentalModel):
    def __init__(self, population, data):
        compartments = ("S", "I", "D")
        duration = len(data)
        super().__init__(compartments, duration, population)
        self.data = data

    def compute_flows(self, prev, curr, t):
        S2I = prev["S"] - curr["S"]  # S can only go in one direction.
        I2D = curr["D"] - prev["D"]  # D can only have come from one direction.
        # Now by conservation at I, change + inflows + outflows = 0,
        # so we can solve for the single unknown I2R.
        I2R = prev["I"] - curr["I"] + S2I - I2D
        return {
            "S2I_{}".format(t): S2I,
            "I2D_{}".format(t): I2D,
            "I2R_{}".format(t): I2R,
        }

    def transition(self, params, state, t):
        # Sample flows between compartments.
        S2I = pyro.sample("S2I_{}".format(t), ...)
        I2D = pyro.sample("I2D_{}".format(t), ...)
        I2R = pyro.sample("I2R_{}".format(t), ...)

        # Update compartments with flows.
        state["S"] = state["S"] - S2I
        state["I"] = state["I"] + S2I - I2D - I2R
        state["D"] = state["D"] + I2D

Note you can name the dict keys anything you want, as long as they match your sample statements in .transition() and you correctly reverse the flow computation in .transition(). During inference Pyro will check that the .compute_flows() and .transition() computations agree. Take care to avoid in-place PyTorch operations, since these can modify the tensors rather than the dictionary:

+ state["S"] = state["S"] - S2I  # Correct
- state["S"] -= S2I              # AVOID: may corrupt tensors
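The requirement that .compute_flows() and .transition() agree can be illustrated with a round trip on plain integers. This sketch uses simplified, hypothetical signatures (no time index, no sampling) and returns a fresh dict rather than updating in place.

```python
def transition(state, flows):
    """Apply flows to a state, returning a new dict (no in-place updates)."""
    return {
        "S": state["S"] - flows["S2I"],
        "I": state["I"] + flows["S2I"] - flows["I2D"] - flows["I2R"],
        "D": state["D"] + flows["I2D"],
    }

def compute_flows(prev, curr):
    """Recover the flows from consecutive states, mirroring SIRDModel."""
    S2I = prev["S"] - curr["S"]   # S can only decrease
    I2D = curr["D"] - prev["D"]   # D can only increase
    I2R = prev["I"] - curr["I"] + S2I - I2D  # conservation at I
    return {"S2I": S2I, "I2D": I2D, "I2R": I2R}

prev = {"S": 90, "I": 10, "D": 0}
flows = {"S2I": 7, "I2D": 1, "I2R": 2}
curr = transition(prev, flows)
recovered = compute_flows(prev, curr)
```

If the two functions disagree, for example by dropping the I2D term in one of them, the recovered flows no longer match the ones that were applied.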

For a slightly more complex example, take a look at the SimpleSEIRDModel.


  1. Lucy M. Li, Nicholas C. Grassly, Christophe Fraser (2017) “Quantifying Transmission Heterogeneity Using Both Pathogen Phylogenies and Incidence Time Series” https://academic.oup.com/mbe/article/34/11/2982/3952784

  2. M. J. Betancourt, Mark Girolami (2013) “Hamiltonian Monte Carlo for Hierarchical Models” https://arxiv.org/abs/1312.0906
