{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Variational Autoencoders\n", "\n", "## Introduction\n", "\n", "The variational autoencoder (VAE) is arguably the simplest setup that realizes deep probabilistic modeling. Note that we're being careful in our choice of language here. The VAE isn't a model as such—rather the VAE is a particular setup for doing variational inference for a certain class of models. The class of models is quite broad: basically\n", "any (unsupervised) density estimator with latent random variables. The basic structure of such a model is simple, almost deceptively so (see Fig. 1)." ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/html" }, "source": [ "
Figure 1: the class of deep models we're interested in.

" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "Here we've depicted the structure of the kind of model we're interested in as a graphical model. We have $N$ observed datapoints $\\{ \\bf x_i \\}$. Each datapoint is generated by a (local) latent random variable $\\bf z_i$. There is also a parameter $\\theta$, which is global in the sense that all the datapoints depend on it (which is why it's drawn outside the rectangle). Note that since $\\theta$ is a parameter, it's not something we're being Bayesian about. Finally, what's of particular importance here is that we allow for each $\\bf x_i$ to depend on $\\bf z_i$ in a complex, non-linear way. In practice this dependency will be parameterized by a (deep) neural network with parameters $\\theta$. It's this non-linearity that makes inference for this class of models particularly challenging. \n", "\n", "Of course this non-linear structure is also one reason why this class of models offers a very flexible approach to modeling complex data. Indeed it's worth emphasizing that each of the components of the model can be 'reconfigured' in a variety of different ways. For example:\n", "\n", "- the neural network in $p_\\theta({\\bf x} | {\\bf z})$ can be varied in all the usual ways (number of layers, type of non-linearities, number of hidden units, etc.)\n", "- we can choose observation likelihoods that suit the dataset at hand: gaussian, bernoulli, categorical, etc.\n", "- we can choose the number of dimensions in the latent space\n", "\n", "The graphical model representation is a useful way to think about the structure of the model, but it can also be fruitful to look at an explicit factorization of the joint probability density:\n", "\n", "$$ p({\\bf x}, {\\bf z}) = \\prod_{i=1}^N p_\\theta({\\bf x}_i | {\\bf z}_i) p({\\bf z}_i) $$\n", "\n", "The fact that $p({\\bf x}, {\\bf z})$ breaks up into a product of terms like this makes it clear what we mean when we call $\\bf z_i$ a local random variable. For any particular $i$, only the single datapoint $\\bf x_i$ depends on $\\bf z_i$. As such the $\\{\\bf z_i\\}$ describe local structure, i.e. structure that is private to each data point. This factorized structure also means that we can do subsampling during the course of learning. As such this sort of model is amenable to the large data setting. (For more discussion on this and related topics see [SVI Part II](svi_part_ii.ipynb).)\n", "\n", "That's all there is to the model. Since the observations depend on the latent random variables in a complicated, non-linear way, we expect the posterior over the latents to have a complex structure. Consequently in order to do inference in this model we need to specify a flexibly family of guides (i.e. variational distributions). Since we want to be able to scale to large datasets, our guide is going to make use of amortization to keep the number of variational parameters under control (see [SVI Part II](svi_part_ii.ipynb) for a somewhat more general discussion of amortization). \n", "\n", "Recall that the job of the guide is to 'guess' good values for the latent random variables—good in the sense that they're true to the model prior _and_ true to the data. If we weren't making use of amortization, we would introduce variational parameters \n", "$\\{ \\lambda_i \\}$ for _each_ datapoint $\\bf x_i$. These variational parameters would represent our belief about 'good' values of $\\bf z_i$; for example, they could encode the mean and variance of a gaussian distribution in ${\\bf z}_i$ space. 
Amortization means that, rather than introducing variational parameters $\\{ \\lambda_i \\}$, we instead learn a _function_ that maps each $\\bf x_i$ to an appropriate $\\lambda_i$. Since we need this function to be flexible, we parameterize it as a neural network. We thus end up with a parameterized family of distributions over the latent $\\bf z$ space that can be instantiated for all $N$ datapoints ${\\bf x}_i$ (see Fig. 2)." ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/html" }, "source": [ "
Figure 2: a graphical representation of the guide.

" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that the guide $q_{\\phi}({\\bf z} | {\\bf x})$ is parameterized by a global parameter $\\phi$ shared by all the datapoints. The goal of inference will be to find 'good' values for $\\theta$ and $\\phi$ so that two conditions are satisfied:\n", "\n", "- the log evidence $\\log p_\\theta({\\bf x})$ is large. this means our model is a good fit to the data\n", "- the guide $q_{\\phi}({\\bf z} | {\\bf x})$ provides a good approximation to the posterior \n", "\n", "(For an introduction to stochastic variational inference see [SVI Part I](svi_part_i.ipynb).)\n", "\n", "At this point we can zoom out and consider the high level structure of our setup. For concreteness, let's suppose the $\\{ \\bf x_i \\}$ are images so that the model is a generative model of images. Once we've learned a good value of $\\theta$ we can generate images from the model as follows:\n", "\n", "- sample $\\bf z$ according to the prior $p({\\bf z})$\n", "- sample $\\bf x$ according to the likelihood $p_\\theta({\\bf x}|{\\bf z})$\n", "\n", "Each image is being represented by a latent code $\\bf z$ and that code gets mapped to images using the likelihood, which depends on the $\\theta$ we've learned. This is why the likelihood is often called the decoder in this context: its job is to decode $\\bf z$ into $\\bf x$. Note that since this is a probabilistic model, there is uncertainty about the $\\bf z$ that encodes a given datapoint $\\bf x$.\n", "\n", "Once we've learned good values for $\\theta$ and $\\phi$ we can also go through the following exercise. \n", "\n", "- we start with a given image $\\bf x$\n", "- using our guide we encode it as $\\bf z$\n", "- using the model likelihood we decode $\\bf z$ and get a reconstructed image ${\\bf x}_{\\rm reco}$\n", "\n", "If we've learned good values for $\\theta$ and $\\phi$, $\\bf x$ and ${\\bf x}_{\\rm reco}$ should be similar. This should clarify how the word autoencoder ended up being used to describe this setup: the model is the decoder and the guide is the encoder. 
Together, they can be thought of as an autoencoder.\n", "\n", "## VAE in Pyro\n", "\n", "Let's see how we implement a VAE in Pyro.\n", "The dataset we're going to model is MNIST, a collection of images of handwritten digits.\n", "Since this is a popular benchmark dataset, we can make use of PyTorch's convenient data loader functionalities to reduce the amount of boilerplate code we need to write:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "import numpy as np\n", "import torch\n", "from pyro.contrib.examples.util import MNIST\n", "import torch.nn as nn\n", "import torchvision.transforms as transforms\n", "\n", "import pyro\n", "import pyro.distributions as dist\n", "import pyro.contrib.examples.util # patches torchvision\n", "from pyro.infer import SVI, Trace_ELBO\n", "from pyro.optim import Adam" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "assert pyro.__version__.startswith('1.9.0')\n", "pyro.distributions.enable_validation(False)\n", "pyro.set_rng_seed(0)\n", "# Enable smoke test - run the notebook cells on CI.\n", "smoke_test = 'CI' in os.environ " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# for loading and batching MNIST dataset\n", "def setup_data_loaders(batch_size=128, use_cuda=False):\n", " root = './data'\n", " download = True\n", " trans = transforms.ToTensor()\n", " train_set = MNIST(root=root, train=True, transform=trans,\n", " download=download)\n", " test_set = MNIST(root=root, train=False, transform=trans)\n", "\n", " kwargs = {'num_workers': 1, 'pin_memory': use_cuda}\n", " train_loader = torch.utils.data.DataLoader(dataset=train_set,\n", " batch_size=batch_size, shuffle=True, **kwargs)\n", " test_loader = torch.utils.data.DataLoader(dataset=test_set,\n", " batch_size=batch_size, shuffle=False, **kwargs)\n", " return train_loader, test_loader" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The main thing to draw attention to here is that we use `transforms.ToTensor()` to normalize the pixel intensities to the range $[0.0, 1.0]$. \n", "\n", "Next we define a PyTorch module that encapsulates our decoder network:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Decoder(nn.Module):\n", " def __init__(self, z_dim, hidden_dim):\n", " super().__init__()\n", " # setup the two linear transformations used\n", " self.fc1 = nn.Linear(z_dim, hidden_dim)\n", " self.fc21 = nn.Linear(hidden_dim, 784)\n", " # setup the non-linearities\n", " self.softplus = nn.Softplus()\n", " self.sigmoid = nn.Sigmoid()\n", "\n", " def forward(self, z):\n", " # define the forward computation on the latent z\n", " # first compute the hidden units\n", " hidden = self.softplus(self.fc1(z))\n", " # return the parameter for the output Bernoulli\n", " # each is of size batch_size x 784\n", " loc_img = self.sigmoid(self.fc21(hidden))\n", " return loc_img" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Given a latent code $z$, the forward call of `Decoder` returns the parameters for a Bernoulli distribution in image space. 
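
As a quick sanity check (a throwaway snippet, not part of the tutorial's code, assuming the imports and the `Decoder` cell above have been run), we can instantiate the decoder and confirm the shape and range of its output:

```python
decoder = Decoder(z_dim=50, hidden_dim=400)
z = torch.randn(128, 50)                    # a mini-batch of 128 latent codes
loc_img = decoder(z)
assert loc_img.shape == (128, 784)          # one Bernoulli probability per pixel
assert (loc_img >= 0.0).all() and (loc_img <= 1.0).all()
```
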
Since each image is of size\n", "$28\\times28=784$, `loc_img` is of size `batch_size` x 784.\n", "\n", "Next we define a PyTorch module that encapsulates our encoder network:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Encoder(nn.Module):\n", " def __init__(self, z_dim, hidden_dim):\n", " super().__init__()\n", " # setup the three linear transformations used\n", " self.fc1 = nn.Linear(784, hidden_dim)\n", " self.fc21 = nn.Linear(hidden_dim, z_dim)\n", " self.fc22 = nn.Linear(hidden_dim, z_dim)\n", " # setup the non-linearities\n", " self.softplus = nn.Softplus()\n", "\n", " def forward(self, x):\n", " # define the forward computation on the image x\n", " # first shape the mini-batch to have pixels in the rightmost dimension\n", " x = x.reshape(-1, 784)\n", " # then compute the hidden units\n", " hidden = self.softplus(self.fc1(x))\n", " # then return a mean vector and a (positive) square root covariance\n", " # each of size batch_size x z_dim\n", " z_loc = self.fc21(hidden)\n", " z_scale = torch.exp(self.fc22(hidden))\n", " return z_loc, z_scale" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Given an image $\\bf x$ the forward call of `Encoder` returns a mean and covariance that together parameterize a (diagonal) Gaussian distribution in latent space.\n", "\n", "With our encoder and decoder networks in hand, we can now write down the stochastic functions that represent our model and guide. First the model: " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# define the model p(x|z)p(z)\n", "def model(self, x):\n", " # register PyTorch module `decoder` with Pyro\n", " pyro.module(\"decoder\", self.decoder)\n", " with pyro.plate(\"data\", x.shape[0]):\n", " # setup hyperparameters for prior p(z)\n", " z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))\n", " z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))\n", " # sample from prior (value will be sampled by guide when computing the ELBO)\n", " z = pyro.sample(\"latent\", dist.Normal(z_loc, z_scale).to_event(1))\n", " # decode the latent code z\n", " loc_img = self.decoder(z)\n", " # score against actual images\n", " pyro.sample(\"obs\", dist.Bernoulli(loc_img).to_event(1), obs=x.reshape(-1, 784))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that `model()` is a callable that takes in a mini-batch of images `x` as input. This is a `torch.Tensor` of size `batch_size` x 784.\n", "\n", "The first thing we do inside of `model()` is register the (previously instantiated) decoder module with Pyro. Note that we give it an appropriate (and unique) name. This call to `pyro.module` lets Pyro know about all the parameters inside of the decoder network. \n", "\n", "Next we setup the hyperparameters for our prior, which is just a unit normal gaussian distribution. Note that:\n", "- we specifically designate independence amongst the data in our mini-batch (i.e. the leftmost dimension) via `pyro.plate`. Also, note the use of `.to_event(1)` when sampling from the latent `z` - this ensures that instead of treating our sample as being generated from a univariate normal with `batch_size = z_dim`, we treat them as being generated from a multivariate normal distribution with diagonal covariance. As such, the log probabilities along each dimension is summed out when we evaluate `.log_prob` for a \"latent\" sample. 
Refer to the [Tensor Shapes](tensor_shapes.ipynb) tutorial for more details.\n", "- since we're processing an entire mini-batch of images, we need the leftmost dimension of `z_loc` and `z_scale` to equal the mini-batch size\n", "- in case we're on GPU, we use `new_zeros` and `new_ones` to ensure that newly created tensors are on the same GPU device.\n", "\n", "Next we sample the latent `z` from the prior, making sure to give the random variable a unique Pyro name `'latent'`. \n", "Then we pass `z` through the decoder network, which returns `loc_img`. We then score the observed images in the mini-batch `x` against the Bernoulli likelihood parametrized by `loc_img`.\n", "Note that we flatten `x` so that all the pixels are in the rightmost dimension.\n", "\n", "That's all there is to it! Note how closely the flow of Pyro primitives in `model` follows the generative story of our model, e.g. as encapsulated by Figure 1. Now we move on to the guide:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# define the guide (i.e. variational distribution) q(z|x)\n", "def guide(self, x):\n", " # register PyTorch module `encoder` with Pyro\n", " pyro.module(\"encoder\", self.encoder)\n", " with pyro.plate(\"data\", x.shape[0]):\n", " # use the encoder to get the parameters used to define q(z|x)\n", " z_loc, z_scale = self.encoder(x)\n", " # sample the latent code z\n", " pyro.sample(\"latent\", dist.Normal(z_loc, z_scale).to_event(1))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Just like in the model, we first register the PyTorch module we're using (namely `encoder`) with Pyro. We take the mini-batch of images `x` and pass it through the encoder. Then using the parameters output by the encoder network we use the normal distribution to sample a value of the latent for each image in the mini-batch. Crucially, we use the same name for the latent random variable as we did in the model: `'latent'`. Also, note the use of `pyro.plate` to designate independence of the mini-batch dimension, and `.to_event(1)` to enforce dependence on `z_dims`, exactly as we did in the model.\n", "\n", "Now that we've defined the full model and guide we can move on to inference. 
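
To see concretely what `.to_event(1)` does to the distribution over `z`, here is a small throwaway snippet (not part of the VAE code; the names `d` and `d_iid` are ours) that inspects the batch and event shapes directly:

```python
# Inspect batch shape vs. event shape before and after .to_event(1)
d = dist.Normal(torch.zeros(128, 50), torch.ones(128, 50))
print(d.batch_shape, d.event_shape)          # torch.Size([128, 50]) torch.Size([])

d_iid = d.to_event(1)
print(d_iid.batch_shape, d_iid.event_shape)  # torch.Size([128]) torch.Size([50])

z = d_iid.sample()                           # shape: [128, 50]
print(d_iid.log_prob(z).shape)               # torch.Size([128])
```

With `.to_event(1)` the rightmost (`z_dim`) dimension is treated as a single event, so `log_prob` sums over it and returns one number per datapoint, which is exactly the shape the mini-batch `pyro.plate` expects.
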
But before we do so let's see how we package the model and guide in a PyTorch module:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class VAE(nn.Module):\n", " # by default our latent space is 50-dimensional\n", " # and we use 400 hidden units\n", " def __init__(self, z_dim=50, hidden_dim=400, use_cuda=False):\n", " super().__init__()\n", " # create the encoder and decoder networks\n", " self.encoder = Encoder(z_dim, hidden_dim)\n", " self.decoder = Decoder(z_dim, hidden_dim)\n", "\n", " if use_cuda:\n", " # calling cuda() here will put all the parameters of\n", " # the encoder and decoder networks into gpu memory\n", " self.cuda()\n", " self.use_cuda = use_cuda\n", " self.z_dim = z_dim\n", "\n", " # define the model p(x|z)p(z)\n", " def model(self, x):\n", " # register PyTorch module `decoder` with Pyro\n", " pyro.module(\"decoder\", self.decoder)\n", " with pyro.plate(\"data\", x.shape[0]):\n", " # setup hyperparameters for prior p(z)\n", " z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))\n", " z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))\n", " # sample from prior (value will be sampled by guide when computing the ELBO)\n", " z = pyro.sample(\"latent\", dist.Normal(z_loc, z_scale).to_event(1))\n", " # decode the latent code z\n", " loc_img = self.decoder(z)\n", " # score against actual images\n", " pyro.sample(\"obs\", dist.Bernoulli(loc_img).to_event(1), obs=x.reshape(-1, 784))\n", "\n", " # define the guide (i.e. variational distribution) q(z|x)\n", " def guide(self, x):\n", " # register PyTorch module `encoder` with Pyro\n", " pyro.module(\"encoder\", self.encoder)\n", " with pyro.plate(\"data\", x.shape[0]):\n", " # use the encoder to get the parameters used to define q(z|x)\n", " z_loc, z_scale = self.encoder(x)\n", " # sample the latent code z\n", " pyro.sample(\"latent\", dist.Normal(z_loc, z_scale).to_event(1))\n", "\n", " # define a helper function for reconstructing images\n", " def reconstruct_img(self, x):\n", " # encode image x\n", " z_loc, z_scale = self.encoder(x)\n", " # sample in latent space\n", " z = dist.Normal(z_loc, z_scale).sample()\n", " # decode the image (note we don't sample in image space)\n", " loc_img = self.decoder(z)\n", " return loc_img" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The point we'd like to make here is that the two `Module`s `encoder` and `decoder` are attributes of `VAE` (which itself inherits from `nn.Module`). This has the consequence they are both automatically registered as belonging to the `VAE` module. So, for example, when we call `parameters()` on an instance of `VAE`, PyTorch will know to return all the relevant parameters. It also means that if we're running on a GPU, the call to `cuda()` will move all the parameters of all the (sub)modules into GPU memory." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inference\n", "\n", "We're now ready for inference. Refer to the full code in the next section. \n", "\n", "First we instantiate an instance of the `VAE` module." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vae = VAE()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we setup an instance of the Adam optimizer." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "optimizer = Adam({\"lr\": 1.0e-3})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we setup our inference algorithm, which is going to learn good parameters for the model and guide by maximizing the ELBO:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That's all there is to it. Now we just have to define our training loop:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def train(svi, train_loader, use_cuda=False):\n", " # initialize loss accumulator\n", " epoch_loss = 0.\n", " # do a training epoch over each mini-batch x returned\n", " # by the data loader\n", " for x, _ in train_loader:\n", " # if on GPU put mini-batch into CUDA memory\n", " if use_cuda:\n", " x = x.cuda()\n", " # do ELBO gradient and accumulate loss\n", " epoch_loss += svi.step(x)\n", "\n", " # return epoch loss\n", " normalizer_train = len(train_loader.dataset)\n", " total_epoch_loss_train = epoch_loss / normalizer_train\n", " return total_epoch_loss_train" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that all the mini-batch logic is handled by the data loader. The meat of the training loop is `svi.step(x)`. There are two things we should draw attention to here:\n", "\n", "- any arguments to `step` are passed to the model and the guide. consequently `model` and `guide` need to have the same call signature\n", "- `step` returns a noisy estimate of the loss (i.e. minus the ELBO). this estimate is not normalized in any way, so e.g. it scales with the size of the mini-batch\n", "\n", "The logic for adding evaluation logic is analogous:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def evaluate(svi, test_loader, use_cuda=False):\n", " # initialize loss accumulator\n", " test_loss = 0.\n", " # compute the loss over the entire test set\n", " for x, _ in test_loader:\n", " # if on GPU put mini-batch into CUDA memory\n", " if use_cuda:\n", " x = x.cuda()\n", " # compute ELBO estimate and accumulate loss\n", " test_loss += svi.evaluate_loss(x)\n", " normalizer_test = len(test_loader.dataset)\n", " total_epoch_loss_test = test_loss / normalizer_test\n", " return total_epoch_loss_test" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Basically the only change we need to make is that we call evaluate_loss instead of step. This function will compute an estimate of the ELBO but won't take any gradient steps.\n", "\n", "The final piece of code we'd like to highlight is the helper method `reconstruct_img` in the VAE class: This is just the image reconstruction experiment we described in the introduction translated into code. We take an image and pass it through the encoder. Then we sample in latent space using the gaussian distribution provided by the encoder. Finally we decode the latent code into an image: we return the mean vector `loc_img` instead of sampling with it. Note that since the `sample()` statement is stochastic, we'll get different draws of z every time we run the reconstruct_img function. 
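
For example, one might exercise this helper as follows (a hypothetical usage sketch; it assumes the `vae` instance above and a `test_loader` produced by `setup_data_loaders`, as in the next section):

```python
# Hypothetical usage sketch of the reconstruction helper
x_test, _ = next(iter(test_loader))          # a mini-batch of test images
x_reco = vae.reconstruct_img(x_test)         # tensor of shape [batch_size, 784]
x_reco = x_reco.reshape(-1, 28, 28)          # back to image shape, e.g. for plotting
```
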
If we've learned a good model and guide—in particular if we've learned a good latent representation—this plurality of z samples will correspond to different styles of digit writing, and the reconstructed images should exhibit an interesting variety of different styles." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Code and Sample results\n", "\n", "Training corresponds to maximizing the evidence lower bound (ELBO) over the training dataset. We train for 100 iterations and evaluate the ELBO for the test dataset, see Figure 3." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Run options\n", "LEARNING_RATE = 1.0e-3\n", "USE_CUDA = False\n", "\n", "# Run only for a single iteration for testing\n", "NUM_EPOCHS = 1 if smoke_test else 100\n", "TEST_FREQUENCY = 5" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_loader, test_loader = setup_data_loaders(batch_size=256, use_cuda=USE_CUDA)\n", "\n", "# clear param store\n", "pyro.clear_param_store()\n", "\n", "# setup the VAE\n", "vae = VAE(use_cuda=USE_CUDA)\n", "\n", "# setup the optimizer\n", "adam_args = {\"lr\": LEARNING_RATE}\n", "optimizer = Adam(adam_args)\n", "\n", "# setup the inference algorithm\n", "svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())\n", "\n", "train_elbo = []\n", "test_elbo = []\n", "# training loop\n", "for epoch in range(NUM_EPOCHS):\n", " total_epoch_loss_train = train(svi, train_loader, use_cuda=USE_CUDA)\n", " train_elbo.append(-total_epoch_loss_train)\n", " print(\"[epoch %03d] average training loss: %.4f\" % (epoch, total_epoch_loss_train))\n", "\n", " if epoch % TEST_FREQUENCY == 0:\n", " # report test diagnostics\n", " total_epoch_loss_test = evaluate(svi, test_loader, use_cuda=USE_CUDA)\n", " test_elbo.append(-total_epoch_loss_test)\n", " print(\"[epoch %03d] average test loss: %.4f\" % (epoch, total_epoch_loss_test))" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/html" }, "source": [ "
\n", "
\n", "\n", "
\n", "Figure 3: How the test ELBO evolves over the course of training. \n", "
\n", "
\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we show a set of randomly sampled images from the model. These are generated by drawing random samples of `z` and generating an image for each one, see Figure 4." ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/html" }, "source": [ "
\n", "
\n", " \n", " \n", " \n", " \n", " \n", "
\n", " \n", " \n", " \n", "
\n", "
\n", " Figure 4: Samples from generative model.\n", "
\n", "
\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also study the 50-dimensional latent space of the entire test dataset by encoding all MNIST images and embedding their means into a 2-dimensional T-SNE space. We then color each embedded image by its class.\n", "The resulting Figure 5 shows separation by class with variance within each class-cluster." ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/html" }, "source": [ "
\n", "
\n", "\n", "
\n", "Figure 5: T-SNE Embedding of the latent z. The colors correspond to different classes of digits.\n", "
\n", "
\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the full code on [Github](https://github.com/pyro-ppl/pyro/blob/dev/examples/vae/vae.py).\n", "\n", "## References\n", "\n", "[1] `Auto-Encoding Variational Bayes`,
    \n", "Diederik P Kingma, Max Welling\n", "\n", "[2] `Stochastic Backpropagation and Approximate Inference in Deep Generative Models`,\n", "
    \n", "Danilo Jimenez Rezende, Shakir Mohamed, Daan Wierstra" ] } ], "metadata": { "anaconda-cloud": {}, "celltoolbar": "Raw Cell Format", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 2 }