Source code for storch.sampling.unordered_set

from typing import Optional

from torch.distributions import Distribution

from storch.sampling.swor import log1mexp, SampleWithoutReplacement
import storch

import torch


class UnorderedSet(SampleWithoutReplacement):
    """Sampling-without-replacement method weighted by the unordered set estimator.

    The weight of each sampled element ``s`` is ``p(s) * R(S^k, s)``: its
    probability times the leave-one-out ratio of the sampled set. The ratio is
    obtained by numerical integration in log-space.

    For details, see https://openreview.net/pdf?id=rklEj2EFvB
    Code based on
    https://github.com/wouterkool/estimating-gradients-without-replacement/blob/master/bernoulli/gumbel.py
    """

    def __init__(
        self,
        plate_name: str,
        k: int,
        comp_leave_two_out: bool = False,
        exact_integration: bool = False,
        num_int_points: int = 1000,
        a: float = 5.0,
        eos=None,
    ):
        """Store the estimator configuration.

        Args:
            plate_name: Forwarded to ``SampleWithoutReplacement.__init__``.
            k: Forwarded to ``SampleWithoutReplacement.__init__``.
            comp_leave_two_out: If true, ``weighting_function`` additionally
                computes the second-order leave-one-out ratio and stores it on
                the plate as ``log_snd_leave_one_out``.
            exact_integration: Stored on the instance; it is not read anywhere
                in this class.
            num_int_points: Number ``N`` defining the integration grid
                ``v = i / N`` for ``i = 1 .. N - 1``.
            a: Constant offset added to the log-probabilities inside the
                integrand.
            eos: Forwarded to ``SampleWithoutReplacement.__init__``.
        """
        super().__init__(plate_name, k, eos=eos)
        self.comp_leave_two_out = comp_leave_two_out
        self.exact_integration = exact_integration
        self.num_int_points = num_int_points
        self.a = a

    @staticmethod
    def _log_gumbel_survival(g_bound):
        """Return ``log(1 - exp(-exp(g_bound)))`` elementwise.

        This equals ``log P(g > -g_bound)`` for a standard Gumbel variable ``g``.
        """
        y = torch.exp(g_bound)
        # For small y, log(1 - exp(-y)) = log(y) - y/2 + y^2/24 - y^4/2880 + O(y^6)
        # with log(y) = g_bound. The expansion is only valid for small y, i.e.
        # strongly negative g_bound: for g_bound <= -10 the truncation error is
        # O((e^-10)^6) (=8.7E-27).
        # See https://www.wolframalpha.com/input/?i=log%281+-+exp%28-y%29%29
        # Everywhere else defer to log1mexp.
        # NOTE(review): log1mexp is called here with a positive argument and
        # below with a non-positive one; presumably it is sign-agnostic -- confirm.
        return torch.where(
            g_bound <= -10,
            g_bound - y / 2 + y ** 2 / 24 - y ** 4 / 2880,
            log1mexp(y),
        )

    def weighting_function(
        self, tensor: storch.StochasticTensor, plate: storch.Plate
    ) -> Optional[storch.Tensor]:
        """Compute the unordered set estimator weights for the samples on ``plate``.

        Args:
            tensor: The sampled tensor. It is not used by this implementation.
            plate: The plate created when sampling. Its ``log_probs`` and ``n``
                attributes are read. When ``comp_leave_two_out`` is set, the
                attribute ``log_snd_leave_one_out`` is written onto it.

        Returns:
            The detached weights ``exp(log R(S^k, s) + log p(s))``.
        """
        # plates_w_k is a sequence of plates of which one is the input ancestral plate
        log_probs = plate.log_probs.detach()

        # Integration points: v should range from 0 to 1, where both v=0 and v=1
        # give a value of 0, so only the interior points are materialised.
        # As the computation happens in log-space, take the logarithm of the result.
        # The `out=` tensor makes the grid share dtype and device with log_probs.
        # N
        v = (
            torch.arange(1, self.num_int_points, out=log_probs._tensor.new())
            / self.num_int_points
        )
        log_v = v.log()

        # Compute log(1-v^{exp(log_probs+a)}) in a numerically stable way in log-space:
        # v^{exp(phi+a)} = exp(-exp(phi + a + log(-log v))) = exp(-exp(g_bound)).
        # plates_w_k x N
        g_bound = (
            log_probs[..., None]
            + self.a
            + torch.log(-log_v)[log_probs.plate_dims * (None,) + (slice(None),)]
        )
        # plates_w_k x N
        terms = self._log_gumbel_survival(g_bound)

        # Compute integrands (without subtracting the special value s)
        # plates x N
        sum_of_terms = storch.sum(terms, plate)
        phi_S = storch.logsumexp(log_probs, plate)
        phi_D_min_S = log1mexp(phi_S)

        # plates x N
        integrand = (
            sum_of_terms
            + torch.expm1(self.a + phi_D_min_S)[..., None]
            * log_v[phi_D_min_S.plate_dims * (None,) + (slice(None),)]
        )

        # Subtract one term for the element that is left out in R
        # Automatically unsqueezes correctly using plate dimensions
        # plates_w_k x N
        integrand_without_s = integrand - terms

        # plates
        log_p_S = integrand.logsumexp(dim=-1)
        # plates_w_k
        log_p_S_without_s = integrand_without_s.logsumexp(dim=-1)

        # plates_w_k
        log_leave_one_out = log_p_S_without_s - log_p_S

        if self.comp_leave_two_out:
            # Compute the integrands for the 2nd order leave one out ratio.
            # Make sure to properly choose the indices: We shouldn't subtract
            # the same term twice on the diagonals.
            # k x k
            skip_diag = storch.Tensor(
                1 - torch.eye(plate.n, out=log_probs._tensor.new()), [], [plate]
            )
            # plates_w_k x k x N
            integrand_without_ss = (
                integrand_without_s[..., None, :]
                - terms[..., None, :] * skip_diag[..., None]
            )
            # plates_w_k x k
            log_p_S_without_ss = integrand_without_ss.logsumexp(dim=-1)

            plate.log_snd_leave_one_out = log_p_S_without_ss - log_p_S_without_s

        # Return the unordered set estimator weighting
        return (log_leave_one_out + log_probs).exp().detach()
