{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Gaussian Process Latent Variable Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The [Gaussian Process Latent Variable Model](https://en.wikipedia.org/wiki/Nonlinear_dimensionality_reduction#Gaussian_process_latent_variable_models) (GPLVM) is a dimensionality reduction method that uses a Gaussian process to learn a low-dimensional representation of (potentially) high-dimensional data. In the typical setting of Gaussian process regression, where we are given inputs $X$ and outputs $y$, we choose a kernel and learn the hyperparameters that best describe the mapping from $X$ to $y$. In the GPLVM, we are given only $y$, not $X$, so we need to learn $X$ along with the kernel hyperparameters." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We do not perform maximum likelihood inference on $X$. Instead, we place a Gaussian prior on $X$ and learn the mean and variance of the approximate (Gaussian) posterior $q(X|y)$. In this notebook, we show how this can be done using the `pyro.contrib.gp` module. In particular, we reproduce a result described in [2]." 
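] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a rough illustration of learning $X$ together with the kernel hyperparameters, the next cell is a minimal sketch in plain PyTorch; it is not the `pyro.contrib.gp` implementation used in this notebook. It assumes synthetic data `Y`, an RBF kernel, a standard normal prior $p(X)$ and a fully factorized Gaussian $q(X)$, and all names in it (`rbf_kernel`, `elbo`, `mu`, `log_sigma`) are hypothetical. Each step draws one reparameterized sample of $X$ from $q(X)$, evaluates the exact GP marginal likelihood $p(y|X)$ with the output dimensions treated as independent, and subtracts the KL divergence from $q(X)$ to $p(X)$. Maximizing this Monte Carlo estimate of the evidence lower bound (ELBO) updates the mean and variance of $q(X)$ and the kernel hyperparameters jointly. No inducing points are used, so the cost grows cubically with the number of data points." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import torch\n",
"from torch.distributions import MultivariateNormal, Normal, kl_divergence\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"# Hypothetical toy data: N observations of D dimensions, latent dimension Q\n",
"N, D, Q = 50, 5, 2\n",
"Y = torch.randn(N, D)\n",
"\n",
"def rbf_kernel(X, lengthscale, variance):\n",
"    # k(x, x') = variance * exp(-||x - x'||^2 / (2 * lengthscale^2))\n",
"    sq_dist = (X.unsqueeze(1) - X.unsqueeze(0)).pow(2).sum(-1)\n",
"    return variance * torch.exp(-0.5 * sq_dist / lengthscale**2)\n",
"\n",
"# Variational parameters of q(X) = N(mu, diag(sigma^2)), one row per observation\n",
"mu = (0.1 * torch.randn(N, Q)).requires_grad_()\n",
"log_sigma = torch.full((N, Q), -2.0, requires_grad=True)\n",
"# Kernel hyperparameters and noise, kept positive by optimizing their logs\n",
"log_lengthscale = torch.zeros((), requires_grad=True)\n",
"log_variance = torch.zeros((), requires_grad=True)\n",
"log_noise = torch.tensor(-2.0, requires_grad=True)\n",
"\n",
"prior = Normal(torch.zeros(N, Q), torch.ones(N, Q))  # p(X) = N(0, I)\n",
"\n",
"def elbo():\n",
"    q = Normal(mu, log_sigma.exp())\n",
"    X = q.rsample()  # one reparameterized sample from q(X)\n",
"    K = rbf_kernel(X, log_lengthscale.exp(), log_variance.exp())\n",
"    K = K + (log_noise.exp() + 1e-5) * torch.eye(N)  # observation noise plus jitter\n",
"    # Each column y_d of Y is modelled as an independent draw from N(0, K)\n",
"    log_lik = MultivariateNormal(torch.zeros(N), covariance_matrix=K).log_prob(Y.T).sum()\n",
"    return log_lik - kl_divergence(q, prior).sum()\n",
"\n",
"params = [mu, log_sigma, log_lengthscale, log_variance, log_noise]\n",
"optimizer = torch.optim.Adam(params, lr=0.05)\n",
"losses = []\n",
"for step in range(200):\n",
"    optimizer.zero_grad()\n",
"    loss = -elbo()  # minimize the negative ELBO estimate\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"    losses.append(loss.item())\n",
"\n",
"print(f'initial loss {losses[0]:.1f}, final loss {losses[-1]:.1f}')"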
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "import matplotlib.pyplot as plt\n", "import pandas as pd\n", "import torch\n", "from torch.nn import Parameter\n", "\n", "import pyro\n", "import pyro.contrib.gp as gp\n", "import pyro.distributions as dist\n", "import pyro.ops.stats as stats\n", "\n", "smoke_test = ('CI' in os.environ) # ignore; used to check code integrity in the Pyro repo\n", "assert pyro.__version__.startswith('1.9.0')\n", "pyro.set_rng_seed(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The data we use consists of [single-cell](https://en.wikipedia.org/wiki/Single-cell_analysis) [qPCR](https://en.wikipedia.org/wiki/Real-time_polymerase_chain_reaction) data for 48 genes obtained from mice (Guo *et al.*, [1]). It is available from the [Open Data Science repository](https://github.com/sods/ods). Each of the 48 columns holds the (normalized) measurements of one gene. Cells differentiate as they develop, and these data were collected at several developmental stages, labelled from the 1-cell stage to the 64-cell stage. At the 32-cell stage, the cells are further split into 'trophectoderm' (TE) and 'inner cell mass' (ICM); at the 64-cell stage, the ICM further differentiates into 'epiblast' (EPI) and 'primitive endoderm' (PE). Each row of the dataset is labelled with one of these stages." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Data shape: (437, 48)\n", "---------------------\n", "\n", "Data labels: ['1', '2', '4', '8', '16', '32 TE', '32 ICM', '64 PE', '64 TE', '64 EPI']\n", "--------------------------------------------------------------------------------------\n", "\n", "Show a small subset of the data:\n" ] }, { "data": { "text/html": [ "
<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th><th>Actb</th><th>Ahcy</th><th>Aqp3</th><th>Atp12a</th><th>Bmp4</th><th>Cdx2</th><th>Creb312</th><th>Cebpa</th><th>Dab2</th><th>DppaI</th><th>...</th><th>Sox2</th><th>Sall4</th><th>Sox17</th><th>Snail</th><th>Sox13</th><th>Tcfap2a</th><th>Tcfap2c</th><th>Tcf23</th><th>Utf1</th><th>Tspan8</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>1</th><td>0.541050</td><td>-1.203007</td><td>1.030746</td><td>1.064808</td><td>0.494782</td><td>-0.167143</td><td>-1.369092</td><td>1.083061</td><td>0.668057</td><td>-1.553758</td><td>...</td><td>-1.351757</td><td>-1.793476</td><td>0.783185</td><td>-1.408063</td><td>-0.031991</td><td>-0.351257</td><td>-1.078982</td><td>0.942981</td><td>1.348892</td><td>-1.051999</td></tr>\n",
"    <tr><th>1</th><td>0.680832</td><td>-1.355306</td><td>2.456375</td><td>1.234350</td><td>0.645494</td><td>1.003868</td><td>-1.207595</td><td>1.208023</td><td>0.800388</td><td>-1.435306</td><td>...</td><td>-1.363533</td><td>-1.782172</td><td>1.532477</td><td>-1.361172</td><td>-0.501715</td><td>1.082362</td><td>-0.930112</td><td>1.064399</td><td>1.469397</td><td>-0.996275</td></tr>\n",
"    <tr><th>1</th><td>1.056038</td><td>-1.280447</td><td>2.046133</td><td>1.439795</td><td>0.828121</td><td>0.983404</td><td>-1.460032</td><td>1.359447</td><td>0.530701</td><td>-1.340283</td><td>...</td><td>-1.296802</td><td>-1.567402</td><td>3.194157</td><td>-1.301777</td><td>-0.445219</td><td>0.031284</td><td>-1.005767</td><td>1.211529</td><td>1.615421</td><td>-0.651393</td></tr>\n",
"    <tr><th>1</th><td>0.732331</td><td>-1.326911</td><td>2.464234</td><td>1.244323</td><td>0.654359</td><td>0.947023</td><td>-1.265609</td><td>1.215373</td><td>0.765212</td><td>-1.431401</td><td>...</td><td>-1.684100</td><td>-1.915556</td><td>2.962515</td><td>-1.349710</td><td>1.875957</td><td>1.699892</td><td>-1.059458</td><td>1.071541</td><td>1.476485</td><td>-0.699586</td></tr>\n",
"    <tr><th>1</th><td>0.629333</td><td>-1.244308</td><td>1.316815</td><td>1.304162</td><td>0.707552</td><td>1.429070</td><td>-0.895578</td><td>-0.007785</td><td>0.644606</td><td>-1.381937</td><td>...</td><td>-1.304653</td><td>-1.761825</td><td>1.265379</td><td>-1.320533</td><td>-0.609864</td><td>0.413826</td><td>-0.888624</td><td>1.114394</td><td>1.519017</td><td>-0.798985</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"<p>5 rows × 48 columns</p>\n", "