# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import argparse
import logging

import numpy as np
import torch

import pyro
import pyro.distributions as dist
from pyro.contrib.examples.bart import load_bart_od
from pyro.contrib.forecast import ForecastingModel, backtest
from pyro.ops.tensor_utils import periodic_cumsum, periodic_repeat

logging.getLogger("pyro").setLevel(logging.DEBUG)
logging.getLogger("pyro").handlers[0].setLevel(logging.DEBUG)


def preprocess(args):
    """
    Extract a tensor of (arrivals, departures) counts at Embarcadero station.
    """
    print("Loading data")
    dataset = load_bart_od()
    # The full dataset has all station->station ridership counts for all 50
    # train stations. In this simple example we will model only the aggregate
    # counts to and from a single station, Embarcadero.
    i = dataset["stations"].index("EMBR")
    arrivals = dataset["counts"][:, :, i].sum(-1)
    departures = dataset["counts"][:, i, :].sum(-1)
    data = torch.stack([arrivals, departures], dim=-1)
    print(f"Loaded data of shape {tuple(data.shape)}")

    # This simple example uses no covariates, so we construct an empty
    # covariate tensor with zero feature columns but the correct length.
    covariates = torch.zeros(len(data), 0)
    return data, covariates
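

# For reference: with T hourly observations, preprocess() returns data of
# shape (T, 2) and covariates of shape (T, 0).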


# We define a model by subclassing the ForecastingModel class and implementing
# a single .model() method.
class Model(ForecastingModel):
    # The .model() method inputs two tensors: a fake tensor zero_data that is
    # the same size and dtype as the real data (but of course the generative
    # model shouldn't depend on the value of the data it generates!), and a
    # tensor of covariates. Our simple model depends on no covariates, so we
    # simply pass in an empty tensor (see the preprocess() function above).
    def model(self, zero_data, covariates):
        period = 24 * 7  # Hourly data with a one-week seasonality.
        duration, dim = zero_data.shape[-2:]
        assert dim == 2  # Data is bivariate: (arrivals, departures).
        # Sample global parameters.
        noise_scale = pyro.sample(
            "noise_scale", dist.LogNormal(torch.full((dim,), -3.0), 1.0).to_event(1)
        )
        assert noise_scale.shape[-1:] == (dim,)
        trans_timescale = pyro.sample(
            "trans_timescale", dist.LogNormal(torch.zeros(dim), 1).to_event(1)
        )
        assert trans_timescale.shape[-1:] == (dim,)
        trans_loc = pyro.sample("trans_loc", dist.Cauchy(0, 1 / period))
        trans_loc = trans_loc.unsqueeze(-1).expand(trans_loc.shape + (dim,))
        assert trans_loc.shape[-1:] == (dim,)
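        # For example (a hypothetical scalar, ignoring batch dims), the expand
        # broadcasts one shared location to all dim data dimensions:
        #   >>> torch.tensor(0.5).unsqueeze(-1).expand((2,))
        #   tensor([0.5000, 0.5000])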
        trans_scale = pyro.sample(
            "trans_scale", dist.LogNormal(torch.zeros(dim), 0.1).to_event(1)
        )
        trans_corr = pyro.sample("trans_corr", dist.LKJCholesky(dim, torch.ones(())))
        trans_scale_tril = trans_scale.unsqueeze(-1) * trans_corr
        assert trans_scale_tril.shape[-2:] == (dim, dim)
        obs_scale = pyro.sample(
            "obs_scale", dist.LogNormal(torch.zeros(dim), 0.1).to_event(1)
        )
        obs_corr = pyro.sample("obs_corr", dist.LKJCholesky(dim, torch.ones(())))
        obs_scale_tril = obs_scale.unsqueeze(-1) * obs_corr
        assert obs_scale_tril.shape[-2:] == (dim, dim)
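        # Note each scale_tril = diag(scale) @ corr_cholesky is a Cholesky
        # factor of a full covariance matrix: up to numerical error,
        #   scale_tril @ scale_tril.T == diag(scale) @ corr @ diag(scale),
        # where corr = corr_cholesky @ corr_cholesky.T is a correlation matrix.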
        # Note the initial seasonality should be sampled in a plate with the
        # same dim as the time_plate, dim=-1. That way we can repeat it over
        # time below using periodic_repeat().
        with pyro.plate("season_plate", period, dim=-1):
            season_init = pyro.sample(
                "season_init", dist.Normal(torch.zeros(dim), 1).to_event(1)
            )
            assert season_init.shape[-2:] == (period, dim)
        # Sample independent noise at each time step.
        with self.time_plate:
            season_noise = pyro.sample(
                "season_noise", dist.Normal(0, noise_scale).to_event(1)
            )
            assert season_noise.shape[-2:] == (duration, dim)
        # Construct a prediction. This prediction has an exactly repeated
        # seasonal part plus slow seasonal drift. We use two deterministic,
        # linear functions to transform our diagonal Normal noise to nontrivial
        # samples from a Gaussian process.
        prediction = periodic_repeat(season_init, duration, dim=-2) + periodic_cumsum(
            season_noise, period, dim=-2
        )
        assert prediction.shape[-2:] == (duration, dim)
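        # To illustrate with hypothetical values (ignoring the dim arg), with
        # period = 2 the two helpers behave as follows:
        #   periodic_repeat([a, b], 5)       -> [a, b, a, b, a]
        #   periodic_cumsum([a, b, c, d], 2) -> [a, b, a + c, b + d]
        # i.e. the prediction is an exact weekly repeat of season_init plus a
        # per-hour-of-week random walk across weeks driven by season_noise.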
        # Construct a joint noise model. This model is a GaussianHMM, whose
        # .rsample() and .log_prob() methods are parallelized over the time
        # axis; thus this entire model is parallelized over time.
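        # Schematically (in Pyro's GaussianHMM convention of right-multiplying
        # states by matrices), the noise model draws
        #   z ~ init_dist
        #   then repeatedly  z = z @ trans_mat + w,  w ~ trans_dist
        #                    x = z @ obs_mat + e,    e ~ obs_dist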
        init_dist = dist.Normal(torch.zeros(dim), 100).to_event(1)
        trans_mat = trans_timescale.neg().exp().diag_embed()
        trans_dist = dist.MultivariateNormal(trans_loc, scale_tril=trans_scale_tril)
        obs_mat = torch.eye(dim)
        obs_dist = dist.MultivariateNormal(torch.zeros(dim), scale_tril=obs_scale_tril)
        noise_model = dist.GaussianHMM(
            init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=duration
        )
        assert noise_model.event_shape == (duration, dim)
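
        # Roughly speaking, self.predict() below causes ForecastingModel to
        # observe the residual (data - prediction) under this noise model
        # during training, and to sample noise to add to the prediction when
        # forecasting.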
        # The final statement registers our noise model and prediction.
        self.predict(noise_model, prediction)
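

# As a usage sketch (hypothetical, not executed here): outside of backtest(),
# a Model instance can be trained and sampled directly with the Forecaster
# class from pyro.contrib.forecast, e.g.
#   forecaster = Forecaster(Model(), data, covariates, num_steps=501)
#   samples = forecaster(data, covariates_plus_future, num_samples=100)
# where covariates_plus_future extends covariates over the forecast horizon.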


def main(args):
    data, covariates = preprocess(args)

    # We will model positive count data by log1p-transforming it into real-
    # valued data. But since we want to evaluate back in the count domain, we
    # also define a transform to apply during evaluation, mapping real values
    # back to count-valued data. Truth is mapped by the log1p() inverse
    # expm1(), but the prediction is sampled from a Poisson distribution.
    data = data.log1p()

    def transform(pred, truth):
        # Clamp in log1p space so the Poisson rate expm1(pred) stays positive.
        pred = torch.poisson(pred.clamp(min=1e-4).expm1())
        truth = truth.expm1()
        return pred, truth
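    # As a quick check (hypothetical values): expm1() inverts log1p(), so e.g.
    # torch.tensor([0.0, 99.0]).log1p().expm1() recovers approximately
    # [0.0, 99.0], while predictions are additionally discretized by the
    # Poisson sampling above.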

    # The backtest() function automatically trains and evaluates our model on
    # different windows of data.
    forecaster_options = {
        "num_steps": args.num_steps,
        "learning_rate": args.learning_rate,
        "log_every": args.log_every,
        "dct_gradients": args.dct,
    }
    metrics = backtest(
        data,
        covariates,
        Model,
        transform=transform,
        train_window=args.train_window,
        test_window=args.test_window,
        stride=args.stride,
        seed=args.seed,
        num_samples=args.num_samples,
        forecaster_options=forecaster_options,
    )
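    # backtest() returns a list with one dict of evaluation results per
    # window; below we summarize each metric's mean and standard deviation
    # across windows.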
    for name in ["mae", "rmse", "crps"]:
        values = [m[name] for m in metrics]
        mean = np.mean(values)
        std = np.std(values)
        print("{} = {:0.3g} +- {:0.3g}".format(name, mean, std))
    return metrics


if __name__ == "__main__":
    assert pyro.__version__.startswith("1.9.1")
    parser = argparse.ArgumentParser(description="BART Ridership Forecasting Example")
    parser.add_argument("--train-window", default=2160, type=int)
    parser.add_argument("--test-window", default=336, type=int)
    parser.add_argument("--stride", default=168, type=int)
    parser.add_argument("-n", "--num-steps", default=501, type=int)
    parser.add_argument("-lr", "--learning-rate", default=0.01, type=float)
    parser.add_argument("--dct", action="store_true")
    parser.add_argument("--num-samples", default=100, type=int)
    parser.add_argument("--log-every", default=50, type=int)
    parser.add_argument("--seed", default=1234567890, type=int)
    args = parser.parse_args()
    main(args)