from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Callable, List, Optional, Tuple

import torch
from torch.distributions import Distribution

import storch
from storch import Plate
class SamplingMethod(ABC, torch.nn.Module):
    """
    Base class for strategies that sample from a distribution into a plate
    named ``plate_name``.

    Subclasses implement :meth:`sample`; :meth:`forward` delegates to it so a
    sampling method can be used like any ``torch.nn.Module``. The ``mc_*``
    hooks provide default iid Monte Carlo behaviour and can be swapped out via
    :meth:`set_mc_sample` / :meth:`set_mc_weighting_function` for methods with
    different sampling or weighting semantics.
    """

    def __init__(self, plate_name: str):
        """
        :param plate_name: Name of the plate that samples drawn by this method
            belong to.
        """
        super().__init__()
        self.reset()
        self.plate_name = plate_name

    def reset(self) -> None:
        """Reset internal state. The base class keeps none, so this is a no-op."""
        pass

    def forward(
        self,
        distr: Distribution,
        parents: List[storch.Tensor],
        plates: List[Plate],
        requires_grad: bool,
    ) -> Tuple[storch.StochasticTensor, Plate]:
        """``nn.Module`` entry point; simply delegates to :meth:`sample`."""
        return self.sample(distr, parents, plates, requires_grad)

    @abstractmethod
    def sample(
        self,
        distr: Distribution,
        parents: List[storch.Tensor],
        plates: List[Plate],
        requires_grad: bool,
    ) -> Tuple[storch.StochasticTensor, Plate]:
        """
        Draw a sample from ``distr``.

        :param distr: Distribution to sample from.
        :param parents: Parent tensors of the new stochastic tensor.
        :param plates: Plates already present in the stochastic computation graph.
        :param requires_grad: Whether the resulting tensor requires gradients.
        :return: The created stochastic tensor and the plate it was sampled into.
        """
        pass

    def weighting_function(
        self, tensor: storch.StochasticTensor, plate: Plate
    ) -> Optional[storch.Tensor]:
        """
        Weight by the size of the sample. Overload this if your sampling method uses some kind of weighting
        of the different events, like importance sampling or computing the expectation.
        If None is returned, it is assumed the samples are iid monte carlo samples.
        This method is called from storch.method.Method.sample, and it is not needed to manually call this on created plates
        """
        return self.mc_weighting_function(tensor, plate)

    def mc_weighting_function(
        self, tensor: storch.StochasticTensor, plate: Plate
    ) -> Optional[storch.Tensor]:
        """Default Monte Carlo weighting: ``None``, signalling iid MC samples."""
        return None

    def update_parameters(
        self,
        result_triples: List[
            Tuple[storch.StochasticTensor, storch.CostTensor, torch.Tensor]
        ],
    ) -> None:
        """Hook for updating internal parameters from evaluation results. No-op by default."""
        pass

    def mc_sample(
        self,
        distr: Distribution,
        parents: List[storch.Tensor],
        plates: List[Plate],
        amt_samples: int,
    ) -> torch.Tensor:
        """
        Draw ``amt_samples`` iid samples from ``distr`` along a new leading
        dimension. ``parents`` and ``plates`` are unused here but are part of
        the signature so overrides (see :meth:`set_mc_sample`) can use them.
        """
        return distr.sample((amt_samples,))

    def set_mc_sample(
        self,
        new_sample_func: Callable[
            [Distribution, List[storch.Tensor], List[Plate], int], torch.Tensor
        ],
    ) -> SamplingMethod:
        """
        Override storch.Method specific sampling functions.
        This is called when initializing a storch.Method that has slightly different MC sampling semantics
        (for example, reparameterization instead of normal sampling).
        This allows for compatibility of different `storch.Method`'s with different `storch.sampling.Method`'s.

        :return: ``self``, to allow fluent chaining.
        """
        # NOTE: stored as an instance attribute, so it shadows the bound method
        # and is invoked WITHOUT ``self``.
        self.mc_sample = new_sample_func
        return self

    def set_mc_weighting_function(
        self,
        new_weighting_func: Callable[
            [storch.StochasticTensor, Plate], Optional[storch.Tensor]
        ],
    ) -> SamplingMethod:
        """
        Override storch.Method specific weighting functions.
        This is called when initializing a storch.Method that has slightly different MC weighting semantics
        (for example, REBAR that weights some samples differently).
        This allows for compatibility of different `storch.Method`'s with different `storch.sampling.Method`'s.

        :return: ``self``, to allow fluent chaining.
        """
        # NOTE: instance attribute shadowing, same caveat as set_mc_sample.
        self.mc_weighting_function = new_weighting_func
        return self

    def on_plate_already_present(self, plate: Plate) -> None:
        """Called when a plate with this name already exists. Raises ``ValueError``."""
        raise ValueError(
            "Cannot create stochastic tensor with name "
            + plate.name
            + ". A parent sample has already used this name. Use a different name for this sample."
        )
class MonteCarlo(SamplingMethod):
    """
    Monte Carlo sampling methods use simple sampling methods that take n independent samples.
    Unlike complex ancestral sampling methods such as SampleWithoutReplacementMethod, the sampling behaviour is not dependent
    on earlier samples in the stochastic computation graph (but the distributions are!).
    """

    def __init__(self, plate_name: str, n_samples: int = 1):
        """
        :param plate_name: Name of the plate to sample into.
        :param n_samples: Number of independent samples to draw when the plate
            does not yet exist in the graph.
        """
        super().__init__(plate_name)
        self.n_samples = n_samples

    def sample(
        self,
        distr: Distribution,
        parents: List[storch.Tensor],
        plates: List[Plate],
        requires_grad: bool,
    ) -> Tuple[storch.StochasticTensor, Plate]:
        """
        Draw ``n_samples`` iid samples from ``distr`` into this method's plate.

        If a plate with ``self.plate_name`` is already in ``plates``, only a
        single sample is drawn (presumably because the plate dimension already
        exists upstream — the existing plate is reused as-is).
        Note: when a new plate is created it is inserted at the front of the
        caller's ``plates`` list (intentional in-place mutation).
        """
        # Reuse an existing plate with our name, if any.
        plate = None
        for _plate in plates:
            if _plate.name == self.plate_name:
                plate = _plate
                break
        n_samples = 1 if plate else self.n_samples
        # Sample outside of storch's wrapping so we get a raw draw.
        with storch.ignore_wrapping():
            tensor = self.mc_sample(distr, parents, plates, n_samples)
        plate_size = tensor.shape[0]
        # A single sample carries no plate dimension; drop it.
        if tensor.shape[0] == 1:
            tensor = tensor.squeeze(0)
        if not plate:
            plate = Plate(self.plate_name, plate_size, plates.copy())
            plates.insert(0, plate)
        # Unwrap to the underlying torch.Tensor before re-wrapping below.
        if isinstance(tensor, storch.Tensor):
            tensor = tensor._tensor
        s_tensor = storch.StochasticTensor(
            tensor, parents, plates, self.plate_name, plate_size, distr, requires_grad,
        )
        return s_tensor, plate