{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Tracking an Unknown Number of Objects\n", "\n", "While SVI can be used to learn components and assignments of a mixture model, pyro.contrib.tracking provides more efficient inference algorithms to estimate assignments. This notebook demonstrates how to use the `MarginalAssignmentPersistent` inside SVI." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import math\n", "import os\n", "import torch\n", "from torch.distributions import constraints\n", "from matplotlib import pyplot\n", "\n", "import pyro\n", "import pyro.distributions as dist\n", "import pyro.poutine as poutine\n", "from pyro.contrib.tracking.assignment import MarginalAssignmentPersistent\n", "from pyro.distributions.util import gather\n", "from pyro.infer import SVI, TraceEnum_ELBO\n", "from pyro.optim import Adam\n", "\n", "%matplotlib inline\n", "assert pyro.__version__.startswith('1.9.0')\n", "smoke_test = ('CI' in os.environ)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's consider a model with deterministic dynamics, say sinusoids with known period but unknown phase and amplitude." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def get_dynamics(num_frames):\n", " time = torch.arange(float(num_frames)) / 4\n", " return torch.stack([time.cos(), time.sin()], -1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It's tricky to define a fully generative model, so instead we'll separate our data generation process `generate_data()` from a factor graph `model()` that will be used in inference." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def generate_data(args):\n", " # Object model.\n", " num_objects = int(round(args.expected_num_objects)) # Deterministic.\n", " states = dist.Normal(0., 1.).sample((num_objects, 2))\n", "\n", " # Detection model.\n", " emitted = dist.Bernoulli(args.emission_prob).sample((args.num_frames, num_objects))\n", " num_spurious = dist.Poisson(args.expected_num_spurious).sample((args.num_frames,))\n", " max_num_detections = int((num_spurious + emitted.sum(-1)).max())\n", " observations = torch.zeros(args.num_frames, max_num_detections, 1+1) # position+confidence\n", " positions = get_dynamics(args.num_frames).mm(states.t())\n", " noisy_positions = dist.Normal(positions, args.emission_noise_scale).sample()\n", " for t in range(args.num_frames):\n", " j = 0\n", " for i, e in enumerate(emitted[t]):\n", " if e:\n", " observations[t, j, 0] = noisy_positions[t, i]\n", " observations[t, j, 1] = 1\n", " j += 1\n", " n = int(num_spurious[t])\n", " if n:\n", " observations[t, j:j+n, 0] = dist.Normal(0., 1.).sample((n,))\n", " observations[t, j:j+n, 1] = 1\n", "\n", " return states, positions, observations" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def model(args, observations):\n", " with pyro.plate(\"objects\", args.max_num_objects):\n", " exists = pyro.sample(\"exists\",\n", " dist.Bernoulli(args.expected_num_objects / args.max_num_objects))\n", " with poutine.mask(mask=exists.bool()):\n", " states = pyro.sample(\"states\", dist.Normal(0., 1.).expand([2]).to_event(1))\n", " positions = get_dynamics(args.num_frames).mm(states.t())\n", " with pyro.plate(\"detections\", observations.shape[1]):\n", " with pyro.plate(\"time\", args.num_frames):\n", " # The combinatorial part of the log prob is approximated to allow independence.\n", " is_observed = (observations[..., -1] > 0)\n", " with poutine.mask(mask=is_observed):\n", " assign = pyro.sample(\"assign\",\n", " dist.Categorical(torch.ones(args.max_num_objects + 1)))\n", " is_spurious = (assign == args.max_num_objects)\n", " is_real = is_observed & ~is_spurious\n", " num_observed = is_observed.float().sum(-1, True)\n", " pyro.sample(\"is_real\",\n", " dist.Bernoulli(args.expected_num_objects / num_observed),\n", " obs=is_real.float())\n", " pyro.sample(\"is_spurious\",\n", " dist.Bernoulli(args.expected_num_spurious / num_observed),\n", " obs=is_spurious.float())\n", "\n", " # The remaining continuous part is exact.\n", " observed_positions = observations[..., 0]\n", " with poutine.mask(mask=is_real):\n", " bogus_position = positions.new_zeros(args.num_frames, 1)\n", " augmented_positions = torch.cat([positions, bogus_position], -1)\n", " predicted_positions = gather(augmented_positions, assign, -1)\n", " pyro.sample(\"real_observations\",\n", " dist.Normal(predicted_positions, args.emission_noise_scale),\n", " obs=observed_positions)\n", " with poutine.mask(mask=is_spurious):\n", " pyro.sample(\"spurious_observations\", dist.Normal(0., 1.),\n", " obs=observed_positions)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This guide uses a smart assignment solver but a naive state estimator. A smarter implementation would use message passing also for state estimation, e.g. a Kalman filter-smoother." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def guide(args, observations):\n", " # Initialize states randomly from the prior.\n", " states_loc = pyro.param(\"states_loc\", lambda: torch.randn(args.max_num_objects, 2))\n", " states_scale = pyro.param(\"states_scale\",\n", " lambda: torch.ones(states_loc.shape) * args.emission_noise_scale,\n", " constraint=constraints.positive)\n", " positions = get_dynamics(args.num_frames).mm(states_loc.t())\n", "\n", " # Solve soft assignment problem.\n", " real_dist = dist.Normal(positions.unsqueeze(-2), args.emission_noise_scale)\n", " spurious_dist = dist.Normal(0., 1.)\n", " is_observed = (observations[..., -1] > 0)\n", " observed_positions = observations[..., 0].unsqueeze(-1)\n", " assign_logits = (real_dist.log_prob(observed_positions) -\n", " spurious_dist.log_prob(observed_positions) +\n", " math.log(args.expected_num_objects * args.emission_prob /\n", " args.expected_num_spurious))\n", " assign_logits[~is_observed] = -float('inf')\n", " exists_logits = torch.empty(args.max_num_objects).fill_(\n", " math.log(args.max_num_objects / args.expected_num_objects))\n", " assignment = MarginalAssignmentPersistent(exists_logits, assign_logits)\n", "\n", " with pyro.plate(\"objects\", args.max_num_objects):\n", " exists = pyro.sample(\"exists\", assignment.exists_dist, infer={\"enumerate\": \"parallel\"})\n", " with poutine.mask(mask=exists.bool()):\n", " pyro.sample(\"states\", dist.Normal(states_loc, states_scale).to_event(1))\n", " with pyro.plate(\"detections\", observations.shape[1]):\n", " with poutine.mask(mask=is_observed):\n", " with pyro.plate(\"time\", args.num_frames):\n", " assign = pyro.sample(\"assign\", assignment.assign_dist, infer={\"enumerate\": \"parallel\"})\n", "\n", " return assignment" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll define a global config object to make it easy to port code to `argparse`." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "args = type('Args', (object,), {}) # A fake ArgumentParser.parse_args() result.\n", "\n", "args.num_frames = 5\n", "args.max_num_objects = 3\n", "args.expected_num_objects = 2.\n", "args.expected_num_spurious = 1.\n", "args.emission_prob = 0.8\n", "args.emission_noise_scale = 0.1\n", "\n", "assert args.max_num_objects >= args.expected_num_objects" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate data" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "generated 16 detections from 2 objects\n" ] } ], "source": [ "pyro.set_rng_seed(0)\n", "true_states, true_positions, observations = generate_data(args)\n", "true_num_objects = len(true_states)\n", "max_num_detections = observations.shape[1]\n", "assert true_states.shape == (true_num_objects, 2)\n", "assert true_positions.shape == (args.num_frames, true_num_objects)\n", "assert observations.shape == (args.num_frames, max_num_detections, 1+1)\n", "print(\"generated {:d} detections from {:d} objects\".format(\n", " (observations[..., -1] > 0).long().sum(), true_num_objects))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def plot_solution(message=''):\n", " assignment = guide(args, observations)\n", " states_loc = pyro.param(\"states_loc\")\n", " positions = get_dynamics(args.num_frames).mm(states_loc.t())\n", " pyplot.figure(figsize=(12,6)).patch.set_color('white')\n", " pyplot.plot(true_positions.numpy(), 'k--')\n", " is_observed = (observations[..., -1] > 0)\n", " pos = observations[..., 0]\n", " time = torch.arange(float(args.num_frames)).unsqueeze(-1).expand_as(pos)\n", " pyplot.scatter(time[is_observed].view(-1).numpy(),\n", " pos[is_observed].view(-1).numpy(), color='k', marker='+',\n", " label='observation')\n", " for i in range(args.max_num_objects):\n", " p_exist = assignment.exists_dist.probs[i].item()\n", " position = positions[:, i].detach().numpy()\n", " pyplot.plot(position, alpha=p_exist, color='C0')\n", " pyplot.title('Truth, observations, and predicted tracks ' + message)\n", " pyplot.plot([], 'k--', label='truth')\n", " pyplot.plot([], color='C0', label='prediction')\n", " pyplot.legend(loc='best')\n", " pyplot.xlabel('time step')\n", " pyplot.ylabel('position')\n", " pyplot.tight_layout()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "pyro.set_rng_seed(1)\n", "pyro.clear_param_store()\n", "plot_solution('(before training)')" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 0 loss = 89.270072937\n", "epoch 10 loss = 85.940826416\n", "epoch 20 loss = 86.1014556885\n", "epoch 30 loss = 83.8865127563\n", "epoch 40 loss = 85.354347229\n", "epoch 50 loss = 82.01512146\n", "epoch 60 loss = 78.1765365601\n", "epoch 70 loss = 78.0290603638\n", "epoch 80 loss = 74.915725708\n", "epoch 90 loss = 74.3280792236\n", "epoch 100 loss = 74.1109313965\n" ] } ], "source": [ "infer = SVI(model, guide, Adam({\"lr\": 0.01}), TraceEnum_ELBO(max_plate_nesting=2))\n", "losses = []\n", "for epoch in range(101 if not smoke_test else 2):\n", " loss = infer.step(args, observations)\n", " if epoch % 10 == 0:\n", " print(\"epoch {: >4d} loss = {}\".format(epoch, loss))\n", " losses.append(loss)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "pyplot.plot(losses);" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_solution('(after training)')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.10" } }, "nbformat": 4, "nbformat_minor": 2 }