import torch
from torch.distributions import Normal
from storch.method import Reparameterization, ScoreFunction
import storch
# Seed PyTorch's global RNG so the sampled estimates printed below are repeatable
# (assuming storch draws its samples from torch's RNG — NOTE(review): confirm).
torch.manual_seed(0)
def compute_f(method):
    """Build the toy objective f = (a + b) * e * e, with e sampled by *method*.

    ``method`` is a storch sampling method (e.g. ``Reparameterization`` or
    ``ScoreFunction``). It is applied to ``Normal(b + c, 1)`` to produce ``e``.
    All of a, b and c are freshly created leaf tensors with ``requires_grad=True``.

    Returns:
        ``(f, c)``: the cost tensor and the leaf ``c``, so the caller can read
        ``c.grad`` after running backward.
    """
    a, b, c = (torch.tensor(value, requires_grad=True) for value in (5.0, -3.0, 0.23))
    scale = a + b
    # c enters only through the mean of the sampled distribution. How its
    # gradient is estimated depends entirely on the estimator in *method*.
    e = method(Normal(b + c, 1))
    return scale * e * e, c
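# With mu = b + c we have e ~ N(mu, 1), so e*e follows a noncentral chi-squared distribution
# (1 degree of freedom, noncentrality mu^2, mean 1 + mu^2): https://en.wikipedia.org/wiki/Noncentral_chi-squared_distribution
# Hence, with d = a + b: exp_f = d * (1 + mu * mu), and the exact gradient is
# d(exp_f)/dc = 2 * d * mu = 4 * (-2.77) = -11.08, which the estimates below should approach.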
# Draw two independent single-sample reparameterized estimates of the derivative
# of f's cost with respect to c. Both runs estimate the same FIRST derivative.
# Each compute_f call builds a fresh leaf tensor c, so c.grad never accumulates
# across runs, and the spread between the two prints hints at the estimator's variance.
repar = Reparameterization("e", n_samples=1)
for run in (1, 2):
    f, c = compute_f(repar)
    storch.add_cost(f, "f")
    print(storch.backward())
    print(f"derivative estimate (run {run})", c.grad)
def estimate_variance(method, n_runs=1000):
    """Empirically summarize the gradient estimates of f w.r.t. c produced by *method*.

    For each run, ``compute_f(method)`` is called, its output is registered as
    cost ``"f"``, ``storch.backward()`` is run, and the resulting ``c.grad`` is
    collected. The collected gradients are gathered into a storch plate named
    ``"gradients"``. The variance, the plate-reduced value (printed as "mean"),
    and the standard deviation are then printed, followed by diagnostic info
    about the gathered tensor.

    Args:
        method: storch sampling method forwarded to ``compute_f``.
        n_runs: number of independent gradient estimates to collect (default 1000).
    """
    gradient_samples = []
    for _ in range(n_runs):
        f, c = compute_f(method)
        storch.add_cost(f, "f")
        storch.backward()
        # compute_f creates a fresh leaf c on every call, so c.grad holds only
        # this run's estimate (nothing accumulates across iterations).
        gradient_samples.append(c.grad)
    gradients = storch.gather_samples(gradient_samples, "gradients")
    # Compute the variance once and reuse it for the standard deviation.
    variance = storch.variance(gradients, "gradients")
    print("variance", variance)
    print("mean", storch.reduce_plates(gradients, "gradients"))
    print("st dev", torch.sqrt(variance))
    # Diagnostics about the gathered storch tensor.
    print(type(gradients))
    print(gradients.shape)
    print(gradients.plates)
# Compare gradient-estimator statistics across estimators and sample counts.
# Each entry is (label, estimator class, constructor keyword arguments). The
# estimator is constructed only after its label has been printed.
_EXPERIMENTS = (
    ("Reparameterization n=1", Reparameterization, {"n_samples": 1}),
    ("Reparameterization n=10", Reparameterization, {"n_samples": 10}),
    ("Score function n=1", ScoreFunction, {"n_samples": 1}),
    ("Score function n=45", ScoreFunction, {"n_samples": 45}),
    (
        "Score function with baseline n=20",
        ScoreFunction,
        {"n_samples": 20, "baseline_factory": "batch_average"},
    ),
)
for _label, _estimator_cls, _kwargs in _EXPERIMENTS:
    print(_label)
    estimate_variance(_estimator_cls("e", **_kwargs))