Stochastic VolatilityΒΆ

View stochastic_volatility.py on github

import argparse

import numpy as onp

import jax.numpy as np
import jax.random as random
from jax.config import config as jax_config

import numpyro.distributions as dist
from numpyro.handlers import sample
from numpyro.hmc_util import initialize_model
from numpyro.mcmc import hmc
from numpyro.util import fori_collect

"""
Generative model:

sigma ~ Exponential(50)
nu ~ Exponential(.1)
s_i ~ Normal(s_{i-1}, sigma - 2)
r_i ~ StudentT(nu, 0, exp(-2 s_i))

This example is from PyMC3 [1], which itself is adapted from the original experiment
from [2]. A discussion about translating this in Pyro appears in [3].

For more details, refer to:
1. *Stochastic Volatility Model*, https://docs.pymc.io/notebooks/stochastic_volatility.html
2. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*,
https://arxiv.org/pdf/1111.4246.pdf
3. Forum discussion, https://forum.pyro.ai/t/problems-transforming-a-pymc3-model-to-pyro-mcmc/208/14

"""

def model(returns):
step_size = sample('sigma', dist.Exponential(50.))
s = sample('s', dist.GaussianRandomWalk(scale=step_size, num_steps=np.shape(returns)[0]))
nu = sample('nu', dist.Exponential(.1))
return sample('r', dist.StudentT(df=nu, loc=0., scale=np.exp(-2*s)),
obs=returns)

def print_results(posterior, dates):
def _print_row(values, row_name=''):
quantiles = [0.2, 0.4, 0.5, 0.6, 0.8]
row_name_fmt = '{:>' + str(len(row_name)) + '}'
header_format = row_name_fmt + '{:>12}' * 5
row_format = row_name_fmt + '{:>12.3f}' * 5
columns = ['(p{})'.format(q * 100) for q in quantiles]
q_values = onp.quantile(values, quantiles, axis=0)
print(row_format.format(row_name, *q_values))
print('\n')

print('=' * 5, 'sigma', '=' * 5)
_print_row(posterior['sigma'])
print('=' * 5, 'nu', '=' * 5)
_print_row(posterior['nu'])
print('=' * 5, 'volatility', '=' * 5)
for i in range(0, len(dates), 180):
_print_row(np.exp(-2 * posterior['s'][:, i]), dates[i])

def main(args):
jax_config.update('jax_platform_name', args.device)
dates, returns = fetch()
init_rng, sample_rng = random.split(random.PRNGKey(args.rng))
init_params, potential_fn, constrain_fn = initialize_model(init_rng, model, returns)
init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
hmc_state = init_kernel(init_params, args.num_warmup, rng=sample_rng)
hmc_states = fori_collect(args.num_samples, sample_kernel, hmc_state,
transform=lambda hmc_state: constrain_fn(hmc_state.z))
print_results(hmc_states, dates)

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Stochastic Volatility Model")