Storch package

Inference

Wrappers

storch.wrappers.deterministic(fn: Optional[Callable] = None, **kwargs)[source]

Wraps the input function around a deterministic storch wrapper. This wrapper unwraps storch.Tensor objects to torch.Tensor objects, aligning the tensors according to the plates, then runs fn on the unwrapped Tensors.

Parameters
  • fn – Optional function to wrap. If None, this returns another wrapper that accepts a function that will be instantiated by the given kwargs.

  • fn_args – List of non-keyword arguments to the wrapped function.

  • fn_kwargs – Dictionary of keyword arguments to the wrapped function.

  • unwrap – Whether to unwrap the arguments to their torch.Tensor counterpart. Set to False to prevent unwrapping storch.Tensor objects. (default: True)

  • align_tensors – Whether to automatically align the input arguments (default: True).

  • l_broadcast – Whether to automatically left-broadcast (default: True).

  • expand_plates – Instead of adding singleton dimensions on non-existent plates, this expands the tensors to the full plate sizes. Note that outputs are unflattened automatically. flatten_plates sets this to True automatically. (default: False)

  • flatten_plates – Flattens the plate dimensions into a single batch dimension if set to True. This can be useful for functions that are written to only work for tensors with a single batch dimension. (default: False)

  • dim – Replaces the dim input in fn_kwargs by the plate dimension corresponding to the given string (optional).

  • dims – Replaces the dims input in fn_kwargs by the plate dimensions corresponding to the given strings (optional).

  • self_wrapper – storch.Tensor that wraps a

Returns

The wrapped function fn.

Return type

Callable
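A minimal usage sketch (the Reparameterization method, the plate name "z", and the sample count are assumptions for illustration, not part of this function's API):

    import storch
    from storch.method import Reparameterization
    from torch.distributions import Normal

    # Assumed setup: sample an 8-sample plate named "z".
    z = Reparameterization("z", n_samples=8)(Normal(0.0, 1.0))

    @storch.deterministic
    def shifted_square(x):
        # x arrives here as a plain torch.Tensor, aligned over its plates.
        return (x + 1.0) ** 2

    out = shifted_square(z)  # a storch.Tensor again, with the "z" plate intact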

storch.wrappers.ignore_wrapping()[source]
storch.wrappers.is_iterable(a: Any)[source]
storch.wrappers.make_left_broadcastable(fn: Optional[Callable])[source]

Deterministic wrapper that is compatible with functions that are not by themselves left-broadcastable, such as torch.nn.Conv2d(). Conv2d operates on (N, C, H, W) inputs and cannot deal with additional 'independent' dimensions on the left. To fix this, use make_left_broadcastable(Conv2d(16, 33, 3)).
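Following the docstring's own suggestion, a short sketch (images is an assumed storch.Tensor input):

    import torch.nn as nn
    from storch.wrappers import make_left_broadcastable

    # Conv2d expects (N, C, H, W); the wrapper lets it ignore extra plate
    # dimensions to the left of that shape.
    conv = make_left_broadcastable(nn.Conv2d(16, 33, 3))
    out = conv(images)  # images: storch.Tensor with plates left of (N, C, H, W)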

storch.wrappers.reduce(fn, plates: Union[str, List[str]])[source]

Wraps the input function around a deterministic storch wrapper. This wrapper unwraps storch.Tensor objects to torch.Tensor objects, aligning the tensors according to the plates, then runs fn on the unwrapped Tensors. It will reduce the plates given by plates.

Parameters

fn (Callable) – Function to wrap.

plates (str or List[str]) – The plate(s) to reduce over.

Returns

The wrapped function fn.

Return type

Callable
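A hedged sketch of use; the plate name "z", and the assumption that the reduced plate appears as the first dimension of the unwrapped tensor, are illustrative rather than guaranteed by this API:

    import storch

    # Average out the sample plate "z" (x is an assumed storch.Tensor).
    mean_over_z = storch.reduce(lambda t: t.mean(dim=0), plates="z")
    y = mean_over_z(x)  # the "z" plate is no longer present on y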

storch.wrappers.stochastic(fn)[source]

Applies fn to the inputs. fn should return one or multiple storch.Tensors. fn should not call storch.stochastic or storch.deterministic. The inputs can include storch.Tensors.

Parameters

fn – Function to wrap.

Returns

The wrapped function fn.
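A sketch of the intended pattern, with the sampling method and the plate name "z" assumed for illustration:

    import storch
    from storch.method import Reparameterization
    from torch.distributions import Normal

    method = Reparameterization("z")  # assumed sampling method

    @storch.stochastic
    def noisy_encode(mu):
        # Sampling happens inside the stochastic wrapper and yields a storch.Tensor.
        return method(Normal(mu, 1.0))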

Exceptions

exception storch.exceptions.IllegalStorchExposeError(message)[source]

Bases: Exception

Unique

This module is highly experimental

Utilities

storch.storch.cat(*args, **kwargs)[source]

Version of torch.cat() that is compatible with storch.Tensor. Required because torch.Tensor.__torch_function__() is not properly implemented for torch.cat(): https://github.com/pytorch/pytorch/issues/34294
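Usage mirrors torch.cat(); a small sketch with a and b as assumed storch.Tensors:

    from storch.storch import cat

    # Concatenate along the last (event) dimension; plate dimensions stay aligned.
    joined = cat([a, b], dim=-1)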

storch.storch.conditional_gumbel_rsample(hard_sample: Tensor, probs: Tensor, bernoulli: bool, temperature) Tensor[source]

Conditionally re-samples from the distribution given the hard sample. This samples z ~ p(z|b), where b is the hard sample and p(z) is a Gumbel distribution.

storch.storch.expand_as(tensor: Union[storch.Tensor, torch.Tensor], expand_as: Union[storch.Tensor, torch.Tensor]) Union[storch.Tensor, torch.Tensor][source]
storch.storch.gather(input: Tensor, dim: str, index: Tensor)[source]
storch.storch.grad(outputs, inputs, grad_outputs=None, retain_graph: Optional[bool] = None, create_graph: bool = False, only_inputs: bool = True, allow_unused: bool = False) Tuple[Tensor, ...][source]

Helper method for computing torch.autograd.grad on storch tensors. Returns storch Tensors as well.
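The call mirrors torch.autograd.grad(); a hedged sketch with loss and theta as assumed tensors:

    from storch.storch import grad

    # Gradients come back as storch Tensors, preserving plate information.
    (grad_theta,) = grad([loss], [theta], create_graph=True)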

storch.storch.logsumexp(tensor: Tensor, dims: Union[List[Union[str, int, Plate]], str, int, Plate]) Tensor[source]
storch.storch.mean(tensor: Tensor, dims: Union[List[Union[str, int, Plate]], str, int, Plate]) Tensor[source]

Simply takes the mean of the tensor over the dimensions given. WARNING: This does NOT weight the different elements according to the plates. You will very likely want to call the reduce_plates method instead.

storch.storch.order_plates(plates: [Plate], reverse=False)[source]

Topologically order the given plates. Uses Kahn’s algorithm.

storch.storch.reduce_plates(tensor: Union[storch.Tensor, torch.Tensor], plates: Optional[Union[List[Union[Plate, str]], Plate, str]] = None, detach_weights=True) Tensor[source]

Reduce the tensor along the given plates. This takes into account how different samples are weighted, and should nearly always be used instead of reducing plate dimensions using the mean or the sum. By default, this reduces all plates.

Parameters
  • tensor – Tensor to reduce

  • plates – Plates to reduce. If None, this reduces all plates (default). Can be a string, Plate, or list of strings and Plates.

  • detach_weights – Whether to detach the weighting of the samples from the graph

Returns

The reduced tensor
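For instance, to compute a weighted expectation over a sample plate (the plate name "z" and the cost tensor are assumed):

    from storch.storch import reduce_plates

    # Unlike mean() or sum(), this weights each sample correctly.
    expected_cost = reduce_plates(cost, plates="z")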

storch.storch.sum(tensor: Tensor, dims: Union[List[Union[str, int, Plate]], str, int, Plate]) Tensor[source]

Simply sums the tensor over the dimensions given. WARNING: This does NOT weight the different elements according to the plates. You will very likely want to call the reduce_plates method instead.

storch.storch.variance(tensor: Union[storch.Tensor, torch.Tensor], variance_plate: Union[Plate, str], plates: Optional[Union[List[Union[Plate, str]], Plate, str]] = None, detach_weights=True) Tensor[source]

Compute the variance of the tensor along the plate dimensions. This takes into account how different samples are weighted.

Parameters
  • tensor – Tensor to compute variance over

  • variance_plate – Plate to compute the variance over.

  • plates – Plates to reduce.

  • detach_weights – Whether to detach the weighting of the samples from the graph

Returns

The variance of the tensor.
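A hedged sketch (x and the plate name "z" are assumed):

    from storch.storch import variance

    # Weighted variance along the sample plate "z".
    var = variance(x, variance_plate="z")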

storch.util.get_distr_parameters(d: Distribution, filter_requires_grad=True) Dict[str, Tensor][source]
storch.util.has_backwards_path(output: Tensor, inputs: [Tensor]) [bool][source]

Returns True for each individual input if the gradient functions of the torch.Tensor underlying output are connected to that input tensor. This is only run once to compute the possibility of links between two storch.Tensors. The result is saved into the parent links on storch.Tensors.

Parameters
  • output – The output tensor to trace backwards from.

  • inputs – The input tensors to check for connections.

  • depth_first – Initialized to False, as this is usually computed only for small distances between tensors.

storch.util.has_differentiable_path(output: Tensor, input: Tensor)[source]
storch.util.magic_box(l: Tensor)[source]

Implements the MagicBox operator from DiCE: The Infinitely Differentiable Monte-Carlo Estimator (https://arxiv.org/abs/1802.05098). It returns 1 in the forward pass, but in the backward pass it propagates magic_box(l) · r, where r is the incoming gradient. This allows for any-order gradient estimation.
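The standard DiCE formulation is compact; a self-contained sketch (not necessarily identical to this implementation):

    import torch

    def magic_box(l: torch.Tensor) -> torch.Tensor:
        # exp(l - stop_gradient(l)): equals 1 in the forward pass, while its
        # gradient is magic_box(l) * grad(l), enabling any-order estimators.
        return torch.exp(l - l.detach())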

storch.util.print_graph(costs: [CostTensor])[source]
storch.util.reduce_mean(tensor: torch.Tensor, keep_dims: [int])[source]
storch.util.rsample_gumbel(distr: Distribution, n: int) Tensor[source]
storch.util.rsample_gumbel_softmax(distr: Distribution, n: int, temperature: Tensor, straight_through: bool = False) Tensor[source]
storch.util.split(tensor: Tensor, plate: Plate, *, amt_slices: Optional[int] = None, slices: Optional[List[slice]] = None, create_plates=True) Tuple[Tensor, ...][source]

Splits the plate dimension on the tensor into several tensors and returns those tensors. Note: It removes the tensors from the computation graph and therefore should only be used when creating estimators, when logging or debugging, or if you know what you’re doing.
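A hedged sketch (x and z_plate, a Plate of assumed size 8, are illustrative):

    from storch.util import split

    # Split the plate into two tensors of 4 samples each; the results are
    # detached from the computation graph, per the warning above.
    first_half, second_half = split(x, z_plate, amt_slices=2)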

storch.util.tensor_stats(tensor: Tensor)[source]
storch.util.topological_sort(costs: [CostTensor]) [Tensor][source]

Implements a reverse Kahn's algorithm to topologically sort the computation graph from the given cost tensors.

storch.util.walk_backward_graph(tensor: Tensor) Iterable[Tensor][source]