Logistic growth models of SARS-CoV-2 lineage proportions¶
This notebook explores logistic growth models, with the goal of inferring the differential growth rates of different SARS-CoV-2 lineages over time. Before this tutorial you may want to familiarize yourself with Pyro modeling basics and tensor shapes in Pyro and PyTorch.
WARNING: The purpose of this tutorial is to demonstrate Pyro’s modeling and inference syntax. Making reliable inferences about SARS-CoV-2 is not the purpose of this tutorial.
Overview¶
When different strains/lineages/variants of a virus like SARS-CoV-2 circulate in a population, the fittest lineages tend to dominate, while the least fit tend to be outcompeted. In this tutorial we set out to infer (differential) growth rates for different SARS-CoV-2 lineages using a spatio-temporal dataset of SARS-CoV-2 genetic sequences. We’ll start with the simplest possible model and then move on to more complex models with hierarchical structure.
[1]:
import os
import datetime
from functools import partial
import numpy as np
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.optim import ClippedAdam
import matplotlib as mpl
import matplotlib.pyplot as plt
if torch.cuda.is_available():
    print("Using GPU")
    torch.set_default_device("cuda")
else:
    print("Using CPU")
smoke_test = ('CI' in os.environ)  # for use in continuous integration testing
Using CPU
Loading data¶
Our data consist of a few million genetic sequences of SARS-CoV-2 viruses, clustered into PANGO lineages, and aggregated into a few hundred regions globally and into 28-day time bins. Preprocessing was performed by Nextstrain’s ncov tool, and aggregation was performed by the Broad Institute’s pyro-cov tool.
[2]:
from pyro.contrib.examples.nextstrain import load_nextstrain_counts
dataset = load_nextstrain_counts()
def summarize(x, name=""):
if isinstance(x, dict):
for k, v in sorted(x.items()):
summarize(v, name + "." + k if name else k)
elif isinstance(x, torch.Tensor):
print(f"{name}: {type(x).__name__} of shape {tuple(x.shape)} on {x.device}")
elif isinstance(x, list):
print(f"{name}: {type(x).__name__} of length {len(x)}")
else:
print(f"{name}: {type(x).__name__}")
summarize(dataset)
counts: Tensor of shape (27, 202, 1316) on cpu
features: Tensor of shape (1316, 2634) on cpu
lineages: list of length 1316
locations: list of length 202
mutations: list of length 2634
sparse_counts.index: Tensor of shape (3, 57129) on cpu
sparse_counts.total: Tensor of shape (27, 202) on cpu
sparse_counts.value: Tensor of shape (57129,) on cpu
start_date: datetime
time_step_days: int
In this tutorial our interest is in the 3-dimensional tensor of counts dataset["counts"], which has shape (T, R, L), where T is the number of time bins, R is the number of regions, and L is the number of strains or PANGO lineages; dataset["counts"][t,r,l] is the number of samples in the corresponding time-region-lineage bin. The count data are heavily skewed towards a few large regions and dominant lineages like B.1.1.7 and B.1.617.2.
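For example, to peek at a single entry (an illustrative sketch added here; the indices are arbitrary):
[ ]:
t, r = 20, 0  # an arbitrary time bin and region
l = dataset["lineages"].index("B.1.1.7")  # look up a lineage id by name
print(dataset["counts"][t, r, l])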
[3]:
fig, axes = plt.subplots(1, 3, figsize=(8, 3), sharey=True)
for i, name in enumerate(["time bin", "location", "lineage"]):
    counts = dataset["counts"].sum(list({0, 1, 2} - {i}))
    Y = counts.sort(0, True).values
    axes[i].plot(Y)
    axes[i].set_xlim(0, None)
    axes[0].set_ylim(1, None)
    axes[i].set_yscale("log")
    axes[i].set_xlabel("rank", fontsize=18)
    axes[i].set_title(f"{len(Y)} {name}s")
axes[0].set_ylabel("# samples", fontsize=18)
plt.subplots_adjust(wspace=0.05);
Helpers for manipulating data¶
[4]:
def get_lineage_id(s):
    """Get lineage id from string name"""
    return np.argmax(np.array([s]) == dataset['lineages'])

def get_location_id(s):
    """Get location id from string name"""
    return np.argmax(np.array([s]) == dataset['locations'])

def get_aggregated_counts_from_locations(locations):
    """Get aggregated counts from a list of locations"""
    return sum([dataset['counts'][:, get_location_id(loc)] for loc in locations])

start = dataset["start_date"]
step = datetime.timedelta(days=dataset["time_step_days"])
date_range = np.array([start + step * t for t in range(len(dataset["counts"]))])
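As a quick sanity check, the helpers can be used like this (an illustrative sketch added here; the printed values depend on the dataset snapshot):
[ ]:
print(get_lineage_id('B.1.1.7'), get_location_id('USA / Massachusetts'))
print(date_range[0], date_range[-1])  # first and last time-bin start dates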
A first model¶
First let’s zoom-in on Massachusetts and a few surrounding states:
[5]:
northeast_states = [
    'USA / Massachusetts',
    'USA / New York',
    'USA / Connecticut',
    'USA / New Hampshire',
    'USA / Vermont',
    'USA / New Jersey',
    'USA / Maine',
    'USA / Rhode Island',
    'USA / Pennsylvania',
]
northeast_counts = get_aggregated_counts_from_locations(northeast_states)
Next let’s extract sublineages corresponding to two of the WHO variants of concern, Alpha and Delta:
[6]:
# The Alpha and Delta variants include many PANGO lineages, which we need to aggregate.
Q_lineages = [lin for lin in dataset['lineages'] if lin[:2] == 'Q.']
AY_lineages = [lin for lin in dataset['lineages'] if lin[:3] == 'AY.']
alpha_lineages = ['B.1.1.7'] + Q_lineages
delta_lineages = ['B.1.617.2'] + AY_lineages
alpha_ids = [get_lineage_id(lin) for lin in alpha_lineages]
delta_ids = [get_lineage_id(lin) for lin in delta_lineages]
alpha_counts = northeast_counts[:, alpha_ids].sum(-1)
delta_counts = northeast_counts[:, delta_ids].sum(-1)
[7]:
# Let's combine the counts into a single tensor
alpha_delta_counts = torch.stack([alpha_counts, delta_counts]).T
print(alpha_delta_counts.shape)
torch.Size([27, 2])
Next let’s plot the time series of count proportions (Alpha vs. Delta):
[8]:
# We skip the first year or so of the pandemic when Alpha and Delta are not present
start_time = 13
total_counts = (alpha_counts + delta_counts)[start_time:]
dates = date_range[start_time:]
plt.figure(figsize=(7, 4))
plt.plot(dates, alpha_counts[start_time:] / total_counts, label='Alpha')
plt.plot(dates, delta_counts[start_time:] / total_counts, label='Delta')
plt.xlim(min(dates), max(dates))
plt.ylabel("Proportion", fontsize=18)
plt.xticks(rotation=90)
plt.gca().xaxis.set_major_locator(mpl.dates.MonthLocator())
plt.gca().xaxis.set_major_formatter(mpl.dates.DateFormatter("%b %Y"))
plt.legend(fontsize=18)
plt.tight_layout()
We see that Alpha was dominant at first, but was then outcompeted by Delta, which eventually became dominant.
Model definition¶
Instead of attempting to model how the total number of observed sequences varies as a function of time (which depends on complex human behavior), we model the proportion of sequences at each time step that are Alpha versus Delta. In other words, if we observed 8 Alpha samples and 2 Delta samples at a given time step, we would model the proportions 80% and 20% instead of the raw counts 8 and 2. To do this we use a logistic growth model with a Multinomial distribution as the likelihood.
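To make this concrete, here is a small numeric sketch (added for illustration; the logits are hand-picked, not inferred): logits whose softmax is roughly (0.8, 0.2) assign high probability to observing the counts (8, 2) under a Multinomial with 10 trials.
[ ]:
probs = torch.softmax(torch.tensor([1.386, 0.0]), dim=-1)
print(probs)  # approximately tensor([0.8000, 0.2000])
lik = dist.Multinomial(total_count=10, probs=probs)
print(lik.log_prob(torch.tensor([8.0, 2.0])).exp())  # probability of (8 Alpha, 2 Delta)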
[9]:
def basic_model(counts):
    T, L = counts.shape

    # Define plates over lineage and time
    lineage_plate = pyro.plate("lineages", L, dim=-1)
    time_plate = pyro.plate("time", T, dim=-2)

    # Define a growth rate (i.e. slope) and an init (i.e. intercept) for each lineage
    with lineage_plate:
        rate = pyro.sample("rate", dist.Normal(0, 1))
        init = pyro.sample("init", dist.Normal(0, 1))

    # We measure time in units of the SARS-CoV-2 generation time of 5.5 days
    time = torch.arange(float(T)) * dataset["time_step_days"] / 5.5

    # Assume lineages grow linearly in logit space
    logits = init + rate * time[:, None]

    # We use the softmax function (the multivariate generalization of the
    # sigmoid function) to define normalized probabilities from the logits
    probs = torch.softmax(logits, dim=-1)
    assert probs.shape == (T, L)

    # Observe counts via a multinomial likelihood.
    with time_plate:
        pyro.sample(
            "obs",
            dist.Multinomial(probs=probs.unsqueeze(-2), validate_args=False),
            obs=counts.unsqueeze(-2),
        )
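Before fitting, it can be useful to check the shape of each sample site; here is a small sketch using poutine.trace (added for illustration):
[ ]:
trace = pyro.poutine.trace(partial(basic_model, alpha_delta_counts)).get_trace()
for name, site in trace.nodes.items():
    if site["type"] == "sample":
        print(name, tuple(site["value"].shape))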
Let’s look at the graphical structure of our model
[10]:
pyro.render_model(partial(basic_model, alpha_delta_counts))
[10]:
(graphical model rendering)
Define a helper for fitting models¶
[11]:
def fit_svi(model, lr=0.1, num_steps=1001, log_every=250):
    pyro.clear_param_store()  # clear parameters from previous runs
    pyro.set_rng_seed(20211214)
    if smoke_test:
        num_steps = 2

    # Define a mean field guide (i.e. variational distribution)
    guide = AutoNormal(model, init_scale=0.01)
    # lrd is a per-step multiplicative decay, chosen here so the learning
    # rate falls by a factor of 10 over the course of training
    optim = ClippedAdam({"lr": lr, "lrd": 0.1 ** (1 / num_steps)})
    svi = SVI(model, guide, optim, Trace_ELBO())

    # Train (i.e. do ELBO optimization) for num_steps iterations
    losses = []
    for step in range(num_steps):
        loss = svi.step()
        losses.append(loss)
        if step % log_every == 0:
            print(f"step {step: >4d} loss = {loss:0.6g}")

    # Plot to assess convergence.
    plt.figure(figsize=(6, 3))
    plt.plot(losses)
    plt.xlabel("SVI step", fontsize=18)
    plt.ylabel("ELBO loss", fontsize=18)
    plt.tight_layout()
    return guide
Let’s fit basic_model and inspect the results¶
[12]:
%%time
# We truncate the data to the period with non-zero counts
guide = fit_svi(partial(basic_model, alpha_delta_counts[13:]), num_steps=1501)
step 0 loss = 103782
step 250 loss = 3373.05
step 500 loss = 1299.14
step 750 loss = 524.81
step 1000 loss = 304.319
step 1250 loss = 278.005
step 1500 loss = 261.731
CPU times: user 4.69 s, sys: 29.6 ms, total: 4.72 s
Wall time: 4.73 s
Let’s inspect the posterior means of our latent parameters:¶
[13]:
for k, v in guide.median().items():
    print(k, v.data.cpu().numpy())
rate [-0.27021623 0.27021623]
init [ 8.870546 -8.870401]
As expected, the Delta lineage (corresponding to index 1) has a differential growth rate advantage over the Alpha lineage (corresponding to index 0). Since the model measures time in units of the 5.5-day generation time, exponentiating the difference in rates gives Delta’s multiplicative fitness advantage per generation:
[14]:
print("Multiplicative advantage: {:.2f}".format(
np.exp(guide.median()['rate'][1] - guide.median()['rate'][0])))
Multiplicative advantage: 1.72
This seems like it might be an overestimate. Can we get better estimates by modeling each spatial region individually?
A regional model¶
Instead of focusing on northeastern US states we now consider the entire global dataset and do not aggregate across regions.
[15]:
# First extract the data we want to use
alpha_counts = dataset['counts'][:, :, alpha_ids].sum(-1)
delta_counts = dataset['counts'][:, :, delta_ids].sum(-1)
counts = torch.stack([alpha_counts, delta_counts], dim=-1)
print("counts.shape: ", counts.shape)
print(f"number of regions: {counts.size(1)}")
counts.shape: torch.Size([27, 202, 2])
number of regions: 202
[16]:
# We skip the first year or so of the pandemic when Alpha and Delta are not present
start_time = 13
total_counts = (alpha_counts + delta_counts)[start_time:]
dates = date_range[start_time:]
plt.figure(figsize=(7, 4))
plt.plot(dates, delta_counts[start_time:] / total_counts, color="C1", lw=1, alpha=0.5)
plt.xlim(min(dates), max(dates))
plt.ylabel("Proportion", fontsize=18)
plt.xticks(rotation=90)
plt.gca().xaxis.set_major_locator(mpl.dates.MonthLocator())
plt.gca().xaxis.set_major_formatter(mpl.dates.DateFormatter("%b %Y"))
plt.title(f"Delta/(Alpha+Delta) in {counts.size(1)} regions", fontsize=18)
plt.tight_layout()
[17]:
# Model lineage proportions in each region as multivariate logistic growth
def regional_model(counts):
    T, R, L = counts.shape

    # Now we also define a region plate in addition to the time/lineage plates
    lineage_plate = pyro.plate("lineages", L, dim=-1)
    region_plate = pyro.plate("region", R, dim=-2)
    time_plate = pyro.plate("time", T, dim=-3)

    # We use the same growth rate (i.e. slope) for each region
    with lineage_plate:
        rate = pyro.sample("rate", dist.Normal(0, 1))

    # We allow the init to vary from region to region
    init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2))
    with region_plate, lineage_plate:
        init = pyro.sample("init", dist.Normal(0, init_scale))

    # We measure time in units of the SARS-CoV-2 generation time of 5.5 days
    time = torch.arange(float(T)) * dataset["time_step_days"] / 5.5

    # Instead of using the softmax function we directly use the
    # logits parameterization of the Multinomial distribution
    logits = init + rate * time[:, None, None]

    # Observe sequences via a multinomial likelihood.
    with time_plate, region_plate:
        pyro.sample(
            "obs",
            dist.Multinomial(logits=logits.unsqueeze(-2), validate_args=False),
            obs=counts.unsqueeze(-2),
        )
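The logits parameterization avoids an explicit softmax but is equivalent to passing softmax-normalized probabilities, as this quick check (added for illustration) confirms:
[ ]:
logits = torch.randn(5)
d1 = dist.Multinomial(total_count=10, logits=logits)
d2 = dist.Multinomial(total_count=10, probs=torch.softmax(logits, dim=-1))
x = d1.sample()
assert torch.allclose(d1.log_prob(x), d2.log_prob(x))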
[18]:
pyro.render_model(partial(regional_model, counts))
[18]:
(graphical model rendering)
[19]:
%%time
guide = fit_svi(partial(regional_model, counts), num_steps=3001)
step 0 loss = 909278
step 250 loss = 509502
step 500 loss = 630927
step 750 loss = 501493
step 1000 loss = 1.04533e+06
step 1250 loss = 1.98151e+06
step 1500 loss = 328504
step 1750 loss = 279016
step 2000 loss = 310281
step 2250 loss = 217622
step 2500 loss = 204381
step 2750 loss = 176877
step 3000 loss = 152123
CPU times: user 17.3 s, sys: 965 ms, total: 18.3 s
Wall time: 17.3 s
[20]:
print("Multiplicative advantage: {:.2f}".format(
np.exp(guide.median()['rate'][1] - guide.median()['rate'][0])))
Multiplicative advantage: 1.17
Notice this estimate is lower than the previous one obtained by aggregating the northeastern US states.
An alternative regional model¶
The regional model we defined above assumed that the rate for each lineage did not vary between regions. Here we add additional hierarchical structure and allow the rate to vary from region to region.
[21]:
def regional_model2(counts):
    T, R, L = counts.shape
    lineage_plate = pyro.plate("lineages", L, dim=-1)
    region_plate = pyro.plate("region", R, dim=-2)
    time_plate = pyro.plate("time", T, dim=-3)

    # We assume the init can vary a lot from region to region but
    # that the rate varies considerably less.
    rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2))
    init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2))

    # As before each lineage has a latent growth rate
    with lineage_plate:
        rate_loc = pyro.sample("rate_loc", dist.Normal(0, 1))

    # We allow the rate and init to vary from region to region
    with region_plate, lineage_plate:
        # The per-region per-lineage rate is governed by a hierarchical prior
        rate = pyro.sample("rate", dist.Normal(rate_loc, rate_scale))
        init = pyro.sample("init", dist.Normal(0, init_scale))

    # We measure time in units of the SARS-CoV-2 generation time of 5.5 days
    time = torch.arange(float(T)) * dataset["time_step_days"] / 5.5
    logits = init + rate * time[:, None, None]

    # Observe sequences via a multinomial likelihood.
    with time_plate, region_plate:
        pyro.sample(
            "obs",
            dist.Multinomial(logits=logits.unsqueeze(-2), validate_args=False),
            obs=counts.unsqueeze(-2),
        )
[22]:
pyro.render_model(partial(regional_model2, counts))
[22]:
(graphical model rendering)
[23]:
%%time
guide = fit_svi(partial(regional_model2, counts), num_steps=3001)
step 0 loss = 2.14938e+06
step 250 loss = 1.44698e+06
step 500 loss = 1.24936e+06
step 750 loss = 701128
step 1000 loss = 602609
step 1250 loss = 530833
step 1500 loss = 454014
step 1750 loss = 450981
step 2000 loss = 384790
step 2250 loss = 340659
step 2500 loss = 305373
step 2750 loss = 279524
step 3000 loss = 262679
CPU times: user 25 s, sys: 1.05 s, total: 26 s
Wall time: 25.1 s
[24]:
print("Multiplicative advantage: {:.2f}".format(
(guide.median()['rate_loc'][1] - guide.median()['rate_loc'][0]).exp()))
Multiplicative advantage: 1.14
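Since this model infers a separate rate for each region, we can also inspect how the estimated advantage varies geographically. A small sketch (added for illustration, using the guide fit above):
[ ]:
rate = guide.median()["rate"]  # shape (R, 2): per-region rates for (Alpha, Delta)
advantage = (rate[:, 1] - rate[:, 0]).exp()
print(f"min {advantage.min():0.2f}, median {advantage.median():0.2f}, max {advantage.max():0.2f}")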
Generalizations¶
So far we’ve seen how to model two variants at a time either globally or split across multiple regions, and how to use pyro.plate to model multiple variants or regions or times.
What other models can you think of that might make epidemiological sense? Here are some ideas:
Can you create a model over more than two variants, or even over all PANGO lineages? (A minimal sketch follows this list.)
What variables should be shared across lineages, across regions, or over time?
How might you deal with changes of behavior over time, e.g. pandemic waves or vaccination?
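As a starting point for the first idea, note that regional_model2 already handles an arbitrary number of lineages, so a minimal (untested, illustrative) approach is to fit it on the full counts tensor; expect this to be much slower and to need more SVI steps and tuning:
[ ]:
full_counts = dataset["counts"]  # shape (T, R, L) = (27, 202, 1316)
# guide = fit_svi(partial(regional_model2, full_counts), num_steps=3001)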
For an example of a larger Pyro model using SARS-CoV-2 lineage data like this, see our paper “Analysis of 2.1 million SARS-CoV-2 genomes identifies mutations associated with transmissibility” (preprint | code), and also the Bayesian workflow tutorial, which uses a slightly smaller dataset.