Source code for storch.wrappers

from __future__ import annotations

from copy import copy
from typing import Union, Any, Tuple, List, Optional, Dict, Callable

from torch.distributions import Distribution

import storch
import torch
from collections.abc import Iterable, Mapping
from functools import wraps
from storch.exceptions import IllegalStorchExposeError
from contextlib import contextmanager

_context_stochastic = False
_context_deterministic = 0
_stochastic_parents = []
_context_name = None
_plate_links = []
_ignore_wrap = False

# TODO: This is_iterable thing is a bit annoying: We really only want to unwrap them if they contain storch
#  Tensors, and then only for some types. Should rethink, maybe. Is unwrapping even necessary if the base torch methods
#  are all overridden? Maybe, see torch.cat?
def is_iterable(a: Any):
    return (
        isinstance(a, Iterable)
        and not storch.is_tensor(a)
        and not isinstance(a, str)
        and not isinstance(a, torch.Storage)
    )
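
# Illustrative behaviour of is_iterable (a sketch, not part of the original module):
# containers such as lists, tuples and dicts are recursed into when collecting and
# unwrapping arguments, while strings and (storch) tensors are treated as leaves.
#
#   is_iterable([1, 2, 3])    # True: may contain storch.Tensors that need unwrapping
#   is_iterable("batch")      # False: strings are never unwrapped
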
def _collect_parents_and_plates(
    a: Any, parents: [storch.Tensor], plates: [storch.Plate]
) -> int:
    if isinstance(a, storch.Tensor):
        parents.append(a)
        for plate in a.plates:
            if plate not in plates:
                plates.append(plate)
        return a.event_dims
    elif isinstance(a, Mapping):
        max_event_dim = 0
        for _a in a.values():
            max_event_dim = max(
                max_event_dim, _collect_parents_and_plates(_a, parents, plates)
            )
        return max_event_dim
    elif is_iterable(a):
        max_event_dim = 0
        for _a in a:
            max_event_dim = max(
                max_event_dim, _collect_parents_and_plates(_a, parents, plates)
            )
        return max_event_dim
    return 0


def _unsqueeze_and_unwrap(
    a: Any,
    multi_dim_plates: [storch.Plate],
    align_tensors: bool,
    l_broadcast: bool,
    expand_plates: bool,
    flatten_plates: bool,
    event_dims: int,
):
    if isinstance(a, storch.Tensor):
        if not align_tensors:
            return a._tensor
        for plate in multi_dim_plates:
            # Used in storch.method.sampling.AncestralPlate
            a = plate.on_unwrap_tensor(a)
        tensor = a._tensor

        # Automatically **RIGHT** broadcast. Ensure each tensor has an equal amount of
        # event dims by inserting dimensions to the right.
        # TODO: What do we think about this design?
        # TODO: The storch.tensor._getitem_level == 0 check prevents right-broadcasting
        #  for __getitem__ and __setitem__... Seems hacky
        if l_broadcast and a.event_dims < event_dims:
            tensor = tensor[(...,) + (None,) * (event_dims - a.event_dims)]

        # It can be possible that the ordering of the plates does not align with the
        # ordering of the inputs. This part corrects this.
        amt_recognized = 0
        links: [storch.Plate] = a.multi_dim_plates()
        for i, plate in enumerate(multi_dim_plates):
            if plate in links:
                if plate != links[amt_recognized]:
                    # The plate is also in the tensor, but not in the ordering expected.
                    # So switch that ordering.
                    j = links.index(plate)
                    tensor = tensor.transpose(j, amt_recognized)
                    links[amt_recognized], links[j] = links[j], links[amt_recognized]
                amt_recognized += 1

        # Add singleton dimensions on missing plates
        plate_dims = []
        for i, plate in enumerate(multi_dim_plates):
            if plate not in a.plates:
                tensor = tensor.unsqueeze(i)
                plate_dims.append(plate.n)
            else:
                # Make sure to use a's plate size here. It's actually possible they are
                # different in ancestral plates!
                plate_dims.append(tensor.shape[i])

        # Optionally expand the singleton dimensions to the plate size
        if expand_plates:
            # TODO: Can maybe be optimized by only running this if the shape is different
            tensor = tensor.expand(tuple(plate_dims) + tensor.shape[len(plate_dims):])
        # Optionally flatten the plate dimensions to a single batch dimension
        if flatten_plates:
            assert expand_plates
            tensor = tensor.reshape((-1,) + tensor.shape[len(plate_dims):])
        return tensor
    elif isinstance(a, Mapping):
        d = {}
        for k, _a in a.items():
            d[k] = _unsqueeze_and_unwrap(
                _a,
                multi_dim_plates,
                align_tensors,
                l_broadcast,
                expand_plates,
                flatten_plates,
                event_dims,
            )
        return d
    elif is_iterable(a):
        l = []
        for _a in a:
            l.append(
                _unsqueeze_and_unwrap(
                    _a,
                    multi_dim_plates,
                    align_tensors,
                    l_broadcast,
                    expand_plates,
                    flatten_plates,
                    event_dims,
                )
            )
        if isinstance(a, tuple):
            return tuple(l)
        return l
    elif isinstance(a, Distribution):
        from storch.util import get_distr_parameters

        params = get_distr_parameters(a, False)
        params_unsq = _unsqueeze_and_unwrap(
            params,
            multi_dim_plates,
            align_tensors,
            l_broadcast,
            expand_plates,
            flatten_plates,
            event_dims,
        )
        try:
            if "logits" in params and "probs" in params:
                # Discrete distributions don't like it if you pass both probs and logits
                _a = a.__class__(logits=params_unsq["logits"])
            else:
                # Attempt to instantiate a copy of the distribution using the
                # unsqueezed parameters
                _a = a.__class__(**params_unsq)
        except Exception as e:
            return a
        return _a
    else:
        return a


def _prepare_args(
    fn_args,
    fn_kwargs,
    unwrap=True,
    align_tensors=True,
    l_broadcast=True,
    expand_plates=False,
    flatten_plates=False,
    dim: Optional[str] = None,
    dims: Optional[Union[str, List[str]]] = None,
) -> (List, Dict, [storch.Tensor], [storch.Plate]):
    """
    Prepares the input arguments of the wrapped function:

    - Unwrap the input arguments from storch.Tensors to normal torch.Tensors so they can be
      used in any torch function
    - Align plate dimensions for automatic broadcasting
    - Add (singleton) plate dimensions for plates that are not present
    - Right-broadcast event dimensions for automatic broadcasting
    - Let Plate subclasses apply plate-specific input handling

    :param fn_args: List of non-keyword arguments to the wrapped function
    :param fn_kwargs: Dictionary of keyword arguments to the wrapped function
    :param unwrap: Whether to unwrap the arguments to their torch.Tensor counterpart (default: True)
    :param align_tensors: Whether to automatically align the input arguments (default: True)
    :param l_broadcast: Whether to automatically left-broadcast (default: True)
    :param expand_plates: Instead of adding singleton dimensions on non-existent plates, this will
        expand to the plate size itself (default: False). flatten_plates sets this to True automatically.
    :param flatten_plates: Flattens the plate dimensions into a single batch dimension if set to True.
        This can be useful for functions that are written to only work for tensors with a single batch
        dimension. Note that outputs are unflattened automatically.
        (default: False)
    :param dim: Replaces the dim input in fn_kwargs by the plate dimension corresponding to the
        given string (optional)
    :param dims: Replaces the dims input in fn_kwargs by the plate dimensions corresponding to the
        given strings (optional)
    :param self_wrapper: storch.Tensor that wraps a
    :return: Handled non-keyword arguments, handled keyword arguments, list of parents, list of plates
    """
    parents: [storch.Tensor] = []
    plates: [storch.Plate] = []

    max_event_dim = max(
        # Collect parent tensors and plates
        _collect_parents_and_plates(fn_args, parents, plates),
        _collect_parents_and_plates(fn_kwargs, parents, plates),
    )
    # Allow plates to filter themselves from being collected.
    plates = list(filter(lambda p: p.on_collecting_args(plates), plates))

    # Get the list of plates with size larger than 1 for the unsqueezing of tensors
    multi_dim_plates = []
    for plate in plates:
        if plate.n > 1:
            multi_dim_plates.append(plate)

    if dim:
        for i, plate in enumerate(multi_dim_plates):
            if dim == plate.name:
                i_dim = i
                break
        fn_kwargs["dim"] = i_dim
    if dims:
        dimz = []
        for dim in dims:
            for i, plate in enumerate(multi_dim_plates):
                if dim == plate.name:
                    dimz.append(i)
                    break
            else:
                raise ValueError("Missing plate dimension " + dim)
        fn_kwargs["dims"] = dimz

    if unwrap:
        expand_plates = expand_plates or flatten_plates
        # Unsqueeze and align batched dimensions so that batching works easily.
        unsqueezed_args = []
        for t in fn_args:
            unsqueezed_args.append(
                _unsqueeze_and_unwrap(
                    t,
                    multi_dim_plates,
                    align_tensors,
                    l_broadcast,
                    expand_plates,
                    flatten_plates,
                    max_event_dim,
                )
            )
        unsqueezed_kwargs = {}
        for k, v in fn_kwargs.items():
            unsqueezed_kwargs[k] = _unsqueeze_and_unwrap(
                v,
                multi_dim_plates,
                align_tensors,
                l_broadcast,
                expand_plates,
                flatten_plates,
                max_event_dim,
            )
        return unsqueezed_args, unsqueezed_kwargs, parents, plates
    return fn_args, fn_kwargs, parents, plates


def _prepare_outputs_det(
    o: Any,
    parents: [storch.Tensor],
    plates: [storch.Plate],
    name: str,
    index: int,
    unflatten_plates,
):
    if o is None:
        return None, index
    if isinstance(o, storch.Tensor):
        if o.stochastic:
            raise RuntimeError(
                "Creation of stochastic storch Tensor within deterministic context"
            )
        # TODO: Does this require shape checking? Parent/Plate checking?
        #  This might be very buggy, hard to figure out how to merge these concepts...
        #  Try to prevent creating storch.Tensors within deterministic contexts.
        new_plates = o.plates.copy()
        for plate in reversed(plates):
            plate_found = False
            for i, other_plate in enumerate(new_plates):
                if plate.name == other_plate.name:
                    plate_found = True
                    if hasattr(plate, "variable_index"):
                        assert hasattr(other_plate, "variable_index")
                        if plate.variable_index > other_plate.variable_index:
                            new_plates[i] = plate
            if not plate_found:
                new_plates.insert(0, plate)
        new_parents = parents.copy()
        new_parents.append(o)
        t = storch.Tensor(o._tensor, new_parents, new_plates, name=name + str(index))
        return t, index + 1
    if isinstance(o, torch.Tensor):
        # Explicitly _not_ a storch.Tensor
        if unflatten_plates:
            plate_dims = tuple([plate.n for plate in plates if plate.n > 1])
            o = o.reshape(plate_dims + o.shape[1:])
        t = storch.Tensor(o, parents, plates, name=name + str(index))
        return t, index + 1
    if is_iterable(o):
        outputs = []
        for _o in o:
            t, index = _prepare_outputs_det(
                _o, parents, plates, name, index, unflatten_plates=unflatten_plates
            )
            outputs.append(t)
        if isinstance(o, tuple):
            return tuple(outputs), index
        return outputs, index
    raise NotImplementedError(
        "Handling of other types of return values is currently not implemented: ", o
    )


def _handle_deterministic(
    fn,
    fn_args,
    fn_kwargs,
    reduce_plates: Optional[Union[str, List[str]]] = None,
    flatten_plates: bool = False,
    **wrapper_kwargs
):
    if storch.wrappers._context_stochastic:
        raise NotImplementedError(
            "It is currently not allowed to open a deterministic context in a stochastic context"
        )

    # TODO check if we can re-add this
    # if storch.wrappers._context_deterministic > 0:
    #     if is_cost:
    #         raise RuntimeError("Cannot call storch.cost from within a deterministic context.")
    #
    #     TODO: The early return below is currently commented out, so nested deterministic calls
    #      will in fact unwrap. This was required because it was, eg, possible to open a
    #      deterministic context, passing distributions with storch.Tensors as parameters, then
    #      doing computations on these parameters. This is because these storch.Tensors will not
    #      be unwrapped in the deterministic context as the unwrapping only considers lists.
    #
    #     # We are already in a deterministic context, no need to wrap or unwrap as only the
    #     # outer dependencies matter
    #     return fn(*args, **kwargs)

    new_fn_args, new_fn_kwargs, parents, plates = _prepare_args(
        fn_args, fn_kwargs, flatten_plates=flatten_plates, **wrapper_kwargs
    )
    if not parents:
        return fn(*fn_args, **fn_kwargs)
    args = new_fn_args
    kwargs = new_fn_kwargs

    storch.wrappers._context_deterministic += 1
    try:
        outputs = fn(*args, **kwargs)
    finally:
        storch.wrappers._context_deterministic -= 1

    if storch.wrappers._ignore_wrap:
        return outputs

    if reduce_plates:
        if isinstance(reduce_plates, str):
            reduce_plates = [reduce_plates]
        plates = [p for p in plates if p.name not in reduce_plates]

    outputs = _prepare_outputs_det(
        outputs, parents, plates, fn.__name__, 1, unflatten_plates=flatten_plates
    )[0]
    return outputs


def _deterministic(
    fn, reduce_plates: Optional[Union[str, List[str]]] = None, **wrapper_kwargs
):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        nonlocal reduce_plates
        return _handle_deterministic(fn, args, kwargs, reduce_plates, **wrapper_kwargs)

    return wrapper
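
# Sketch of what the plumbing above does when a wrapped function is called (an
# illustration under assumed shapes, not code from this module). Suppose `a` is a
# storch.Tensor with one plate ("z", n=4) and event shape (3,), and `b` is a
# storch.Tensor with one plate ("x", n=2) and event shape (3,). For fn(a, b):
#
#   1. _collect_parents_and_plates gathers parents [a, b] and plates [z, x].
#   2. _unsqueeze_and_unwrap turns `a` into a torch.Tensor of shape (4, 1, 3) and
#      `b` into one of shape (1, 2, 3), so plain torch broadcasting applies.
#   3. fn runs on the unwrapped tensors, producing e.g. a tensor of shape (4, 2, 3).
#   4. _prepare_outputs_det wraps the result in a new storch.Tensor with plates
#      [z, x] and parents [a, b].
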
def deterministic(fn: Optional[Callable] = None, **kwargs):
    """
    Wraps the input function around a deterministic storch wrapper.
    This wrapper unwraps :class:`~storch.Tensor` objects to :class:`~torch.Tensor` objects,
    aligning the tensors according to the plates, then runs `fn` on the unwrapped Tensors.

    Args:
        fn: Optional function to wrap. If None, this returns another wrapper that accepts a
            function, which will then be wrapped with the given kwargs.
        unwrap: Whether to unwrap the arguments to their torch.Tensor counterpart. Set to False
            to prevent unwrapping :class:`~storch.Tensor` objects. (default: True)
        align_tensors: Whether to automatically align the input arguments (default: True)
        l_broadcast: Whether to automatically left-broadcast (default: True)
        expand_plates: Instead of adding singleton dimensions on non-existent plates, this will
            expand to the plate size itself (default: False). flatten_plates sets this to True
            automatically.
        flatten_plates: Flattens the plate dimensions into a single batch dimension if set to True.
            This can be useful for functions that are written to only work for tensors with a
            single batch dimension. Note that outputs are unflattened automatically.
            (default: False)
        dim: Replaces the dim input in fn_kwargs by the plate dimension corresponding to the
            given string (optional)
        dims: Replaces the dims input in fn_kwargs by the plate dimensions corresponding to the
            given strings (optional)
        self_wrapper: storch.Tensor that wraps a

    Returns:
        Callable: The wrapped function `fn`.
    """
    if fn:
        return _deterministic(fn, **kwargs)
    return lambda _f: _deterministic(_f, **kwargs)
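
# Example usage of the deterministic wrapper (a hedged sketch: the sampling method,
# the plate name "x" and the tensor shapes are assumptions following the storch
# documentation, not defined in this module):
#
#   import torch
#   import storch
#   from torch.distributions import Normal
#
#   @storch.deterministic
#   def scaled_diff(a, b):
#       # Inside the wrapper, a and b are plain torch.Tensors with aligned plate dims.
#       return 2.0 * (a - b)
#
#   mu = torch.tensor([0.0, 1.0], requires_grad=True)
#   x = storch.method.Reparameterization("x", n_samples=8)(Normal(mu, 1.0))
#   out = scaled_diff(x, mu)   # out is a storch.Tensor that keeps the "x" plate
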
def make_left_broadcastable(fn: Optional[Callable]):
    """
    Deterministic wrapper that is compatible with functions that are not by themselves
    left-broadcastable, such as :class:`torch.nn.Conv2d`. That module operates on
    (N, C, H, W) inputs and cannot deal with additional 'independent' dimensions on the left.
    To fix this, use `make_left_broadcastable(Conv2d(16, 33, 3))`.
    """
    return deterministic(fn, flatten_plates=True)
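
# Example following the docstring above (a sketch: the plate setup and shapes are
# assumptions). flatten_plates folds all plate dimensions into one batch dimension
# before calling the module and unflattens the output afterwards, so Conv2d only
# ever sees a 4-dimensional input.
#
#   import torch.nn as nn
#
#   conv = make_left_broadcastable(nn.Conv2d(16, 33, 3))
#   # Assume `images` is a storch.Tensor with two plates of sizes 4 and 8 and event
#   # shape (16, H, W); Conv2d is then called on a plain (32, 16, H, W) tensor.
#   features = conv(images)
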
def reduce(fn, plates: Union[str, List[str]]):
    """
    Wraps the input function around a deterministic storch wrapper. This wrapper unwraps
    :class:`~storch.Tensor` objects to :class:`~torch.Tensor` objects, aligning the tensors
    according to the plates, then runs `fn` on the unwrapped Tensors. It will reduce the
    plates given by `plates`.

    Args:
        fn (Callable): Function to wrap.
        plates: Name (or list of names) of the plates to reduce out of the output.

    Returns:
        Callable: The wrapped function `fn`.
    """
    if storch._debug:
        print("Reducing plates", plates)
    return _deterministic(fn, reduce_plates=plates)
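
# Example use of reduce (an illustrative sketch: it assumes "z" is the only plate
# with size > 1, so it occupies dimension 0 of the unwrapped tensor handed to the
# reduction function):
#
#   mean_over_z = reduce(lambda t: t.mean(dim=0), plates="z")
#   avg = mean_over_z(x)   # a storch.Tensor in which the "z" plate has been reduced
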
def _self_deterministic(fn, self: storch.Tensor):
    fn = deterministic(fn)

    @wraps(fn)
    def wrapper(*args, **kwargs):
        # Inserts the self object at the beginning of the passed arguments.
        # In essence, it "fakes" the self reference.
        args = list(args)
        args.insert(0, self)
        return fn(*args, **kwargs)

    return wrapper


def _process_stochastic(
    output: torch.Tensor, parents: [storch.Tensor], plates: [storch.Plate]
):
    if isinstance(output, storch.Tensor):
        if not output.stochastic:
            # TODO: Calls _add_parents so something is going wrong here
            # The Tensor was created by calling @deterministic within a stochastic
            # context. This means that we have to conservatively assume it is dependent
            # on the parents.
            output._add_parents(storch.wrappers._stochastic_parents)
        return output
    if isinstance(output, torch.Tensor):
        t = storch.Tensor(output, parents, plates)
        return t
    else:
        raise TypeError(
            "All outputs of functions wrapped in @storch.stochastic "
            "should be Tensors. At " + str(output)
        )
def stochastic(fn):
    """
    Applies `fn` to the `inputs`. `fn` should return one or multiple `storch.Tensor`s.
    `fn` should not call `storch.stochastic` or `storch.deterministic`. `inputs` can
    include `storch.Tensor`s.

    :param fn: Function to wrap.
    :return: The wrapped function `fn`.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs):
        if (
            storch.wrappers._context_stochastic
            or storch.wrappers._context_deterministic > 0
        ):
            raise RuntimeError(
                "Cannot call storch.stochastic from within a stochastic or deterministic context."
            )

        # Save the parents
        args, kwargs, parents, plates = _prepare_args(args, kwargs)

        storch.wrappers._plate_links = plates
        storch.wrappers._stochastic_parents = parents
        storch.wrappers._context_stochastic = True
        storch.wrappers._context_name = fn.__name__
        try:
            outputs = fn(*args, **kwargs)
        finally:
            storch.wrappers._plate_links = []
            storch.wrappers._stochastic_parents = []
            storch.wrappers._context_stochastic = False
            storch.wrappers._context_name = None

        # Add parents to the outputs
        if is_iterable(outputs):
            processed_outputs = []
            for o in outputs:
                processed_outputs.append(_process_stochastic(o, parents, plates))
        else:
            processed_outputs = _process_stochastic(outputs, parents, plates)
        return processed_outputs

    return wrapper
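
# Example usage of the stochastic wrapper (a hedged sketch: the sampling methods and
# their signatures follow the storch documentation and are assumptions, not defined
# in this module):
#
#   import torch
#   import storch
#   from torch.distributions import Normal, Bernoulli
#
#   @storch.stochastic
#   def model(mu):
#       z = storch.method.Reparameterization("z")(Normal(mu, 1.0))
#       y = storch.method.ScoreFunction("y", n_samples=4)(Bernoulli(logits=z))
#       return z, y
#
#   mu = torch.tensor([0.5], requires_grad=True)
#   z, y = model(mu)   # both outputs are storch.Tensors registered with their parents
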
def _exception_wrapper(fn):
    def wrapper(*args, **kwargs):
        for a in args:
            if isinstance(a, storch.Tensor):
                raise IllegalStorchExposeError(
                    "It is not allowed to call this method using storch.Tensor, likely "
                    "because it exposes its wrapped tensor to Python."
                )
        return fn(*args, **kwargs)

    return wrapper


def _unpack_wrapper(fn, self: Optional[storch.Tensor] = None):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if self is not None:
            args = list(args)
            args.insert(0, self)
        new_args = []
        for a in args:
            if isinstance(a, storch.Tensor):
                new_args.append(a._tensor)
            else:
                new_args.append(a)
        return fn(*tuple(new_args), **kwargs)

    return wrapper
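
# Illustration of the two wrappers above (a sketch; which torch functions storch
# registers with each wrapper is decided elsewhere in the library):
#
#   guarded = _exception_wrapper(torch.Tensor.item)
#   guarded(x)        # raises IllegalStorchExposeError if x is a storch.Tensor
#
#   unpacked = _unpack_wrapper(torch.add)
#   unpacked(x, y)    # unwraps storch.Tensors and returns a plain torch.Tensor
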
@contextmanager
def ignore_wrapping():
    storch.wrappers._ignore_wrap = True
    try:
        yield
    finally:
        # Reset the flag even if the wrapped code raises.
        storch.wrappers._ignore_wrap = False
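
# Example use of ignore_wrapping (illustrative: `wrapped_fn` is a hypothetical
# function decorated with @deterministic). Inside the context, deterministic calls
# skip re-wrapping their outputs, so the raw return value of the function is passed
# through unchanged.
#
#   with ignore_wrapping():
#       raw = wrapped_fn(x)   # plain output, not wrapped back into a storch.Tensor
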