Next, we are going to dive deeper into Storchastic’s API and discuss how to use gradient estimation to train a variational auto-encoder with a discrete latent space on MNIST.

## Discrete Variational Auto-Encoder

A variational auto-encoder (VAE) is a popular family of generative deep learning models that use variational inference and gradient estimation [B3]. See [B1] for a thorough mathematical introduction. In this tutorial, we show how to use Storchastic to train this model with a discrete latent space. Let’s first look at the stochastic computation graph of VAEs, in which the diamond nodes are deterministic nodes that represent cost functions (or “losses”). The corresponding generative story is:

1. Sample a datapoint $$x$$ from the dataset.

2. Compute $$\mathbf{q}=E_\phi(x)$$. These are the logits of the variational distribution, encoded from $$x$$.

3. Sample $$\mathbf{z}\sim \operatorname{Cat}(\mathbf{q})$$ from the categorical distribution with parameters $$\mathbf{q}$$.

4. Compute cost node $$\ell_{KL}(\mathbf{q})$$. This is the KL-divergence between the variational posterior and the prior.

5. Compute $$\hat{x}=D_\theta(\mathbf{z})$$. This decodes from $$\mathbf{z}$$ the reconstructed input $$\hat{x}$$.

6. Compute cost node $$\ell_{rec}(x, \hat{x})$$. This is the ‘reconstruction’ loss.
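Taken together, the two cost nodes form the objective that is minimized for a datapoint $$x$$: the negative evidence lower bound (ELBO). Written in the notation above, with $$\ell_{rec}$$ playing the role of the negative log-likelihood of $$x$$ given $$\mathbf{z}$$:

$$
\mathcal{L}(\phi, \theta; x) = \ell_{KL}(\mathbf{q}) + \mathbb{E}_{\mathbf{z}\sim \operatorname{Cat}(\mathbf{q})}\left[\ell_{rec}\big(x, D_\theta(\mathbf{z})\big)\right], \qquad \mathbf{q}=E_\phi(x)
$$

Minimizing $$\mathcal{L}$$ maximizes the ELBO $$-\mathcal{L}$$. Note that $$\phi$$ enters the second term only through the distribution that the expectation is taken over.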

Because $$\mathbf{z}$$ is sampled from $$\operatorname{Cat}(\mathbf{q})$$, we need to use gradient estimation; otherwise, we cannot train the parameters $$\phi$$! Reparameterization is not an option, as we are dealing with discrete random variables. Let’s see how we can solve this using Storchastic.
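The following plain-PyTorch sketch (it does not use Storchastic) illustrates both the problem and the score-function fix on a toy example: a single categorical variable with 3 choices and a hypothetical linear cost. The sample itself carries no gradient, but the gradient of the surrogate log q(z) · cost(z) matches the exact gradient on average. The costs and the number of samples are illustrative choices.

```python
import torch
from torch.distributions import OneHotCategorical

torch.manual_seed(0)

# Hypothetical cost of each of the 3 categories: cost(z) = z . c
c = torch.tensor([1.0, 2.0, 3.0])
logits = torch.zeros(3, requires_grad=True)
q = OneHotCategorical(logits=logits)

# Sampling is not differentiable: the sample is detached from the logits.
z = q.sample((200_000,))
print(z.requires_grad)  # False

# Score-function (REINFORCE) surrogate: the gradient of mean(log q(z) * cost(z))
# is an unbiased estimate of the gradient of E_q[cost(z)].
cost = (z * c).sum(-1)
surrogate = (q.log_prob(z) * cost.detach()).mean()
score_grad, = torch.autograd.grad(surrogate, logits)

# Exact gradient: E_q[cost(z)] = softmax(logits) . c is differentiable.
exact = (torch.softmax(logits, dim=-1) * c).sum()
exact_grad, = torch.autograd.grad(exact, logits)

print(score_grad)  # close to exact_grad
print(exact_grad)  # approximately [-1/3, 0, 1/3] for uniform logits
```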

First, we define the encoder $$E_\phi$$ and decoder $$D_\theta$$ networks. For both the encoder and decoder we use 2 fully connected hidden layers with 512 and 256 hidden units. For the latent space, we will use 2 conditionally independent categorical distributions with 10 choices each. This means there are $$10^2=100$$ possible configurations to sum over.

```python
import torch
import torch.nn as nn
import storch
from torch.distributions import OneHotCategorical


class DiscreteVAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 2 * 10)
        self.fc4 = nn.Linear(2 * 10, 256)
        self.fc5 = nn.Linear(256, 512)
        self.fc6 = nn.Linear(512, 784)

    def encode(self, x):
        h1 = self.fc1(x).relu()
        h2 = self.fc2(h1).relu()
        return self.fc3(h2)

    def decode(self, z):
        h3 = self.fc4(z).relu()
        h4 = self.fc5(h3).relu()
        return self.fc6(h4).sigmoid()
```


The model itself does not know which gradient estimator is used. Instead, we pass the storch.method.Method that we will use to estimate gradients with respect to $$\mathbf{q}$$ as an argument to the generative story.

Time to translate our generative story!

```python
def generative_story(method: storch.method.Method, model: DiscreteVAE, data: torch.Tensor):
    x = storch.denote_independent(data.view(-1, 784), 0, "data")
```


data is a tensor containing a minibatch of MNIST images of shape (minibatch, 1, 28, 28), which we flatten to (minibatch, 784). As we mentioned in our generative story, we sample a datapoint $$x$$ from the dataset, so the datapoints in a minibatch are independent of each other. We thus have to tell Storchastic that the first (minibatch) dimension is an independent dimension! We give this dimension the plate name “data”.

```python
    # Encode data. Shape: (data, 2 * 10)
    q_logits = model.encode(x)
    # Shape: (data, 2, 10)
    q_logits = q_logits.reshape(-1, 2, 10)
    # Define variational posterior
    q = OneHotCategorical(probs=q_logits.softmax(dim=-1))
    # Sample from variational posterior. Shape: (amt_samples, data, 2, 10)
    z = method(q)
```


Here, we define a one-hot categorical distribution based on the logits from the encoder, and sample from it with the passed storch.method.Method to get z. We reshape the logits to denote that we want 2 conditionally independent categorical latent variables with 10 choices each, instead of 1 categorical latent variable with 20 choices.
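The shapes involved can be traced in plain PyTorch. In this sketch, random logits stand in for model.encode(x), and the batch size and number of samples are illustrative. In Storchastic the sample dimension is tracked as a plate; here it is just an ordinary leading dimension.

```python
import torch
from torch.distributions import OneHotCategorical

batch, n_samples = 64, 10  # illustrative sizes

# An MNIST minibatch as produced by a DataLoader with ToTensor(): (batch, 1, 28, 28)
data = torch.rand(batch, 1, 28, 28)
x = data.view(-1, 784)                        # (batch, 784)

# Stand-in for model.encode(x): random logits of shape (batch, 2 * 10)
q_logits = torch.randn(batch, 2 * 10)
q_logits = q_logits.reshape(-1, 2, 10)        # 2 variables with 10 choices each
q = OneHotCategorical(probs=q_logits.softmax(dim=-1))
print(q.batch_shape, q.event_shape)           # torch.Size([64, 2]) torch.Size([10])

# Drawing several samples prepends a sample dimension.
z = q.sample((n_samples,))                    # (n_samples, batch, 2, 10)
z_in = z.reshape(z.shape[:-2] + (2 * 10,))    # (n_samples, batch, 20): decoder input
```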

The KL-divergence loss $$\ell_{KL}(\mathbf{q})$$ can be computed using

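```python
    prior = OneHotCategorical(probs=torch.ones_like(q.probs) / 10.0)
    # Shape: (data)
    KL_div = torch.distributions.kl_divergence(q, prior).sum(-1)
    # Register the KL term as a cost node (the name "kl-div" is our choice)
    storch.add_cost(KL_div, "kl-div")
```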


We define a uniform prior over the categorical random variables, and then use torch.distributions.kl_divergence() to analytically compute the KL-divergence between the variational posterior and this prior, summing over the 2 latent variables. We want to minimize this KL-divergence, so we use storch.add_cost() to register it as a cost node.
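For a uniform prior over $$K=10$$ choices, the KL-divergence has a simple closed form, $$\log K + \sum_k p_k \log p_k$$ per latent variable, with $$p = \operatorname{softmax}(\mathbf{q})$$. This plain-PyTorch sketch, with random logits standing in for the encoder output, checks that torch.distributions.kl_divergence() agrees with it.

```python
import math

import torch
from torch.distributions import OneHotCategorical, kl_divergence

torch.manual_seed(0)
q_logits = torch.randn(64, 2, 10)  # stand-in for the reshaped encoder output
q = OneHotCategorical(probs=q_logits.softmax(dim=-1))
prior = OneHotCategorical(probs=torch.ones_like(q.probs) / 10.0)

# kl_divergence gives shape (64, 2); summing the 2 independent variables gives (64,).
kl_torch = kl_divergence(q, prior).sum(-1)

# Closed form for a uniform prior over K = 10 choices: log K + sum_k p_k log p_k
p = q.probs
kl_manual = (math.log(10) + (p * p.log()).sum(-1)).sum(-1)
print(kl_torch[:3], kl_manual[:3])
```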

Next, we reconstruct $$\hat{x}$$ from $$z$$, and compute the reconstruction loss:

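```python
    # Shape: (amt_samples, data, 2 * 10)
    z_in = z.reshape(z.shape[:-2] + (2 * 10,))
    # Shape: (amt_samples, data, 784)
    reconstruction = model.decode(z_in)
    # Shape: (amt_samples, data)
    bce = torch.nn.BCELoss(reduction="none")(reconstruction, x).sum(-1)
    # Register the reconstruction loss as a cost node (the name is our choice)
    storch.add_cost(bce, "reconstruction")

    return z
```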


Here we use our model to decode $$\mathbf{z}$$, then compute the binary cross entropy between the reconstruction and the original datapoint. The computation of the binary cross entropy is a bit subtle: we pass reduction="none" so that the loss is not reduced at all, yet afterwards we still sum over the last (pixel) dimension. Why not simply:

```
>>> torch.nn.BCELoss(reduction="sum")(reconstruction, x)
ValueError: Got an input tensor with too few dimensions. We expected 2 plate dimensions. Instead, we found only 0 dimensions. Violated at dimension 0
```


This error means that we have removed a plate dimension: setting reduction="sum" makes the loss function return only a single number. In Storchastic, it is not allowed to remove dimensions that are denoted as independent unless the user explicitly asks Storchastic to do so. For example, we also cannot do the following:

```
>>> torch.mean(bce)
ValueError: Got an input tensor with too few dimensions. We expected 2 plate dimensions. Instead, we found only 0 dimensions. Violated at dimension 0
```
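The difference between the two reductions can be seen in plain PyTorch, with random stand-ins for reconstruction and x. Plain BCELoss requires input and target of the same shape, so this sketch expands x over the sample dimension itself; in the tutorial code, x and reconstruction carry different plates and are presumably broadcast by Storchastic. Only reduction="none" followed by .sum(-1) keeps the (samples, data) dimensions that correspond to the plates.

```python
import torch

torch.manual_seed(0)
n_samples, batch = 10, 64
# Decoder output in (0, 1), clamped away from 0 and 1
reconstruction = torch.rand(n_samples, batch, 784).clamp(1e-4, 1 - 1e-4)
x = torch.rand(batch, 784)  # pixel intensities in [0, 1]

# Broadcast x over the sample dimension explicitly.
target = x.expand_as(reconstruction)

# reduction="none" keeps one loss per pixel; summing over the last dimension
# gives one loss per (sample, datapoint), so both dimensions survive.
bce = torch.nn.BCELoss(reduction="none")(reconstruction, target).sum(-1)  # (10, 64)

# reduction="sum" collapses everything into a single number.
total = torch.nn.BCELoss(reduction="sum")(reconstruction, target)         # shape ()
print(bce.shape, total.shape)
```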


Why is this not allowed here, since we are just computing our loss function and would average over our samples anyway? Because after taking the mean, Storchastic is no longer able to compute gradient estimates. For example, if we use the score function with multiple samples, the log-probability of each sample must be multiplied with the loss computed for that sample. This happens in storch.backward(). If we had taken the mean, we could no longer recover the individual losses!

To make life easier, Storchastic is designed with “fail-quick” in mind: code that is likely to result in such errors crashes immediately!
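A toy plain-PyTorch sketch of why the individual losses matter for the score function. With $$n$$ samples per estimate, the correct estimator averages $$\log q(z_i) \cdot \text{cost}(z_i)$$. If the costs are averaged first, as torch.mean(bce) would do, every log-probability is multiplied by the same mean cost; because $$\mathbb{E}[\nabla \log q(z)] = 0$$, the cross terms vanish in expectation and the gradient shrinks by a factor $$n$$. The costs, $$n = 10$$ and the number of repetitions are illustrative assumptions.

```python
import torch
from torch.distributions import OneHotCategorical

torch.manual_seed(0)
c = torch.tensor([1.0, 2.0, 3.0])  # hypothetical cost per category
logits = torch.zeros(3, requires_grad=True)
n_samples, n_repeats = 10, 20_000


def score_function_grad(average_costs_first: bool) -> torch.Tensor:
    """Average of n_repeats score-function estimates, each using n_samples samples."""
    q = OneHotCategorical(logits=logits)
    z = q.sample((n_repeats, n_samples))  # (repeats, samples, 3)
    cost = (z * c).sum(-1)                # one cost per sample
    if average_costs_first:
        # Mimics torch.mean(bce): every sample now sees the same averaged cost.
        cost = cost.mean(-1, keepdim=True).expand_as(cost)
    surrogate = (q.log_prob(z) * cost).mean()
    grad, = torch.autograd.grad(surrogate, logits)
    return grad


exact, = torch.autograd.grad((torch.softmax(logits, -1) * c).sum(), logits)
per_sample = score_function_grad(average_costs_first=False)  # close to exact
averaged = score_function_grad(average_costs_first=True)     # close to exact / n_samples
print(exact, per_sample, averaged)
```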

Next, we load the MNIST dataset (see footnote 1):

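```python
from torchvision import datasets, transforms

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        # Assumption: the root directory and transform are not shown in the source.
        # ToTensor() yields (1, 28, 28) tensors with values in [0, 1].
        "./data", train=True, download=True, transform=transforms.ToTensor()
    ),
    batch_size=64,
    shuffle=True,
)
```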


Finally, we put everything together in the training loop and add a training evaluation that also computes gradient variance:

```python
def train(method: storch.method.Method, train_loader):
    model = DiscreteVAE()
    # Assumption: the optimizer and learning rate are not shown in the source
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(5):
        print("Epoch:" + str(epoch + 1))
        for i, (data, _) in enumerate(train_loader):
            optimizer.zero_grad()

            generative_story(method, model, data)
            storch.backward()
            optimizer.step()
            if i % 300 == 0:
                evaluate(method, model, data, optimizer)
```
67
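```python
def evaluate(method: storch.method.Method, model: DiscreteVAE, data, optimizer):
    gradients = []
    for i in range(100):
        optimizer.zero_grad()

        z = generative_story(method, model, data)
        elbo = storch.backward()
        # Gradient with respect to the "probs" parameter of q
        # (assumption: param_grads is keyed by parameter name and holds tensors)
        gradients.append(z.param_grads["probs"])

    # Assumption: the original variance computation is not shown. Here we stack
    # the 100 gradient estimates and sum the element-wise variance.
    gradients = torch.stack(gradients)
    print(
        "Training ELBO "
        + str(elbo.item())
        + ". Gradient variance "
        + str(gradients.var(dim=0).sum().item())
    )
```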


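We use an optimizer as usual; however, instead of calling backward() on a loss tensor, we call storch.backward() to compute the gradients. storch.backward() returns the total cost, which we print as the training ELBO (strictly speaking, it is the negative ELBO, so lower is better). To get the gradients for the gradient variance computation, we use storch.StochasticTensor.param_grads. In this example, we train for 5 epochs.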

## Experimenting with the Discrete VAE

Let us naively try the score function with no baseline and a single sample:

```
>>> train(storch.method.ScoreFunction("z", n_samples=1, baseline_factory="None"), train_loader)
Epoch:1
Training ELBO 543.1165771484375. Gradient variance 12860.05078125
Training ELBO 192.7163848876953. Gradient variance 1841.5672607421875
Training ELBO 196.0297088623047. Gradient variance 146922.4375
Training ELBO 203.4992218017578. Gradient variance 277160.9375
Epoch:2
Training ELBO 191.17823791503906. Gradient variance 28171.1796875
Training ELBO 193.627685546875. Gradient variance 130113.6953125
Training ELBO 200.20506286621094. Gradient variance 210252.90625
Training ELBO 197.44796752929688. Gradient variance 523125.375
Epoch:3
Training ELBO 202.95068359375. Gradient variance 39035.52734375
Training ELBO 195.01597595214844. Gradient variance 104070.5078125
Training ELBO 198.34580993652344. Gradient variance 7388.16845703125
Training ELBO 192.8509979248047. Gradient variance 457.5860595703125
Epoch:4
Training ELBO 184.8743896484375. Gradient variance 1029.0767822265625
Training ELBO 201.11277770996094. Gradient variance 168409.390625
Training ELBO 199.73178100585938. Gradient variance 13173.759765625
Training ELBO 198.03358459472656. Gradient variance 4439.3017578125
Epoch:5
Training ELBO 194.9002685546875. Gradient variance 18869.560546875
Training ELBO 188.87930297851562. Gradient variance 45681.5546875
Training ELBO 209.60585021972656. Gradient variance 5402.30615234375
Training ELBO 190.37799072265625. Gradient variance 34421.51953125
```


It is clearly training, but the training ELBO is rather unstable and the gradient variance is very high. Let’s again use the batch average baseline and 10 samples:

```
>>> train(storch.method.ScoreFunction("z", n_samples=10, baseline_factory="batch_average"), train_loader)
Epoch:1
Training ELBO 543.8447265625. Gradient variance 0.00031363777816295624
Training ELBO 195.42037963867188. Gradient variance 4.808237075805664
Training ELBO 176.24708557128906. Gradient variance 20.694387435913086
Training ELBO 175.12139892578125. Gradient variance 114.5234375
Epoch:2
Training ELBO 161.55543518066406. Gradient variance 117.06814575195312
Training ELBO 163.81021118164062. Gradient variance 764.1204223632812
Training ELBO 167.0965576171875. Gradient variance 0.7521735429763794
Training ELBO 163.2193145751953. Gradient variance 1854.9320068359375
Epoch:3
Training ELBO 165.54237365722656. Gradient variance 34.32332229614258
Training ELBO 159.99176025390625. Gradient variance 121.91394805908203
Training ELBO 149.61558532714844. Gradient variance 4.967251777648926
Training ELBO 165.6819305419922. Gradient variance 2.564244270324707
Epoch:4
Training ELBO 156.40789794921875. Gradient variance 215.02999877929688
Training ELBO 152.97520446777344. Gradient variance 258.04400634765625
Training ELBO 157.0828094482422. Gradient variance 13.990401268005371
Training ELBO 157.7599639892578. Gradient variance 1.4151099920272827
Epoch:5
Training ELBO 164.08978271484375. Gradient variance 391.89794921875
Training ELBO 156.1527862548828. Gradient variance 2.9808785915374756
Training ELBO 154.16932678222656. Gradient variance 10.244932174682617
Training ELBO 151.02488708496094. Gradient variance 4701.791015625
```


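Much better! Our variance reduction techniques reduce the gradient variance by several orders of magnitude, which results in considerably lower training ELBO values: after the first epoch they lie roughly between 150 and 167, compared to roughly 185 to 210 without variance reduction.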
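The same effect can be reproduced on a toy problem in plain PyTorch. This sketch measures gradient variance by repeating the estimator 100 times, comparing a single sample without baseline to 10 samples with a leave-one-out average baseline. That baseline is our choice for the illustration; Storchastic’s batch_average baseline may be defined differently. The hypothetical costs share a large offset, much like a reconstruction loss in the hundreds, which is the situation where a baseline helps most.

```python
import torch
from torch.distributions import OneHotCategorical

torch.manual_seed(0)
# Hypothetical per-category costs sharing a large offset (like a reconstruction loss).
c = torch.tensor([101.0, 102.0, 103.0])
logits = torch.zeros(3, requires_grad=True)


def estimate(n_samples: int, use_baseline: bool) -> torch.Tensor:
    """One score-function gradient estimate from n_samples samples."""
    q = OneHotCategorical(logits=logits)
    z = q.sample((n_samples,))  # (n_samples, 3)
    cost = (z * c).sum(-1)      # (n_samples,)
    if use_baseline:
        # Leave-one-out baseline: the mean cost of the *other* samples. It does not
        # depend on sample i, so subtracting it keeps the estimator unbiased.
        baseline = (cost.sum() - cost) / (n_samples - 1)
        cost = cost - baseline
    surrogate = (q.log_prob(z) * cost).mean()
    grad, = torch.autograd.grad(surrogate, logits)
    return grad


def gradient_variance(n_samples: int, use_baseline: bool, n_runs: int = 100) -> float:
    """Sum over components of the variance of n_runs repeated gradient estimates."""
    grads = torch.stack([estimate(n_samples, use_baseline) for _ in range(n_runs)])
    return grads.var(dim=0).sum().item()


naive = gradient_variance(n_samples=1, use_baseline=False)
with_baseline = gradient_variance(n_samples=10, use_baseline=True)
print(naive, with_baseline)
```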

Another popular technique is the Gumbel-softmax trick [B2][B4]. This trick uses a continuous distribution that approximates the one-hot categorical distribution and that allows reparameterization. Because the decoder of the Discrete VAE does not require its inputs to be discrete, we can apply this trick here (see footnote 2)!
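A minimal sketch of the relaxation itself, using PyTorch’s RelaxedOneHotCategorical rather than Storchastic: rsample() returns softmax((logits + g) / temperature) with Gumbel noise g, so samples lie inside the probability simplex and gradients flow back to the logits. Lower temperatures give samples closer to one-hot vectors. The logits, temperatures and downstream function are illustrative choices.

```python
import torch
from torch.distributions import RelaxedOneHotCategorical

torch.manual_seed(0)
logits = torch.tensor([2.0, 0.0, -1.0], requires_grad=True)


def relaxed_samples(temperature: float, n: int = 10_000) -> torch.Tensor:
    # rsample() adds Gumbel noise g to the logits and returns
    # softmax((logits + g) / temperature): a differentiable point in the simplex.
    distr = RelaxedOneHotCategorical(torch.tensor(temperature), logits=logits)
    return distr.rsample((n,))


soft = relaxed_samples(temperature=5.0)   # (10000, 3), spread out over the simplex
sharp = relaxed_samples(temperature=0.1)  # (10000, 3), close to one-hot vectors

# Reparameterization: the gradient of a downstream function flows back to the logits.
weights = torch.tensor([1.0, 2.0, 3.0])   # hypothetical downstream function: z . weights
(sharp @ weights).mean().backward()
print(logits.grad)
```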

The Gumbel-softmax trick is a biased gradient estimator: in expectation, its estimate does not equal the true gradient. Using storch.method.Expect, we can measure just how biased it is. Let’s edit our evaluation function:

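```python
def evaluate(method: storch.method.Method, model: DiscreteVAE, data, optimizer):
    # Exact gradient: storch.method.Expect sums over all 100 latent configurations.
    optimizer.zero_grad()
    z = generative_story(storch.method.Expect("z"), model, data)
    storch.backward()
    # Assumption: param_grads is keyed by parameter name and holds tensors
    expected_gradient = z.param_grads["probs"]

    # 100 gradient estimates from the method under evaluation
    gradients = []
    for i in range(100):
        optimizer.zero_grad()
        z = generative_story(method, model, data)
        elbo = storch.backward()
        gradients.append(z.param_grads["probs"])

    # Assumption: the original bias and variance lines are not shown;
    # this is one way to compute them from the stacked estimates.
    gradients = torch.stack(gradients)
    variance = gradients.var(dim=0).sum()
    bias = ((gradients.mean(dim=0) - expected_gradient) ** 2).sum()
    print(
        "Training ELBO "
        + str(elbo.item())
        + ". Gradient variance "
        + str(variance.item())
        + ". Gradient bias "
        + str(bias.item())
    )
```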


storch.method.Expect is not a gradient estimation method: it computes the gradient analytically by summing over all options. Because our latent space is small, with only $$10^2=100$$ options, this is computationally feasible.
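Conceptually, computing the exact expectation amounts to enumerating all joint configurations. The plain-PyTorch sketch below (not Storchastic’s implementation) does this for 2 categorical variables with 10 choices each, using a hypothetical additive cost so that the result can be checked against a closed form.

```python
import torch

torch.manual_seed(0)
# Logits of 2 conditionally independent categorical variables with 10 choices each.
logits = torch.randn(2, 10, requires_grad=True)
p = logits.softmax(dim=-1)  # (2, 10)

# Hypothetical additive cost: cost(i, j) = a[i] + b[j]
a, b = torch.randn(10), torch.randn(10)

# Enumerate all 10 ** 2 = 100 joint configurations (i, j).
configs = torch.cartesian_prod(torch.arange(10), torch.arange(10))  # (100, 2)
joint_prob = p[0, configs[:, 0]] * p[1, configs[:, 1]]              # (100,)
cost = a[configs[:, 0]] + b[configs[:, 1]]                          # (100,)

# Exact expectation and its gradient: no sampling, hence no variance.
expected_cost = (joint_prob * cost).sum()
expected_cost.backward()
print(expected_cost.item(), logits.grad.shape)
```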

```
>>> train(storch.method.GumbelSoftmax("z", n_samples=1), train_loader)
Epoch:1
Epoch:2
Epoch:3
Epoch:4
Epoch:5
```


Oof, that is not great! The Gumbel-softmax does even worse than the score function without variance reduction. In theory, the convergence guarantees of stochastic optimization with gradient estimation only hold for unbiased gradient estimators. We should note, however, that the Gumbel-softmax performs much better for larger latent spaces, for example when using 20 categorical latent variables with 10 options each.
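The bias shows up even in a tiny example where the exact answer is known. With two equally likely categories, temperature 1 and the function $$z_0$$ of the (relaxed) sample, the true gradient of $$\mathbb{E}[z_0]$$ with respect to the first logit is $$p_0(1-p_0) = 1/4$$. The Gumbel-softmax gradient is $$\sigma'(L)$$ with $$L = g_0 - g_1$$ logistic-distributed, whose expectation is $$\int \sigma'(x)^2\,dx = 1/6$$. The plain-PyTorch sketch below (independent of Storchastic) estimates it by averaging over many samples.

```python
import torch

torch.manual_seed(0)
n, tau = 200_000, 1.0
logits = torch.zeros(2, requires_grad=True)  # two equally likely categories

# Exact gradient of E_{z ~ Cat}[z_0] = softmax(logits)_0: [p0 * (1 - p0), -p0 * p1]
exact, = torch.autograd.grad(torch.softmax(logits, -1)[0], logits)

# Gumbel-softmax: average gradient of the relaxed sample's first component.
u = torch.rand(n, 2).clamp(min=1e-10)  # avoid log(0)
gumbels = -torch.log(-torch.log(u))
relaxed = torch.softmax((logits + gumbels) / tau, dim=-1)
gumbel_grad, = torch.autograd.grad(relaxed[:, 0].mean(), logits)
print(exact, gumbel_grad)  # the first components differ: about 1/4 versus 1/6
```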

We could also just use storch.method.Expect to train the model:

```
>>> train(storch.method.Expect("z"), train_loader)
Epoch:1
Training ELBO 543.6659545898438
Training ELBO 175.1640625
Training ELBO 163.3818359375
Training ELBO 158.3362274169922
Epoch:2
Training ELBO 159.03167724609375
Training ELBO 158.54054260253906
Training ELBO 151.9814453125
Training ELBO 162.34519958496094
Epoch:3
Training ELBO 154.2731475830078
Training ELBO 159.92709350585938
Training ELBO 157.92642211914062
Training ELBO 147.97755432128906
Epoch:4
Training ELBO 151.23654174804688
Training ELBO 155.57571411132812
Training ELBO 142.53665161132812
Training ELBO 141.1732635498047
Epoch:5
Training ELBO 152.55979919433594
Training ELBO 154.68777465820312
Training ELBO 151.78952026367188
Training ELBO 156.02206420898438
```


### Footnotes

1. Note that it is best practice to use the binarized MNIST dataset as proposed by http://proceedings.mlr.press/v15/larochelle11a/larochelle11a.pdf.

2. It is very important to note that the Gumbel-softmax trick only applies to differentiable functions of continuous inputs. Using the straight-through variant (storch.method.GumbelSoftmax("z", straight_through=True)), however, we can use it for differentiable functions of discrete inputs.