class GumbelSoftmaxWOR(UnorderedSet):
    """Unordered-set sampler that returns a relaxed (Gumbel) version of its samples.

    Samples are first drawn by the parent class and then replaced by the result
    of ``storch.conditional_gumbel_rsample`` conditioned on those samples.
    """

    def __init__(
        self,
        plate_name: str,
        k: int,
        initial_temperature=1.0,
        min_temperature=1.0e-4,
        annealing_rate=1.0e-5,
        eos=None,
    ):
        """Configure the sampler.

        Args:
            plate_name: Forwarded to ``UnorderedSet.__init__``.
            k: Forwarded to ``UnorderedSet.__init__``.
            initial_temperature: Stored as ``self.temperature`` and passed to
                ``conditional_gumbel_rsample`` when sampling.
            min_temperature: Accepted but not used by this constructor.
            annealing_rate: Accepted but not used by this constructor.
            eos: Forwarded to ``UnorderedSet.__init__``.
        """
        # The second-order leave-one-out ratio is always disabled for this method.
        super().__init__(plate_name, k, comp_leave_two_out=False, eos=eos)
        self.temperature = initial_temperature

    def sample(
        self,
        distr: Distribution,
        parents: [storch.Tensor],
        orig_distr_plates: [storch.Plate],
        requires_grad: bool,
    ) -> (torch.Tensor, storch.Plate):
        """Draw samples via the parent class and swap in their relaxed counterpart.

        Args:
            distr: Distribution to sample from; its ``probs`` attribute is read.
            parents: Forwarded unchanged to the parent ``sample``.
            orig_distr_plates: Forwarded unchanged to the parent ``sample``.
            requires_grad: Forwarded to the parent ``sample`` and to the
                ``storch.StochasticTensor`` that wraps the relaxed sample.

        Returns:
            A pair of the relaxed sample, wrapped as a ``storch.StochasticTensor``
            that reuses the parents, plates and name of the parent's sample, and
            the plate returned by the parent ``sample``.
        """
        discrete, plate = super().sample(
            distr, parents, orig_distr_plates, requires_grad
        )

        # Imported at call time, exactly as in the original implementation.
        from storch import conditional_gumbel_rsample

        probs = distr.probs
        is_bernoulli = isinstance(distr, torch.distributions.Bernoulli)
        relaxed = conditional_gumbel_rsample(
            discrete, probs, is_bernoulli, self.temperature
        )

        # Re-wrap the raw relaxed tensor with the bookkeeping of the discrete sample.
        wrapped = storch.StochasticTensor(
            relaxed._tensor,
            discrete.parents,
            discrete.plates,
            discrete.name,
            self.k,
            distr,
            requires_grad,
        )
        return wrapped, plate