{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Kalman Filter\n", "\n", "Kalman filters are linear models for state estimation of dynamic systems [1]. They have been the de facto standard in many robotics and tracking/prediction applications because they are well suited for systems with uncertainty about an observable dynamic process. They use a \"observe, predict, correct\" paradigm to extract information from an otherwise noisy signal. In Pyro, we can build differentiable Kalman filters with learnable parameters using the pyro.contrib.tracking [library](http://docs.pyro.ai/en/dev/contrib.tracking.html#module-pyro.contrib.tracking.extended_kalman_filter)\n", "\n", "## Dynamic process\n", "\n", "To start, consider this simple motion model:\n", "\n", "$$X_{k+1} = FX_k + \\mathbf{W}_k$$\n", "$$\\mathbf{Z}_k = HX_k + \\mathbf{V}_k$$\n", "\n", "where $k$ is the state, $X$ is the signal estimate, $Z_k$ is the observed value at timestep $k$, $\\mathbf{W}_k$ and $\\mathbf{V}_k$ are independent noise processes (ie $\\mathbb{E}[w_k v_j^T] = 0$ for all $j, k$) which we'll approximate as Gaussians. Note that the state transitions are linear." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Kalman Update\n", "At each time step, we perform a prediction for the mean and covariance:\n", "$$\\hat{X}_k = F\\hat{X}_{k-1}$$\n", "$$\\hat{P}_k = FP_{k-1}F^T + Q$$\n", "\n", "and a correction for the measurement:\n", "\n", "$$K_k = \\hat{P}_k H^T(H\\hat{P}_k H^T + R)^{-1}$$\n", "$$X_k = \\hat{X}_k + K_k(z_k - H\\hat{X}_k)$$\n", "$$P_k = (I-K_k H)\\hat{P}_k$$\n", "\n", "where $X$ is the position estimate, $P$ is the covariance matrix, $K$ is the Kalman Gain, and $Q$ and $R$ are covariance matrices.\n", "\n", "For an in-depth derivation, see \$2\$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Nonlinear Estimation: Extended Kalman Filter\n", "\n", "What if our system is non-linear, eg in GPS navigation? 
Consider the following non-linear system:\n", "\n", "$$X_{k+1} = \\mathbf{f}(X_k) + \\mathbf{W}_k$$\n", "$$\\mathbf{Z}_k = \\mathbf{h}(X_k) + \\mathbf{V}_k$$\n", "\n", "Notice that $\\mathbf{f}$ and $\\mathbf{h}$ are now (smooth) non-linear functions.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The Extended Kalman Filter (EKF) attacks this problem by using a local linearization of the Kalman filter via a [Taylor series expansion](https://en.wikipedia.org/wiki/Taylor_series).\n", "\n", "$$f(X_k, k) \\approx f(x_k^R, k) + \\mathbf{H}_k(X_k - x_k^R) + \\cdots$$\n", "\n", "where $\\mathbf{H}_k$ is the Jacobian matrix at time $k$, $x_k^R$ is the previous optimal estimate, and we ignore the higher order terms. At each time step, we compute a Jacobian conditioned on the previous predictions (this computation is handled by Pyro under the hood), and use the result to perform a prediction and update.\n", "\n", "Omitting the derivations, the modified predictions are now:\n", "$$\\hat{X}_k \\approx \\mathbf{f}(X_{k-1}^R)$$\n", "$$\\hat{P}_k = \\mathbf{H}_\\mathbf{f}(X_{k-1})P_{k-1}\\mathbf{H}_\\mathbf{f}^T(X_{k-1}) + Q$$\n", "\n", "and the updates are now:\n", "\n", "$$K_k = \\hat{P}_k \\mathbf{H}_\\mathbf{h}^T(\\hat{X}_k) \\Big(\\mathbf{H}_\\mathbf{h}(\\hat{X}_k)\\hat{P}_k \\mathbf{H}_\\mathbf{h}^T(\\hat{X}_k) + R_k\\Big)^{-1}$$\n", "$$X_k \\approx \\hat{X}_k + K_k\\big(z_k - \\mathbf{h}(\\hat{X}_k)\\big)$$\n", "$$P_k = \\big(I - K_k \\mathbf{H}_\\mathbf{h}(\\hat{X}_k)\\big)\\hat{P}_k$$\n", "\n", "In Pyro, all we need to do is create an EKFState object and use its predict and update methods. Pyro will do exact inference to compute the innovations and we will use SVI to learn a MAP estimate of the position and measurement covariances.\n", "\n", "As an example, let's look at an object moving at near-constant velocity in 2-D in a discrete time space over 10 time steps."
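] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make the predict/correct equations above concrete before introducing Pyro's abstractions, here is a single linear Kalman step written directly in PyTorch. All matrices and numbers below are illustrative values for this sketch only, not part of the tutorial's model; the EKF simply replaces $F$ and $H$ with the Jacobians of $\\mathbf{f}$ and $\\mathbf{h}$:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "dt = 1e-2\n", "F = torch.eye(4); F[0, 2] = F[1, 3] = dt         # constant-velocity transition\n", "H = torch.zeros(2, 4); H[0, 0] = H[1, 1] = 1.0   # observe positions only\n", "Q = 1e-4 * torch.eye(4)                          # process noise covariance\n", "R = 1e-5 * torch.eye(2)                          # measurement noise covariance\n", "\n", "def kalman_step(x, P, z):\n", "    # predict the mean and covariance\n", "    x_hat = F @ x\n", "    P_hat = F @ P @ F.T + Q\n", "    # correct with the measurement z\n", "    K = P_hat @ H.T @ torch.inverse(H @ P_hat @ H.T + R)\n", "    x_new = x_hat + K @ (z - H @ x_hat)\n", "    P_new = (torch.eye(4) - K @ H) @ P_hat\n", "    return x_new, P_new\n", "\n", "x, P = torch.tensor([0.0, 0.0, 1.0, 0.0]), torch.eye(4)\n", "x, P = kalman_step(x, P, torch.tensor([0.011, 0.001]))"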
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import math\n", "\n", "import torch\n", "import pyro\n", "import pyro.distributions as dist\n", "from pyro.infer.autoguide import AutoDelta\n", "from pyro.optim import Adam\n", "from pyro.infer import SVI, Trace_ELBO, config_enumerate\n", "from pyro.contrib.tracking.extended_kalman_filter import EKFState\n", "from pyro.contrib.tracking.distributions import EKFDistribution\n", "from pyro.contrib.tracking.dynamic_models import NcvContinuous\n", "from pyro.contrib.tracking.measurements import PositionMeasurement\n", "\n", "smoke_test = ('CI' in os.environ)\n", "assert pyro.__version__.startswith('1.8.3')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dt = 1e-2\n", "num_frames = 10\n", "dim = 4\n", "\n", "# Continuous model\n", "ncv = NcvContinuous(dim, 2.0)\n", "\n", "# Truth trajectory\n", "xs_truth = torch.zeros(num_frames, dim)\n", "# initial direction\n", "theta0_truth = 0.0\n", "# initial state\n", "with torch.no_grad():\n", " xs_truth[0, :] = torch.tensor([0.0, 0.0, math.cos(theta0_truth), math.sin(theta0_truth)])\n", " for frame_num in range(1, num_frames):\n", " # sample independent process noise\n", " dx = pyro.sample('process_noise_{}'.format(frame_num), ncv.process_noise_dist(dt))\n", " xs_truth[frame_num, :] = ncv(xs_truth[frame_num-1, :], dt=dt) + dx" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, let's specify the measurements. Notice that we only measure the positions of the particle." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Measurements\n", "measurements = []\n", "mean = torch.zeros(2)\n", "# no correlations\n", "cov = 1e-5 * torch.eye(2)\n", "with torch.no_grad():\n", " # sample independent measurement noise\n", " dzs = pyro.sample('dzs', dist.MultivariateNormal(mean, cov).expand((num_frames,)))\n", " # compute measurement means\n", " zs = xs_truth[:, :2] + dzs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll use a [Delta autoguide](http://docs.pyro.ai/en/dev/infer.autoguide.html#autodelta) to learn MAP estimates of the position and measurement covariances. The EKFDistribution computes the joint log density of all of the EKF states given a tensor of sequential measurements." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def model(data):\n", " # a HalfNormal can be used here as well\n", " R = pyro.sample('pv_cov', dist.HalfCauchy(2e-6)) * torch.eye(4)\n", " Q = pyro.sample('measurement_cov', dist.HalfCauchy(1e-6)) * torch.eye(2)\n", " # observe the measurements\n", " pyro.sample('track_{}'.format(i), EKFDistribution(xs_truth[0], R, ncv,\n", " Q, time_steps=num_frames),\n", " obs=data)\n", " \n", "guide = AutoDelta(model) # MAP estimation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "optim = pyro.optim.Adam({'lr': 2e-2})\n", "svi = SVI(model, guide, optim, loss=Trace_ELBO(retain_graph=True))\n", "\n", "pyro.set_rng_seed(0)\n", "pyro.clear_param_store()\n", "\n", "for i in range(250 if not smoke_test else 2):\n", " loss = svi.step(zs)\n", " if not i % 10:\n", " print('loss: ', loss)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# retrieve states for visualization\n", "R = guide()['pv_cov'] * torch.eye(4)\n", "Q = guide()['measurement_cov'] * torch.eye(2)\n", "ekf_dist = EKFDistribution(xs_truth[0], R, ncv, Q, 
time_steps=num_frames)\n", "states = ekf_dist.filter_states(zs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## References\n", "\n", "[1] Kalman, R. E. *A New Approach to Linear Filtering and Prediction Problems.* 1960.\n", "\n", "[2] Welch, Greg, and Bishop, Gary. *An Introduction to the Kalman Filter.* 2006.\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.10" } }, "nbformat": 4, "nbformat_minor": 2 }