Time Series Forecasting

In this tutorial, we will use NumPyro to do forecasting. Specifically, we will replicate the Seasonal, Global Trend (SGT) model from the Rlgt: Bayesian Exponential Smoothing Models with Trend Modifications package. The data is the famous lynx time series, which contains the annual numbers of lynx trappings in Canada from 1821 to 1934.

[1]:
%matplotlib inline
import matplotlib.pyplot as plt
import pandas as pd

import jax.numpy as np
from jax import lax, random, vmap
from jax.config import config; config.update("jax_platform_name", "cpu")

import numpyro.distributions as dist
from numpyro.diagnostics import autocorrelation, hpdi
from numpyro.distributions.util import softmax
from numpyro.handlers import sample, seed, substitute, trace
from numpyro.hmc_util import initialize_model
from numpyro.mcmc import mcmc

Data

First, let’s import and take a look at the dataset.

[2]:
URL = "https://raw.githubusercontent.com/vincentarelbundock/Rdatasets/master/csv/datasets/lynx.csv"
lynx = pd.read_csv(URL, index_col=0)
data = lynx["value"].values
print("Length of time series:", data.shape[0])
plt.figure(figsize=(12, 6))
plt.plot(lynx["time"], data)
plt.xlabel("time", fontsize=14)
plt.ylabel("value", fontsize=14)
plt.show()
Length of time series: 114
[Figure: plot of the annual lynx trappings from 1821 to 1934]

The time series has a length of 114 (one data point per year). Looking at the plot, we can observe seasonality in this dataset, i.e. the recurrence of similar patterns at specific time periods. For example, we see a cyclical pattern every 10 years, as well as a less obvious but still noticeable spike in the number of trappings every 40 years. Let us see if we can model these effects in NumPyro.

In this tutorial, we will use the first 80 values for training and the last 34 values for testing.

[3]:
y_train, y_test = np.array(data[:80], dtype=np.float32), data[80:]

Model

The model we are going to use is called Seasonal, Global Trend (SGT). When tested on the 3003 time series of the M-3 competition, it has been shown to outperform the models that originally participated in the competition:

\begin{equation*} \begin{gathered} \text{exp_val}_{t} = \text{level}_{t-1} + \text{coef_trend} \times \text{level}_{t-1}^{\text{pow_trend}} + \text{s}_t \times \text{level}_{t-1}^{\text{pow_season}}, \\ \sigma_{t} = \sigma \times \text{exp_val}_{t}^{\text{powx}} + \text{offset}, \\ y_{t} \sim \text{StudentT}(\nu, \text{exp_val}_{t}, \sigma_{t}), \end{gathered} \end{equation*}

where level and s follow these recursion rules:

\begin{equation*} \begin{gathered} \text{level_p} = \begin{cases} y_t - \text{s}_t \times \text{level}_{t-1}^{\text{pow_season}} & \text{if } t \le \text{seasonality}, \\ \text{Average} \left[y(t - \text{seasonality} + 1), \ldots, y(t)\right] & \text{otherwise}, \end{cases} \\ \text{level}_{t} = \text{level_sm} \times \text{level_p} + (1 - \text{level_sm}) \times \text{level}_{t-1}, \\ \text{s}_{t + \text{seasonality}} = \text{s_sm} \times \frac{y_{t} - \text{level}_{t}}{\text{level}_{t-1}^{\text{pow_season}}} + (1 - \text{s_sm}) \times \text{s}_{t}. \end{gathered} \end{equation*}

A more detailed explanation of the SGT model can be found in this vignette from the authors of the Rlgt package. Here we summarize the core ideas of this model:

+ Student’s t-distribution, which has heavier tails than the normal distribution, is used for the likelihood.
+ The expected value exp_val consists of a trending component and a seasonal component:
  - The trend is governed by the map \(x \mapsto x + ax^b\), where \(x\) is level, \(a\) is coef_trend, and \(b\) is pow_trend. Note that when \(b \sim 0\), the trend is linear with slope \(a\), and when \(b \sim 1\), the trend is exponential with rate \(a\). Hence this map covers a large family of trends.
  - As time progresses, level and s are updated to new values. The coefficients level_sm and s_sm are used to make the transitions smooth.
+ When powx is near \(0\), the error \(\sigma_t\) will be nearly constant, while when powx is near \(1\), the error will be proportional to the expected value.
+ There are several varieties of SGT. In this tutorial, we use generalized seasonality and the seasonal average method.
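
To see the effect of pow_trend concretely, here is a quick numerical check of the trend map (the values of x and a below are arbitrary and chosen only for illustration):

# illustrative values only: trend map x -> x + a * x**b at the two extremes of b
x = np.arange(1., 6.)
a = 0.5
print(x + a * x ** 0.01)  # b ~ 0: approximately x + a, an additive (linear) trend
print(x + a * x ** 0.99)  # b ~ 1: approximately (1 + a) * x, a multiplicative (exponential) trend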

Note that level and s are updated recursively while we collect the expected value at each time step. NumPyro uses JAX in the backend to JIT compile many critical parts of the NUTS algorithm, including the Verlet integrator and the tree building process. However, doing the recursion with a Python for loop in the model would result in a long compilation time for the model, so we use jax.lax.scan instead. A detailed explanation of this utility can be found in the lax.scan documentation. Here we use it to collect the expected values, while the pair (level, s) plays the role of the carried state.
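
As a reminder of how lax.scan works, here is a minimal sketch (unrelated to the SGT model, with a made-up step function) that threads a running sum through a sequence while collecting the partial sums:

def cumsum_step(carry, x):
    # carry is the running sum; the second return value is collected across steps
    new_carry = carry + x
    return new_carry, new_carry

final_sum, partial_sums = lax.scan(cumsum_step, 0., np.arange(5.))
print(final_sum)      # 10.0
print(partial_sums)   # [ 0.  1.  3.  6. 10.]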

[4]:
def scan_exp_val(y, init_s, level_sm, s_sm, coef_trend, pow_trend, pow_season):
    seasonality = init_s.shape[0]

    def scan_fn(carry, t):
        level, s, moving_sum = carry
        season = s[0] * level ** pow_season
        exp_val = level + coef_trend * level ** pow_trend + season
        exp_val = np.clip(exp_val, a_min=0)

        # maintain a moving sum of the last `seasonality` observations for the level update
        moving_sum = moving_sum + y[t] - np.where(t >= seasonality, y[t - seasonality], 0.)
        level_p = np.where(t >= seasonality, moving_sum / seasonality, y[t] - season)
        level = level_sm * level_p + (1 - level_sm) * level
        level = np.clip(level, a_min=0)
        # compute the new seasonal factor s_{t + seasonality} following the recursion above
        new_s = (s_sm * (y[t] - level) / season + (1 - s_sm)) * s[0]
        s = np.concatenate([s[1:], new_s[None]], axis=0)
        return (level, s, moving_sum), exp_val

    level_init = y[0]
    s_init = np.concatenate([init_s[1:], init_s[:1]], axis=0)
    moving_sum = level_init
    # scan over time steps, carrying (level, s, moving_sum) and collecting exp_val at each step
    (last_level, last_s, moving_sum), exp_vals = lax.scan(
        scan_fn, (level_init, s_init, moving_sum), np.arange(1, y.shape[0]))
    return exp_vals, last_level, last_s

With our utility function defined above, we are ready to specify the model using NumPyro primitives. In NumPyro, we use the primitive sample(name, prior) to declare a latent random variable with a corresponding prior. These primitives can have custom interpretations depending on the effect handlers that are used by NumPyro inference algorithms in the backend. For example, we can condition on specific values using the substitute handler, or record the values at these sample sites in the execution trace using the trace handler. Note that these details are not important for specifying the model or running inference, but curious readers are encouraged to read the tutorial on effect handlers in Pyro.
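
As a small illustration, here is a toy model (unrelated to SGT; the site name "x" and the values are made up) showing how these handlers compose around a model function:

def toy_model():
    return sample("x", dist.Normal(0., 1.))

seeded_model = seed(toy_model, random.PRNGKey(0))
print(seeded_model())                        # draw a random value at site "x"
print(substitute(toy_model, {"x": 2.})())    # condition site "x" on the value 2.
exec_tr = trace(seeded_model).get_trace()
print(exec_tr["x"]["value"])                 # the value recorded at site "x" in the trace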

[5]:
def sgt(y, seasonality):
    # heuristically, the standard deviation of the Cauchy prior depends on the max value of the data
    cauchy_sd = np.max(y) / 150

    nu = sample("nu", dist.Uniform(2, 20))
    powx = sample("powx", dist.Uniform(0, 1))
    sigma = sample("sigma", dist.HalfCauchy(cauchy_sd))
    offset_sigma = sample("offset_sigma", dist.TruncatedCauchy(low=1e-10, loc=1e-10,
                                                               scale=cauchy_sd))

    coef_trend = sample("coef_trend", dist.Cauchy(0, cauchy_sd))
    pow_trend_beta = sample("pow_trend_beta", dist.Beta(1, 1))
    # pow_trend takes values from -0.5 to 1
    pow_trend = 1.5 * pow_trend_beta - 0.5
    pow_season = sample("pow_season", dist.Beta(1, 1))

    level_sm = sample("level_sm", dist.Beta(1, 2))
    s_sm = sample("s_sm", dist.Uniform(0, 1))
    init_s = sample("init_s", dist.Cauchy(0, y[:seasonality] * 0.3))

    exp_val, last_level, last_s = scan_exp_val(
        y, init_s, level_sm, s_sm, coef_trend, pow_trend, pow_season)
    omega = sigma * exp_val ** powx + offset_sigma
    sample("y", dist.StudentT(nu, exp_val, omega), obs=y[1:])
    # we return last `level` and last `s` for forecasting
    return last_level, last_s

Note that all of the prior parameters above follow those used in the original Rlgt source code.

Inference

First, we want to choose a good value for seasonality. Following the demo in Rlgt, we will set seasonality=38. Indeed, this value can be guessed by looking at the plot of the training data, where the second order seasonality effect has a periodicity around \(40\) years. Note that \(38\) is also one of the highest-autocorrelation lags.

[6]:
print("Lag values sorted according to their autocorrelation values:\n")
print(np.argsort(autocorrelation(y_train))[::-1])
Lag values sorted according to their autocorrelation values:

[ 0 67 57 38 68  1 29 58 37 56 28 10 19 39 66 78 47 77  9 79 48 76 30 18
 20 11 46 59 69 27 55 36  2  8 40 49 17 21 75 12 65 45 31 26  7 54 35 41
 50  3 22 60 70 16 44 13  6 25 74 53 42 32 23 43 51  4 15 14 34 24  5 52
 73 64 33 71 72 61 63 62]

HMC algorithms require a potential function and initial parameter values for the latent variables to begin sampling. The utility initialize_model will help us convert our sgt model into a potential function whose inputs are unconstrained values of the latent variables. Because sampling happens in this unconstrained space, we also need the utility constrain_fn to transform the unconstrained parameters back to their original constrained domains at the end of the MCMC run.

[7]:
init_params, potential_fn, constrain_fn = initialize_model(
    random.PRNGKey(2), sgt, y_train, seasonality=38)
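
As a quick sketch, assuming (as described above) that both utilities accept a dict of unconstrained latent values such as init_params, we can inspect them directly:

# illustrative only: the potential energy (negative log joint density in unconstrained space)
# at the initial parameters, and the same parameters mapped back to their supports
print(potential_fn(init_params))
print(constrain_fn(init_params)["nu"])  # `nu` mapped back into the support of its Uniform(2, 20) prior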

Now, let us run mcmc (using the default No-U-Turn Sampler algorithm) with \(5000\) warmup steps and \(5000\) sampling steps. The returned value will be a collection of \(5000\) samples. We set target_accept_prob=0.9 and max_tree_depth=12 following the original implementation in Rlgt package.

[8]:
samples = mcmc(num_warmup=5000, num_samples=5000,
               init_params=init_params, potential_fn=potential_fn, constrain_fn=constrain_fn,
               target_accept_prob=0.9, max_tree_depth=12)
warmup: 100%|██████████| 5000/5000 [01:46<00:00, 46.90it/s, 63 steps of size 6.31e-02. acc. prob=0.89]
sample: 100%|██████████| 5000/5000 [00:08<00:00, 601.84it/s, 31 steps of size 6.31e-02. acc. prob=0.88]


                           mean         sd       5.5%      94.5%      n_eff       Rhat
          coef_trend      29.05      99.48     -90.26     153.76     754.24       1.00
           init_s[0]      87.21     102.38     -60.45     225.31    1633.31       1.01
           init_s[1]     -21.07      65.37    -124.43      77.13    2369.64       1.00
           init_s[2]      32.08      96.57    -101.80     173.98    1487.31       1.00
           init_s[3]     127.54     131.21     -68.64     296.50    1291.95       1.00
           init_s[4]     453.53     259.71      92.16     802.24    1457.89       1.00
           init_s[5]    1179.88     465.43     504.62    1834.83     694.00       1.00
           init_s[6]    1989.11     686.78     915.28    2923.90     498.53       1.01
           init_s[7]    3699.56    1142.89    1964.50    5326.27     435.87       1.01
           init_s[8]    2624.63     876.57    1317.56    3844.97     427.44       1.00
           init_s[9]     944.22     421.57     302.07    1544.14     909.94       1.00
          init_s[10]      47.97     102.08    -104.50     186.96    1964.59       1.00
          init_s[11]      -0.01      50.71     -78.38      69.25    2003.29       1.00
          init_s[12]      -8.22      65.30    -116.94      76.10    2014.81       1.00
          init_s[13]      66.01      96.55     -67.28     213.26    1737.48       1.00
          init_s[14]     340.37     252.00      -1.73     680.72    1422.52       1.00
          init_s[15]     957.13     385.42     388.26    1490.20     621.26       1.00
          init_s[16]    1264.06     489.91     566.11    1969.40     592.58       1.00
          init_s[17]    1371.54     546.35     524.25    2116.39     701.41       1.00
          init_s[18]     613.54     313.43     141.40    1035.21     985.18       1.00
          init_s[19]      16.65      82.58    -115.93     128.11    2403.60       1.00
          init_s[20]     -31.66      65.28    -136.92      61.54    1862.62       1.00
          init_s[21]     -14.43      41.98     -74.99      46.92     932.23       1.00
          init_s[22]       0.55      50.14     -72.97      59.02     792.24       1.00
          init_s[23]      40.38      84.25     -77.60     155.88    1865.66       1.00
          init_s[24]     527.29     322.62      43.85     972.25    1058.18       1.00
          init_s[25]     948.44     478.41     289.40    1584.94     917.39       1.00
          init_s[26]    1772.88     666.34     769.56    2735.86     594.66       1.01
          init_s[27]    1275.58     478.20     580.10    1930.74     493.07       1.01
          init_s[28]     224.28     191.67     -41.62     446.65     948.49       1.00
          init_s[29]      -9.04      80.04    -128.36     116.23    2328.98       1.00
          init_s[30]      -3.20      85.47    -132.15     116.85    1371.62       1.00
          init_s[31]     -36.78      67.18    -147.35      60.47    2969.14       1.00
          init_s[32]     -10.13      83.04    -127.72     120.45    2186.65       1.00
          init_s[33]     116.67     137.07     -68.39     314.83    1418.59       1.00
          init_s[34]     509.81     280.08     113.83     927.73    1112.61       1.00
          init_s[35]    1067.36     442.12     403.47    1642.18     588.45       1.01
          init_s[36]    1854.03     656.23     880.65    2806.18     564.89       1.01
          init_s[37]    1456.71     562.67     605.02    2228.10     639.97       1.00
            level_sm       0.00       0.00       0.00       0.00    3306.11       1.00
                  nu      12.21       4.74       5.77      20.00    2426.72       1.00
        offset_sigma      32.99      32.04       0.06      70.22    2223.79       1.00
          pow_season       0.09       0.04       0.02       0.15     193.70       1.01
      pow_trend_beta       0.24       0.16       0.00       0.46    1061.99       1.00
                powx       0.62       0.13       0.40       0.82     926.98       1.00
                s_sm       0.08       0.08       0.00       0.17    2740.97       1.00
               sigma       9.68       9.64       0.46      19.86    1540.91       1.00

Forecasting

Given the samples from mcmc, we want to forecast the testing dataset y_test. First, we will write some utilities to produce a forecast given a single posterior sample. Note that to retrieve the last level and the last s values, we substitute the sample into the model:

...    level, s = substitute(sgt, asample)(y, seasonality)
[9]:
# Ref: https://github.com/cbergmeir/Rlgt/blob/master/Rlgt/R/forecast.rlgtfit.R
def sgt_forecast(future, asample, y, level, s):
    seasonality = s.shape[0]
    moving_sum = np.sum(y[-seasonality:])
    pow_trend = 1.5 * asample["pow_trend_beta"] - 0.5
    yfs = [0] * (seasonality + future)
    for t in range(future):
        season = s[0] * level ** asample["pow_season"]
        exp_val = level + asample["coef_trend"] * level ** pow_trend + season
        exp_val = np.clip(exp_val, a_min=0)
        omega = asample["sigma"] * exp_val ** asample["powx"] + asample["offset_sigma"]
        yf = sample("yf[{}]".format(t), dist.StudentT(asample["nu"], exp_val, omega))
        yf = np.clip(yf, a_min=1e-30)
        yfs[t] = yf

        moving_sum = moving_sum + yf - np.where(t >= seasonality,
                                                yfs[t - seasonality], y[-seasonality + t])
        level_p = moving_sum / seasonality
        level_tmp = asample["level_sm"] * level_p + (1 - asample["level_sm"]) * level
        level = np.where(level_tmp > 1e-30, level_tmp, level)
        # s is repeated instead of being updated
        s = np.concatenate([s[1:], s[:1]], axis=0)


def forecast(future, rng, asample, y, seasonality):
    level, s = substitute(sgt, asample)(y, seasonality)
    forecast_model = seed(sgt_forecast, rng)
    forecast_trace = trace(forecast_model).get_trace(future, asample, y, level, s)
    results = [np.clip(forecast_trace["yf[{}]".format(t)]["value"], a_min=1e-30)
               for t in range(future)]
    return np.stack(results, axis=0)

Then, we can use jax.vmap to get predictions from the whole collection of posterior samples. This vectorizes the computation across the samples, which can be dramatically faster than using a Python for loop to collect the predictions one sample at a time.
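
As a tiny illustration (with made-up values) of why we can pass the dict of posterior samples directly, note that vmap maps over the leading axis of every array in a pytree, including dictionaries:

# toy example: each leaf has a leading axis of size 3, so the mapped function
# sees one scalar per leaf at a time
toy_samples = {"a": np.arange(3.), "b": np.arange(3.) * 10.}
print(vmap(lambda d: d["a"] + d["b"])(toy_samples))  # [ 0. 11. 22.]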

[10]:
rngs = random.split(random.PRNGKey(3), samples["nu"].shape[0])
forecast_marginal = vmap(lambda rng, asample: forecast(
    len(y_test), rng, asample, y_train, seasonality=38))(rngs, samples)

Finally, let’s compute the sMAPE and the root mean squared error of the predictions, and visualize the result together with the mean prediction and the 89% highest posterior density interval (HPDI).
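
For reference, the sMAPE reported below is the symmetric mean absolute percentage error, which matches the computation in the next cell:

\begin{equation*} \text{sMAPE} = \frac{200}{N} \sum_{t=1}^{N} \frac{|\hat{y}_t - y_t|}{\hat{y}_t + y_t}, \end{equation*}

where \(\hat{y}_t\) is the mean forecast and \(y_t\) is the observed value in the test set.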

[11]:
y_pred = np.mean(forecast_marginal, axis=0)
sMAPE = np.mean(np.abs(y_pred - y_test) / (y_pred + y_test)) * 200
msqrt = np.sqrt(np.mean((y_pred - y_test) ** 2))
print("sMAPE: {:.2f}, rmse: {:.2f}".format(sMAPE, msqrt))
sMAPE: 62.78, rmse: 1243.49
[12]:
plt.figure(figsize=(12, 6))
plt.plot(lynx["time"], data)
t_future = lynx["time"][80:]
hpd_low, hpd_high = hpdi(forecast_marginal)
plt.plot(t_future, y_pred, lw=2)
plt.fill_between(t_future, hpd_low, hpd_high, alpha=0.3)
plt.title("Forecasting lynx dataset with SGT model (89% HPDI)", fontsize=18)
plt.xlabel("time", fontsize=14)
plt.ylabel("value", fontsize=14)
plt.show()
[Figure: forecasting the lynx dataset with the SGT model, showing the observed series, the mean forecast, and the 89% HPDI band]

As we can observe, the model has been able to learn both the first and second order seasonality effects, i.e. a cyclical pattern with a periodicity of around 10 years, as well as the spikes that can be seen once every 40 or so years. Moreover, we not only have point estimates for the forecast but can also use the uncertainty estimates from the model to bound our forecasts.

Acknowledgements

We would like to thank Slawek Smyl for many helpful resources and suggestions. Fast inference would not have been possible without the support of JAX and the XLA teams, so we would like to thank them for providing such a great open-source platform for us to build on, and for their responsiveness in dealing with our feature requests and bug reports.

References

[1] Rlgt: Bayesian Exponential Smoothing Models with Trend Modifications, Slawek Smyl, Christoph Bergmeir, Erwin Wibowo, To Wang Ng, Trustees of Columbia University.