# Source code for storch.sampling.expect

from typing import Optional

from storch.sampling import SamplingMethod
from torch.distributions import Distribution
import storch
import torch
from storch import Plate
import itertools

class Enumerate(SamplingMethod):
    """Sampling method that exhaustively enumerates a discrete distribution.

    Instead of drawing random samples, every joint configuration of the
    distribution's support is listed along a new plate. Each configuration is
    weighted by its probability (see :meth:`weighting_function`), which is
    intended to allow computing the expected value exactly.

    Args:
        plate_name: Name of the plate created along the enumerated dimension.
        budget: Maximum number of joint configurations that may be enumerated;
            :meth:`sample` raises ``ValueError`` if it would be exceeded.
    """

    def __init__(self, plate_name: str, budget=10000):
        super().__init__(plate_name)
        # Upper bound on the number of enumerated configurations.
        self.budget = budget

    def sample(
        self,
        distr: Distribution,
        parents: [storch.Tensor],
        plates: [Plate],
        requires_grad: bool,
    ) -> (storch.StochasticTensor, Plate):
        """Enumerate all joint values of ``distr`` along a new plate.

        Args:
            distr: Distribution to enumerate; must have enumerable support.
            parents: Parent tensors of the resulting stochastic tensor.
            plates: Plates the distribution is currently batched over. NOTE:
                this list is mutated in place - the new plate is inserted at
                index 0.
            requires_grad: Forwarded to the created ``StochasticTensor``.

        Returns:
            A ``(StochasticTensor, Plate)`` pair. The tensor holds one joint
            configuration per entry of the new plate. The plate is the newly
            created one.

        Raises:
            ValueError: If ``distr`` has no enumerable support, or if the
                number of configurations exceeds ``self.budget``.
        """
        # TODO: Currently very inefficient as it isn't batched
        # TODO: What if the expectation has a parent
        if not distr.has_enumerate_support:
            raise ValueError(
                "Can only calculate the expected value for distributions with enumerable support."
            )
        # Shape: (cardinality, *batch_shape, *event_shape).
        support: torch.Tensor = distr.enumerate_support(expand=True)
        # Shape: (cardinality, 1, ..., 1, *event_shape), one singleton per batch dim.
        support_non_expanded: torch.Tensor = distr.enumerate_support(expand=False)
        expect_size = support.shape[0]

        # Batch dims that are not plates. Every position in them is a separate
        # variable, and all of them are enumerated jointly.
        # NOTE(review): assumes the plate dims lead the batch shape - confirm.
        batch_len = len(plates)
        sizes = support.shape[
            batch_len + 1 : len(support.shape) - len(distr.event_shape)
        ]
        # Number of jointly enumerated variables: the product of the sizes
        # (1 when there are no non-plate batch dims).
        cross_products = torch.Size(sizes).numel()
        # Each variable independently takes one of `expect_size` values.
        amt_samples_used = expect_size ** cross_products

        if amt_samples_used > self.budget:
            raise ValueError(
                "Computing the expectation on this distribution would exceed the computation budget."
            )

        event_shape = tuple(distr.event_shape)
        enumerate_tensor = support.new_zeros(
            [amt_samples_used] + list(support.shape[1:])
        )
        # (cardinality, 1, *event_shape): row k holds support value k. An
        # explicit reshape (rather than squeeze().unsqueeze(1)) keeps size-1
        # event dims and also works when the cardinality is 1.
        support_values = support_non_expanded.reshape(
            (expect_size, 1) + event_shape
        )
        config_shape = tuple(sizes) + event_shape
        for i, t in enumerate(
            itertools.product(support_values, repeat=cross_products)
        ):
            # One value per variable, laid out in the non-plate batch shape.
            # Leading plate dims of enumerate_tensor[i] broadcast.
            enumerate_tensor[i] = torch.cat(t, dim=0).reshape(config_shape)

        # The enumerated values carry no autograd history.
        enumerate_tensor = enumerate_tensor.detach()

        plate_size = enumerate_tensor.shape[0]

        plate = Plate(self.plate_name, plate_size, plates.copy())
        # Mutates the caller's list. The new plate goes first, matching dim 0
        # of enumerate_tensor.
        plates.insert(0, plate)

        s_tensor = storch.StochasticTensor(
            enumerate_tensor,
            parents,
            plates,
            self.plate_name,
            plate_size,
            distr,
            requires_grad,
        )
        return s_tensor, plate

    def weighting_function(
        self, tensor: storch.StochasticTensor, plate: Plate
    ) -> Optional[storch.Tensor]:
        """Weight each enumerated configuration by its probability.

        Args:
            tensor: The stochastic tensor produced by :meth:`sample`.
            plate: The plate created by :meth:`sample` (not used here).

        Returns:
            ``exp`` of the log-probability of ``tensor`` under its own
            distribution, summed over all non-plate dimensions.
        """
        log_probs = tensor.distribution.log_prob(tensor)
        # Collapse the non-plate dims into one joint log-probability per
        # configuration. The same plate count is used both for the check and
        # for the summed range, so plate dims are never summed away.
        plate_dims = log_probs.plate_dims
        if plate_dims < len(log_probs.shape):
            log_probs = log_probs.sum(
                dim=list(range(plate_dims, len(log_probs.shape)))
            )
        return log_probs.exp()