Source code for storch.sampling.swor

from __future__ import annotations
from typing import Optional, Callable, Tuple

import storch
import torch
from torch.distributions import Distribution, Gumbel
from storch.sampling.method import SamplingMethod
from storch.sampling.seq import IterDecoding, AncestralPlate, right_expand_as


class SampleWithoutReplacement(IterDecoding):
    """
    Sampling method for sampling without replacement from (sequences of) discrete distributions.
    Implements Stochastic Beam Search https://arxiv.org/abs/1903.06059 with the weighting
    as defined by REINFORCE without replacement https://openreview.net/forum?id=r1lgTGL5DE.
    (A plain-torch sketch of the Gumbel-top-k identity this builds on follows the class.)
    """

    EPS = 1e-8
    perturbed_log_probs: Optional[storch.Tensor] = None

    def __init__(self, plate_name: str, k: int, biased_iw: bool = False, eos=None):
        super().__init__(plate_name, k, eos)
        if k < 2:
            raise ValueError(
                "Can only sample without replacement with 2 or more samples."
            )
        self.biased_iw = biased_iw

    def reset(self):
        super().reset()
        # Cumulative perturbed log probabilities of the samples
        self.perturbed_log_probs = None

    def decode_step(
        self,
        indices: Tuple[int],
        yv_log_probs: storch.Tensor,
        joint_log_probs: Optional[storch.Tensor],
        sampled_support_indices: Optional[storch.Tensor],
        parent_indexing: Optional[storch.Tensor],
        is_conditional_sample: bool,
        amt_plates: int,
        amt_samples: int,
    ) -> (storch.Tensor, storch.Tensor, storch.Tensor, int):
        """
        Decode given the input arguments for a specific event using stochastic beam search.

        :param indices: Tuple of integers indexing the current event to sample.
        :param yv_log_probs: Log probabilities of the different options for this event.
          distr_plates x k? x |D_yv|
        :param joint_log_probs: The log probabilities of the samples so far.
          None if `not is_conditional_sample`. prev_plates x amt_samples
        :param sampled_support_indices: Tensor of samples so far. None if this is the
          first set of indices. plates x k x events
        :param parent_indexing: Tensor indexing the parent sample.
          None if `not is_conditional_sample`.
        :param is_conditional_sample: True if a parent has already been sampled.
          This means the plates are more complex!
        :param amt_plates: The total amount of plates in both the distribution and the
          previously sampled variables.
        :param amt_samples: The amount of active samples.
        :return: 4-tuple of `storch.Tensor`.
          1: sampled_support_indices, with `:, indices` referring to the indices for the support.
          2: The updated `joint_log_probs` of the samples.
          3: The updated `parent_indexing`. How the samples index the parent samples.
             Can just return parent_indexing if nothing happens.
          4: The amount of active samples after this step.
        """
        first_sample = False
        if joint_log_probs is None:
            # We also know that k? is not present, so distr_plates x |D_yv|
            all_joint_log_probs = yv_log_probs
            # First condition on max being 0:
            self.perturbed_log_probs = 0.0
            first_sample = True
        elif is_conditional_sample:
            # Make sure we are selecting the correct log-probabilities.
            # As parents have been selected, this might change!
            # plates x amt_samples x |D_yv|
            yv_log_probs = yv_log_probs.gather(
                dim=-2,
                index=right_expand_as(
                    # Use the parent_indexing to select the correct plate samples.
                    # Make sure to limit to amt_samples!
                    parent_indexing[..., :amt_samples],
                    yv_log_probs,
                ),
            )
            # self.joint_log_probs: prev_plates x amt_samples
            # plates x amt_samples x |D_yv|
            all_joint_log_probs = joint_log_probs.unsqueeze(-1) + yv_log_probs
        else:
            # self.joint_log_probs: plates x amt_samples
            # plates x amt_samples x |D_yv|
            all_joint_log_probs = joint_log_probs.unsqueeze(
                -1
            ) + yv_log_probs.unsqueeze(-2)

        # Sample plates x k? x |D_yv| conditional Gumbel variables
        cond_G_yv = cond_gumbel_sample(all_joint_log_probs, self.perturbed_log_probs)

        # If there are finished samples, ensure eos is always sampled.
        if self.finished_samples is not None:
            # TODO: Is this the correct way of ensuring self.eos is always sampled
            #  for finished sequences? Could it bias things in any way?
            # Set the probability of continuing on finished sequences to -infinity
            # so that they are filtered out during topk.
            # amt_finished
            finished_perturb_log_probs = self.perturbed_log_probs._tensor[
                self.finished_samples._tensor
            ]
            # amt_finished x |D_yv|
            finished_vec = finished_perturb_log_probs.new_full(
                (finished_perturb_log_probs.shape[0], cond_G_yv.shape[-1]),
                -float("inf"),
            )
            # Then make sure the log probability of the eos token is equal to the
            # last perturbed log prob.
            finished_vec[:, self.eos] = finished_perturb_log_probs
            cond_G_yv[self.finished_samples] = finished_vec

        if not first_sample:
            # plates x (k * |D_yv|) (k == prev_amt_samples, in this case)
            cond_G_yv = cond_G_yv.reshape(cond_G_yv.shape[:-2] + (-1,))
            # Reshape log probs to plates x (k * |D_yv|). Matches perturbed shape.
            all_joint_log_probs = all_joint_log_probs.reshape(cond_G_yv.shape)

        # Select the samples given the perturbed log probabilities
        self.perturbed_log_probs, joint_log_probs, arg_top = self.select_samples(
            cond_G_yv, all_joint_log_probs
        )
        amt_samples = arg_top.shape[-1]

        if first_sample:
            # plates x amt_samples
            # Index for the selected samples. Uses slice(amt_samples) for the first
            # index in case k > |D_yv|.
            # (:) * amt_plates + (indices for events) + amt_samples
            indexing = (slice(None),) * amt_plates + (slice(0, amt_samples),) + indices
            sampled_support_indices[indexing] = arg_top
        else:
            # Gather the corresponding joint log probabilities.
            joint_log_probs = all_joint_log_probs.reshape(cond_G_yv.shape).gather(
                dim=-1, index=arg_top
            )
            # |D_yv|
            size_domain = yv_log_probs.shape[-1]
            # Keep track of what parents were sampled for the arg top
            # plates x amt_samples
            chosen_parents = arg_top // size_domain
            sampled_support_indices = sampled_support_indices.gather(
                dim=amt_plates,
                index=right_expand_as(chosen_parents, sampled_support_indices),
            )
            if parent_indexing is not None:
                parent_indexing = parent_indexing.gather(dim=-1, index=chosen_parents)
            # Index for the selected samples. Uses slice(amt_samples) for the first
            # index in case k > |D_yv|.
            # plates x amt_samples
            chosen_samples = arg_top.remainder(size_domain)
            indexing = (slice(None),) * amt_plates + (slice(0, amt_samples),) + indices
            sampled_support_indices[indexing] = chosen_samples
        return sampled_support_indices, joint_log_probs, parent_indexing, amt_samples

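    # Illustrative note (added sketch, values hypothetical): after the first step the
    # k x |D_yv| grid of candidate extensions is flattened to k * |D_yv| before topk,
    # so each selected flat index in `arg_top` encodes both a parent beam and a token.
    # For example, with |D_yv| = 5, a flat index of 13 decodes to parent 13 // 5 = 2
    # and token 13 % 5 = 3, which is exactly what `chosen_parents` and
    # `chosen_samples` compute above.
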
    def select_samples(
        self, perturbed_log_probs: storch.Tensor, joint_log_probs: storch.Tensor,
    ) -> (storch.Tensor, storch.Tensor, storch.Tensor):
        """
        Given the perturbed log probabilities and the joint log probabilities of the
        new options, select which ones to use for the sample.

        :param perturbed_log_probs: plates x (k? * |D_yv|). Perturbed log-probabilities.
          k is present unless this is the first sample.
        :param joint_log_probs: plates x (k? * |D_yv|). Joint log probabilities of the
          options. k is present unless this is the first sample.
        :return: perturbed log probs of chosen samples, joint log probs of chosen
          samples, index of chosen samples.
        """
        # We can sample at most the amount of what we previously sampled, combined
        # with every option in the current domain.
        # That is: prev_amt_samples * |D_yv|.
        amt_samples = min(self.k, perturbed_log_probs.shape[-1])
        # Take the top k over the conditional perturbed log probs
        # plates x amt_samples
        perturbed_log_probs, arg_top = torch.topk(
            perturbed_log_probs, amt_samples, dim=-1
        )
        joint_log_probs = joint_log_probs.gather(dim=-1, index=arg_top)
        return perturbed_log_probs, joint_log_probs, arg_top

    def create_plate(self, plate_size: int, plates: [storch.Plate]) -> AncestralPlate:
        plate = super().create_plate(plate_size, plates)
        plate.perturb_log_probs = storch.Tensor(
            self.perturbed_log_probs._tensor,
            [self.perturbed_log_probs],
            self.perturbed_log_probs.plates + [plate],
        )
        return plate

    def weighting_function(
        self, tensor: storch.StochasticTensor, plate: storch.Plate
    ) -> Optional[storch.Tensor]:
        # TODO: Doesn't take into account eos tokens.
        # TODO: Does this add the plate to the weighting function result? Could be a big bug!
        return self.compute_iw(plate, self.biased_iw).detach()

    def compute_iw(self, plate: AncestralPlate, biased: bool):
        k = plate.perturb_log_probs.shape[-1]
        # Compute importance weights. The kth sample has 0 weight, and is only used
        # to compute the importance weights of the others.
        # Equation 5
        q = (
            1
            - torch.exp(
                -torch.exp(
                    plate.log_probs
                    - plate.perturb_log_probs._tensor[..., k - 1].unsqueeze(-1)
                )
            )
        ).detach()
        iw = plate.log_probs.exp() / (q + self.EPS)
        # Set the weight of the kth sample (kappa) to 0.
        iw[..., k - 1] = 0.0
        if biased:
            # Equation 6 (normalization of importance weights)
            WS = storch.sum(iw, plate).detach()
            return iw / WS
        return iw

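    # Illustrative note (added sketch, values hypothetical): with kappa, the k-th
    # perturbed log prob, equal to -2.0 and a sample with log_prob = -1.0, Equation 5
    # gives q = 1 - exp(-exp(-1.0 - (-2.0))) = 1 - exp(-e) ~= 0.934, so the sample's
    # unnormalized importance weight is exp(-1.0) / q ~= 0.368 / 0.934 ~= 0.394.
    # The kappa sample itself always gets weight 0; with biased=True the weights are
    # additionally normalized by their sum over the plate.
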
    def on_plate_already_present(self, plate: storch.Plate):
        if (
            not isinstance(plate, AncestralPlate)
            or plate.variable_index > self.variable_index
            or plate.n > self.k
        ):
            super().on_plate_already_present(plate)

    def set_mc_sample(
        self,
        new_sample_func: Callable[
            [Distribution, [storch.Tensor], [storch.Plate], int], torch.Tensor
        ],
    ) -> SamplingMethod:
        raise RuntimeError(
            "Cannot set Monte Carlo sampling for sampling without replacement."
        )

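# Added sketch (not part of the original module; the helper name is hypothetical):
# the identity underlying SampleWithoutReplacement is the Gumbel-top-k trick, which
# Stochastic Beam Search applies lazily over sequences: perturbing log-probabilities
# with i.i.d. Gumbel noise and taking the top k yields k distinct categories sampled
# without replacement.
def _gumbel_top_k_sketch(log_probs: torch.Tensor, k: int) -> torch.Tensor:
    # log_probs: 1-D tensor of log-probabilities of a single categorical.
    gumbels = Gumbel(loc=log_probs, scale=1.0).sample()
    # The k largest perturbed log-probabilities identify k distinct samples.
    _, top_k_indices = torch.topk(gumbels, k, dim=-1)
    return top_k_indices


# For example, _gumbel_top_k_sketch(torch.log_softmax(torch.randn(10), -1), k=4)
# returns 4 distinct indices from the support of the categorical.
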
def log1mexp(a: torch.Tensor) -> torch.Tensor:
    """Numerically stable implementation of log(1 - exp(a)) for a <= 0.
    See appendix A of http://jmlr.org/papers/v21/19-985.html."""
    c = -0.693  # switch point at -log(2)
    a1 = -a.abs()
    eps = 1e-6
    return torch.where(
        a1 > c, torch.log(-a1.expm1() + eps), torch.log1p(-a1.exp() + eps)
    )

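# Added sketch (not part of the original module; the helper name is hypothetical):
# near a == 0 the naive torch.log(1 - a.exp()) underflows to -inf in float32, while
# log1mexp stays finite by switching between log(-expm1(a)) and log1p(-exp(a)) at
# the -log(2) threshold.
def _compare_log1mexp_to_naive() -> None:
    a = torch.tensor([-10.0, -1.0, -1e-8])
    naive = torch.log(1 - a.exp())  # last entry: 1 - exp(-1e-8) rounds to 0 -> -inf
    stable = log1mexp(a)
    # The first two entries agree; the last is finite for `stable` (its accuracy
    # near 0 is limited by the eps added inside log1mexp).
    print(naive, stable)
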
@storch.deterministic
def cond_gumbel_sample(all_joint_log_probs, perturbed_log_probs) -> torch.Tensor:
    # Sample plates x k? x |D_yv| Gumbel variables
    gumbel_d = Gumbel(loc=all_joint_log_probs, scale=1.0)
    G_yv = gumbel_d.rsample()
    # Condition the Gumbel samples on the maximum of previous samples
    # plates x k
    Z = G_yv.max(dim=-1)[0]
    T = perturbed_log_probs
    vi = T - G_yv + log1mexp(G_yv - Z.unsqueeze(-1))
    # plates (x k) x |D_yv|
    return T - vi.relu() - torch.nn.Softplus()(-vi.abs())

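# Added sketch (not part of the original module; the helper name is hypothetical):
# the shift applied by cond_gumbel_sample ensures the maximum of the conditional
# Gumbels over the domain equals the target T exactly, up to the eps inside
# log1mexp. This reproduces its arithmetic with plain tensors and checks that
# property.
def _check_conditional_gumbel_max(num_trials: int = 100) -> None:
    log_probs = torch.log_softmax(torch.randn(num_trials, 10), dim=-1)
    T = -torch.rand(num_trials)  # e.g. previous perturbed log probs, <= 0
    G = Gumbel(loc=log_probs, scale=1.0).sample()
    Z = G.max(dim=-1)[0]
    # Same computation as the body of cond_gumbel_sample:
    vi = T.unsqueeze(-1) - G + log1mexp(G - Z.unsqueeze(-1))
    shifted = T.unsqueeze(-1) - vi.relu() - torch.nn.Softplus()(-vi.abs())
    assert torch.allclose(shifted.max(dim=-1)[0], T, atol=1e-3)
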
class SumAndSample(SampleWithoutReplacement):
    """
    Sums over S probable samples according to beam search and K sampled values that
    are not in the probable samples, then normalizes them accordingly.
    """

    def __init__(
        self,
        plate_name: str,
        sum_size: int,
        sample_size: int = 1,
        without_replacement: bool = False,
        eos=None,
    ):
        super().__init__(plate_name, sum_size + sample_size, eos=eos)
        self.sum_size = sum_size
        self.sample_size = sample_size
        if sum_size < 1 or sample_size < 1:
            raise ValueError("sum_size and sample_size should both be at least 1.")

    def select_samples(
        self, perturbed_log_probs: storch.Tensor, joint_log_probs: storch.Tensor,
    ) -> (storch.Tensor, storch.Tensor):
        # Select sum_size samples using the joint log probs, and sample_size samples
        # using the perturbed joint log probs.
        # We can sample at most the amount of what we previously sampled, combined
        # with every option in the current domain.
        # That is: prev_amt_samples * |D_yv|.
        amt_sum = min(self.sum_size, joint_log_probs.shape[-1])
        # Take the top sum_size over the joint log probs. This is like beam search.
        # plates x amt_samples
        _, sum_samples = torch.topk(joint_log_probs, amt_sum, dim=-1)
        # Note: for batched plates this indexing likely needs a gather over the
        # last dimension instead.
        sum_perturbed_log_probs = perturbed_log_probs[sum_samples]
        if amt_sum < self.sum_size:
            return sum_perturbed_log_probs, sum_samples
        # Not sure if this is the most efficient implementation.
        # Should be positive, by the previous conditional.
        amt_sample = min(
            self.sample_size, perturbed_log_probs.shape[-1] - self.sum_size
        )
        sample_perturbed_log_probs, samples = torch.topk(
            joint_log_probs, amt_sample + perturbed_log_probs.shape[-1], dim=-1
        )
        # TODO: This isn't finished yet.