{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Normalizing Flows - Introduction (Part 1)\n", "\n", "This tutorial introduces Pyro's normalizing flow library. It is independent of much of Pyro, but users may want to read about distribution shapes in the [Tensor Shapes Tutorial](http://pyro.ai/examples/tensor_shapes.html).\n", " \n", "## Introduction\n", "\n", "In standard probabilistic modeling practice, we represent our beliefs over unknown continuous quantities with simple parametric distributions like the normal, exponential, and Laplace distributions. However, using such simple forms, which are commonly symmetric and unimodal (or have a fixed number of modes when we take a mixture of them), restricts the performance and flexibility of our methods. For instance, standard variational inference in the Variational Autoencoder uses independent univariate normal distributions to represent the variational family. The true posterior is neither independent nor normally distributed, which results in suboptimal inference and simplifies the model that is learnt. In other scenarios, we are likewise restricted by not being able to model multimodal distributions and heavy or light tails.\n", "\n", "Normalizing Flows \\[1-4\\] are a family of methods for constructing flexible learnable probability distributions, often with neural networks, which allow us to surpass the limitations of simple parametric forms. Pyro contains state-of-the-art normalizing flow implementations, and this tutorial explains how you can use this library for learning complex models and performing flexible variational inference. We introduce the main idea of Normalizing Flows (NFs) and demonstrate learning simple univariate distributions with element-wise, multivariate, and conditional flows." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Univariate Distributions\n", " \n", "### Background\n", " \n", "Normalizing Flows are a family of methods for constructing flexible distributions. Let's first restrict our attention to representing univariate distributions. The basic idea is that a simple source of noise, for example a variable with a standard normal distribution, $X\\sim\\mathcal{N}(0,1)$, is passed through a bijective (i.e. invertible) function, $g(\\cdot)$, to produce a more complex transformed variable $Y=g(X)$.\n", "\n", "For a given random variable, we typically want to perform two operations: sampling and scoring. Sampling $Y$ is trivial. First, we sample $X=x$, then calculate $y=g(x)$. Scoring $Y$, or rather, evaluating the log-density $\\log(p_Y(y))$, is more involved. How does the density of $Y$ relate to the density of $X$? We can use the substitution rule of integral calculus to answer this. Suppose we want to evaluate the expectation of some function of $X$. Then,\n", "\n", "\n", "\\begin{align}\n", "\\mathbb{E}_{p_X(\\cdot)}\\left[f(X)\\right] &= \\int_{\\text{supp}(X)}f(x)p_X(x)dx\\\\\n", "&= \\int_{\\text{supp}(Y)}f(g^{-1}(y))p_X(g^{-1}(y))\\left|\\frac{dx}{dy}\\right|dy\\\\\n", "&= \\mathbb{E}_{p_Y(\\cdot)}\\left[f(g^{-1}(Y))\\right],\n", "\\end{align}\n", "\n", "\n", "where $\\text{supp}(X)$ denotes the support of $X$, which in this case is $(-\\infty,\\infty)$. Crucially, we used the fact that $g$ is bijective to apply the substitution rule in going from the first to the second line. 
Equating the last two lines we get,\n", "\n", "\n", "\\begin{align}\n", "\\log(p_Y(y)) &= \\log(p_X(g^{-1}(y)))+\\log\\left(\\left|\\frac{dx}{dy}\\right|\\right)\\\\\n", "&= \\log(p_X(g^{-1}(y)))-\\log\\left(\\left|\\frac{dy}{dx}\\right|\\right).\n", "\\end{align}\n", "\n", "\n", "Intuitively, this equation says that the log-density of $Y$ is equal to the log-density at the corresponding point in $X$ plus a term that corrects for how the transformation warps the volume of an infinitesimally small interval around $y$.\n", "\n", "If $g$ is cleverly constructed (and we will see several examples shortly), we can produce distributions that are more complex than standard normal noise and yet have easy sampling and computationally tractable scoring. Moreover, we can compose such bijective transformations to produce even more complex distributions. By an inductive argument, if we have $L$ transforms $g_{(0)}, g_{(1)},\\ldots,g_{(L-1)}$, then the log-density of the transformed variable $Y=(g_{(0)}\\circ g_{(1)}\\circ\\cdots\\circ g_{(L-1)})(X)$ is\n", "\n", "\n", "\\begin{align}\n", "\\log(p_Y(y)) &= \\log\\left(p_X\\left(\\left(g_{(L-1)}^{-1}\\circ\\cdots\\circ g_{(0)}^{-1}\\right)\\left(y\\right)\\right)\\right)+\\sum^{L-1}_{l=0}\\log\\left(\\left|\\frac{dg^{-1}_{(l)}(y_{(l)})}{dy_{(l)}}\\right|\\right),\n", "\\end{align}\n", "\n", "\n", "where we've defined $y_{(0)}=y$ and $y_{(l+1)}=g^{-1}_{(l)}(y_{(l)})$, so that $y_{(L)}=x$, for convenience of notation.\n", "\n", "In a later section, we will see how to generalize this method to multivariate $X$. The field of Normalizing Flows aims to construct such $g$ for multivariate $X$ to transform simple i.i.d. standard normal noise into complex, learnable, high-dimensional distributions. The methods have been applied to such diverse applications as image modeling, text-to-speech, unsupervised language induction, data compression, and modeling molecular structures. 
As probability distributions are the most fundamental component of probabilistic modeling, we will likely see many more exciting state-of-the-art applications in the near future." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Fixed Univariate Transforms in Pyro\n", "\n", "PyTorch contains classes for representing *fixed* univariate bijective transformations, and for sampling from and scoring the transformed distributions derived from them. Pyro extends this with a comprehensive library of *learnable* univariate and multivariate transformations using the latest developments in the field. As Pyro imports all of PyTorch's distributions and transformations, we will work solely with Pyro. We also note that the NF components in Pyro can be used independently of the probabilistic programming functionality of Pyro, which is what we will be doing in the first two tutorials.\n", "\n", "Let us begin by showing how to represent and manipulate a simple transformed distribution,\n", "\n", "\n", "\\begin{align}\n", "X &\\sim \\mathcal{N}(0,1)\\\\\n", "Y &= \\text{exp}(X).\n", "\\end{align}\n", "\n", "\n", "You may have recognized that this is, by definition, $Y\\sim\\text{LogNormal}(0,1)$.\n", "\n", "We begin by importing the relevant libraries:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import pyro\n", "import pyro.distributions as dist\n", "import pyro.distributions.transforms as T\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "import os\n", "smoke_test = ('CI' in os.environ)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A variety of bijective transformations live in the [pyro.distributions.transforms](http://docs.pyro.ai/en/stable/distributions.html#transforms) module, and the classes to define transformed distributions live in [pyro.distributions](http://docs.pyro.ai/en/stable/distributions.html). 
We first create the base distribution of $X$ and the class encapsulating the transform $\\text{exp}(\\cdot)$:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "dist_x = dist.Normal(torch.zeros(1), torch.ones(1))\n", "exp_transform = T.ExpTransform()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The class [ExpTransform](https://pytorch.org/docs/master/distributions.html#torch.distributions.transforms.ExpTransform) derives from [Transform](https://pytorch.org/docs/master/distributions.html#torch.distributions.transforms.Transform) and defines the forward, inverse, and log-absolute-derivative operations for this transform,\n", "\n", "\n", "\\begin{align}\n", "g(x) &= \\text{exp}(x)\\\\\n", "g^{-1}(y) &= \\log(y)\\\\\n", "\\log\\left(\\left|\\frac{dg}{dx}\\right|\\right) &= x.\n", "\\end{align}\n", "\n", "\n", "In general, a transform class defines these three operations, which together are sufficient to perform sampling and scoring.\n", "\n", "The class [TransformedDistribution](https://pytorch.org/docs/master/distributions.html#torch.distributions.transformed_distribution.TransformedDistribution) takes a base distribution of simple noise and a list of transforms, and encapsulates the distribution formed by applying these transformations in sequence. We use it as:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "dist_y = dist.TransformedDistribution(dist_x, [exp_transform])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we plot samples from both to verify that we have produced the log-normal distribution:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "