{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Attend Infer Repeat\n", "\n", "In this tutorial we will implement the model and inference strategy described in \"Attend, Infer, Repeat:\n", "Fast Scene Understanding with Generative Models\" (AIR) [1] and apply it to the multi-MNIST dataset.\n", "\n", "A [standalone implementation](https://github.com/pyro-ppl/pyro/tree/dev/examples/air) is also available." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Populating the interactive namespace from numpy and matplotlib\n" ] } ], "source": [ "%pylab inline\n", "import os\n", "from collections import namedtuple\n", "import pyro\n", "import pyro.optim as optim\n", "from pyro.infer import SVI, TraceGraph_ELBO\n", "import pyro.distributions as dist\n", "import pyro.poutine as poutine\n", "import pyro.contrib.examples.multi_mnist as multi_mnist\n", "import torch\n", "import torch.nn as nn\n", "from torch.nn.functional import relu, sigmoid, softplus, grid_sample, affine_grid\n", "import numpy as np\n", "\n", "smoke_test = ('CI' in os.environ)\n", "assert pyro.__version__.startswith('1.3.0')\n", "pyro.enable_validation(True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Introduction\n", "\n", "The model described in [1] is a generative model of scenes. In this tutorial we will use it to model images from a dataset that is similar to the multi-MNIST dataset in [1]. 
Here are some data points from this data set:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "keep_output": true }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeAAAABvCAYAAAA0RRMsAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAC7NJREFUeJzt3X9oldUfwPH3bEtDaqO5SpYRGTVKMCxDCpaRs18qRj+QKcsKUxdSwgjL0iR/tYp+/RNZVlhWUIZRRKwwIiPBlUmxrVxkac7KylZj2bb7/eNyn2Zfnb92n/Ps3vcLLtw9u8/O2eHc+7nn85xznoJUKoUkSYrXoNAVkCQpHxmAJUkKwAAsSVIABmBJkgIwAEuSFIABWJKkAAzAkiQFYACWJCkAA7AkSQEUxllYQUGB224doVQqVXC059reR+5Y2hts86NhH4+X7R2vvtrbEbAkSQEYgCVJCsAALElSAAZgSZICiHUSlnSkRo0aBcDkyZMZOXJk9Hzp0qUAPPXUU8HqJknHoiDO+wE7g+7I5fOMxdmzZ/PYY48BMHjw4P1+t3PnTgDOOOOMfi3TWdDxy+c+HoLtHS9nQUuSlDCOgBMuH7+trl69GoCpU6dSXFx8wNd0d3cD8O233wJQUVHRL2U7Ao5fPvbxkGzvePXV3gbghMu3N0tRURF///03AP/tmx9++CEA48ePj461t7cDUFJS0i/lG4Djl299vLdZs2Zx/vnnA3DXXXfFUmY+t3cIpqAlSUoYZ0EruBNOOIElS5YAMG3atOj47t27ue666wDYtGlTdDyTfgb44YcfYqql1H9WrFgBwMyZMykrKwPgu+++4/HHHw9ZLcXMFHTC5UO6qLKykg0bNkQ/t7S0AFBTU8PmzZuj40OHDgXgjz/+oKurC0h/gAG88sor/VIXU9Dxy4c+3lthYeEBL7NUVFSwbdu2rJefb+0dmiloSZISxhFwwuXyt9UFCxYAMGfOHEaMGBEdP+644w74+gkTJgDw3nvvuQ44h+RyH+9t3LhxAGzcuJFBg9Jjn56enuj3B+v3/S1f2jsp+mpvrwErmNraWgDKy8ujY3v37g1VHSmrPvvsMwCWLVvGwoULgXQKurGxMWS1FJApaEmSAnAErNhlZnqefvrpAOzZs4eqqioAtmzZcsjzBw0aREHBMWWKpdjt27cPgFWrVjFjxgwgfQklc1z5xwCsWD355JPMnTsX+HcG6Pr16w8r8GauofX09PDAAw9krY5SNj3yyCNs374dSAfgurq6wDVSKKagJUkKwBGwYpHZq3n69OnRDNDMns/z5s07rL8xefLk6LlpOw00mQzODTfcEL0HmpubaWtrC1ktBeQIWJKkABwBKxaZUW7vmyY88cQTAHR2dh7y/KqqKkaNGgWk74D01ltvZaGWUvZccMEFQHruQ2b97/XXXx9dD1YeSqVSsT2AlI8je+RCe5955pmp9vb2VHt7e6q7uzt6HMnfaGhoiM674447EtneSWrzgfTIh/aeNGlSqqOjI9XR0ZHq6upKjR49OjV69GjbOw8efbWnKWhJkgIwBa2smzhxIkOGDIl+/vHHHwEoLS0F0hOqMvf1hX8nbFVXV0c7BgF0dHQAsG7duqzXWepPO3bsYM+ePQAMHz6cL774InCNlATuBZ1wubJva+Z679y5c6M9bz///HMgvf3kjh07otfW1NQA+++T29nZyfr164F0YM4W94KOX6708b58//33+225Gte+zweSD+2dJN4NSZKkhHEEnHC59m21ubmZkSNHAkRrIf8rc7yzszO6b+q8efNYs2Z
N1uvnCDh+udbHD2TNmjVMnDgRSF96KSwMd/UvH9o7SbwbkhKjoqKC22+/HYCTTz4ZgGHDhjF//vzoNffeey8AmzdvpqGhIf5KSv3s1Vdf5aKLLgJg165dgWujpDAFLUlSAKagE850UbxCpqC7u7sBmD17Ns8+++yxVGNAyZc+/vLLLwOwePFitm3bFqwe+dLeSdFXexuAE843S7ySEIBfe+21rM70Thr7eLxs73g5C1qSpIRxBHwIpaWlbN26FYDKykpaW1tjLd9vq/FKwggYwq4TjZt9PF62d7ycBX0MFi1axKpVqwBiD76SpNxlClqSpABMQR/E2WefDUBLS4vbxuURU9Dxs4/Hy/aOlynoo7Bo0aLQVZAk5TBT0JIkBTDgRsDjxo3j008/zWoZt912GzNmzADcNk6SlB2OgCVJCmDABODi4mKKi4t54YUXsl7WhAkTKCgooKCggKeffjrr5Um9ffDBB6GrICkGAyIFXVRUFK3FbWlpyXp51157bfT8999/z3p5Um9XXHFF6CpIisGAGQFLkpRLBsQ64F27dnHqqacCB7+Je39KpVL89ttvwL/3rA3FNXvxch3w/qqqqgCorq5m5syZQPr9AensUO/3R6bOPT09HMnnin08XrZ3vAbsOuCTTjoJSN+wPXMrrzj09PTw8MMPx1aeBEQz71966SX++usvAO68806A2G5POGbMGObPnw9ASUkJ11xzDQAdHR38+uuvABQUpD9PiouL9zv3/fffB6C+vp533303lvpKA5kpaEmSAkjsCHjt2rVMmzYNgLKyMvbs2ZP1MqdMmRI9f+edd7JentTbvn37AOjq6mLIkCEArFixAsj+CHjMmDEALFmyhPvvvx+ALVu29HlORUUFFRUVAGzYsIH77rsPgIaGhizWVModibsG/OCDDwJQV1dHV1cXACeeeOIh/3ZZWRmQ/hDbu3fvEdUrc25jYyMAy5cvT8zyI6/XxCvkNeCQ6uvrAWhqauL5558/rHMKCwujuRJNTU1cfPHFR1W2fTxetne8+mpvU9CSJAWQuBR0ZtIJwI033nhY5zzzzDNMnToVgD///JO7774bgNdff/2wzh8/fjwA5eXlANGaYylfZLJG9fX1XHbZZQAsWLCAtra2g56zevVqvvrqKwAmTZqU/UpKOSYRKehhw4YBsHHjxmjjiylTprB79+7/e+1ZZ53FQw89BMCVV14JQGlpKf/88w8AK1euZM6cOUB6FuehDB8+nHXr1gFEKbSkLAEB00Vxy9cUdMbQoUOjL7AzZsyIlhl98sknnHPOOUD6PQjwxhtvcNNNNx1zmfbxeNne8TIFLUlSwiRiBFxXVwek01+ZCVeZdZAZhYXpbHlra2s0Mq6pqQGgubmZwYMHA/DLL7+wdOlSgGik3Jerr76at99+e79jjoDzV76PgHs77bTT2LlzZ/RzZv1vZk3+LbfcEk2UPBb28XjZ3vFK/EYcmRmYjz766P8FXkjPjM6kxUpKSjjvvPMAmDVrFgC1tbU0NTVFv++9o9ChHH/88dHz2trao/sHpBzU1tYWLcebNGkSl1xyCUDWbwcq5QtT0JIkBZCIFPT27dsBGDt2LD/99FN0PLMVZWNjIyNHjgTS6ebM4v+MyspKPv7446zUOTTTRfEyBb2/TDbpm2++4Z577gHgzTff7Ncy7OPxsr3j5SQsSZISJhHXgEeMGAHApk2b+PLLL6PjY8eOBeCUU06JjpWXl3PrrbcC8PPPPwPk7OhXCiEz4bG2tpaPPvoIgJtvvpkLL7wwZLWknJOIAJwJsL3X7aZSKVpbW4F0gM7MaL700kuZPn06AIsXL465plJuq6ioiPZy/vrrr7n88suj3xmApf5lClqSpAASMQlLB+eEiXjl+ySs5557jquuugpIr5HfunUrAOeee260Dri5ublfy7SPx8v2jldf7W0ATjjfLPHK9wDc3d0dzXaur6/nxRdfBKC6upqioqKslGkfj5ftHS9nQUu
SlDCOgBPOb6vxcgTcHa39bW9vjyZGrly5koULF2alTPt4vGzveCV+K0pJybB27Vqqq6uB9KqE5cuXA7Bs2bKQ1ZJykiloSZICMAWdcKaL4pXvKegQ7OPxsr3j5SQsSZISxgAsSVIABmBJkgIwAEuSFIABWJKkAAzAkiQFYACWJCkAA7AkSQHEuhGHJElKcwQsSVIABmBJkgIwAEuSFIABWJKkAAzAkiQFYACWJCkAA7AkSQEYgCVJCsAALElSAAZgSZICMABLkhSAAViSpAAMwJIkBWAAliQpAAOwJEkBGIAlSQrAACxJUgAGYEmSAjAAS5IUgAFYkqQADMCSJAVgAJYkKQADsCRJAfwPVUxC5Ncuf1kAAAAASUVORK5CYII=\n", "text/plain": [ "