\n", "\n", "$$\\log p_{\\theta}({\\bf x} | {\\rm Pa}_p ({\\bf x})) +\n", "\\sum_i \\log p_{\\theta}({\\bf z}_i | {\\rm Pa}_p ({\\bf z}_i)) \n", "- \\sum_i \\log q_{\\phi}({\\bf z}_i | {\\rm Pa}_q ({\\bf z}_i))$$\n", "\n", "where we've broken the log ratio $\\log p_{\\theta}({\\bf x}, {\\bf z})/q_{\\phi}({\\bf z})$ into an observation log likelihood piece and a sum over the different latent random variables $\\{{\\bf z}_i \\}$. We've also introduced the notation \n", "${\\rm Pa}_p (\\cdot)$ and ${\\rm Pa}_q (\\cdot)$ to denote the parents of a given random variable in the model and in the guide, respectively. (The reader might worry about what the appropriate notion of dependency would be in the case of general stochastic functions; here we simply mean regular ol' dependency within a single execution trace.) The point is that different terms in the cost function have different dependencies on the random variables $\\{ {\\bf z}_i \\}$, and this is something we can leverage.\n", "\n", "To make a long story short, for any non-reparameterizable latent random variable ${\\bf z}_i$ the surrogate objective is going to have a term \n", "\n", "$$\\log q_{\\phi}({\\bf z}_i) \\overline{f_{\\phi}({\\bf z})} $$\n", "\n", "It turns out that we can remove some of the terms in $\\overline{f_{\\phi}({\\bf z})}$ and still get an unbiased gradient estimator; furthermore, doing so will generally decrease the variance. In particular (see reference [4] for details), we can remove any terms in $\\overline{f_{\\phi}({\\bf z})}$ that are not downstream of the latent variable ${\\bf z}_i$ (downstream w.r.t. the dependency structure of the guide). Note that this general trick, where certain random variables are dealt with analytically to reduce variance, often goes under the name of Rao-Blackwellization.\n", "\n", "In Pyro, all of this logic is taken care of automatically by the `SVI` class. 
In particular, as long as we use a `TraceGraph_ELBO` loss, Pyro will keep track of the dependency structure within the execution traces of the model and guide and construct a surrogate objective that has all the unnecessary terms removed:\n", "\n", "```python\n", "svi = SVI(model, guide, optimizer, TraceGraph_ELBO())\n", "```\n", "\n", "Note that leveraging this dependency information takes extra computation, so `TraceGraph_ELBO` should only be used in the case where your model has non-reparameterizable random variables; in most applications `Trace_ELBO` suffices." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### An Example with Rao-Blackwellization\n", "\n", "Suppose we have a Gaussian mixture model with $K$ components. For each data point we: (i) first sample a component index $k \\in [1,...,K]$; and (ii) observe the data point using the $k^{\\rm th}$ component distribution. The simplest way to write down a model of this sort is as follows:\n", "\n", "```python\n", "ks = pyro.sample(\"k\", dist.Categorical(probs)\n", "                      .to_event(1))\n", "pyro.sample(\"obs\", dist.Normal(locs[ks], scale)\n", "                   .to_event(1),\n", "            obs=data)\n", "```\n", "\n", "Since the user hasn't taken care to mark any of the conditional independencies in the model, the gradient estimator constructed by Pyro's `SVI` class is unable to take advantage of Rao-Blackwellization, with the result that the gradient estimator will tend to suffer from high variance. To address this problem, the user needs to explicitly mark the conditional independence. Happily, this is not much work:\n", "\n", "```python\n", "# mark conditional independence\n", "# (assumed to be along the rightmost tensor dimension)\n", "with pyro.plate(\"foo\", data.size(-1)):\n", "    ks = pyro.sample(\"k\", dist.Categorical(probs))\n", "    pyro.sample(\"obs\", dist.Normal(locs[ks], scale),\n", "                obs=data)\n", "```\n", "\n", "That's all there is to it." 
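, "\n", "\n", "For readers who want to run the plated snippet above end to end, the next cell is a minimal sketch rather than a definitive implementation. The toy data, $K=2$, the uniform `probs`, the fixed `scale`, the learnable `locs` parameter in the model, and the per-datapoint `assignment_logits` in the guide are all illustrative assumptions that are not part of the original example. Because the non-reparameterizable site `k` is declared inside a `pyro.plate` in both model and guide, this is the situation described above in which `TraceGraph_ELBO` can exploit the marked conditional independence when it builds the surrogate objective." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import pyro\n", "import pyro.distributions as dist\n", "import pyro.optim as optim\n", "from pyro.infer import SVI, TraceGraph_ELBO\n", "\n", "# a minimal sketch: the toy data, K, probs, scale and the guide are illustrative assumptions\n", "pyro.set_rng_seed(0)\n", "pyro.clear_param_store()\n", "\n", "# toy data: 100 scalar points drawn from two well-separated clusters\n", "data = torch.cat([torch.randn(50) - 3.0, torch.randn(50) + 3.0])\n", "K = 2\n", "probs = torch.ones(K) / K\n", "scale = torch.tensor(1.0)\n", "\n", "\n", "def model(data):\n", "    # learnable component means (an assumption made for this sketch)\n", "    locs = pyro.param(\"locs\", torch.tensor([-1.0, 1.0]))\n", "    # mark conditional independence along the rightmost tensor dimension\n", "    with pyro.plate(\"data_plate\", data.size(-1)):\n", "        ks = pyro.sample(\"k\", dist.Categorical(probs))\n", "        pyro.sample(\"obs\", dist.Normal(locs[ks], scale), obs=data)\n", "\n", "\n", "def guide(data):\n", "    # hypothetical guide: one vector of assignment logits per data point\n", "    assignment_logits = pyro.param(\"assignment_logits\", torch.zeros(data.size(-1), K))\n", "    with pyro.plate(\"data_plate\", data.size(-1)):\n", "        # Categorical is not reparameterizable, so this site contributes a score-function term\n", "        pyro.sample(\"k\", dist.Categorical(logits=assignment_logits))\n", "\n", "\n", "svi = SVI(model, guide, optim.Adam({\"lr\": 0.05}), loss=TraceGraph_ELBO())\n", "losses = [svi.step(data) for _ in range(100)]\n", "print(\"final loss: %.2f\" % losses[-1])\n", "print(\"learned locs:\", pyro.param(\"locs\").detach())"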
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Aside: Dependency tracking in Pyro\n", "\n", "Finally, a word about dependency tracking. Tracking dependency within a stochastic function that includes arbitrary Python code is a bit tricky. The approach currently implemented in Pyro is analogous to the one used in WebPPL (cf. reference [5]). Briefly, a conservative notion of dependency is used that relies on sequential ordering. If random variable ${\\bf z}_2$ follows ${\\bf z}_1$ in a given stochastic function, then ${\\bf z}_2$ _may be_ dependent on ${\\bf z}_1$ and therefore _is_ assumed to be dependent. To mitigate the overly coarse conclusions that can be drawn by this kind of dependency tracking, Pyro includes constructs for declaring things as independent, namely `plate` and `markov` ([see the previous tutorial](svi_part_ii.ipynb)). For use cases with non-reparameterizable variables, it is therefore important for the user to make use of these constructs (when applicable) to take full advantage of the variance reduction provided by `SVI`. In some cases it may also pay to consider reordering random variables within a stochastic function (if possible).\n", "\n", "### Reducing Variance with Data-Dependent Baselines\n", "\n", "The second strategy for reducing variance in our ELBO gradient estimator goes under the name of baselines (see e.g. reference [6]). It actually makes use of the same bit of math that underlies the variance reduction strategy discussed above, except now instead of removing terms we're going to add terms. Basically, instead of removing terms with zero expectation that tend to _contribute_ to the variance, we're going to add specially chosen terms with zero expectation that work to _reduce_ the variance. 
As such, this is a control variate strategy.\n", "\n", "In more detail, the idea is to take advantage of the fact that for any constant $b$, the following identity holds:\n", "\n", "$$\\mathbb{E}_{q_{\\phi}({\\bf z})} \\left [\\nabla_{\\phi}\n", "(\\log q_{\\phi}({\\bf z}) \\times b) \\right]=0$$\n", "\n", "Since $b$ is a constant, it can be pulled out of both the gradient and the expectation, so the identity follows from the fact that $q(\\cdot)$ is normalized:\n", "\n", "$$\\mathbb{E}_{q_{\\phi}({\\bf z})} \\left [\\nabla_{\\phi}\n", "\\log q_{\\phi}({\\bf z}) \\right]=\n", " \\int \\!d{\\bf z} \\; q_{\\phi}({\\bf z}) \\nabla_{\\phi}\n", "\\log q_{\\phi}({\\bf z})=\n", " \\int \\! d{\\bf z} \\; \\nabla_{\\phi} q_{\\phi}({\\bf z})=\n", "\\nabla_{\\phi} \\int \\! d{\\bf z} \\; q_{\\phi}({\\bf z})=\\nabla_{\\phi} 1 = 0$$\n", "\n", "What this means is that we can replace any term\n", "\n", "$$\\log q_{\\phi}({\\bf z}_i) \\overline{f_{\\phi}({\\bf z})} $$\n", "\n", "in our surrogate objective with\n", "\n", "$$\\log q_{\\phi}({\\bf z}_i) \\left(\\overline{f_{\\phi}({\\bf z})}-b\\right) $$\n", "\n", "Doing so doesn't affect the mean of our gradient estimator, but it does affect the variance. If we choose $b$ wisely, we can hope to reduce the variance. In fact, $b$ need not be a constant: it can depend on any of the random choices upstream (or sidestream) of ${\\bf z}_i$.\n", "\n", "#### Baselines in Pyro\n", "\n", "There are several ways the user can instruct Pyro to use baselines in the context of stochastic variational inference. Since baselines can be attached to any non-reparameterizable random variable, the current baseline interface is at the level of the `pyro.sample` statement. In particular, the baseline interface makes use of a `baseline` entry in the `infer` dictionary passed to `pyro.sample`; its value is itself a dictionary that specifies baseline options. 
Note that it only makes sense to specify baselines for sample statements within the guide (and not in the model).\n", "\n", "##### Decaying Average Baseline\n", "\n", "The simplest baseline is constructed from a running average of recent samples of $\\overline{f_{\\phi}({\\bf z})}$. In Pyro this kind of baseline can be invoked as follows:\n", "\n", "```python\n", "z = pyro.sample(\"z\", dist.Bernoulli(...),\n", "                infer=dict(baseline={'use_decaying_avg_baseline': True,\n", "                                     'baseline_beta': 0.95}))\n", "```\n", "\n", "The optional `baseline_beta` entry specifies the decay rate of the decaying average (default value: `0.90`).\n", "\n", "##### Neural Baselines\n", "\n", "In some cases a decaying average baseline works well. In others, using a baseline that depends on upstream randomness is crucial for getting good variance reduction. A powerful approach for constructing such a baseline is to use a neural network that can be adapted during the course of learning. Pyro provides two ways to specify such a baseline (for an extended example see the [AIR tutorial](air.ipynb)).\n", "\n", "First the user needs to decide what inputs the baseline is going to consume (e.g. the current datapoint under consideration or the previously sampled random variable). Then the user needs to construct an `nn.Module` that encapsulates the baseline computation. This might look something like\n", "\n", "```python\n", "import torch\n", "import torch.nn as nn\n", "\n", "\n", "class BaselineNN(nn.Module):\n", "    def __init__(self, dim_input, dim_hidden):\n", "        super().__init__()\n", "        self.linear = nn.Linear(dim_input, dim_hidden)\n", "        # map the hidden features to a single baseline value per input\n", "        self.linear_out = nn.Linear(dim_hidden, 1)\n", "\n", "    def forward(self, x):\n", "        hidden = torch.relu(self.linear(x))\n", "        # ... more computations can go here ...\n", "        baseline = self.linear_out(hidden).squeeze(-1)\n", "        return baseline\n", "```\n", "\n", "Then, assuming a `BaselineNN` instance `baseline_module` has been created somewhere else, in the guide we'll have something like\n", "\n", "```python\n", "def guide(x):  # here x is the current mini-batch of data\n", "    pyro.module(\"my_baseline\", baseline_module)\n", "    # ... other computations ...\n", "    z = pyro.sample(\"z\", dist.Bernoulli(...),\n", "                    infer=dict(baseline={'nn_baseline': baseline_module,\n", "                                         'nn_baseline_input': x}))\n", "```\n", "\n", "Here the argument `nn_baseline` tells Pyro which `nn.Module` to use to construct the baseline. On the backend the argument `nn_baseline_input` is fed into the forward method of the module to compute the baseline $b$. Note that the baseline module needs to be registered with Pyro with a `pyro.module` call so that Pyro is aware of the trainable parameters within the module.\n", "\n", "Under the hood Pyro constructs a loss of the form \n", "\n", "$${\\rm baseline\\; loss} \\equiv\\left(\\overline{f_{\\phi}({\\bf z})} - b \\right)^2$$\n", "\n", "which is used to adapt the parameters of the neural network. There's no theorem that suggests this is the optimal loss function to use in this context (it's not), but in practice it can work pretty well. Just as for the decaying average baseline, the idea is that a baseline that can track the mean of $\\overline{f_{\\phi}({\\bf z})}$ will help reduce the variance. At each step, `SVI` takes one step on the baseline loss in conjunction with a step on the ELBO.\n", "\n", "Note that in practice it can be important to use a different set of learning hyperparameters (e.g. a higher learning rate) for baseline parameters. 
In Pyro this can be done as follows:\n", "\n", "```python\n", "def per_param_args(param_name):\n", "    if 'baseline' in param_name:\n", "        return {\"lr\": 0.010}\n", "    else:\n", "        return {\"lr\": 0.001}\n", "\n", "optimizer = optim.Adam(per_param_args)\n", "```\n", "\n", "Note that in order for the overall procedure to be correct, the baseline parameters should only be optimized through the baseline loss. Similarly, the model and guide parameters should only be optimized through the ELBO. To ensure that this is the case, under the hood `SVI` detaches the baseline $b$ that enters the ELBO from the autograd graph. Also, since the inputs to the neural baseline may depend on the parameters of the model and guide, the inputs are also detached from the autograd graph before they are fed into the neural network.\n", "\n", "Finally, there is an alternative way for the user to specify a neural baseline. Simply use the `baseline_value` option:\n", "\n", "```python\n", "b = ...  # do baseline computation\n", "z = pyro.sample(\"z\", dist.Bernoulli(...),\n", "                infer=dict(baseline={'baseline_value': b}))\n", "```\n", "\n", "This works as above, except in this case it's the user's responsibility to make sure that any autograd tape connecting $b$ to the parameters of the model and guide has been cut. To say the same thing in language more familiar to PyTorch users, any inputs to $b$ that depend on $\\theta$ or $\\phi$ need to be detached from the autograd graph with `detach()` calls.\n", "\n", "#### A complete example with baselines\n", "\n", "Recall that in the [first SVI tutorial](svi_part_i.ipynb) we considered a Bernoulli-Beta model for coin flips. Because the Beta random variable is non-reparameterizable (or rather not easily reparameterizable), the corresponding ELBO gradients can be quite noisy. In that context we dealt with this problem by using a Beta distribution that provides (approximate) reparameterized gradients. 
Here we showcase how a simple decaying average baseline can reduce the variance in the case where the Beta distribution is treated as non-reparameterized (so that the ELBO gradient estimator is of the score function type). While we're at it, we also use `plate` to write our model in a fully vectorized manner.\n", "\n", "Instead of directly comparing gradient variances, we're going to see how many steps it takes for SVI to converge. Recall that for this particular model (because of conjugacy) we can compute the exact posterior. So to assess the utility of baselines in this context, we set up the following simple experiment. We initialize the guide at a specified set of variational parameters. We then do SVI until the variational parameters are within a fixed tolerance of the parameters of the exact posterior. We do this both with and without the decaying average baseline. We then compare the number of gradient steps we needed in the two cases. Here's the complete code:\n", "\n", "(_Since apart from the use of_ `plate` _and_ `use_decaying_avg_baseline`, _this code is very similar to the code in parts I and II of the SVI tutorial, we're not going to go through the code line by line._)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import torch\n", "import torch.distributions.constraints as constraints\n", "import pyro\n", "import pyro.distributions as dist\n", "# Pyro also has a reparameterized Beta distribution, so we import\n", "# the non-reparameterized version to make our point\n", "from pyro.distributions.testing.fakes import NonreparameterizedBeta\n", "import pyro.optim as optim\n", "from pyro.infer import SVI, TraceGraph_ELBO\n", "import sys\n", "\n", "assert pyro.__version__.startswith('1.7.0')\n", "\n", "# this is for running the notebook in our testing framework\n", "smoke_test = ('CI' in os.environ)\n", "max_steps = 2 if smoke_test else 10000\n", "\n", "\n", "def 
param_abs_error(name, target):\n", " return torch.sum(torch.abs(target - pyro.param(name))).item()\n", "\n", "\n", "class BernoulliBetaExample:\n", " def __init__(self, max_steps):\n", " # the maximum number of inference steps we do\n", " self.max_steps = max_steps\n", " # the two hyperparameters for the beta prior\n", " self.alpha0 = 10.0\n", " self.beta0 = 10.0\n", " # the dataset consists of six 1s and four 0s\n", " self.data = torch.zeros(10)\n", " self.data[0:6] = torch.ones(6)\n", " self.n_data = self.data.size(0)\n", " # compute the alpha parameter of the exact beta posterior\n", " self.alpha_n = self.data.sum() + self.alpha0\n", " # compute the beta parameter of the exact beta posterior\n", " self.beta_n = - self.data.sum() + torch.tensor(self.beta0 + self.n_data)\n", " # initial values of the two variational parameters\n", " self.alpha_q_0 = 15.0\n", " self.beta_q_0 = 15.0\n", "\n", " def model(self, use_decaying_avg_baseline):\n", " # sample `latent_fairness` from the beta prior\n", " f = pyro.sample(\"latent_fairness\", dist.Beta(self.alpha0, self.beta0))\n", " # use plate to indicate that the observations are\n", " # conditionally independent given f and get vectorization\n", " with pyro.plate(\"data_plate\"):\n", " # observe all ten datapoints using the bernoulli likelihood\n", " pyro.sample(\"obs\", dist.Bernoulli(f), obs=self.data)\n", "\n", " def guide(self, use_decaying_avg_baseline):\n", " # register the two variational parameters with pyro\n", " alpha_q = pyro.param(\"alpha_q\", torch.tensor(self.alpha_q_0),\n", " constraint=constraints.positive)\n", " beta_q = pyro.param(\"beta_q\", torch.tensor(self.beta_q_0),\n", " constraint=constraints.positive)\n", " # sample f from the beta variational distribution\n", " baseline_dict = {'use_decaying_avg_baseline': use_decaying_avg_baseline,\n", " 'baseline_beta': 0.90}\n", " # note that the baseline_dict specifies whether we're using\n", " # decaying average baselines or not\n", " 
pyro.sample(\"latent_fairness\", NonreparameterizedBeta(alpha_q, beta_q),\n", "                    infer=dict(baseline=baseline_dict))\n", "\n", "    def do_inference(self, use_decaying_avg_baseline, tolerance=0.80):\n", "        # clear the param store in case we're in a REPL\n", "        pyro.clear_param_store()\n", "        # set up the optimizer and the inference algorithm\n", "        optimizer = optim.Adam({\"lr\": .0005, \"betas\": (0.93, 0.999)})\n", "        svi = SVI(self.model, self.guide, optimizer, loss=TraceGraph_ELBO())\n", "        print(\"Doing inference with use_decaying_avg_baseline=%s\" % use_decaying_avg_baseline)\n", "\n", "        # do up to this many steps of inference\n", "        for k in range(self.max_steps):\n", "            svi.step(use_decaying_avg_baseline)\n", "            if k % 100 == 0:\n", "                print('.', end='')\n", "                sys.stdout.flush()\n", "\n", "            # compute the distance to the parameters of the true posterior\n", "            alpha_error = param_abs_error(\"alpha_q\", self.alpha_n)\n", "            beta_error = param_abs_error(\"beta_q\", self.beta_n)\n", "\n", "            # stop inference early if we're close to the true posterior\n", "            if alpha_error < tolerance and beta_error < tolerance:\n", "                break\n", "\n", "        print(\"\\nDid %d steps of inference.\" % k)\n", "        print((\"Final absolute errors for the two variational parameters \" +\n", "               \"were %.4f & %.4f\") % (alpha_error, beta_error))\n", "\n", "# do the experiment\n", "bbe = BernoulliBetaExample(max_steps=max_steps)\n", "bbe.do_inference(use_decaying_avg_baseline=True)\n", "bbe.do_inference(use_decaying_avg_baseline=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Sample output:**\n", "```\n", "Doing inference with use_decaying_avg_baseline=True\n", "....................\n", "Did 1932 steps of inference.\n", "Final absolute errors for the two variational parameters were 0.7997 & 0.0800\n", "Doing inference with use_decaying_avg_baseline=False\n", "..................................................\n", "Did 4908 steps of inference.\n", "Final absolute errors for the two variational 
parameters were 0.7991 & 0.2532\n", "```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For this particular run we can see that baselines cut the number of SVI steps we needed by more than half (1932 versus 4908). The results are stochastic and will vary from run to run, but this is an encouraging result. This is a pretty contrived example, but for certain model and guide pairs, baselines can provide a substantial win." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## References\n", "\n", "[1] `Automated Variational Inference in Probabilistic Programming`,\n", "David Wingate, Theo Weber\n", "\n", "[2] `Black Box Variational Inference`,\n", "Rajesh Ranganath, Sean Gerrish, David M. Blei\n", "\n", "[3] `Auto-Encoding Variational Bayes`,\n", "Diederik P Kingma, Max Welling\n", "\n", "[4] `Gradient Estimation Using Stochastic Computation Graphs`,\n", "John Schulman, Nicolas Heess, Theophane Weber, Pieter Abbeel\n", "\n", "[5] `Deep Amortized Inference for Probabilistic Programs`,\n", "Daniel Ritchie, Paul Horsfall, Noah D. Goodman\n", "\n", "[6] `Neural Variational Inference and Learning in Belief Networks`,\n", "Andriy Mnih, Karol Gregor" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.10" } }, "nbformat": 4, "nbformat_minor": 2 }