Sampling Methods
- class storch.sampling.method.MonteCarlo(plate_name: str, n_samples: int = 1)
Bases: SamplingMethod
Monte Carlo sampling methods take n independent samples. Unlike more complex ancestral sampling methods such as SampleWithoutReplacement, the sampling behaviour does not depend on earlier samples in the stochastic computation graph (although the distributions do!).
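As a rough illustration of "n independent samples", here is a minimal pure-Python sketch (not storch code) of i.i.d. Monte Carlo sampling from a categorical distribution. The helper `monte_carlo_sample` and the probabilities are hypothetical; the point is only that each draw depends on the distribution and never on earlier draws.

```python
import random
from typing import List, Sequence


def monte_carlo_sample(probs: Sequence[float], n_samples: int, rng: random.Random) -> List[int]:
    """Draw n_samples independent category indices from a categorical distribution.

    Each draw uses only `probs`; earlier draws never influence later ones,
    which is what separates plain Monte Carlo from ancestral schemes such as
    sampling without replacement.
    """
    support = range(len(probs))
    return rng.choices(support, weights=probs, k=n_samples)


rng = random.Random(0)
samples = monte_carlo_sample([0.1, 0.6, 0.3], n_samples=5, rng=rng)
print(samples)  # five indices in {0, 1, 2}, each drawn independently
```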
- sample(distr: Distribution, parents: [storch.Tensor], plates: [Plate], requires_grad: bool)
- training: bool
- class storch.sampling.method.SamplingMethod(plate_name: str)
Bases: ABC, Module
- forward(distr: Distribution, parents: [storch.Tensor], plates: [Plate], requires_grad: bool)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- mc_sample(distr: Distribution, parents: [storch.Tensor], plates: [Plate], amt_samples: int) -> torch.Tensor
- abstract sample(distr: Distribution, parents: [storch.Tensor], plates: [Plate], requires_grad: bool)
- set_mc_sample(new_sample_func: Callable[[Distribution, [storch.Tensor], [Plate], int], torch.Tensor]) -> SamplingMethod
Override the storch.Method-specific sampling function. This is called when initializing a storch.Method that has slightly different MC sampling semantics (for example, reparameterization instead of normal sampling). This keeps different storch.Method classes compatible with different storch.sampling.Method classes.
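The reparameterization case mentioned above can be illustrated without storch. The sketch below is plain Python with hypothetical names (`normal_sample`, `reparameterized_normal_sample`, `ToySampler`); it only shows the pattern of swapping the function a sampler uses to draw samples, loosely mirroring how set_mc_sample returns the sampling method itself. It is not storch's API.

```python
import random
from typing import Callable, List

# Hypothetical signature for a swappable sampling function: (mu, sigma, rng) -> float.
SampleFunc = Callable[[float, float, random.Random], float]


def normal_sample(mu: float, sigma: float, rng: random.Random) -> float:
    # Standard sampling: the draw is treated as an opaque function of (mu, sigma).
    return rng.gauss(mu, sigma)


def reparameterized_normal_sample(mu: float, sigma: float, rng: random.Random) -> float:
    # Reparameterization: draw parameter-free noise, then transform it
    # deterministically, so the sample is differentiable in mu and sigma.
    eps = rng.gauss(0.0, 1.0)
    return mu + sigma * eps


class ToySampler:
    """Hypothetical stand-in for a sampling method whose MC sampling function can be swapped."""

    def __init__(self) -> None:
        self.sample_func: SampleFunc = normal_sample

    def set_sample_func(self, new_func: SampleFunc) -> "ToySampler":
        self.sample_func = new_func
        return self  # returning self mirrors set_mc_sample returning the SamplingMethod

    def sample(self, mu: float, sigma: float, n: int, rng: random.Random) -> List[float]:
        return [self.sample_func(mu, sigma, rng) for _ in range(n)]


sampler = ToySampler().set_sample_func(reparameterized_normal_sample)
print(sampler.sample(0.0, 1.0, 3, random.Random(1)))
```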
- set_mc_weighting_function(new_weighting_func: Callable[[StochasticTensor, Plate], Optional[Tensor]]) -> SamplingMethod
Override the storch.Method-specific weighting function. This is called when initializing a storch.Method that has slightly different MC weighting semantics (for example, REBAR, which weights some samples differently). This keeps different storch.Method classes compatible with different storch.sampling.Method classes.
- training: bool
- update_parameters(result_triples: [storch.StochasticTensor, storch.CostTensor, torch.Tensor])
- weighting_function(tensor: StochasticTensor, plate: Plate) -> Optional[Tensor]
Weight by the size of the sample. Override this if your sampling method weights the different events, as in importance sampling or when computing the expectation. If None is returned, the samples are assumed to be i.i.d. Monte Carlo samples.
This method is called from storch.method.Method.sample; it does not need to be called manually on created plates.
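To show what the returned weights are used for, the sketch below (plain Python, hypothetical helper `weighted_estimate`, not part of storch) combines per-sample costs into an estimate of an expectation. Passing `None` follows the documented convention: the samples are treated as i.i.d. Monte Carlo samples, each weighted by 1/n, the size of the sample.

```python
from typing import Callable, Optional, Sequence


def weighted_estimate(
    samples: Sequence[float],
    f: Callable[[float], float],
    weights: Optional[Sequence[float]] = None,
) -> float:
    """Estimate E[f(x)] as sum_i w_i * f(x_i).

    weights=None means i.i.d. Monte Carlo samples, so every sample gets 1/n.
    Explicit weights model schemes such as importance sampling or enumeration.
    """
    n = len(samples)
    if weights is None:
        weights = [1.0 / n] * n
    return sum(w * f(x) for w, x in zip(weights, samples))


print(weighted_estimate([1.0, 2.0, 3.0], lambda x: x))  # uniform 1/n weights
print(weighted_estimate([0.0, 1.0, 2.0], lambda x: x, [0.2, 0.3, 0.5]))  # explicit weights
```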
Expectation
- class storch.sampling.expect.Enumerate(plate_name: str, budget=10000)
Bases: SamplingMethod
- sample(distr: Distribution, parents: [storch.Tensor], plates: [storch.Plate], requires_grad: bool) -> (storch.StochasticTensor, storch.Plate)
- training: bool
- weighting_function(tensor: StochasticTensor, plate: Plate) -> Optional[Tensor]
Weight by the size of the sample. Override this if your sampling method weights the different events, as in importance sampling or when computing the expectation. If None is returned, the samples are assumed to be i.i.d. Monte Carlo samples.
This method is called from storch.method.Method.sample; it does not need to be called manually on created plates.
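For enumeration, the natural weight of each outcome is its probability rather than 1/n, which turns the weighted sum into an exact expectation. The sketch below is plain Python and assumes independent categorical variables; `enumerate_expectation` is hypothetical, and reading `budget` as a cap on the number of joint outcomes is an assumption, not documented behaviour of Enumerate.

```python
import itertools
import math
from typing import Callable, Sequence, Tuple


def enumerate_expectation(
    dists: Sequence[Sequence[float]],
    f: Callable[[Tuple[int, ...]], float],
    budget: int = 10000,
) -> float:
    """Exact E[f(x)] for independent categorical variables by full enumeration.

    Each joint outcome is weighted by its probability instead of 1/n.
    `budget` caps the number of joint outcomes (an assumed interpretation).
    """
    size = math.prod(len(p) for p in dists)
    if size > budget:
        raise ValueError(f"{size} outcomes exceed the enumeration budget of {budget}")
    total = 0.0
    for outcome in itertools.product(*(range(len(p)) for p in dists)):
        weight = math.prod(p[i] for p, i in zip(dists, outcome))
        total += weight * f(outcome)
    return total


print(enumerate_expectation([[0.5, 0.5], [0.25, 0.75]], lambda x: x[0] + x[1]))
```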
Sequence decoding
- class storch.sampling.seq.AncestralPlate(name: str, n: int, parents: List[Plate], variable_index: int, parent_plate: AncestralPlate, selected_samples: Optional[Tensor], log_probs: Optional[Tensor], weight: Optional[Tensor] = None)
Bases: Plate
- on_collecting_args(plates: [storch.Plate]) -> bool
Filter the collected plates to only keep the AncestralPlates (with the same name) that have the highest variable index.
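The filtering rule can be sketched in plain Python. `ToyPlate` and `keep_latest` are hypothetical stand-ins that model only a plate's name and variable index; they are not storch's AncestralPlate and do not reproduce the boolean return value of on_collecting_args.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class ToyPlate:
    # Hypothetical stand-in for AncestralPlate: only the fields the filter needs.
    name: str
    variable_index: int


def keep_latest(plates: List[ToyPlate], name: str) -> List[ToyPlate]:
    """Drop every plate called `name` except the one with the highest variable index."""
    same = [p for p in plates if p.name == name]
    if not same:
        return list(plates)
    latest = max(same, key=lambda p: p.variable_index)
    return [p for p in plates if p.name != name or p is latest]


plates = [ToyPlate("k", 0), ToyPlate("z", 0), ToyPlate("k", 2), ToyPlate("k", 1)]
print(keep_latest(plates, "k"))  # the "z" plate plus only the "k" plate with index 2
```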
- on_unwrap_tensor(tensor: Tensor) -> Tensor
Gets called whenever the given tensor is being unwrapped and unsqueezed for batch use. This method should not be called on tensors whose variable index is higher than this plate's.
selected_samples is used to choose, from the parent plates, the previous element in the sequence. This is used, for example, in sampling without replacement. If set to None, the different sequences are assumed to be indexed by the plate dimension.
- Parameters
tensor – The input tensor that is being unwrapped
- Returns
The tensor that will be unwrapped and unsqueezed in the future. Can be a modification of the input tensor.
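The role of selected_samples can be pictured with plain Python lists standing in for tensors. In the hypothetical helper `extend_sequences`, entry i of `selected_samples` names the parent prefix that sample i continues; with `None`, sample i simply continues prefix i, matching the "indexed by the plate dimension" default described above.

```python
from typing import List, Optional, Sequence


def extend_sequences(
    prefixes: Sequence[List[int]],
    new_tokens: Sequence[int],
    selected_samples: Optional[Sequence[int]] = None,
) -> List[List[int]]:
    """Append one token per active sample to the prefix it continues.

    selected_samples[i] is the index of the parent prefix that sample i extends.
    When it is None, sample i is assumed to continue prefix i.
    """
    if selected_samples is None:
        selected_samples = range(len(new_tokens))
    return [list(prefixes[parent]) + [token] for parent, token in zip(selected_samples, new_tokens)]


# Both new samples continue parent 1, as can happen when sampling without replacement.
print(extend_sequences([[1], [2]], [7, 8], selected_samples=[1, 1]))
```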
- class storch.sampling.seq.IterDecoding(plate_name, k, eos)
Bases: SequenceDecoding
- decode(distr: Distribution, joint_log_probs: Optional[storch.Tensor], parents: [storch.Tensor], orig_distr_plates: [storch.Plate])
Decode given the input arguments.
- Parameters
distr – The distribution to decode
joint_log_probs – The log probabilities of the samples so far. Shape: prev_plates x amt_samples
parents – List of parents of this tensor
orig_distr_plates – List of plates from the distribution. Can include the self plate k.
- Returns
A 3-tuple of storch.Tensor: (1) the sampled value; (2) the new joint log probabilities of the samples; (3) how the samples index the parent samples, which can be None if no choosing happens.
- abstract decode_step(indices: Tuple[int], yv_log_probs: storch.Tensor, joint_log_probs: Optional[storch.Tensor], sampled_support_indices: Optional[storch.Tensor], parent_indexing: Optional[storch.Tensor], is_conditional_sample: bool, amt_plates: int, amt_samples: int)
Decode given the input arguments for a specific event.
- Parameters
indices – Tuple of integers indexing the current event to sample.
yv_log_probs – Log probabilities of the different options for this event. Shape: distr_plates x k? x |D_yv|
joint_log_probs – The log probabilities of the samples so far. None if not is_conditional_sample. Shape: prev_plates x amt_samples
sampled_support_indices – Tensor of samples so far. None if this is the first set of indices. Shape: plates x k x events
parent_indexing – Tensor indexing the parent sample. None if not is_conditional_sample.
is_conditional_sample – True if a parent has already been sampled. This means the plates are more complex!
amt_plates – The total number of plates in both the distribution and the previously sampled variables.
amt_samples – The number of active samples.
- Returns
A 4-tuple: (1) sampled_support_indices, with [:, indices] referring to the indices for the support; (2) the updated joint_log_probs of the samples; (3) the updated parent_indexing, i.e. how the samples index the parent samples (can simply return parent_indexing if nothing changes); (4) the number of active samples after this step.
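To make the four return values concrete, here is a hypothetical with-replacement decode step in plain Python, with nested lists instead of storch tensors and without plate handling. It is a sketch of the contract, not storch's implementation: each active sample draws one option, its joint log-probability grows by that option's log-probability, and because nothing is reordered, parent_indexing and the number of active samples pass through unchanged.

```python
import math
import random
from typing import List, Optional, Tuple


def ancestral_decode_step(
    yv_log_probs: List[List[float]],
    joint_log_probs: Optional[List[float]],
    parent_indexing: Optional[List[int]],
    amt_samples: int,
    rng: random.Random,
) -> Tuple[List[int], List[float], Optional[List[int]], int]:
    """Hypothetical with-replacement decode step returning the documented 4 items.

    yv_log_probs[i] holds the log-probabilities of every option for active sample i.
    """
    if joint_log_probs is None:
        joint_log_probs = [0.0] * amt_samples
    sampled, new_joint = [], []
    for i in range(amt_samples):
        probs = [math.exp(lp) for lp in yv_log_probs[i]]
        choice = rng.choices(range(len(probs)), weights=probs, k=1)[0]
        sampled.append(choice)
        new_joint.append(joint_log_probs[i] + yv_log_probs[i][choice])
    # Sampling with replacement never reorders samples, so parent_indexing is returned unchanged.
    return sampled, new_joint, parent_indexing, amt_samples


log_probs = [[math.log(0.7), math.log(0.3)], [math.log(0.4), math.log(0.6)]]
print(ancestral_decode_step(log_probs, None, None, 2, random.Random(0)))
```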
- training: bool
- class storch.sampling.seq.MCDecoder(plate_name: str, k: int, eos: None)
Bases: SequenceDecoding
- decode(distribution: Distribution, joint_log_probs: Optional[storch.Tensor], parents: [storch.Tensor], orig_distr_plates: [storch.Plate])
Decode given the input arguments.
- Parameters
distribution – The distribution to decode
joint_log_probs – The log probabilities of the samples so far. Shape: prev_plates x amt_samples
parents – List of parents of this tensor
orig_distr_plates – List of plates from the distribution. Can include the self plate k.
- Returns
A 3-tuple of storch.Tensor: (1) the sampled value; (2) the new joint log probabilities of the samples; (3) how the samples index the parent samples, which can simply be a range if no choosing happens. For all of these, the last plate index should be the plate index, with the other plates like all_plates.
- training: bool
- class storch.sampling.seq.SequenceDecoding(plate_name: str, k: int, eos: None)
Bases: SamplingMethod
Methods for generating sequences of discrete random variables. Examples: simple ancestral sampling with replacement, beam search, and stochastic beam search (sampling without replacement).
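Of the examples listed, deterministic beam search is the simplest to sketch. The plain-Python step below (hypothetical helper `beam_search_step`, not storch's decoder) extends every prefix with every token and keeps the k candidates with the highest joint log-probability. Stochastic beam search would rank Gumbel-perturbed scores instead, and ancestral sampling would draw one token per prefix.

```python
import heapq
import math
from typing import List, Tuple

Beam = Tuple[List[int], float]  # (prefix, joint log-probability)


def beam_search_step(beams: List[Beam], step_log_probs: List[List[float]], k: int) -> List[Beam]:
    """One step of deterministic beam search.

    step_log_probs[i][v] is log p(v | prefix of beam i). Every prefix is extended
    with every token and the k highest joint log-probabilities survive.
    """
    candidates = [
        (prefix + [token], joint + lp)
        for (prefix, joint), row in zip(beams, step_log_probs)
        for token, lp in enumerate(row)
    ]
    return heapq.nlargest(k, candidates, key=lambda c: c[1])


beams: List[Beam] = [([], 0.0)]
beams = beam_search_step(beams, [[math.log(0.5), math.log(0.3), math.log(0.2)]], k=2)
print(beams)  # prefixes [0] and [1] with their joint log-probabilities
```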
- EPS = 1e-08
- create_plate(plate_size: int, plates: [storch.Plate]) -> AncestralPlate
- abstract decode(distribution: Distribution, joint_log_probs: Optional[storch.Tensor], parents: [storch.Tensor], orig_distr_plates: [storch.Plate])
Decode given the input arguments.
- Parameters
distribution – The distribution to decode
joint_log_probs – The log probabilities of the samples so far. Shape: prev_plates x amt_samples
parents – List of parents of this tensor
orig_distr_plates – List of plates from the distribution. Can include the self plate k.
- Returns
A 3-tuple of storch.Tensor: (1) the sampled value; (2) the new joint log probabilities of the samples; (3) how the samples index the parent samples, which can simply be a range if no choosing happens. For all of these, the last plate index should be the plate index, with the other plates like all_plates.
- sample(distr: Distribution, parents: [storch.Tensor], orig_distr_plates: [storch.Plate], requires_grad: bool)
Sample from the distribution given the sequence so far.
- Parameters
distr – The distribution to sample from
- training: bool
- weighting_function(tensor: StochasticTensor, plate: Plate) -> Optional[Tensor]
Weight by the size of the sample. Override this if your sampling method weights the different events, as in importance sampling or when computing the expectation. If None is returned, the samples are assumed to be i.i.d. Monte Carlo samples.
This method is called from storch.method.Method.sample; it does not need to be called manually on created plates.
- storch.sampling.seq.expand_with_ignore_as(tensor, expand_as, ignore_dim: Union[str, int]) -> Tensor
Expands the tensor like expand_as, but ignores a single dimension. I.e., if tensor is of size a x b, expand_as is of size d x a x c and ignore_dim=-1, then the result will be of size d x a x b. It also automatically expands all plate dimensions correctly.
- Parameters
ignore_dim – The dimension to ignore. Can be a string referring to the plate dimension.
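The shape rule from the example above can be checked with a small plain-Python sketch. `expanded_shape` is hypothetical: it computes only the resulting shape, supports only negative (right-aligned) integer dimensions, and does not model the plate-name form of ignore_dim or validate that the remaining dimensions broadcast.

```python
from typing import Sequence, Tuple


def expanded_shape(
    tensor_shape: Sequence[int],
    expand_as_shape: Sequence[int],
    ignore_dim: int,
) -> Tuple[int, ...]:
    """Shape obtained by expanding like `expand_as` while keeping `ignore_dim` from `tensor`.

    Negative dimensions are aligned from the right, as in broadcasting.
    """
    if ignore_dim >= 0:
        raise ValueError("this sketch only supports negative (right-aligned) dimensions")
    result = list(expand_as_shape)
    result[ignore_dim] = tensor_shape[ignore_dim]
    return tuple(result)


# tensor: a x b = 3 x 4, expand_as: d x a x c = 2 x 3 x 5, ignore_dim=-1 -> d x a x b
print(expanded_shape((3, 4), (2, 3, 5), -1))
```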
Sampling without replacement
- class storch.sampling.swor.SampleWithoutReplacement(plate_name: str, k: int, biased_iw: bool = False, eos=None)
Bases: IterDecoding
Sampling method for sampling without replacement from (sequences of) discrete distributions. Implements Stochastic Beam Search https://arxiv.org/abs/1903.06059 with the weighting as defined by REINFORCE without replacement https://openreview.net/forum?id=r1lgTGL5DE
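Stochastic Beam Search builds on the Gumbel-top-k trick: perturb each log-probability with independent Gumbel(0, 1) noise and keep the k largest perturbed values, which yields k distinct samples without replacement. The plain-Python sketch below shows that trick for a single categorical variable only; `gumbel_top_k` is hypothetical and is not storch's sequence-level implementation.

```python
import math
import random
from typing import List, Sequence


def gumbel_top_k(log_probs: Sequence[float], k: int, rng: random.Random) -> List[int]:
    """Sample k distinct indices without replacement via the Gumbel-top-k trick."""
    perturbed = []
    for lp in log_probs:
        u = rng.random()
        while u == 0.0:  # avoid log(0); random() can return exactly 0.0
            u = rng.random()
        perturbed.append(lp - math.log(-math.log(u)))  # lp + Gumbel(0, 1) noise
    # Indices of the k largest perturbed log-probabilities.
    return sorted(range(len(log_probs)), key=lambda i: perturbed[i], reverse=True)[:k]


picked = gumbel_top_k([math.log(p) for p in (0.1, 0.2, 0.3, 0.4)], k=3, rng=random.Random(0))
print(picked)  # three distinct indices
```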
- EPS = 1e-08
- compute_iw(plate: AncestralPlate, biased: bool)
- create_plate(plate_size: int, plates: [storch.Plate]) -> AncestralPlate
- decode_step(indices: Tuple[int], yv_log_probs: storch.Tensor, joint_log_probs: Optional[storch.Tensor], sampled_support_indices: Optional[storch.Tensor], parent_indexing: Optional[storch.Tensor], is_conditional_sample: bool, amt_plates: int, amt_samples: int)
Decode given the input arguments for a specific event using stochastic beam search.
- Parameters
indices – Tuple of integers indexing the current event to sample.
yv_log_probs – Log probabilities of the different options for this event. Shape: distr_plates x k? x |D_yv|
joint_log_probs – The log probabilities of the samples so far. None if not is_conditional_sample. Shape: prev_plates x amt_samples
sampled_support_indices – Tensor of samples so far. None if this is the first set of indices. Shape: plates x k x events
parent_indexing – Tensor indexing the parent sample. None if not is_conditional_sample.
is_conditional_sample – True if a parent has already been sampled. This means the plates are more complex!
amt_plates – The total number of plates in both the distribution and the previously sampled variables.
amt_samples – The number of active samples.
- Returns
A 4-tuple: (1) sampled_support_indices, with [:, indices] referring to the indices for the support; (2) the updated joint_log_probs of the samples; (3) the updated parent_indexing, i.e. how the samples index the parent samples (can simply return parent_indexing if nothing changes); (4) the number of active samples after this step.
- perturbed_log_probs: Optional[storch.Tensor] = None
- select_samples(perturbed_log_probs: storch.Tensor, joint_log_probs: storch.Tensor)
Given the perturbed log probabilities and the joint log probabilities of the new options, select which ones to use for the sample.
- Parameters
perturbed_log_probs – Perturbed log-probabilities. Shape: plates x (k? * |D_yv|), where k is present if first_sample.
joint_log_probs – Joint log probabilities of the options. Shape: plates x (k? * |D_yv|), where k is present if first_sample.
- Returns
The perturbed log probs of the chosen samples, the joint log probs of the chosen samples, and the indices of the chosen samples.
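A minimal plain-Python sketch of this selection, assuming it keeps the k options with the largest perturbed log-probabilities (the usual stochastic beam search rule). `select_top_k` is hypothetical, works on one flat list per plate, and returns the same three pieces as described above. A flat index over k * |D_yv| options can be split into (parent, token) with divmod.

```python
from typing import List, Sequence, Tuple


def select_top_k(
    perturbed_log_probs: Sequence[float],
    joint_log_probs: Sequence[float],
    k: int,
) -> Tuple[List[float], List[float], List[int]]:
    """Keep the k options with the largest perturbed log-probabilities.

    Returns the perturbed log-probs, the joint log-probs and the flat indices of
    the chosen options, ordered from largest to smallest perturbed value.
    """
    order = sorted(range(len(perturbed_log_probs)), key=lambda i: perturbed_log_probs[i], reverse=True)[:k]
    return (
        [perturbed_log_probs[i] for i in order],
        [joint_log_probs[i] for i in order],
        order,
    )


perturbed, joint, chosen = select_top_k([0.1, 2.0, -1.0, 1.5], [-1.0, -2.0, -3.0, -4.0], k=2)
vocab_size = 2  # |D_yv| in this toy example
print([divmod(i, vocab_size) for i in chosen])  # (parent index, token index) per chosen option
```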
- set_mc_sample(new_sample_func: Callable[[Distribution, [storch.Tensor], [storch.Plate], int], torch.Tensor]) -> SamplingMethod
Override the storch.Method-specific sampling function. This is called when initializing a storch.Method that has slightly different MC sampling semantics (for example, reparameterization instead of normal sampling). This keeps different storch.Method classes compatible with different storch.sampling.Method classes.
- weighting_function(tensor: StochasticTensor, plate: Plate) -> Optional[Tensor]
Weight by the size of the sample. Override this if your sampling method weights the different events, as in importance sampling or when computing the expectation. If None is returned, the samples are assumed to be i.i.d. Monte Carlo samples.
This method is called from storch.method.Method.sample; it does not need to be called manually on created plates.
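For samples drawn with the Gumbel-top-k trick, the REINFORCE-without-replacement style weighting divides each sample's probability by q(s) = 1 - exp(-exp(log p(s) - kappa)), the probability that its perturbed log-probability exceeds the threshold kappa (the (k+1)-th largest perturbed value). The plain-Python sketch below computes those weights; `swor_importance_weights` is hypothetical, and the assumption that `biased_iw` / `compute_iw(..., biased)` corresponds to the self-normalized variant (`normalize=True`) is not confirmed by this page.

```python
import math
from typing import List, Sequence


def swor_importance_weights(log_probs: Sequence[float], kappa: float, normalize: bool = False) -> List[float]:
    """Importance weights p(s) / q(s) for samples drawn with the Gumbel-top-k trick.

    q(s) = 1 - exp(-exp(log p(s) - kappa)), where kappa is the (k+1)-th largest
    perturbed log-probability (kappa = -inf when the whole support was drawn).
    normalize=True divides by the sum of the weights (biased, lower variance).
    """
    weights = []
    for lp in log_probs:
        q = -math.expm1(-math.exp(lp - kappa))  # 1 - exp(-exp(lp - kappa)), computed stably
        weights.append(math.exp(lp) / q)
    if normalize:
        total = sum(weights)
        weights = [w / total for w in weights]
    return weights


print(swor_importance_weights([math.log(0.2), math.log(0.3)], kappa=math.log(0.1)))
```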
- class storch.sampling.swor.SumAndSample(plate_name: str, sum_size: int, sample_size: int = 1, without_replacement: bool = False, eos=None)
Bases: SampleWithoutReplacement
Sums over the S most probable samples, found with beam search, and K sampled values that are not among those probable samples, then normalizes them accordingly.
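For a single categorical variable, the sum-and-sample idea can be sketched in plain Python: sum exactly over the `sum_size` most probable outcomes, then estimate the remaining part of the expectation with `sample_size` draws from the leftover outcomes, scaled by their total probability mass. `sum_and_sample` is hypothetical; it samples the remainder with replacement (storch's `without_replacement` option is not modelled) and does not reproduce storch's exact normalization.

```python
import random
from typing import Callable, Sequence


def sum_and_sample(
    probs: Sequence[float],
    f: Callable[[int], float],
    sum_size: int,
    sample_size: int,
    rng: random.Random,
) -> float:
    """Estimate E[f(x)] by an exact sum over the top outcomes plus MC over the rest."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    summed, rest = order[:sum_size], order[sum_size:]
    estimate = sum(probs[i] * f(i) for i in summed)
    rest_mass = sum(probs[i] for i in rest)
    if rest and rest_mass > 0.0:
        # Draw from the remaining outcomes in proportion to their probabilities.
        draws = rng.choices(rest, weights=[probs[i] for i in rest], k=sample_size)
        estimate += rest_mass * sum(f(i) for i in draws) / sample_size
    return estimate


print(sum_and_sample([0.5, 0.3, 0.2], float, sum_size=2, sample_size=4, rng=random.Random(0)))
```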
- select_samples(perturbed_log_probs: storch.Tensor, joint_log_probs: storch.Tensor)
Given the perturbed log probabilities and the joint log probabilities of the new options, select which ones to use for the sample.
- Parameters
perturbed_log_probs – Perturbed log-probabilities. Shape: plates x (k? * |D_yv|), where k is present if first_sample.
joint_log_probs – Joint log probabilities of the options. Shape: plates x (k? * |D_yv|), where k is present if first_sample.
- Returns
The perturbed log probs of the chosen samples, the joint log probs of the chosen samples, and the indices of the chosen samples.
- training: bool
- storch.sampling.swor.log1mexp(a: Tensor) -> Tensor
Numerically stable implementation of log(1 - exp(a)). See appendix A of http://jmlr.org/papers/v21/19-985.html.
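A scalar version of this computation can be sketched with a common two-branch rule: for a close to 0, use log(-expm1(a)) so that 1 - exp(a) keeps its precision; for very negative a, use log1p(-exp(a)) so that the tiny exp(a) term is not lost. The function below is plain Python for floats and is not claimed to match storch's tensor implementation.

```python
import math


def log1mexp(a: float) -> float:
    """Numerically stable log(1 - exp(a)) for a <= 0."""
    if a > 0:
        raise ValueError("log(1 - exp(a)) is undefined for a > 0")
    if a == 0:
        return float("-inf")  # log(1 - 1)
    if a > -math.log(2.0):
        return math.log(-math.expm1(a))  # a near 0: expm1 keeps 1 - exp(a) accurate
    return math.log1p(-math.exp(a))  # a very negative: log1p keeps the tiny term accurate


print(log1mexp(-1e-10), log1mexp(-50.0))
```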
Unordered set sampling
- class storch.sampling.unordered_set.GumbelSoftmaxWOR(plate_name: str, k: int, initial_temperature=1.0, min_temperature=0.0001, annealing_rate=1e-05, eos=None)
Bases: UnorderedSet
- sample(distr: Distribution, parents: [storch.Tensor], orig_distr_plates: [storch.Plate], requires_grad: bool) -> (torch.Tensor, storch.Plate)
Sample from the distribution given the sequence so far.
- Parameters
distr – The distribution to sample from
- training: bool
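The constructor's `initial_temperature`, `min_temperature` and `annealing_rate` suggest an annealed Gumbel-softmax relaxation. The plain-Python sketch below pairs a common exponential annealing schedule with the standard Gumbel-softmax relaxed sample, reusing the constructor defaults. Both the schedule and the helper names are assumptions for illustration; how GumbelSoftmaxWOR actually anneals, and how it combines the relaxation with sampling without replacement, is not documented on this page.

```python
import math
import random
from typing import List, Sequence


def annealed_temperature(step: int, initial_temperature: float = 1.0,
                         min_temperature: float = 1e-4, annealing_rate: float = 1e-5) -> float:
    """Assumed exponential schedule: initial * exp(-rate * step), clipped from below."""
    return max(min_temperature, initial_temperature * math.exp(-annealing_rate * step))


def gumbel_softmax(log_probs: Sequence[float], temperature: float, rng: random.Random) -> List[float]:
    """Relaxed one-hot sample: softmax((log p + Gumbel noise) / temperature)."""
    scores = []
    for lp in log_probs:
        u = rng.random()
        while u == 0.0:  # avoid log(0)
            u = rng.random()
        scores.append((lp - math.log(-math.log(u))) / temperature)
    m = max(scores)  # subtract the max for a numerically stable softmax
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


tau = annealed_temperature(step=50_000)
print(tau, gumbel_softmax([math.log(0.2), math.log(0.8)], tau, random.Random(0)))
```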
- class storch.sampling.unordered_set.UnorderedSet(plate_name: str, k: int, comp_leave_two_out: bool = False, exact_integration: bool = False, num_int_points: int = 1000, a: float = 5.0, eos=None)
Bases: SampleWithoutReplacement
- training: bool
- weighting_function(tensor: StochasticTensor, plate: Plate) -> Optional[Tensor]
Weight by the size of the sample. Override this if your sampling method weights the different events, as in importance sampling or when computing the expectation. If None is returned, the samples are assumed to be i.i.d. Monte Carlo samples.
This method is called from storch.method.Method.sample; it does not need to be called manually on created plates.
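Weighting an unordered sample set requires the probability of the set itself, which is the sum over every order in which its elements could have been drawn without replacement. The plain-Python sketch below computes this by brute force over all k! orderings, so it is only feasible for small k. `unordered_set_prob` is hypothetical; the `exact_integration` and `num_int_points` parameters suggest UnorderedSet obtains such quantities by numerical integration instead, but that is an assumption.

```python
import itertools
from typing import Sequence


def unordered_set_prob(probs: Sequence[float], subset: Sequence[int]) -> float:
    """P(S = subset) when len(subset) items are drawn without replacement.

    Sums the probability of every ordering of the subset; each ordering is a
    product of probabilities renormalized over the items not yet drawn.
    """
    total = 0.0
    for ordering in itertools.permutations(subset):
        p, remaining = 1.0, 1.0
        for i in ordering:
            p *= probs[i] / remaining
            remaining -= probs[i]
        total += p
    return total


probs = [0.2, 0.3, 0.5]
print({s: unordered_set_prob(probs, s) for s in itertools.combinations(range(3), 2)})
```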