Forecasting I: univariate, heavy tailed

This tutorial introduces the pyro.contrib.forecast module, a framework for forecasting with Pyro models. It covers only univariate models and simple likelihoods, and assumes the reader is already familiar with SVI and tensor shapes.

[1]:
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.contrib.examples.bart import load_bart_od
from pyro.contrib.forecast import ForecastingModel, Forecaster, backtest, eval_crps
from pyro.infer.reparam import LocScaleReparam, StableReparam
from pyro.ops.tensor_utils import periodic_cumsum, periodic_repeat, periodic_features
from pyro.ops.stats import quantile
import matplotlib.pyplot as plt

%matplotlib inline
assert pyro.__version__.startswith('1.4.0')
pyro.enable_validation(True)
pyro.set_rng_seed(20200221)
[2]:
dataset = load_bart_od()
print(dataset.keys())
print(dataset["counts"].shape)
print(" ".join(dataset["stations"]))
dict_keys(['stations', 'start_date', 'counts'])
torch.Size([78888, 50, 50])
12TH 16TH 19TH 24TH ANTC ASHB BALB BAYF BERY CAST CIVC COLM COLS CONC DALY DBRK DELN DUBL EMBR FRMT FTVL GLEN HAYW LAFY LAKE MCAR MLBR MLPT MONT NBRK NCON OAKL ORIN PCTR PHIL PITT PLZA POWL RICH ROCK SANL SBRN SFIA SHAY SSAN UCTY WARM WCRK WDUB WOAK

Intro to Pyro’s forecasting framework

Pyro’s forecasting framework consists of:
- a ForecastingModel base class, whose .model() method can be implemented for custom forecasting models,
- a Forecaster class that trains and forecasts using ForecastingModels, and
- a backtest() helper to evaluate models on a number of metrics.

Consider a simple univariate dataset, say weekly BART train ridership aggregated over all stations in the network. This data is roughly logarithmic, so we log-transform it for modeling.

[3]:
T, O, D = dataset["counts"].shape
data = dataset["counts"][:T // (24 * 7) * 24 * 7].reshape(T // (24 * 7), -1).sum(-1).log()
data = data.unsqueeze(-1)
plt.figure(figsize=(9, 3))
plt.plot(data)
plt.title("Total weekly ridership")
plt.ylabel("log(# rides)")
plt.xlabel("Week after 2011-01-01")
plt.xlim(0, len(data));
_images/forecasting_i_4_0.png
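
The slicing and reshaping in the cell above drops the trailing partial week and then sums each block of 24 * 7 hourly counts before taking the log. Here is a library-free sketch of the same idea on a flat list of hourly totals. The function name weekly_log_totals and the toy data are illustrative assumptions, not part of Pyro or the BART dataset.

```python
import math

HOURS_PER_WEEK = 24 * 7

def weekly_log_totals(hourly_counts):
    """Sum hourly counts into whole weeks, then take the log of each weekly total."""
    num_weeks = len(hourly_counts) // HOURS_PER_WEEK
    # Drop the trailing partial week, mirroring counts[:T // (24 * 7) * 24 * 7].
    truncated = hourly_counts[:num_weeks * HOURS_PER_WEEK]
    totals = [
        sum(truncated[w * HOURS_PER_WEEK:(w + 1) * HOURS_PER_WEEK])
        for w in range(num_weeks)
    ]
    return [math.log(total) for total in totals]

# 400 hours of toy data: two full weeks plus 64 leftover hours that are discarded.
toy = [2.0] * 400
weekly = weekly_log_totals(toy)
print(len(weekly), weekly[0])
```

In the tensor version, the single .sum(-1) after the reshape also adds up all origin and destination stations, because each row of the reshaped tensor holds one full week of station-pair counts.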

Let’s start with a simple log-linear regression model, with no trend or seasonality. Note that while this example is univariate, Pyro’s forecasting framework is multivariate, so we’ll often need to reshape using .unsqueeze(-1), .expand([1]), and .to_event(1).

[4]:
# First we need some boilerplate to create a class and define a .model() method.
class Model1(ForecastingModel):
    # We then implement the .model() method. Since this is a generative model, it shouldn't
    # look at data; however it is convenient to see the shape of data we're supposed to
    # generate, so this inputs a zeros_like(data) tensor instead of the actual data.
    def model(self, zero_data, covariates):
        data_dim = zero_data.size(-1)  # Should be 1 in this univariate tutorial.
        feature_dim = covariates.size(-1)

        # The first part of the model is a probabilistic program to create a prediction.
        # We use the zero_data as a template for the shape of the prediction.
        bias = pyro.sample("bias", dist.Normal(0, 10).expand([data_dim]).to_event(1))
        weight = pyro.sample("weight", dist.Normal(0, 0.1).expand([feature_dim]).to_event(1))
        prediction = bias + (weight * covariates).sum(-1, keepdim=True)
        # The prediction should have the same shape as zero_data (duration, obs_dim),
        # but may have additional sample dimensions on the left.
        assert prediction.shape[-2:] == zero_data.shape

        # The next part of the model creates a likelihood or noise distribution.
        # Again we'll be Bayesian and write this as a probabilistic program with
        # priors over parameters.
        noise_scale = pyro.sample("noise_scale", dist.LogNormal(-5, 5).expand([1]).to_event(1))
        noise_dist = dist.Normal(0, noise_scale)

        # The final step is to call the .predict() method.
        self.predict(noise_dist, prediction)
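
Stripped of tensor broadcasting, the prediction line in Model1 is an ordinary linear regression evaluated at each time step. The following plain-Python sketch shows that arithmetic for a single posterior sample. The helper name linear_prediction and the numbers are made up for illustration.

```python
def linear_prediction(bias, weight, covariates):
    """Return bias + sum_j weight[j] * covariates[t][j] for every time step t."""
    return [
        bias + sum(w * x for w, x in zip(weight, row))
        for row in covariates
    ]

# Three time steps, two features per step (feature_dim == 2).
covariates = [[0.0, 1.0], [1.0, 0.0], [2.0, -1.0]]
prediction = linear_prediction(bias=14.0, weight=[0.5, 0.25], covariates=covariates)
print(prediction)  # one value per time step
```

Because the sum runs over the feature axis, the same code works for any number of covariate columns, which is why Model1 can later be reused with seasonal features.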

We can now train this model by creating a Forecaster object. We’ll split the data into [T0,T1) for training and [T1,T2) for testing.

[5]:
T0 = 0              # beginning
T2 = data.size(-2)  # end
T1 = T2 - 52        # train/test split
[6]:
%%time
pyro.set_rng_seed(1)
pyro.clear_param_store()
time = torch.arange(float(T2)) / 365
covariates = torch.stack([time], dim=-1)
forecaster = Forecaster(Model1(), data[:T1], covariates[:T1], learning_rate=0.1)
INFO     step    0 loss = 484401
INFO     step  100 loss = 0.609042
INFO     step  200 loss = -0.535144
INFO     step  300 loss = -0.605789
INFO     step  400 loss = -0.59744
INFO     step  500 loss = -0.596203
INFO     step  600 loss = -0.614217
INFO     step  700 loss = -0.612415
INFO     step  800 loss = -0.613236
INFO     step  900 loss = -0.59879
INFO     step 1000 loss = -0.601271
CPU times: user 5.02 s, sys: 61.6 ms, total: 5.08 s
Wall time: 5.12 s

Next we can evaluate by drawing posterior samples from the forecaster, passing in full covariates but only partial data. We’ll use Pyro’s quantile() function to plot the median and an 80% confidence interval. To evaluate fit we’ll use eval_crps() to compute the Continuous Ranked Probability Score (CRPS); this is a good metric for assessing the distributional fit of a heavy-tailed distribution.

[7]:
samples = forecaster(data[:T1], covariates, num_samples=1000)
p10, p50, p90 = quantile(samples, (0.1, 0.5, 0.9)).squeeze(-1)
crps = eval_crps(samples, data[T1:])
print(samples.shape, p10.shape)

plt.figure(figsize=(9, 3))
plt.fill_between(torch.arange(T1, T2), p10, p90, color="red", alpha=0.3)
plt.plot(torch.arange(T1, T2), p50, 'r-', label='forecast')
plt.plot(data, 'k-', label='truth')
plt.title("Total weekly ridership (CRPS = {:0.3g})".format(crps))
plt.ylabel("log(# rides)")
plt.xlabel("Week after 2011-01-01")
plt.xlim(0, None)
plt.legend(loc="best");
torch.Size([1000, 52, 1]) torch.Size([52])
_images/forecasting_i_11_1.png
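
To build intuition for the CRPS value in the plot title, here is a minimal sketch of a sample-based CRPS estimate for a single scalar observation. It assumes the standard identity CRPS = E|X - y| - 0.5 E|X - X'|, and it is not necessarily how eval_crps() is implemented internally. The function name crps_from_samples and the toy forecasts are illustrative.

```python
def crps_from_samples(samples, truth):
    """Estimate CRPS for one observation from forecast samples.

    Uses the identity CRPS = E|X - y| - 0.5 * E|X - X'|, with both
    expectations replaced by averages over the samples.
    """
    n = len(samples)
    abs_error = sum(abs(x - truth) for x in samples) / n
    spread = sum(abs(a - b) for a in samples for b in samples) / (n * n)
    return abs_error - 0.5 * spread

# Two forecasts centered on the truth: one sharp, one vague.
sharp = [9.9, 10.0, 10.1]
vague = [5.0, 10.0, 15.0]
print(crps_from_samples(sharp, 10.0), crps_from_samples(vague, 10.0))
```

Lower is better. Both toy forecasts are centered on the truth, but the vague one is penalized for its spread. For a single-sample (point) forecast the score reduces to the absolute error.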

Zooming in to just the forecasted region, we see this model ignores seasonal behavior.

[8]:
plt.figure(figsize=(9, 3))
plt.fill_between(torch.arange(T1, T2), p10, p90, color="red", alpha=0.3)
plt.plot(torch.arange(T1, T2), p50, 'r-', label='forecast')
plt.plot(torch.arange(T1, T2), data[T1:], 'k-', label='truth')
plt.title("Total weekly ridership (CRPS = {:0.3g})".format(crps))
plt.ylabel("log(# rides)")
plt.xlabel("Week after 2011-01-01")
plt.xlim(T1, None)
plt.legend(loc="best");
_images/forecasting_i_13_0.png

We could add a yearly seasonal component simply by adding new covariates (note we’ve already taken care in the model to handle feature_dim > 1).

[9]:
%%time
pyro.set_rng_seed(1)
pyro.clear_param_store()
time = torch.arange(float(T2)) / 365
covariates = torch.cat([time.unsqueeze(-1),
                        periodic_features(T2, 365.25 / 7)], dim=-1)
forecaster = Forecaster(Model1(), data[:T1], covariates[:T1], learning_rate=0.1)
INFO     step    0 loss = 53174.4
INFO     step  100 loss = 0.519148
INFO     step  200 loss = -0.0264822
INFO     step  300 loss = -0.314983
INFO     step  400 loss = -0.413243
INFO     step  500 loss = -0.487756
INFO     step  600 loss = -0.472516
INFO     step  700 loss = -0.595866
INFO     step  800 loss = -0.500985
INFO     step  900 loss = -0.558623
INFO     step 1000 loss = -0.589603
CPU times: user 5.74 s, sys: 88.5 ms, total: 5.83 s
Wall time: 5.89 s
[10]:
samples = forecaster(data[:T1], covariates, num_samples=1000)
p10, p50, p90 = quantile(samples, (0.1, 0.5, 0.9)).squeeze(-1)
crps = eval_crps(samples, data[T1:])

plt.figure(figsize=(9, 3))
plt.fill_between(torch.arange(T1, T2), p10, p90, color="red", alpha=0.3)
plt.plot(torch.arange(T1, T2), p50, 'r-', label='forecast')
plt.plot(data, 'k-', label='truth')
plt.title("Total weekly ridership (CRPS = {:0.3g})".format(crps))
plt.ylabel("log(# rides)")
plt.xlabel("Week after 2011-01-01")
plt.xlim(0, None)
plt.legend(loc="best");
_images/forecasting_i_16_0.png
[11]:
plt.figure(figsize=(9, 3))
plt.fill_between(torch.arange(T1, T2), p10, p90, color="red", alpha=0.3)
plt.plot(torch.arange(T1, T2), p50, 'r-', label='forecast')
plt.plot(torch.arange(T1, T2), data[T1:], 'k-', label='truth')
plt.title("Total weekly ridership (CRPS = {:0.3g})".format(crps))
plt.ylabel("log(# rides)")
plt.xlabel("Week after 2011-01-01")
plt.xlim(T1, None)
plt.legend(loc="best");
_images/forecasting_i_17_0.png
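
The seasonal covariates above are sine and cosine waves whose periods divide the yearly cycle, so a linear model on them can represent a smooth periodic curve. The sketch below builds such features in plain Python. It is an illustration of the idea only: the helper fourier_features, its num_harmonics argument, and its column layout are assumptions and do not reproduce the exact output of periodic_features().

```python
import math

def fourier_features(duration, period, num_harmonics):
    """Return a duration x (2 * num_harmonics) list of cos/sin features."""
    rows = []
    for t in range(duration):
        row = []
        for k in range(1, num_harmonics + 1):
            # The k-th harmonic completes k full cycles per period.
            angle = 2 * math.pi * k * t / period
            row.extend([math.cos(angle), math.sin(angle)])
        rows.append(row)
    return rows

# Weekly data with a yearly cycle of about 365.25 / 7 weeks.
features = fourier_features(duration=104, period=365.25 / 7, num_harmonics=3)
print(len(features), len(features[0]))
```

Each row can be used as one time step of covariates, with one regression weight per column.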

Time-local random variables: self.time_plate

So far we’ve seen the ForecastingModel.model() method and self.predict(). The last piece of forecasting-specific syntax is the self.time_plate context for time-local variables. To see how this works, consider changing our global linear trend model above to a local level model. Note the poutine.reparam() handler is a general Pyro inference trick, not specific to forecasting.

[12]:
class Model2(ForecastingModel):
    def model(self, zero_data, covariates):
        data_dim = zero_data.size(-1)
        feature_dim = covariates.size(-1)
        bias = pyro.sample("bias", dist.Normal(0, 10).expand([data_dim]).to_event(1))
        weight = pyro.sample("weight", dist.Normal(0, 0.1).expand([feature_dim]).to_event(1))

        # We'll sample a time-global scale parameter outside the time plate,
        # then time-local iid noise inside the time plate.
        drift_scale = pyro.sample("drift_scale",
                                  dist.LogNormal(-20, 5).expand([1]).to_event(1))
        with self.time_plate:
            # We'll use a reparameterizer to improve variational fit. The model would still be
            # correct if you removed this context manager, but the fit appears to be worse.
            with poutine.reparam(config={"drift": LocScaleReparam()}):
                drift = pyro.sample("drift", dist.Normal(zero_data, drift_scale).to_event(1))

        # After we sample the iid "drift" noise we can combine it in any time-dependent way.
        # It is important to keep everything inside the plate independent and apply dependent
        # transforms outside the plate.
        motion = drift.cumsum(-2)  # A Brownian motion.

        # The prediction now includes three terms.
        prediction = motion + bias + (weight * covariates).sum(-1, keepdim=True)
        assert prediction.shape[-2:] == zero_data.shape

        # Construct the noise distribution and predict.
        noise_scale = pyro.sample("noise_scale", dist.LogNormal(-5, 5).expand([1]).to_event(1))
        noise_dist = dist.Normal(0, noise_scale)
        self.predict(noise_dist, prediction)
[13]:
%%time
pyro.set_rng_seed(1)
pyro.clear_param_store()
time = torch.arange(float(T2)) / 365
covariates = periodic_features(T2, 365.25 / 7)
forecaster = Forecaster(Model2(), data[:T1], covariates[:T1], learning_rate=0.1)
INFO     step    0 loss = 1.7326e+09
INFO     step  100 loss = 0.902688
INFO     step  200 loss = -0.0639999
INFO     step  300 loss = -0.102488
INFO     step  400 loss = -0.301241
INFO     step  500 loss = -0.404315
INFO     step  600 loss = -0.365754
INFO     step  700 loss = -0.429714
INFO     step  800 loss = -0.447207
INFO     step  900 loss = -0.515883
INFO     step 1000 loss = -0.519698
CPU times: user 9.11 s, sys: 66.9 ms, total: 9.18 s
Wall time: 9.21 s
[14]:
samples = forecaster(data[:T1], covariates, num_samples=1000)
p10, p50, p90 = quantile(samples, (0.1, 0.5, 0.9)).squeeze(-1)
crps = eval_crps(samples, data[T1:])

plt.figure(figsize=(9, 3))
plt.fill_between(torch.arange(T1, T2), p10, p90, color="red", alpha=0.3)
plt.plot(torch.arange(T1, T2), p50, 'r-', label='forecast')
plt.plot(data, 'k-', label='truth')
plt.title("Total weekly ridership (CRPS = {:0.3g})".format(crps))
plt.ylabel("log(# rides)")
plt.xlabel("Week after 2011-01-01")
plt.xlim(0, None)
plt.legend(loc="best");
_images/forecasting_i_21_0.png
[15]:
plt.figure(figsize=(9, 3))
plt.fill_between(torch.arange(T1, T2), p10, p90, color="red", alpha=0.3)
plt.plot(torch.arange(T1, T2), p50, 'r-', label='forecast')
plt.plot(torch.arange(T1, T2), data[T1:], 'k-', label='truth')
plt.title("Total weekly ridership (CRPS = {:0.3g})".format(crps))
plt.ylabel("log(# rides)")
plt.xlabel("Week after 2011-01-01")
plt.xlim(T1, None)
plt.legend(loc="best");
_images/forecasting_i_22_0.png
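
The key pattern in Model2 is to sample independent noise inside the time plate and only afterwards combine it across time. A plain-Python sketch of that two-step construction for a single series follows. The helper local_level_path is hypothetical, and the drift scale is chosen only for illustration.

```python
import itertools
import random

def local_level_path(num_steps, drift_scale, rng):
    """Cumulatively sum iid Normal(0, drift_scale) increments into a random walk."""
    # Inside the time plate: independent noise, one draw per time step.
    drift = [rng.gauss(0.0, drift_scale) for _ in range(num_steps)]
    # Outside the plate: the time-dependent transform, mirroring drift.cumsum(-2).
    motion = list(itertools.accumulate(drift))
    return drift, motion

rng = random.Random(0)
drift, motion = local_level_path(num_steps=5, drift_scale=0.1, rng=rng)
print(motion)
```

Each entry of motion depends on all earlier drift values, which is why the cumulative sum has to happen outside the plate, where Pyro no longer assumes independence across time.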

Heavy-tailed noise

Our final univariate model will generalize from Gaussian noise to heavy-tailed Stable noise. The only difference is the noise_dist which now takes two new parameters: stability determines tail weight and skew determines the relative size of positive versus negative spikes.

The Stable distribution is a natural heavy-tailed generalization of the Normal distribution, but it is difficult to work with due to its intractable density function. Pyro implements auxiliary variable methods for working with Stable distributions. To inform Pyro to use those auxiliary variable methods, we wrap the final line in a poutine.reparam() effect handler that applies the StableReparam transform to the implicit observe site named “residual”. You can use Stable distributions for other sites by specifying config={"my_site_name": StableReparam()}.

[16]:
class Model3(ForecastingModel):
    def model(self, zero_data, covariates):
        data_dim = zero_data.size(-1)
        feature_dim = covariates.size(-1)
        bias = pyro.sample("bias", dist.Normal(0, 10).expand([data_dim]).to_event(1))
        weight = pyro.sample("weight", dist.Normal(0, 0.1).expand([feature_dim]).to_event(1))

        drift_scale = pyro.sample("drift_scale", dist.LogNormal(-20, 5).expand([1]).to_event(1))
        with self.time_plate:
            with poutine.reparam(config={"drift": LocScaleReparam()}):
                drift = pyro.sample("drift", dist.Normal(zero_data, drift_scale).to_event(1))
        motion = drift.cumsum(-2)  # A Brownian motion.

        prediction = motion + bias + (weight * covariates).sum(-1, keepdim=True)
        assert prediction.shape[-2:] == zero_data.shape

        # The next part of the model creates a likelihood or noise distribution.
        # Again we'll be Bayesian and write this as a probabilistic program with
        # priors over parameters.
        stability = pyro.sample("noise_stability", dist.Uniform(1, 2).expand([1]).to_event(1))
        skew = pyro.sample("noise_skew", dist.Uniform(-1, 1).expand([1]).to_event(1))
        scale = pyro.sample("noise_scale", dist.LogNormal(-5, 5).expand([1]).to_event(1))
        noise_dist = dist.Stable(stability, skew, scale)

        # We need to use a reparameterizer to handle the Stable distribution.
        # Note "residual" is the name of Pyro's internal sample site in self.predict().
        with poutine.reparam(config={"residual": StableReparam()}):
            self.predict(noise_dist, prediction)
[17]:
%%time
pyro.set_rng_seed(2)
pyro.clear_param_store()
time = torch.arange(float(T2)) / 365
covariates = periodic_features(T2, 365.25 / 7)
forecaster = Forecaster(Model3(), data[:T1], covariates[:T1], learning_rate=0.1)
for name, value in forecaster.guide.median().items():
    if value.numel() == 1:
        print("{} = {:0.4g}".format(name, value.item()))
INFO     step    0 loss = 5.92062e+07
INFO     step  100 loss = 13.5949
INFO     step  200 loss = 3.0411
INFO     step  300 loss = 0.866627
INFO     step  400 loss = 0.362264
INFO     step  500 loss = 0.0508628
INFO     step  600 loss = -0.236901
INFO     step  700 loss = -0.290881
INFO     step  800 loss = -0.242376
INFO     step  900 loss = -0.339689
INFO     step 1000 loss = -0.33147
bias = 14.64
drift_scale = 2.173e-08
noise_stability = 1.937
noise_skew = 0.0007298
noise_scale = 0.06047
CPU times: user 19 s, sys: 97.5 ms, total: 19 s
Wall time: 19.1 s
[18]:
samples = forecaster(data[:T1], covariates, num_samples=1000)
p10, p50, p90 = quantile(samples, (0.1, 0.5, 0.9)).squeeze(-1)
crps = eval_crps(samples, data[T1:])

plt.figure(figsize=(9, 3))
plt.fill_between(torch.arange(T1, T2), p10, p90, color="red", alpha=0.3)
plt.plot(torch.arange(T1, T2), p50, 'r-', label='forecast')
plt.plot(data, 'k-', label='truth')
plt.title("Total weekly ridership (CRPS = {:0.3g})".format(crps))
plt.ylabel("log(# rides)")
plt.xlabel("Week after 2011-01-01")
plt.xlim(0, None)
plt.legend(loc="best");
_images/forecasting_i_26_0.png
[19]:
plt.figure(figsize=(9, 3))
plt.fill_between(torch.arange(T1, T2), p10, p90, color="red", alpha=0.3)
plt.plot(torch.arange(T1, T2), p50, 'r-', label='forecast')
plt.plot(torch.arange(T1, T2), data[T1:], 'k-', label='truth')
plt.title("Total weekly ridership (CRPS = {:0.3g})".format(crps))
plt.ylabel("log(# rides)")
plt.xlabel("Week after 2011-01-01")
plt.xlim(T1, None)
plt.legend(loc="best");
_images/forecasting_i_27_0.png
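
To get a feel for how the stability parameter controls tail weight, the sketch below draws from a standard symmetric Stable distribution (skew = 0) using the Chambers-Mallows-Stuck method and counts how often large values occur. This is a standalone illustration in plain Python, not Pyro's implementation. The helper names are hypothetical, and in this sketch's scale convention stability 2 corresponds to a Normal with variance 2.

```python
import math
import random

def sample_symmetric_stable(alpha, rng):
    """Draw one sample from a standard symmetric alpha-stable law (skew = 0).

    Chambers-Mallows-Stuck method, valid for 0 < alpha <= 2.
    """
    u = rng.uniform(-math.pi / 2, math.pi / 2)  # random angle
    w = rng.expovariate(1.0)                    # unit exponential
    return (math.sin(alpha * u) / math.cos(u) ** (1 / alpha)
            * (math.cos(u - alpha * u) / w) ** ((1 - alpha) / alpha))

def tail_fraction(alpha, threshold, num_samples, rng):
    """Fraction of draws whose magnitude exceeds the threshold."""
    hits = sum(abs(sample_symmetric_stable(alpha, rng)) > threshold
               for _ in range(num_samples))
    return hits / num_samples

rng = random.Random(0)
for alpha in (2.0, 1.9, 1.5):
    print(alpha, tail_fraction(alpha, threshold=5.0, num_samples=20000, rng=rng))
```

The fraction of large draws should grow as the stability decreases. The fitted noise_stability of about 1.94 above is close to the Gaussian limit of 2, so Model3 behaves much like Model2 except for the occasional large spike.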

Backtesting

To compare our Gaussian Model2 and Stable Model3 we’ll use the simple backtest() helper. This helper by default evaluates three metrics: CRPS assesses distributional accuracy of heavy-tailed data, MAE assesses point accuracy of heavy-tailed data, and RMSE assesses accuracy of Normal-tailed data. The one nuance here is to set warm_start=True to reduce the need for random restarts.

[20]:
%%time
pyro.set_rng_seed(1)
pyro.clear_param_store()
windows2 = backtest(data, covariates, Model2,
                    min_train_window=104, test_window=52, stride=26,
                    forecaster_options={"learning_rate": 0.1, "log_every": 1000,
                                        "warm_start": True})
INFO     Training on window [0:104], testing on window [104:156]
INFO     step    0 loss = 3534.09
INFO     step 1000 loss = 0.11251
INFO     Training on window [0:130], testing on window [130:182]
INFO     step    0 loss = 0.238584
INFO     step 1000 loss = -0.184576
INFO     Training on window [0:156], testing on window [156:208]
INFO     step    0 loss = 0.62968
INFO     step 1000 loss = -0.0259982
INFO     Training on window [0:182], testing on window [182:234]
INFO     step    0 loss = 0.195288
INFO     step 1000 loss = -0.120416
INFO     Training on window [0:208], testing on window [208:260]
INFO     step    0 loss = 0.188322
INFO     step 1000 loss = -0.18523
INFO     Training on window [0:234], testing on window [234:286]
INFO     step    0 loss = 0.0471417
INFO     step 1000 loss = -0.185852
INFO     Training on window [0:260], testing on window [260:312]
INFO     step    0 loss = 0.00251847
INFO     step 1000 loss = -0.246146
INFO     Training on window [0:286], testing on window [286:338]
INFO     step    0 loss = -0.0702055
INFO     step 1000 loss = -0.25786
INFO     Training on window [0:312], testing on window [312:364]
INFO     step    0 loss = -0.133986
INFO     step 1000 loss = -0.375242
INFO     Training on window [0:338], testing on window [338:390]
INFO     step    0 loss = -0.167895
INFO     step 1000 loss = -0.331766
INFO     Training on window [0:364], testing on window [364:416]
INFO     step    0 loss = -0.270294
INFO     step 1000 loss = -0.438097
INFO     Training on window [0:390], testing on window [390:442]
INFO     step    0 loss = -0.297009
INFO     step 1000 loss = -0.473476
INFO     Training on window [0:416], testing on window [416:468]
INFO     step    0 loss = -0.398169
INFO     step 1000 loss = -0.502486
CPU times: user 1min 51s, sys: 724 ms, total: 1min 52s
Wall time: 1min 52s
[21]:
%%time
pyro.set_rng_seed(1)
pyro.clear_param_store()
windows3 = backtest(data, covariates, Model3,
                    min_train_window=104, test_window=52, stride=26,
                    forecaster_options={"learning_rate": 0.1, "log_every": 1000,
                                        "warm_start": True})
INFO     Training on window [0:104], testing on window [104:156]
INFO     step    0 loss = 1849.22
INFO     step 1000 loss = 0.543365
INFO     Training on window [0:130], testing on window [130:182]
INFO     step    0 loss = 2.51271
INFO     step 1000 loss = 0.0757928
INFO     Training on window [0:156], testing on window [156:208]
INFO     step    0 loss = 2.6663
INFO     step 1000 loss = 0.0912818
INFO     Training on window [0:182], testing on window [182:234]
INFO     step    0 loss = 1.97279
INFO     step 1000 loss = -0.00365819
INFO     Training on window [0:208], testing on window [208:260]
INFO     step    0 loss = 1.59146
INFO     step 1000 loss = -0.0871935
INFO     Training on window [0:234], testing on window [234:286]
INFO     step    0 loss = 1.34227
INFO     step 1000 loss = -0.103136
INFO     Training on window [0:260], testing on window [260:312]
INFO     step    0 loss = 1.21624
INFO     step 1000 loss = -0.214513
INFO     Training on window [0:286], testing on window [286:338]
INFO     step    0 loss = 1.0086
INFO     step 1000 loss = -0.272347
INFO     Training on window [0:312], testing on window [312:364]
INFO     step    0 loss = 0.962262
INFO     step 1000 loss = -0.293812
INFO     Training on window [0:338], testing on window [338:390]
INFO     step    0 loss = 0.598708
INFO     step 1000 loss = -0.190582
INFO     Training on window [0:364], testing on window [364:416]
INFO     step    0 loss = 0.719034
INFO     step 1000 loss = -0.362534
INFO     Training on window [0:390], testing on window [390:442]
INFO     step    0 loss = 0.353514
INFO     step 1000 loss = -0.431448
INFO     Training on window [0:416], testing on window [416:468]
INFO     step    0 loss = 0.402931
INFO     step 1000 loss = -0.48814
CPU times: user 4min, sys: 1.07 s, total: 4min 1s
Wall time: 4min 3s
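
The training and testing windows in the logs above follow a simple expanding-window schedule: training always starts at 0, and the train/test split advances by stride until the test window no longer fits. The sketch below reproduces that schedule in plain Python. The function expanding_windows is a hypothetical helper written for illustration, not the internals of backtest(), and it covers only the settings used in this tutorial.

```python
def expanding_windows(duration, min_train_window, test_window, stride):
    """List (train_start, split, test_end) triples for an expanding-window backtest."""
    windows = []
    split = min_train_window
    # Keep advancing the split while a full test window still fits in the data.
    while split + test_window <= duration:
        windows.append((0, split, split + test_window))
        split += stride
    return windows

# 78888 hourly rows // (24 * 7) = 469 whole weeks of data.
windows = expanding_windows(duration=469, min_train_window=104, test_window=52, stride=26)
print(len(windows), windows[0], windows[-1])
```

Because consecutive training windows overlap heavily, starting each fit from the previous solution with warm_start=True is a natural choice.
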
[22]:
fig, axes = plt.subplots(3, figsize=(8, 6), sharex=True)
axes[0].set_title("Gaussian versus Stable accuracy over {} windows".format(len(windows2)))
axes[0].plot([w["crps"] for w in windows2], "b<", label="Gaussian")
axes[0].plot([w["crps"] for w in windows3], "r>", label="Stable")
axes[0].set_ylabel("CRPS")
axes[1].plot([w["mae"] for w in windows2], "b<", label="Gaussian")
axes[1].plot([w["mae"] for w in windows3], "r>", label="Stable")
axes[1].set_ylabel("MAE")
axes[2].plot([w["rmse"] for w in windows2], "b<", label="Gaussian")
axes[2].plot([w["rmse"] for w in windows3], "r>", label="Stable")
axes[2].set_ylabel("RMSE")
axes[0].legend(loc="best")
plt.tight_layout()
_images/forecasting_i_31_0.png

Note that RMSE is a poor metric for evaluating heavy-tailed data. Our Stable model has such heavy tails that its variance is infinite, so we cannot expect RMSE to converge, hence the occasional outlying points.
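
A small numeric example makes the sensitivity concrete. The sketch below compares MAE and RMSE on a made-up list of forecast errors, with and without a single large spike. The error values are invented for illustration and are not taken from the backtest above.

```python
import math

def mae(errors):
    """Mean absolute error."""
    return sum(abs(e) for e in errors) / len(errors)

def rmse(errors):
    """Root mean squared error."""
    return math.sqrt(sum(e * e for e in errors) / len(errors))

typical = [0.1] * 100
with_spike = [0.1] * 99 + [10.0]  # one outlying error among 100
print(mae(typical), rmse(typical))
print(mae(with_spike), rmse(with_spike))
```

Here the single spike roughly doubles MAE but inflates RMSE about tenfold, because squaring lets the largest error dominate the average.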
