Sampling, Inference and Variance Reduction

Storchastic allows you to define stochastic computation graphs using an API that resembles generative stories. It is designed with plug-and-play in mind: it is very easy to swap in different gradient estimation methods to compare their performance on your task. In this tutorial, we apply gradient estimation to a simple and small problem.

Converting generative stories

We return to the generative story from Stochastic computation graphs:

  1. Compute \(d=a+b\)

  2. Sample \(e\sim \mathcal{N}(c+b, 1)\)

  3. Compute \(f=d\cdot e^2\)

This story is easily converted using the following code:

 1import torch
 2import storch
 3from torch.distributions import Normal
 4from storch.method import Reparameterization, ScoreFunction
 5
 6def compute_f(n):
 7    a = torch.tensor(5.0, requires_grad=True)
 8    b = torch.tensor(-3.0, requires_grad=True)
 9    c = torch.tensor(0.23, requires_grad=True)
10    d = a + b
11
12    # Sample e from a normal distribution using reparameterization
13    normal_distribution = Normal(b + c, 1)
14    method = Reparameterization("e", n_samples=n)
15    e = method(normal_distribution)
16
17    f = d * e * e
18    return f, c

Lines 10 and 17 represent the deterministic nodes. Lines 13-15 represent the stochastic node: we sample a value from the normal distribution using reparameterization (also known as the pathwise derivative). The storch.method.Reparameterization class is a subclass of storch.method.Method. Subclasses implement the functionality for sampling and gradient estimation, and you can subclass storch.method.Method to implement new gradient estimation methods. Furthermore, storch.method.Method subclasses torch.nn.Module, which makes it easy to use methods as part of a PyTorch model.
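
Because gradient estimation methods are torch.nn.Module instances, they can be registered as submodules of a larger model just like any other layer. The following minimal sketch illustrates this; the model and its parameter are made up for illustration and are not part of this example:

import torch
import torch.nn as nn
from torch.distributions import Normal
from storch.method import Reparameterization

class TinyModel(nn.Module):
    def __init__(self, n_samples=1):
        super().__init__()
        self.loc = nn.Parameter(torch.tensor(0.0))
        # The sampling method is an nn.Module, so it is registered like any other submodule
        self.sample_e = Reparameterization("e", n_samples=n_samples)

    def forward(self):
        # Sample e using the configured gradient estimation method
        return self.sample_e(Normal(self.loc, 1.0))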

storch.method.Reparameterization is initialized with the variable name “e”. This is done to initialize the plate that corresponds to this sample; we will introduce plates later on. Methods also take an optional n_samples argument, which controls how many samples are taken from the normal distribution. Note that the method is called directly on the distribution (a torch.distributions.Distribution) to sample from it.
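
For example, sampling with n_samples=10 should attach a plate named “e” of size 10 to the resulting tensor (plates are explained in more detail below). A quick sketch to check this:

# Hypothetical check, not part of the example above: inspect the plate created by the method.
method = Reparameterization("e", n_samples=10)
e = method(Normal(0.0, 1.0))
print(e.plates)  # expected: one plate named "e" of size 10 with uniform weights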

Gradient estimation

Great. Now, how do we get the derivative with respect to \(c\)? Storchastic requires you to register cost nodes using storch.add_cost(). These are leaf nodes that will be minimized. When all cost nodes are registered, storch.backward() is used to estimate the gradients:

>>> f, c = compute_f(1)
>>> storch.add_cost(f, "f")
>>> storch.backward()
tensor(3.0209, grad_fn=<AddBackward0>)
>>> c.grad
tensor(-4.9160)

The second line registers the cost node with the name “f”, and the third line computes the gradients, where PyTorch’s automatic differentiation is used for deterministic nodes, and Storchastic’s gradient estimation methods for stochastic nodes. storch.backward() returns the estimated value of the sum of cost nodes, which in this case is just \(f\).

We also show the estimated gradient with respect to \(c\) (-4.9160). Note that this gradient is stochastic! Running the code another time, we get -12.2537.
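
Since the generative story is so small, we can also work out the exact gradient by hand as a sanity check. With \(e\sim\mathcal{N}(b+c,1)\) and \(f=d\cdot e^2\),

\[
\mathbb{E}[f] = d\,\big((b+c)^2 + 1\big), \qquad
\frac{\partial\,\mathbb{E}[f]}{\partial c} = 2d\,(b+c) = 2\cdot 2\cdot(-2.77) \approx -11.08.
\]

Single-sample estimates such as -4.9160 and -12.2537 scatter around this value; averaging many of them, as we do next, should get us close to it.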

Computing gradient statistics

We can estimate the mean and variance of the gradient as follows:

19n = 1
20gradient_samples = []
21for i in range(1000):
22    f, c = compute_f(n)
23    storch.add_cost(f, "f")
24    storch.backward()
25    gradient_samples.append(c.grad)
26gradients = storch.gather_samples(gradient_samples, "gradients")
>>> storch.variance(gradients, "gradients")
Deterministic tensor(16.7321) Batch links: []
>>> print(storch.reduce_plates(gradients, "gradients"))
Deterministic tensor(-11.0195) Batch links: []

Alright, a few things to note. storch.gather_samples() is a function that takes a list of tensors that are (conditionally) independent samples of some value, in this case the gradients. Like most other methods in Storchastic, it returns a storch.Tensor, in this case a storch.IndependentTensor:

>>> type(gradients)
<class 'storch.tensor.IndependentTensor'>

storch.Tensor is a special “tensor-like” object which wraps a torch.Tensor and includes extra metadata to help with estimating gradients and keeping track of the plate dimensions. Plate dimensions are dimensions of the tensor for which we know conditional independence properties. We can inspect the plate dimensions of a storch.Tensor using

>>> gradients.plates
[('gradients', 1000, tensor(0.0010))]

The gradients tensor has one plate dimension with the name “gradients” (as we defined using storch.gather_samples()). As we simulated the gradient 1000 times, the size of the plate dimension is 1000. The third value is the weight of the samples. In this case, the samples are weighted identically (that is, each weight is 1/1000), which corresponds to a plain Monte Carlo sample.

Note that we used the plate dimension name “gradients” in storch.variance(gradients, "gradients"). This means that we compute the variance over the gradients plate dimension, which indexes the different independent samples of gradient estimates.
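
As a rough sanity check, we can compare against plain PyTorch statistics over the collected samples. This is only a sketch: storch.variance() reduces over the plate using the sample weights, so its exact normalization may differ slightly from torch.var():

# Plain-PyTorch comparison over the 1000 collected gradient samples
stacked = torch.stack(gradient_samples)
print(stacked.mean())  # should be close to storch.reduce_plates(gradients, "gradients")
print(stacked.var())   # should be close to storch.variance(gradients, "gradients")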

Reducing variance

Next, let us try to reduce the variance. A simple way to do this is to use more samples of \(e\). In line 14 (method = Reparameterization("e", n_samples=n)), we pass the number of samples to use for this method. Let’s use 10 by setting line 19 to n = 10, and compute the variance again:

>>> storch.variance(gradients, "gradients")
Deterministic tensor(1.6388) Batch links: []

By using 10 times as many samples, we reduced the variance by (roughly) a factor of 10. Note that we did not have to change any code other than the value of n. Storchastic is designed so that all (left-broadcastable!) code supports using either a single sample or multiple samples. Using more samples is an easy way to reduce variance. Storchastic automatically parallelizes the computation over the different samples, so if your GPU has enough memory, there is (usually) almost no overhead to using more samples, yet we get better estimates of the gradient!
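
To make “left-broadcastable” concrete: multiple samples are stacked in a leading plate dimension, so elementwise code such as f = d * e * e works unchanged whether e carries zero or more leading sample dimensions. A plain-PyTorch sketch (no Storchastic involved) of the broadcasting behaviour:

import torch

d = torch.tensor(2.0)
e_single = torch.randn(())  # a single sample
e_multi = torch.randn(10)   # 10 samples stacked in a leading (plate) dimension

print((d * e_single * e_single).shape)  # torch.Size([])
print((d * e_multi * e_multi).shape)    # torch.Size([10]), one value per sample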

Using different estimators

Storchastic is designed to make swapping in different gradient estimation methods as easy as possible. For instance, say we want to use the score function instead of reparameterization. This is done as follows:

 6def compute_f(n):
 7    a = torch.tensor(5.0, requires_grad=True)
 8    b = torch.tensor(-3.0, requires_grad=True)
 9    c = torch.tensor(0.23, requires_grad=True)
10    d = a + b
11
12    # Sample e from a normal distribution using the score function
13    normal_distribution = Normal(b + c, 1)
14    method = ScoreFunction("e", n_samples=n, baseline_factory=None)
15    e = method(normal_distribution)
16
17    f = d * e * e
18    return f, c

Note how we only changed line 14 (previously method = Reparameterization("e", n_samples=n)), where we defined the gradient estimation method, to now create a storch.method.ScoreFunction instead of a storch.method.Reparameterization. Let’s see the variance of this method (using 1 sample):

>>> storch.variance(gradients, "gradients")
Deterministic tensor(748.1914) Batch links: []

Ouch, that really is much higher than with reparameterization! While the score function is much more generally applicable than reparameterization (it can be used for discrete distributions and non-differentiable functions), it clearly has a prohibitively large variance here. Storchastic also has the storch.method.Infer gradient estimation method, which automatically applies reparameterization if possible and otherwise falls back to the score function.
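
As a sketch of how that might look, assuming (do check the API reference, as the exact signature is an assumption here) that storch.method.Infer is constructed with the plate name and the distribution type it should decide for:

from storch.method import Infer

# Assumption: Infer takes the plate name and the distribution class, and then
# picks reparameterization for Normal (which supports rsample) automatically.
method = Infer("e", Normal, n_samples=n)
e = method(normal_distribution)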

Can we do something about the large variance? Using more samples is always an option. Since the variance of an average of n independent estimates scales as 1/n, to get the variance in the same ballpark as single-sample reparameterization we would need about 748.2/16.7 ≈ 45 samples, that is, n=45!

>>> storch.variance(gradients, "gradients")
Deterministic tensor(17.0591) Batch links: []

Luckily, we can make more efficient use of the multiple samples we take. Note how we set baseline_factory=None when defining the storch.method.ScoreFunction. A baseline is a very common variance reduction method that subtracts a value from the cost to stabilize the gradient estimate. A simple but effective one is the batch average baseline (storch.method.baseline.BatchAverage), which subtracts the average of the other samples. Simply change line 14 to method = ScoreFunction("e", n_samples=n, baseline_factory="batch_average"). Let’s use 20 samples:

>>> storch.variance(gradients, "gradients")
Deterministic tensor(16.8761) Batch links: []

Sweet! We used fewer than half the samples, yet get a lower variance than before. For complicated settings where reparameterization is not an option, strong variance reduction like this is very important for efficient algorithms.
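
Concretely, with \(n\) samples \(e_1,\dots,e_n\) and distribution parameters \(\theta\), the score-function estimator with this leave-one-out baseline looks roughly like

\[
\hat{g} = \frac{1}{n}\sum_{i=1}^{n}\big(f(e_i) - b_i\big)\,\nabla_\theta \log p_\theta(e_i),
\qquad b_i = \frac{1}{n-1}\sum_{j\neq i} f(e_j),
\]

so each sample’s cost is compared against the average cost of the other samples, which keeps the estimator unbiased while reducing its variance.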

For the full code of this example, see Introduction Example.