Storch package
Inference
Wrappers
- storch.wrappers.deterministic(fn: Optional[Callable] = None, **kwargs)
Wraps the input function in a deterministic storch wrapper. This wrapper unwraps storch.Tensor objects to torch.Tensor objects, aligning the tensors according to their plates, then runs fn on the unwrapped tensors.
- Parameters
fn – Optional function to wrap. If None, this returns another wrapper that accepts a function, instantiated by the given kwargs.
fn_args – List of non-keyword arguments to the wrapped function
fn_kwargs – Dictionary of keyword arguments to the wrapped function
unwrap – Whether to unwrap the arguments to their torch.Tensor counterpart. Set to False to prevent unwrapping storch.Tensor objects. (default: True)
align_tensors – Whether to automatically align the input arguments (default: True)
l_broadcast – Whether to automatically left-broadcast (default: True)
expand_plates – Instead of adding singleton dimensions on non-existent plates, this expands those dimensions to the full plate size. flatten_plates sets this to True automatically. (default: False)
flatten_plates – Flattens the plate dimensions into a single batch dimension if set to True. This can be useful for functions that are written to only work for tensors with a single batch dimension. Note that outputs are unflattened automatically. (default: False)
dim – Replaces the dim input in fn_kwargs by the plate dimension corresponding to the given string (optional)
dims – Replaces the dims input in fn_kwargs by the plate dimensions corresponding to the given strings (optional)
self_wrapper – storch.Tensor that wraps a
- Returns
The wrapped function fn.
- Return type
Callable
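The fn=None behaviour described above is the usual Python pattern for a wrapper that can be used both bare and with keyword arguments. The minimal sketch below illustrates that pattern together with unwrapping and re-wrapping. Wrapped and deterministic_sketch are hypothetical names invented for this illustration, and plate alignment and broadcasting are deliberately left out, so this is not storch's implementation.

```python
import functools
from typing import Any, Callable, Optional


class Wrapped:
    """Hypothetical stand-in for storch.Tensor: a value plus the names of its plates."""

    def __init__(self, value: Any, plates: tuple = ()):
        self.value = value
        self.plates = tuple(plates)


def deterministic_sketch(fn: Optional[Callable] = None, *, unwrap: bool = True):
    """Hypothetical sketch of the optional-argument wrapper pattern; not storch's code."""

    def decorator(f: Callable) -> Callable:
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            if not unwrap:
                # Pass the wrapped objects through untouched.
                return f(*args, **kwargs)
            # Union of the plate names of all wrapped inputs, in first-seen order.
            plates: list = []
            for a in list(args) + list(kwargs.values()):
                if isinstance(a, Wrapped):
                    plates += [p for p in a.plates if p not in plates]
            # Unwrap positional and keyword arguments to their underlying values.
            raw_args = [a.value if isinstance(a, Wrapped) else a for a in args]
            raw_kwargs = {k: v.value if isinstance(v, Wrapped) else v for k, v in kwargs.items()}
            # Re-wrap the plain result so that it carries the plates of its inputs.
            return Wrapped(f(*raw_args, **raw_kwargs), tuple(plates))

        return wrapper

    # Used as @deterministic_sketch(unwrap=...): fn is None, so return a wrapper awaiting the function.
    if fn is None:
        return decorator
    # Used as @deterministic_sketch or deterministic_sketch(f): wrap f immediately.
    return decorator(fn)


@deterministic_sketch
def add(x, y):
    return x + y


out = add(Wrapped(1, ("z",)), Wrapped(2, ("n", "z")))
print(out.value, out.plates)  # 3 ('z', 'n')
```

Calling deterministic_sketch(unwrap=False) first returns a wrapper, which is then applied to the function; this mirrors how the kwargs instantiate the wrapper when fn is None.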
- storch.wrappers.make_left_broadcastable(fn: Optional[Callable])
Deterministic wrapper that is compatible with functions that are not by themselves left-broadcastable, such as torch.nn.Conv2d(). Conv2d operates on inputs of shape (N, C, H, W) and cannot deal with additional ‘independent’ dimensions on the left. To fix this, use make_left_broadcastable(Conv2d(16, 33, 3)).
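One common way to make a function left-broadcastable is to flatten every extra leading dimension into the batch dimension, call the function, and then split the batch dimension back up. The shape-only sketch below assumes this is roughly what make_left_broadcastable does; left_broadcast_shape and conv2d_16_33_3_shape are hypothetical helpers that operate on shape tuples rather than real tensors.

```python
from math import prod
from typing import Callable, Tuple

Shape = Tuple[int, ...]


def left_broadcast_shape(fn_shape: Callable[[Shape], Shape], in_shape: Shape, core_dims: int) -> Shape:
    """Shape-level sketch: run a function that only accepts (B, *core) inputs
    on an input of shape (*leading, *core), where core has core_dims dimensions."""
    if core_dims < 1:
        raise ValueError("core_dims must be at least 1")
    leading, core = in_shape[:-core_dims], in_shape[-core_dims:]
    # Flatten all leading dimensions (plates and the original batch) into one batch dimension.
    flat_in = (prod(leading),) + core
    flat_out = fn_shape(flat_in)
    # The function keeps the batch dimension first; split it back into the leading dimensions.
    return leading + flat_out[1:]


def conv2d_16_33_3_shape(shape: Shape) -> Shape:
    """Output shape of Conv2d(16, 33, 3) with the default stride 1 and no padding."""
    b, c, h, w = shape
    if c != 16:
        raise ValueError("expected 16 input channels")
    return (b, 33, h - 2, w - 2)


# A plate of size 5 to the left of an (N=8, C=16, H=32, W=32) batch.
print(left_broadcast_shape(conv2d_16_33_3_shape, (5, 8, 16, 32, 32), core_dims=3))
# (5, 8, 33, 30, 30)
```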
- storch.wrappers.reduce(fn, plates: Union[str, List[str]])
Wraps the input function in a deterministic storch wrapper. This wrapper unwraps storch.Tensor objects to torch.Tensor objects, aligning the tensors according to their plates, then runs fn on the unwrapped tensors. It reduces the plates given by plates.
- Parameters
fn (Callable) – Function to wrap.
plates – Name or list of names of the plates to reduce.
- Returns
The wrapped function fn.
- Return type
Callable
Exceptions
Unique
This module is highly experimental.
Utilities
- storch.storch.cat(*args, **kwargs)
Version of torch.cat() that is compatible with storch.Tensor. Required because torch.Tensor.__torch_function__() is not properly implemented for torch.cat(): https://github.com/pytorch/pytorch/issues/34294
- storch.storch.conditional_gumbel_rsample(hard_sample: Tensor, probs: Tensor, bernoulli: bool, temperature) → Tensor
Conditionally re-samples from the distribution given the hard sample. This samples z ~ p(z|b), where b is the hard sample and p(z) is a Gumbel distribution.
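For a categorical distribution, one standard way to sample z ~ p(z|b) is the truncated-Gumbel construction used in REBAR (Tucker et al., 2017): coordinate b receives a standard Gumbel sample, and every other coordinate receives a Gumbel sample truncated to stay below it. The sketch below implements that construction, plus the analogous logistic construction for a Bernoulli hard sample, in plain Python. It assumes, without checking storch's source, that conditional_gumbel_rsample follows the same construction; the function names are hypothetical, and the temperature-controlled relaxation is left out.

```python
import math
import random
from typing import List


def _open_uniform(rng: random.Random) -> float:
    """Draw from the open interval (0, 1); random() alone can return 0.0."""
    u = rng.random()
    while u == 0.0:
        u = rng.random()
    return u


def conditional_gumbel_categorical(probs: List[float], b: int, rng: random.Random) -> List[float]:
    """Sample z = log(probs) + Gumbel noise, conditioned on argmax(z) == b.

    probs must be strictly positive and sum to 1.
    """
    v = [_open_uniform(rng) for _ in probs]
    # The winning coordinate is a standard Gumbel sample (the max of log-probs + noise).
    top = -math.log(-math.log(v[b]))
    z = []
    for k, p in enumerate(probs):
        if k == b:
            z.append(top)
        else:
            # Gumbel with location log(p) truncated to lie below top (inverse-CDF sampling).
            z.append(-math.log(-math.log(v[k]) / p - math.log(v[b])))
    return z


def conditional_logistic_bernoulli(p: float, b: int, rng: random.Random) -> float:
    """Sample z = logit(p) + logit(u), u ~ U(0, 1), conditioned on b == (z > 0)."""
    v = _open_uniform(rng)
    # z > 0 exactly when u > 1 - p, so draw u uniformly from the matching sub-interval.
    u = (1 - p) + v * p if b == 1 else v * (1 - p)
    return math.log(p / (1 - p)) + math.log(u / (1 - u))
```

A relaxed sample would then presumably be obtained by applying a softmax (or a sigmoid, in the Bernoulli case) to z divided by the temperature.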
- storch.storch.expand_as(tensor: Union[Tensor, Tensor], expand_as: Union[Tensor, Tensor]) → Union[Tensor, Tensor]
- storch.storch.grad(outputs, inputs, grad_outputs=None, retain_graph: Optional[bool] = None, create_graph: bool = False, only_inputs: bool = True, allow_unused: bool = False) → Tuple[Tensor, ...]
Helper method for computing torch.autograd.grad on storch tensors. Returns storch.Tensor objects as well.
- storch.storch.logsumexp(tensor: Tensor, dims: Union[List[Union[str, int, Plate]], str, int, Plate]) → Tensor
- storch.storch.mean(tensor: Tensor, dims: Union[List[Union[str, int, Plate]], str, int, Plate]) → Tensor
Simply takes the mean of the tensor over the dimensions given. WARNING: This does NOT weight the different elements according to the plates. You will very likely want to call the reduce_plates method instead.
- storch.storch.order_plates(plates: [storch.tensor.Plate], reverse=False)
Topologically orders the given plates using Kahn’s algorithm.
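Kahn’s algorithm repeatedly emits a plate whose parents have all been emitted already. The sketch below assumes that each plate must come after its parents, and it represents plates as a hypothetical mapping from plate name to parent names instead of storch.tensor.Plate objects; order_plates_sketch is a hypothetical name, and its reverse flag mirrors the one in the signature.

```python
from collections import deque
from typing import Dict, List


def order_plates_sketch(parents: Dict[str, List[str]], reverse: bool = False) -> List[str]:
    """Kahn's algorithm: order plate names so every plate comes after its parents.

    Every parent name must also appear as a key of parents.
    """
    # Number of parents that have not been emitted yet, per plate.
    in_degree = {plate: len(ps) for plate, ps in parents.items()}
    children: Dict[str, List[str]] = {plate: [] for plate in parents}
    for plate, ps in parents.items():
        for parent in ps:
            children[parent].append(plate)
    # Start from the plates without parents.
    queue = deque(plate for plate in parents if in_degree[plate] == 0)
    order: List[str] = []
    while queue:
        plate = queue.popleft()
        order.append(plate)
        for child in children[plate]:
            in_degree[child] -= 1
            if in_degree[child] == 0:
                queue.append(child)
    # Plates left over are part of a cycle and can never be emitted.
    if len(order) != len(parents):
        raise ValueError("plate dependencies contain a cycle")
    return order[::-1] if reverse else order


print(order_plates_sketch({"z": [], "n": ["z"], "k": ["n", "z"]}))  # ['z', 'n', 'k']
```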
- storch.storch.reduce_plates(tensor: Union[Tensor, Tensor], plates: Optional[Union[List[Union[Plate, str]], Plate, str]] = None, detach_weights=True) → Tensor
Reduce the tensor along the given plates. This takes into account how different samples are weighted, and should nearly always be used instead of reducing plate dimensions using the mean or the sum. By default, this reduces all plates.
- Parameters
tensor – Tensor to reduce
plates – Plates to reduce. If None, this reduces all plates (default). Can be a string, a Plate, or a list of strings and Plates.
detach_weights – Whether to detach the weighting of the samples from the graph
- Returns
The reduced tensor
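The difference between reduce_plates and a plain mean or sum is the per-sample weighting. In the minimal single-plate sketch below, reduce_plate_sketch is a hypothetical helper that assumes each plate supplies one weight per sample, with the weights summing to 1; detach_weights has no counterpart because there is no autograd graph here.

```python
from typing import Sequence


def reduce_plate_sketch(values: Sequence[float], weights: Sequence[float]) -> float:
    """Weighted reduction over a single plate: sum_i w_i * x_i, with weights summing to 1."""
    if len(values) != len(weights):
        raise ValueError("one weight per sample is required")
    return sum(w * x for x, w in zip(values, weights))


samples = [1.0, 2.0, 3.0, 4.0]
# Plain Monte Carlo plate: uniform weights, so the weighted reduction equals the mean.
print(reduce_plate_sketch(samples, [0.25] * 4))  # 2.5
# Non-uniform weights: the plain mean (2.5) would give the wrong answer here.
print(reduce_plate_sketch(samples, [0.5, 0.25, 0.125, 0.125]))  # 1.875
```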
- storch.storch.sum(tensor: Tensor, dims: Union[List[Union[str, int, Plate]], str, int, Plate]) → Tensor
Simply sums the tensor over the dimensions given. WARNING: This does NOT weight the different elements according to the plates. You will very likely want to call the reduce_plates method instead.
- storch.storch.variance(tensor: Union[Tensor, Tensor], variance_plate: Union[Plate, str], plates: Optional[Union[List[Union[Plate, str]], Plate, str]] = None, detach_weights=True) → Tensor
Compute the variance of the tensor along the plate dimensions. This takes into account how different samples are weighted.
- Parameters
tensor – Tensor to compute variance over
variance_plate – Plate, or name of the plate, along which the variance is computed
plates – Plates to reduce.
detach_weights – Whether to detach the weighting of the samples from the graph
- Returns
The variance of the tensor.
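A variance that takes sample weights into account replaces both the mean and the average of the squared deviations by weighted sums. The single-plate sketch below uses the population (biased) form, which is an assumption, as the description above does not say which form storch uses; weighted_variance_sketch is a hypothetical helper and the weights are assumed to sum to 1.

```python
from typing import Sequence


def weighted_variance_sketch(values: Sequence[float], weights: Sequence[float]) -> float:
    """Weighted (population) variance over one plate; weights are assumed to sum to 1."""
    # Weighted mean of the samples.
    mean = sum(w * x for x, w in zip(values, weights))
    # Weighted average of the squared deviations from that mean.
    return sum(w * (x - mean) ** 2 for x, w in zip(values, weights))


print(weighted_variance_sketch([1.0, 2.0, 3.0, 4.0], [0.25] * 4))  # 1.25
```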
- storch.util.get_distr_parameters(d: Distribution, filter_requires_grad=True) → Dict[str, Tensor]
- storch.util.has_backwards_path(output: storch.tensor.Tensor, inputs: [storch.tensor.Tensor]) → [bool]
Returns, for each individual input, whether the gradient functions of the torch.Tensor underlying output are connected to that input tensor. This is only run once to compute the possibility of links between two storch.Tensor objects; the result is saved into the parent links on storch.Tensor objects.
- Parameters
output – Tensor whose underlying gradient graph is searched.
inputs – Tensors to test for a connection to output.
depth_first – Initialized to False, as this is usually only done for small distances between tensors.
- Returns
One bool per input.
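Checking for a backwards path amounts to a reachability search in the autograd graph, following each node's next_functions (in PyTorch, grad_fn.next_functions is a tuple of (node, index) pairs). The sketch below runs a breadth-first search over a hypothetical Node class shaped like such graph nodes rather than over real torch tensors; has_backwards_path_sketch illustrates the idea and is not storch's code.

```python
from collections import deque
from typing import List, Optional, Sequence, Tuple


class Node:
    """Hypothetical graph node shaped like a torch autograd node: next_functions holds (node, index) pairs."""

    def __init__(self, name: str, next_functions: Sequence[Tuple[Optional["Node"], int]] = ()):
        self.name = name
        self.next_functions = tuple(next_functions)


def has_backwards_path_sketch(output: Node, inputs: List[Node]) -> List[bool]:
    """For each input, report whether it is reachable from output through next_functions."""
    seen = set()
    queue = deque([output])
    # Breadth-first search, since links between tensors are usually short.
    while queue:
        node = queue.popleft()
        # None entries stand for inputs that need no gradient.
        if node is None or id(node) in seen:
            continue
        seen.add(id(node))
        queue.extend(fn for fn, _ in node.next_functions)
    return [id(inp) in seen for inp in inputs]


a, b, c = Node("a"), Node("b"), Node("c")
mul = Node("mul", [(a, 0), (b, 0)])
out = Node("sum", [(mul, 0), (None, 0)])
print(has_backwards_path_sketch(out, [a, b, c]))  # [True, True, False]
```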
- storch.util.magic_box(l: Tensor)
Implements the MagicBox operator from DiCE: The Infinitely Differentiable Monte-Carlo Estimator (https://arxiv.org/abs/1802.05098). It returns 1 in the forward pass, but returns magic_box(l) · r in the backwards pass. This allows for any-order gradient estimation.
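In the DiCE paper, the MagicBox operator is defined using a stop-gradient operator ⊥, with ℓ the sum of the log-probabilities of the stochastic nodes that a cost depends on. The derivation below restates that definition, with → read as "evaluates to"; treating r as a cost multiplied by the box is an assumption about the notation in the description above.

```latex
% MagicBox, with \bot the stop-gradient operator
\square(\ell) = \exp\bigl(\ell - \bot(\ell)\bigr)

% Forward pass: \bot(\ell) evaluates to \ell, so the box evaluates to 1
\square(\ell) \rightarrow \exp(0) = 1

% Backward pass: \bot(\ell) contributes no gradient
\nabla_\theta \square(\ell) = \square(\ell)\,\nabla_\theta \ell \rightarrow \nabla_\theta \ell

% Product rule for a cost r multiplied by the box
\nabla_\theta\bigl(\square(\ell)\,r\bigr)
  = \square(\ell)\bigl(r\,\nabla_\theta \ell + \nabla_\theta r\bigr)
  \rightarrow r\,\nabla_\theta \ell + \nabla_\theta r
```

Because \square(\ell) survives differentiation, the same rule applies again when the result is differentiated a second time, which is what makes higher-order gradient estimates possible.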
- storch.util.rsample_gumbel(distr: Distribution, n: int) → Tensor
- storch.util.rsample_gumbel_softmax(distr: Distribution, n: int, temperature: Tensor, straight_through: bool = False) → Tensor
- storch.util.split(tensor: Tensor, plate: Plate, *, amt_slices: Optional[int] = None, slices: Optional[List[slice]] = None, create_plates=True) → Tuple[Tensor, ...]
Splits the plate dimension on the tensor into several tensors and returns those tensors. Note: It removes the tensors from the computation graph and therefore should only be used when creating estimators, when logging or debugging, or if you know what you’re doing.