Sampling, Inference and Variance Reduction
Storchastic allows you to define stochastic computation graphs using an API that resembles generative stories. It is designed with plug-and-play in mind: it is very easy to swap in different gradient estimation methods to compare their performance on your task. In this tutorial, we apply gradient estimation to a simple and small problem.
Converting generative stories
We return to the generative story from Stochastic computation graphs:
Compute \(d=a+b\)
Sample \(e\sim \mathcal{N}(c+b, 1)\)
Compute \(f=d\cdot e^2\)
This story is easily converted using the following code:
1   import torch
2   import storch
3   from torch.distributions import Normal
4   from storch.method import Reparameterization, ScoreFunction
5
6   def compute_f(n):
7       a = torch.tensor(5.0, requires_grad=True)
8       b = torch.tensor(-3.0, requires_grad=True)
9       c = torch.tensor(0.23, requires_grad=True)
10      d = a + b
11
12      # Sample e from a normal distribution using reparameterization
13      normal_distribution = Normal(b + c, 1)
14      method = Reparameterization("e", n_samples=n)
15      e = method(normal_distribution)
16
17      f = d * e * e
18      return f, c
Lines 10 and 17 represent the deterministic nodes. Lines 13-15 represent the stochastic node: we sample a value from the normal distribution using reparameterization (also known as the pathwise derivative).
The storch.method.Reparameterization class is a subclass of storch.method.Method. Subclasses implement functionality for sampling and gradient estimation, and you can subclass storch.method.Method to implement new gradient estimation methods. Furthermore, storch.method.Method subclasses torch.nn.Module, which makes it easy for methods to become part of a PyTorch model.
storch.method.Reparameterization is initialized with the variable name “e”. This name is used to initialize the plate that corresponds to this sample. We will introduce plates later on. The method also takes an optional n_samples argument, which controls how many samples are taken from the normal distribution. Note that the method is called directly on the distribution (a torch.distributions.Distribution) to sample from it.
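To make the pathwise derivative concrete, here is a minimal hand-written sketch of the same computation in plain Python. It is not how Storchastic implements reparameterization; the function name `reparam_sample_and_grad` is made up for this illustration, and the gradient is written out with the chain rule instead of using automatic differentiation.

```python
import random

def reparam_sample_and_grad(a, b, c, rng):
    """Draw one sample of f = d * e**2 and its pathwise gradient w.r.t. c.

    Hypothetical helper for illustration only; not part of Storchastic.
    """
    d = a + b
    mu = b + c                 # mean of the normal distribution
    eps = rng.gauss(0.0, 1.0)  # noise that does not depend on c
    e = mu + eps               # reparameterized sample, so de/dc = 1
    f = d * e * e
    df_dc = 2.0 * d * e        # chain rule: df/de * de/dc = (2 * d * e) * 1
    return f, df_dc

rng = random.Random(0)
f, grad_c = reparam_sample_and_grad(5.0, -3.0, 0.23, rng)
print(f, grad_c)
```

Because the randomness is moved into `eps`, the sample `e` is a differentiable function of `c`, which is what allows the gradient to flow through the stochastic node.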
Gradient estimation
Great. Now how do we get the derivative with respect to \(c\)? Storchastic requires you to register cost nodes using storch.add_cost(). These are leaf nodes that will be minimized. When all cost nodes are registered, storch.backward() is used to estimate the gradients:
>>> f, c = compute_f(1)
>>> storch.add_cost(f, "f")
>>> storch.backward()
tensor(3.0209, grad_fn=<AddBackward0>)
>>> c.grad
tensor(-4.9160)
The second line registers the cost node with the name “f”, and the third line computes the gradients, where PyTorch’s automatic differentiation is used for deterministic nodes, and Storchastic’s gradient estimation methods for stochastic nodes. storch.backward() returns the estimated value of the sum of cost nodes, which in this case is just \(f\).
We also show the estimated gradient with respect to \(c\) (-4.9160). Note that this gradient is stochastic! Running the code another time, we get -12.2537.
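As a sanity check on these numbers, the expected cost of this small example can be worked out by hand. The derivation below is added for illustration; it uses only the values from compute_f and the fact that \(\mathbb{E}[e^2] = \mu^2 + \sigma^2\) for a normally distributed \(e\):
\[
\mathbb{E}[f] = d\,\mathbb{E}[e^2] = d\left((b+c)^2 + 1\right) = 2\left((-2.77)^2 + 1\right) \approx 17.35
\]
\[
\frac{\partial}{\partial c}\mathbb{E}[f] = 2d(b+c) = 2 \cdot 2 \cdot (-2.77) = -11.08
\]
Both -4.9160 and -12.2537 are therefore noisy single-sample estimates of a true gradient of -11.08.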
Computing gradient statistics
We can estimate the mean and variance of the gradient as follows:
19  n = 1
20  gradient_samples = []
21  for i in range(1000):
22      f, c = compute_f(n)
23      storch.add_cost(f, "f")
24      storch.backward()
25      gradient_samples.append(c.grad)
26  gradients = storch.gather_samples(gradient_samples, "gradients")
>>> storch.variance(gradients, "gradients")
Deterministic tensor(16.7321) Batch links: []
>>> print(storch.reduce_plates(gradients, "gradients"))
Deterministic tensor(-11.0195) Batch links: []
Alright, a few things to note. storch.gather_samples() is a function that takes a list of tensors that are (conditionally) independent samples of some value, in this case the gradients. Like most other functions in Storchastic, it returns a storch.Tensor, in this case a storch.IndependentTensor:
>>> type(gradients)
<class 'storch.tensor.IndependentTensor'>
storch.Tensor is a special “tensor-like” object which wraps a torch.Tensor and includes extra metadata to help with estimating gradients and keeping track of the plate dimensions. Plate dimensions are dimensions of the tensor for which we know conditional independence properties. We can look at the plate dimensions of a storch.Tensor using
>>> gradients.plates
[('gradients', 1000, tensor(0.0010))]
The gradients tensor has one plate dimension with name “gradients” (as we defined using storch.gather_samples()). As we simulated the gradient 1000 times, the size of the plate dimension is 1000. The third value is the weight of the samples. In this case, samples are weighted identically (that is, the weight is 1/1000), which corresponds to a standard Monte Carlo estimate.
Note that we used the plate dimension name “gradients” in storch.variance(gradients, "gradients"). This means that we compute the variance over the “gradients” plate dimension, which represents the different independent samples of gradient estimates.
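The role of the sample weights can be illustrated without Storchastic. The sketch below is a plain-Python analogue, not the library's implementation: it simulates 1000 single-sample pathwise gradients for this example (the helper `gradient_sample` is hypothetical) and reduces them with a uniform weight of 1/1000. Whether storch.variance applies a bias correction is not stated here, so the sketch uses the simple weighted form.

```python
import random

def gradient_sample(rng, a=5.0, b=-3.0, c=0.23):
    """One single-sample pathwise gradient of f = d * e**2 w.r.t. c.

    Hypothetical helper for illustration only; not part of Storchastic.
    """
    d = a + b
    e = (b + c) + rng.gauss(0.0, 1.0)  # reparameterized sample
    return 2.0 * d * e

rng = random.Random(0)
samples = [gradient_sample(rng) for _ in range(1000)]

# Every sample carries the same weight, 1/1000, as in a standard
# Monte Carlo estimate.
weight = 1.0 / len(samples)
mean = sum(weight * g for g in samples)
variance = sum(weight * (g - mean) ** 2 for g in samples)
print(mean, variance)
```

With identical weights, the weighted sum is simply the sample average over the plate dimension.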
Reducing variance
Next, let us try to reduce the variance. A simple way to do this is to use more samples of \(e\). In line 14 (method = Reparameterization("e", n_samples=n)), we pass the number of samples to use for this method. Let’s use 10 by setting line 19 to n = 10, and compute the variance again:
>>> storch.variance(gradients, "gradients")
Deterministic tensor(1.6388) Batch links: []
By using 10 times as many samples, we reduced the variance by (about) a factor of 10. Note that we did not have to change any code other than the value of n. Storchastic is designed so that all (left-broadcastable!) code supports both a single sample and multiple samples. Using more samples is an easy way to reduce variance. Storchastic automatically parallelizes the computation over the different samples, so that if your GPU has enough memory, there is (usually) almost no overhead to using more samples, yet we get better estimates of the gradient!
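The factor of 10 is what one would expect for this example. As a short derivation added for illustration, write the reparameterized sample as \(e = b + c + \epsilon\) with \(\epsilon \sim \mathcal{N}(0, 1)\), and assume the \(n\) samples are independent and averaged with equal weight. The single-sample gradient and its variance are
\[
g = \frac{\partial f}{\partial c} = 2de = 2d(b+c+\epsilon), \qquad \operatorname{Var}[g] = 4d^2 = 16,
\]
and averaging \(n\) samples gives
\[
\operatorname{Var}\left[\frac{1}{n}\sum_{i=1}^{n} g_i\right] = \frac{16}{n},
\]
which is 1.6 for \(n = 10\). This is close to the empirical values 16.7321 and 1.6388 reported above.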
Using different estimators
Storchastic is designed to make swapping in different gradient estimation methods as easy as possible. For instance, say we want to use the score function instead of reparameterization. This is done as follows:
6   def compute_f(n):
7       a = torch.tensor(5.0, requires_grad=True)
8       b = torch.tensor(-3.0, requires_grad=True)
9       c = torch.tensor(0.23, requires_grad=True)
10      d = a + b
11
12      # Sample e from a normal distribution using the score function
13      normal_distribution = Normal(b + c, 1)
14      method = ScoreFunction("e", n_samples=n, baseline_factory=None)
15      e = method(normal_distribution)
16
17      f = d * e * e
18      return f, c
Note how we only changed line 14 (method = Reparameterization("e", n_samples=n)), where we defined the gradient estimation method, to create a storch.method.ScoreFunction instead of a storch.method.Reparameterization.
Let’s see the variance of this method (using 1 sample):
>>> storch.variance(gradients, "gradients")
Deterministic tensor(748.1914) Batch links: []
Ouch, that really is much higher than using Reparameterization! While the score function is much more generally applicable than reparameterization (as it can be used for discrete distributions and non-differentiable functions), it clearly has a prohibitively large variance. Storchastic also has the storch.method.Infer gradient estimation method, which automatically applies reparameterization if possible and otherwise uses the score function.
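To see where the large score function variance comes from, the estimator for this example can be sketched in plain Python. This is an illustrative stand-in, not Storchastic's implementation; `score_function_grad` is a hypothetical helper. It multiplies the cost \(f\) by the derivative of the log-density with respect to \(c\), which for a unit-variance normal with mean \(b + c\) is \(e - (b + c)\).

```python
import random
import statistics

def score_function_grad(rng, a=5.0, b=-3.0, c=0.23):
    """One single-sample score function gradient of E[d * e**2] w.r.t. c.

    Hypothetical helper for illustration only; not part of Storchastic.
    """
    d = a + b
    mu = b + c
    e = rng.gauss(mu, 1.0)  # sample e directly; no gradient flows through it
    f = d * e * e
    score = e - mu          # d/dc log N(e; b + c, 1)
    return f * score

rng = random.Random(0)
grads = [score_function_grad(rng) for _ in range(20000)]
mean = statistics.mean(grads)
variance = statistics.variance(grads)
print(mean, variance)
```

The estimate targets the same expected gradient as reparameterization, but each sample scales the full cost \(f\) by a noisy score term, which is why its spread is so much larger.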
Can we do something about the large variance? Using more samples is always an option. To get the variance in the same ballpark as a single-sample reparameterization, we would need to use about 748.2/16.7 samples, or about n=45!
>>> storch.variance(gradients, "gradients")
Deterministic tensor(17.0591) Batch links: []
Luckily, we can make efficient reuse of the multiple samples we take. Note how we set baseline_factory=None when defining the storch.method.ScoreFunction. A baseline is a very common variance reduction method that subtracts a value from the cost function to stabilize the gradient. A simple but effective one is the batch average baseline (storch.method.baseline.BatchAverage), which subtracts the average of the other samples. Simply change the method to ScoreFunction("e", n_samples=n, baseline_factory="batch_average"). Let’s use 20 samples:
>>> storch.variance(gradients, "gradients")
Deterministic tensor(16.8761) Batch links: []
Sweet! We used fewer than half of the samples, yet got a slightly lower variance than before. For complicated settings where reparameterization is not an option, strong variance reduction is unfortunately very important for efficient algorithms.
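The effect of such a baseline can be reproduced in a few lines of plain Python. The sketch below assumes a leave-one-out reading of “the average of the other samples”; it is not taken from the storch.method.baseline.BatchAverage source, and `score_function_estimate` is a hypothetical helper. It compares 20-sample score function estimates with and without the baseline.

```python
import random
import statistics

def score_function_estimate(rng, n, use_baseline, a=5.0, b=-3.0, c=0.23):
    """Average of n score function gradients of E[d * e**2] w.r.t. c.

    Hypothetical helper for illustration only; requires n >= 2 when
    use_baseline is True.
    """
    d = a + b
    mu = b + c
    es = [rng.gauss(mu, 1.0) for _ in range(n)]
    fs = [d * e * e for e in es]
    total = sum(fs)
    grad = 0.0
    for e, f in zip(es, fs):
        # Leave-one-out baseline: mean cost of the other n - 1 samples.
        # It does not depend on this sample's e, so the estimate stays unbiased.
        baseline = (total - f) / (n - 1) if use_baseline else 0.0
        grad += (f - baseline) * (e - mu)
    return grad / n

rng = random.Random(0)
plain = [score_function_estimate(rng, 20, False) for _ in range(2000)]
with_baseline = [score_function_estimate(rng, 20, True) for _ in range(2000)]
print(statistics.variance(plain), statistics.variance(with_baseline))
```

Both variants estimate the same gradient; the baseline only removes the part of the cost that is shared across samples and would otherwise be multiplied by the noisy score term.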
For full code of this example, go to Introduction Example.