The Semi-Supervised VAE¶
Introduction¶
Most of the models we’ve covered in the tutorials are unsupervised:

- Variational Autoencoder (VAE)
- Deep Markov Model (DMM)
- Attend-Infer-Repeat (AIR)

We’ve also covered a simple supervised model:

- Bayesian Regression
The semi-supervised setting represents an interesting intermediate case where some of the data is labeled and some is not. It is also of great practical importance, since we often have very little labeled data and much more unlabeled data. We’d clearly like to leverage labeled data to improve our models of the unlabeled data.
The semi-supervised setting is also well suited to generative models, where missing data can be accounted for quite naturally—at least conceptually. As we will see, in restricting our attention to semi-supervised generative models, there will be no shortage of different model variants and possible inference strategies. Although we’ll only be able to explore a few of these variants in detail, hopefully you will come away from the tutorial with a greater appreciation for the abstractions and modularity offered by probabilistic programming.
So let’s go about building a generative model. We have a dataset $\mathcal{D}$ with $N$ datapoints,

$$\mathcal{D} = \{ ({\bf x}_1, y_1), \ldots, ({\bf x}_N, y_N) \}$$

where the $\{ {\bf x}_i \}$ are always observed and the labels $\{ y_i \}$ are only observed for some subset of the data. Since we want to be able to model complex variations in the data, we’re going to make this a latent variable model with a local latent variable ${\bf z}_i$ private to each pair $({\bf x}_i, y_i)$:

$$p_\theta({\bf x}, y, {\bf z}) = p_\theta({\bf x} \,|\, y, {\bf z}) \, p_\theta(y) \, p({\bf z})$$

For convenience—and since we’re going to model MNIST in our experiments below—let’s suppose the $\{ {\bf x}_i \}$ are images and the $\{ y_i \}$ are digit labels. In this context the latent random variable ${\bf z}_i$ can usefully be thought of as encoding everything about an image other than its digit label, for example the style of the handwriting.
The Challenges of Inference¶
For concreteness we’re going to continue to assume that the partially-observed labels $\{ y_i \}$ are discrete and that the latent variables $\{ {\bf z}_i \}$ are continuous.
If we apply the general recipe for stochastic variational inference to our model (see SVI Part I) we would be sampling the discrete (and thus non-reparameterizable) variable $y_i$ whenever it’s unobserved. As discussed in SVI Part III this will generally lead to high-variance gradient estimates.

A common way to ameliorate this problem—and one that we’ll explore below—is to forego sampling and instead sum out all ten values of the class label $y_i$ when we calculate the ELBO for an unlabeled datapoint ${\bf x}_i$. This is more expensive per step, but can help us reduce the variance of our gradient estimator and thereby take fewer steps.

Recall that the role of the guide is to ‘fill in’ latent random variables. Concretely, one component of our guide will be a digit classifier $q_\phi(y \,|\, {\bf x})$ that will randomly ‘fill in’ labels $\{ y_i \}$ given an image $\{ {\bf x}_i \}$. Crucially, this means that the only term in the ELBO that will depend on $q_\phi(\cdot \,|\, {\bf x})$ is the term that involves a sum over unlabeled datapoints. This means that our classifier $q_\phi(\cdot \,|\, {\bf x})$—which in many cases will be the primary object of interest—will not be learning from the labeled datapoints (at least not directly).

This seems like a potential problem. Luckily, various fixes are possible. Below we’ll follow the approach in reference [1], which involves introducing an additional objective function for the classifier that ensures the classifier learns directly from the labeled data.
We have our work cut out for us, so let’s get started!
First Variant: Standard Objective Function, Naive Estimator¶
As discussed in the introduction, we’re considering the model depicted in Figure 1. In more detail, the model has the following structure:
- $p(y) = \mathrm{Cat}(y \,|\, {\boldsymbol \pi})$: multinomial (or categorical) prior for the class label
- $p({\bf z}) = \mathcal{N}({\bf z} \,|\, {\bf 0}, {\bf I})$: unit normal prior for the latent code ${\bf z}$
- $p_\theta({\bf x} \,|\, {\bf z}, y) = \mathrm{Bernoulli}({\bf x} \,|\, \mu_\theta({\bf z}, y))$: parameterized Bernoulli likelihood function; $\mu_\theta({\bf z}, y)$ corresponds to `decoder` in the code

We structure the components of our guide $q_\phi(\cdot)$ as follows:

- $q_\phi(y \,|\, {\bf x}) = \mathrm{Cat}(y \,|\, \alpha_\phi({\bf x}))$: parameterized multinomial (or categorical) distribution; the neural digit classifier $\alpha_\phi({\bf x})$ corresponds to `encoder_y` in the code
- $q_\phi({\bf z} \,|\, {\bf x}, y) = \mathcal{N}({\bf z} \,|\, \mu_\phi({\bf x}, y), \sigma^2_\phi({\bf x}, y))$: parameterized normal distribution; $\mu_\phi({\bf x}, y)$ and $\sigma^2_\phi({\bf x}, y)$ correspond to `encoder_z` in the code
These choices reproduce the structure of model M2 and its corresponding inference network in reference [1].
We translate this model and guide pair into Pyro code below. Note that:
- The labels `ys`, which are represented with a one-hot encoding, are only partially observed (`None` denotes unobserved values).
- `model()` handles both the observed and the unobserved case.
- The code assumes that `xs` and `ys` are mini-batches of images and labels, respectively, with the size of each batch denoted by `batch_size`.
[ ]:
def model(self, xs, ys=None):
    # register this pytorch module and all of its sub-modules with pyro
    pyro.module("ss_vae", self)
    batch_size = xs.size(0)
    # inform Pyro that the variables in the batch of xs, ys are conditionally independent
    with pyro.plate("data"):
        # sample the handwriting style from the constant prior distribution
        prior_loc = xs.new_zeros([batch_size, self.z_dim])
        prior_scale = xs.new_ones([batch_size, self.z_dim])
        zs = pyro.sample("z", dist.Normal(prior_loc, prior_scale).to_event(1))
        # if the label y (which digit to write) is unsupervised, sample it from the
        # constant prior; otherwise, observe the given value (i.e. score it against the constant prior)
        alpha_prior = xs.new_ones([batch_size, self.output_size]) / (1.0 * self.output_size)
        ys = pyro.sample("y", dist.OneHotCategorical(alpha_prior), obs=ys)
        # finally, score the image (x) using the handwriting style (z) and
        # the class label y (which digit to write) against the
        # parametrized distribution p(x|y,z) = bernoulli(decoder(y,z))
        # where `decoder` is a neural network
        loc = self.decoder([zs, ys])
        pyro.sample("x", dist.Bernoulli(loc).to_event(1), obs=xs)

def guide(self, xs, ys=None):
    with pyro.plate("data"):
        # if the class label (the digit) is not supervised, sample
        # (and score) the digit with the variational distribution
        # q(y|x) = categorical(alpha(x))
        if ys is None:
            alpha = self.encoder_y(xs)
            ys = pyro.sample("y", dist.OneHotCategorical(alpha))
        # sample (and score) the latent handwriting-style with the variational
        # distribution q(z|x,y) = normal(loc(x,y),scale(x,y))
        loc, scale = self.encoder_z([xs, ys])
        pyro.sample("z", dist.Normal(loc, scale).to_event(1))
Network Definitions¶
In our experiments we use the same network configurations as used in reference [1]. The encoder and decoder networks have one hidden layer with 500 hidden units and softplus activation functions. We use softmax as the output activation function for `encoder_y`, sigmoid as the output activation function for `decoder`, and exponentiation for the scale part of the output of `encoder_z`. The latent dimension is 50.
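The network definitions themselves aren’t shown in the snippets above. As a rough sketch of how the pieces referenced there (`encoder_y`, `encoder_z`, `decoder`, `z_dim`, `output_size`) might be wired up, consider the following; the class, attribute names, and exact wiring are our own illustrative choices rather than the code from the GitHub repository:

import torch
import torch.nn as nn

class SSVAENetworks(nn.Module):
    # an illustrative container for the three networks described above;
    # hidden_dim=500, z_dim=50 and the output activations follow the text,
    # everything else (names, exact wiring) is a guess
    def __init__(self, x_dim=784, y_dim=10, z_dim=50, hidden_dim=500):
        super().__init__()
        self.z_dim, self.output_size = z_dim, y_dim
        # q(y|x): one hidden layer, softmax output (the digit classifier)
        self.encoder_y = nn.Sequential(
            nn.Linear(x_dim, hidden_dim), nn.Softplus(),
            nn.Linear(hidden_dim, y_dim), nn.Softmax(dim=-1),
        )
        # q(z|x,y): one hidden layer, two heads for the location and scale
        self.encoder_z_hidden = nn.Sequential(
            nn.Linear(x_dim + y_dim, hidden_dim), nn.Softplus(),
        )
        self.encoder_z_loc = nn.Linear(hidden_dim, z_dim)
        self.encoder_z_log_scale = nn.Linear(hidden_dim, z_dim)
        # p(x|y,z): one hidden layer, sigmoid output (Bernoulli probabilities)
        self.decoder_net = nn.Sequential(
            nn.Linear(z_dim + y_dim, hidden_dim), nn.Softplus(),
            nn.Linear(hidden_dim, x_dim), nn.Sigmoid(),
        )

    def encoder_z(self, inputs):
        xs, ys = inputs
        hidden = self.encoder_z_hidden(torch.cat([xs, ys], dim=-1))
        # exponentiate the scale head so that the scale is positive
        return self.encoder_z_loc(hidden), torch.exp(self.encoder_z_log_scale(hidden))

    def decoder(self, inputs):
        zs, ys = inputs
        return self.decoder_net(torch.cat([zs, ys], dim=-1))

In the actual experiments these networks live inside the module that `model()` and `guide()` register via `pyro.module("ss_vae", self)`.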
MNIST Pre-Processing¶
We normalize the pixel values to the range $[0.0, 1.0]$ and use the MNIST data loader from the `torchvision` library. The testing set consists of 10,000 examples. The default training set consists of 60,000 examples; we use the first 50,000 examples for training (divided into supervised and unsupervised parts) and the remaining 10,000 examples for validation. For our experiments we consider four levels of supervision, with 3000, 1000, 600, and 100 supervised examples chosen randomly (while ensuring that each class is balanced).
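For concreteness, a data-loading sketch along these lines might look as follows; the batch size, the exact transform, and the 50,000/10,000 split indices are illustrative, and the real data-handling code lives in the full example on GitHub:

import torch
from torchvision import datasets, transforms

# flatten each 28x28 image into a 784-dimensional vector;
# ToTensor() already maps the raw pixel values into [0.0, 1.0]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1)),
])

train_set = datasets.MNIST("./data", train=True, download=True, transform=transform)
test_set = datasets.MNIST("./data", train=False, download=True, transform=transform)

# first 50,000 images for training, the remaining 10,000 for validation
train_subset = torch.utils.data.Subset(train_set, range(50_000))
valid_subset = torch.utils.data.Subset(train_set, range(50_000, 60_000))
train_loader = torch.utils.data.DataLoader(train_subset, batch_size=200, shuffle=True)

# note that model() and guide() above expect one-hot labels, so the integer
# labels returned by the loader would still need e.g. torch.nn.functional.one_hot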
The Objective Function¶
The objective function for this model has the two terms (cf. Eqn. 8 in reference [1]):

$$\mathcal{J} = \sum_{({\bf x}, y) \in \mathcal{D}_{\rm supervised}} \mathcal{L}({\bf x}, y) \;+ \sum_{{\bf x} \in \mathcal{D}_{\rm unsupervised}} \mathcal{U}({\bf x})$$

where $\mathcal{L}({\bf x}, y)$ is the ELBO for a labeled datapoint and $\mathcal{U}({\bf x})$ is the ELBO for an unlabeled datapoint.
To implement this in Pyro, we set up a single instance of the `SVI` class. The two different terms in the objective function will emerge automatically depending on whether we pass the `step` method labeled or unlabeled data. We will alternate taking steps with labeled and unlabeled mini-batches, with the number of steps taken for each type of mini-batch depending on the total fraction of data that is labeled. For example, if we have 1,000 labeled images and 49,000 unlabeled ones, then we’ll take 49 steps with unlabeled mini-batches for each labeled mini-batch. (Note that there are different ways we could do this, but for simplicity we only consider this variant.) The code for this setup is given below:
[ ]:
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam
# setup the optimizer
adam_params = {"lr": 0.0003}
optimizer = Adam(adam_params)
# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
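The alternation between labeled and unlabeled mini-batches described above is handled by the training loop in the full code; a stripped-down sketch of one epoch might look like this, where `sup_batches` and `unsup_batches` are hypothetical lists of labeled `(xs, ys)` pairs and of unlabeled `xs`, respectively:

def run_epoch(svi, sup_batches, unsup_batches):
    # e.g. with 1,000 labeled and 49,000 unlabeled images (and equal batch sizes)
    # we take 49 unlabeled steps for every labeled step
    ratio = len(unsup_batches) // len(sup_batches)
    unsup_iter = iter(unsup_batches)
    epoch_loss = 0.0
    for xs_sup, ys_sup in sup_batches:
        # a step on a labeled mini-batch contributes the supervised term ...
        epoch_loss += svi.step(xs_sup, ys_sup)
        # ... while steps on unlabeled mini-batches contribute the unsupervised term
        for _ in range(ratio):
            epoch_loss += svi.step(next(unsup_iter))
    return epoch_loss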
When we run this inference in Pyro, the performance seen during test time is degraded by the noise inherent in the sampling of the categorical variables (see Figure 2 and Table 1 at the end of this tutorial). To deal with this we’re going to need a better ELBO gradient estimator.
[Figure 2]
Interlude: Summing Out Discrete Latents¶
As highlighted in the introduction, when the discrete latent labels $y$ are not observed, the ELBO gradient estimates rely on sampling from $q_\phi(y \,|\, {\bf x})$, which leads to high-variance gradients. A common way to reduce the variance is to replace the Monte Carlo expectation over $y$ with an explicit sum

$$\mathbb{E}_{y \sim q_\phi(y \,|\, {\bf x})} \big[ f(y) \big] \;=\; \sum_{y} q_\phi(y \,|\, {\bf x}) \, f(y)$$

This sum is usually implemented by hand, as in [1], but Pyro can automate this in many cases. To automatically sum out all discrete latent variables (here only $y$), we simply wrap the guide in `config_enumerate()`:
svi = SVI(model, config_enumerate(guide), optimizer, loss=TraceEnum_ELBO(max_plate_nesting=1))
In this mode of operation, each `svi.step(...)` computes a gradient term for each of the ten latent states of $y$. Although each step is therefore roughly ten times more expensive, the lower-variance gradient estimate more than makes up for the extra cost.

Going beyond the particular model in this tutorial, Pyro supports summing over arbitrarily many discrete latent variables. Beware that the cost of summing is exponential in the number of discrete variables, but is cheap(er) if multiple independent discrete variables are packed into a single tensor (as in this tutorial, where the discrete labels for the entire mini-batch are packed into a single tensor). When using `config_enumerate()`, we must inform Pyro that the items in a mini-batch are indeed independent by wrapping our vectorized code in a `with pyro.plate("name")` block.
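To make the variance argument concrete, here is a small toy comparison (not part of the tutorial code) between a single-sample Monte Carlo estimate of an expectation over a ten-way categorical and the kind of exact sum that enumeration performs:

import torch
import pyro.distributions as dist

probs = torch.softmax(torch.randn(10), dim=-1)  # a made-up q(y|x) over ten labels
f = torch.randn(10)                             # a made-up integrand f(y)

# naive estimator: score a single sampled y (unbiased, but high variance across calls)
y = dist.Categorical(probs).sample()
mc_estimate = f[y]

# enumeration: weight f(y) by q(y|x) and sum over all ten values (exact, zero variance)
exact_value = (probs * f).sum()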
Second Variant: Standard Objective Function, Better Estimator¶
Now that we have the tools to sum out discrete latents, we can see if doing so helps our performance. First, as we can see from Figure 3, the test and validation accuracies now evolve much more smoothly over the course of training. More importantly, this single modification improved test accuracy from around 20% to about 90% for the case of 3000 labeled examples (see Table 1 below).
[Figure 3]
Third Variant: Adding a Term to the Objective¶
For the two variants we’ve explored so far, the classifier $q_\phi(y \,|\, {\bf x})$ doesn’t learn directly from labeled data. As we discussed in the introduction, this seems like a potential problem. One way to address it, and the approach adopted in reference [1] (see their Eqn. 9), is to add an extra term to the objective so that the classifier does learn directly from labeled data:

$$\mathcal{J}^\alpha = \mathcal{J} + \alpha \, \mathop{\mathbb{E}}_{\tilde{p}_l({\bf x}, y)} \big[ -\log q_\phi(y \,|\, {\bf x}) \big]$$

where $\tilde{p}_l({\bf x}, y)$ is the empirical distribution over the labeled (supervised) data and $\alpha \ge 0$ is a hyperparameter that controls the relative weight of the extra term.
To learn using this modified objective in Pyro we do the following:
- We use a new model and guide pair (see the code snippet below) that corresponds to scoring the observed label $y$ for a given image ${\bf x}$ against the predictive distribution $q_\phi(y \,|\, {\bf x})$.
- We specify the scaling factor $\alpha$ (`aux_loss_multiplier` in the code) in the `pyro.sample` call by making use of `poutine.scale`. Note that `poutine.scale` was used to similar effect in the Deep Markov Model tutorial to implement KL annealing.
- We create a new `SVI` object and use it to take gradient steps on the new objective term.
[ ]:
def model_classify(self, xs, ys=None):
    pyro.module("ss_vae", self)
    with pyro.plate("data"):
        # this here is the extra term to yield an auxiliary loss
        # that we do gradient descent on
        if ys is not None:
            alpha = self.encoder_y(xs)
            with pyro.poutine.scale(scale=self.aux_loss_multiplier):
                pyro.sample("y_aux", dist.OneHotCategorical(alpha), obs=ys)

def guide_classify(xs, ys):
    # the guide is trivial, since there are no
    # latent random variables
    pass

svi_aux = SVI(model_classify, guide_classify, optimizer, loss=Trace_ELBO())
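The snippets above define the auxiliary objective but don’t show how the two `SVI` objects are used together during training. One plausible arrangement (a sketch of our own, not necessarily verbatim what the GitHub code does) is to take a step on both losses whenever a labeled mini-batch arrives:

def labeled_step(xs, ys):
    # labeled data contributes both the usual ELBO term L(x, y) ...
    elbo_loss = svi.step(xs, ys)
    # ... and the auxiliary classification term scaled by aux_loss_multiplier
    aux_loss = svi_aux.step(xs, ys)
    return elbo_loss + aux_loss

def unlabeled_step(xs):
    # unlabeled data only contributes the U(x) term of the standard objective
    return svi.step(xs)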
When we run inference in Pyro with the additional term in the objective, we outperform both previous inference setups. For example, the test accuracy for the case with 3000 labeled examples improves from about 90% to about 96% (see Figure 4 below and Table 1 in the next section). Note that we used the validation accuracy to select the hyperparameter $\alpha$.
[Figure 4]
Results¶
Supervised data | First variant | Second variant | Third variant | Baseline classifier |
---|---|---|---|---|
100 | 0.2007(0.0353) | 0.2254(0.0346) | 0.9319(0.0060) | 0.7712(0.0159) |
600 | 0.1791(0.0244) | 0.6939(0.0345) | 0.9437(0.0070) | 0.8716(0.0064) |
1000 | 0.2006(0.0295) | 0.7562(0.0235) | 0.9487(0.0038) | 0.8863(0.0025) |
3000 | 0.1982(0.0522) | 0.8932(0.0159) | 0.9582(0.0012) | 0.9108(0.0015) |
Table 1: Result accuracies (with 95% confidence bounds) for different inference methods
Table 1 collects our results from the three variants explored in the tutorial. For comparison, we also show results from a simple classifier baseline, which only makes use of the supervised data (and no latent random variables). Reported are mean accuracies (with 95% confidence bounds in parentheses) across five random selections of supervised data.
We first note that the results for the third variant, where we sum out the discrete latent random variable $y$ and add the auxiliary classification term, are the best across the board. This variant also outperforms the simple classifier baseline at every level of supervision, with the largest gains coming in the regime where labeled data is scarce: with only 100 labeled examples it reaches about 93% test accuracy, versus about 77% for the baseline.
Latent Space Visualization¶
[Figure 5]
We use T-SNE to reduce the dimensionality of the latent ${\bf z}$ from 50 to 2 and visualize the ten digit classes in Figure 5. Note that the structure of the embedding is quite different from what we obtained for the unsupervised VAE, where the digit classes were clearly separated. This makes sense: in the semi-supervised case the latent ${\bf z}$ is free to spend its representational capacity on other sources of variation, such as handwriting style, since the variation between digits is already captured by the (partially observed) labels.
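A visualization along these lines can be produced with scikit-learn; the sketch below assumes a trained module `ss_vae` exposing `encoder_z`, together with a held-out batch `xs_test` of images and one-hot labels `ys_test` (all of these names are placeholders):

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# use the mean of q(z|x, y) as the 50-dimensional embedding of each test image
z_loc, z_scale = ss_vae.encoder_z([xs_test, ys_test])
z_2d = TSNE(n_components=2).fit_transform(z_loc.detach().cpu().numpy())

# color each 2-d point by its digit class
digits = ys_test.argmax(dim=-1).cpu().numpy()
plt.scatter(z_2d[:, 0], z_2d[:, 1], c=digits, s=5, cmap="tab10")
plt.colorbar()
plt.show()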
Conditional image generation¶
[Figure 6]
We sampled images for each of the ten class labels in Figure 6 by fixing $y$ and drawing different values of the latent handwriting style ${\bf z}$ from its prior; each row corresponds to one digit class. As intended, varying ${\bf z}$ changes the style of the handwriting while the identity of the digit is preserved.
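A sketch of how such conditional samples might be generated with the decoder from the model above, assuming a trained module `ss_vae` with the `decoder`, `z_dim` and `output_size` attributes used earlier (the reshape to 28x28 assumes flattened MNIST images):

import torch

def generate_digit_images(ss_vae, digit, num_samples=10):
    # fix the class label y as a one-hot vector for the requested digit ...
    ys = torch.zeros(num_samples, ss_vae.output_size)
    ys[:, digit] = 1.0
    # ... and vary the handwriting style by drawing z from its unit normal prior
    zs = torch.randn(num_samples, ss_vae.z_dim)
    # the decoder returns Bernoulli probabilities, which we can view as grayscale images
    with torch.no_grad():
        images = ss_vae.decoder([zs, ys])
    return images.reshape(num_samples, 28, 28)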
Final thoughts¶
We’ve seen that generative models offer a natural approach to semi-supervised machine learning. One of the most attractive features of generative models is that we can explore a large variety of models in a single unified setting. In this tutorial we’ve only been able to explore a small fraction of the possible model and inference setups. There is no reason to expect that one variant is best; depending on the dataset and application, there will be reason to prefer one over another. And there are a lot of variants (see Figure 7)!

Some of these variants clearly make more sense than others, but a priori it’s difficult to know which ones are worth trying out. This is especially true once we open the door to more complicated setups, like the two models at the bottom of the figure, which include an always latent random variable in addition to the partially observed label $y$.
The reader probably doesn’t need any convincing that a systematic exploration of even a fraction of these options would be incredibly time-consuming and error-prone if each model and each inference procedure were coded up by scratch. It’s only with the modularity and abstraction made possible by a probabilistic programming system that we can hope to explore the landscape of generative models with any kind of nimbleness—and reap any awaiting rewards.
See the full code on Github.
References¶
[1] Semi-supervised Learning with Deep Generative Models, Diederik P. Kingma, Danilo J. Rezende, Shakir Mohamed, Max Welling

[2] Learning Disentangled Representations with Semi-Supervised Deep Generative Models, N. Siddharth, Brooks Paige, Jan-Willem Van de Meent, Alban Desmaison, Frank Wood, Noah D. Goodman, Pushmeet Kohli, Philip H.S. Torr