from __future__ import annotations
from queue import Queue
import torch
import storch
from torch.distributions import Distribution
from collections import deque
from typing import List, Iterable, Any, Callable, Iterator, Dict, Tuple, Deque
import builtins
from itertools import product
from typing import Optional
from storch.exceptions import IllegalStorchExposeError
from storch.excluded_init import (
exception_methods,
excluded_methods,
unwrap_only_methods,
expand_methods,
)
# from storch.typing import BatchTensor
class Plate:
    """
    A named batch dimension of size ``n`` together with the weight used when reducing over it.

    Note that this class defines ``__eq__`` without ``__hash__``, so instances are not hashable.

    Args:
        name (str): The name of the plate.
        n (int): The size of the plate dimension.
        parents ([Plate]): Plates that must be present on a tensor when the weight has more than one dimension
            (checked in :meth:`reduce`).
        weight (Optional[storch.Tensor]): The weighting applied in :meth:`reduce`. Defaults to the scalar
            ``1 / n``, which turns the weighted sum into a mean.
    """

    def __init__(
        self,
        name: str,
        n: int,
        parents: List[Plate],
        weight: Optional[storch.Tensor] = None,
    ):
        if weight is None:
            # Uniform weighting: summing and multiplying by 1/n takes the mean over the plate.
            weight = torch.tensor(1.0 / n)
        self.weight = weight
        self.name = name
        self.n = n
        self.parents = parents

    def __eq__(self, other) -> bool:
        """
        Two plates are equal when they have the same name, the same size and matching weights.

        Weights are only compared element-wise when both are :class:`storch.Tensor`. If neither is, the
        weights are treated as equal without inspecting them.
        """
        if not isinstance(other, Plate):
            return False
        if self.name != other.name:
            return False
        if self.n != other.n:
            # TODO: This should maybe return an error...?
            return False
        self_is_storch = isinstance(self.weight, storch.Tensor)
        other_is_storch = isinstance(other.weight, storch.Tensor)
        if self_is_storch != other_is_storch:
            return False
        if not self_is_storch:
            # Neither of the weights are Tensors, so the weights must be equal as self.n==other.n
            return True
        if self.weight._tensor is other.weight._tensor:
            return True
        if self.weight.shape != other.weight.shape:
            return False
        # Coerce the 0-dim torch result into a real bool, as promised by the signature.
        return bool((self.weight._tensor == other.weight._tensor).all())

    def __str__(self):
        return self.name + ", " + str(self.n)

    def __repr__(self):
        return f"({self.name!r}, {self.n!r}, {self.weight!r})"

    def reduce(self, tensor: storch.Tensor, detach_weights=True):
        """
        Reduces ``tensor`` over this plate using a sum weighted by :attr:`weight`.

        Args:
            tensor (storch.Tensor): The tensor to reduce.
            detach_weights (bool): If True, the weight is detached before it is used.

        Raises:
            ValueError: If the weight has more than one dimension and one of :attr:`parents` is not in
                ``tensor.plates``.
        """
        plate_weighting = self.weight.detach() if detach_weights else self.weight

        if self.n == 1:
            return storch.reduce(lambda x: x * plate_weighting, self.name)(tensor)

        # Case: The weight is a single number. First sum, then multiply with the weight (usually taking the mean)
        if plate_weighting.ndim == 0:
            return storch.sum(tensor, self) * plate_weighting

        # Case: There is a weight for each plate which is not dependent on the other batch dimensions
        if plate_weighting.ndim == 1:
            index = tensor.get_plate_dim_index(self.name)
            # Append singleton dimensions so the weight lines up with this plate's dimension of the tensor.
            trailing_dims = (None,) * (tensor.ndim - index - 1)
            plate_weighting = plate_weighting[(...,) + trailing_dims]
            return storch.sum(tensor * plate_weighting, self)

        # Case: The weight is a vector of numbers equal to batch dimension. Assumes it is a storch.Tensor
        for parent_plate in self.parents:
            if parent_plate not in tensor.plates:
                raise ValueError(
                    "Plate missing when reducing tensor: " + parent_plate.name
                )
        return storch.sum(tensor * plate_weighting, self)

    def on_collecting_args(self, plates: List[Plate]) -> bool:
        """
        Gets called after a wrapper collected plates from its input arguments.

        Args:
            plates ([Plate]): All collected plates

        Returns:
            bool: True if this plate should remain in the collected plates.
        """
        for plate in plates:
            if plate.name == self.name and plate != self:
                if not plate.on_duplicate_plate(self):
                    return False
        return True

    def on_duplicate_plate(self, plate: Plate) -> bool:
        """
        Gets called from :meth:`on_collecting_args` when ``plate`` has the same name as this plate but is not
        equal to it. This implementation always rejects the duplicate.

        Raises:
            ValueError: Always.
        """
        raise ValueError(
            "Found two plates with the same name that are not equal: ("
            + str(self)
            + ") and ("
            + str(plate)
            + ")."
        )

    def on_unwrap_tensor(self, tensor: storch.Tensor) -> storch.Tensor:
        """
        Gets called whenever the given tensor is being unwrapped and unsqueezed for batch use.

        Args:
            tensor (storch.Tensor): The input tensor that is being unwrapped

        Returns:
            storch.Tensor: The tensor that will be unwrapped and unsqueezed in the future. Can be a modification
            of the input tensor. This implementation returns the input unchanged.
        """
        return tensor

    def index_in(self, plates: List[Plate]) -> int:
        """Returns the index of this plate in ``plates`` (``list.index`` semantics, so it uses ``__eq__``)."""
        return plates.index(self)

    def is_in(self, plates: Iterable[Plate]) -> bool:
        """Returns True if this plate is contained in ``plates``."""
        return self in plates
class Tensor:
    """
    A :class:`storch.Tensor` is a wrapper around a :class:`torch.Tensor` that acts like a normal :class:`torch.Tensor`
    with some restrictions and some extra data.

    By design, :class:`storch.Tensor` cannot expose the wrapped :class:`torch.Tensor` to regular Python control flow.
    For example, using a :class:`storch.Tensor` inside an ``if`` condition or in a ``for`` loop will throw an
    :class:`~storch.exceptions.IllegalStorchExposeError`. This is done because a node could be dependent on a Tensor that is used as
    a conditional to branch between different computation paths. However, Python control flow will not register dependencies
    between nodes in the computation graph.

    The underlying :class:`torch.Tensor` can be unwrapped in two ways. The safe way is using the :func:`.deterministic`
    wrapper, which safely unwraps the :class:`storch.Tensor` and runs the function on the unwrapped :class:`torch.Tensor`.
    Note that all ``torch`` methods are automatically wrapped using :func:`.deterministic` when an input argument
    is :class:`storch.Tensor`.

    The unsafe way to unwrap the tensor is to access :attr:`storch.Tensor._tensor`. This should only be used when one is
    sure this will not introduce missing dependency links.

    Args:
        tensor (torch.Tensor): The tensor to wrap. The leftmost dimensions should correspond to the sizes of ``plates``
            that are larger than 1.
        parents ([storch.Tensor]): The parents of this Tensor. Parents represent the incoming links in stochastic
            computation graphs.
        plates ([storch.Plate]): The plates of this Tensor. Plates contain information about the sampling procedure and
            dependencies of this Tensor with respect to earlier samples.
        name (Optional[str]): The name of this Tensor.

    Raises:
        TypeError: If ``tensor`` is itself a :class:`storch.Tensor`.
        ValueError: If two plates share a name, or if the leftmost dimensions of ``tensor`` do not match the
            sizes of the plates with ``n > 1``.
    """

    def __init__(
        self,
        tensor: torch.Tensor,
        parents: List[Tensor],
        plates: List[Plate],
        name: Optional[str] = None,
    ):
        if isinstance(tensor, Tensor):
            raise TypeError(
                "storch.Tensors should be constructed with torch.Tensors, not other storch.Tensors."
            )
        plate_names = set()
        batch_dims = 0

        # Check whether this tensor does not violate the constraints imposed by the given plates
        for plate in plates:
            if plate.name in plate_names:
                raise ValueError(
                    "Plates contain two instances of same plate "
                    + plate.name
                    + ". This can be caused by different samples with the same name using a different amount of "
                    + "samples n or different weighting of the samples. Make sure that these samples use the same number of samples."
                )
            plate_names.add(plate.name)
            # plate length is 1. Ignore this dimension, as singleton dimensions should not exist.
            if plate.n == 1:
                continue
            if len(tensor.shape) <= batch_dims:
                raise ValueError(
                    "Got an input tensor with too few dimensions. We expected "
                    + str(len(plates))
                    + " plate dimensions. Instead, we found only "
                    + str(len(tensor.shape))
                    + " dimensions. Violated at dimension "
                    + str(batch_dims)
                )
            elif not tensor.shape[batch_dims] == plate.n:
                raise ValueError(
                    "Storch Tensors should take into account their surrounding plates. Violated at dimension "
                    + str(batch_dims)
                    + " and plate "
                    + plate.name
                    + " with size "
                    + str(plate.n)
                    + ". "
                    "Instead, it was "
                    + str(tensor.shape[batch_dims])
                    + ". Batch links: "
                    + str(plates)
                    + " Tensor shape: "
                    + str(tensor.shape)
                )
            batch_dims += 1

        self._name = name
        self._tensor = tensor
        # Each link is stored as (node, differentiable) on both ends of the edge.
        self._parents: List[Tuple[Tensor, bool]] = []
        self._cleaned = False
        # DISCONTINUED FOR NOW BECAUSE OF PERFORMANCE OVERHEAD
        # differentiable_links = has_backwards_path(self, parents)
        for p in parents:
            # TODO: Should I re-add this?
            # if p.is_cost:
            #     raise ValueError("Cost nodes cannot have children.")
            # TODO: DIFFERENTIABLE LINKS MANUALLY SET TO FALSE. THIS MIGHT CAUSE BUGS IN THE FUTURE
            self._parents.append((p, False))
            p._children.append((self, False))
        self._children: List[Tuple[Tensor, bool]] = []
        # The number of leading dimensions that belong to plates with n > 1; the rest is the event shape.
        self.plate_dims = batch_dims
        self.event_shape = tensor.shape[batch_dims:]
        self.event_dims = len(self.event_shape)
        self.plates = plates

    @classmethod
    def __torch_function__(cls, func: Callable, types, args=(), kwargs=None) -> Any:
        """
        Called whenever a torch.* or torch.nn.functional.* method is being called on a storch.Tensor. This wraps
        that method in the deterministic wrapper to properly handle all input arguments and outputs.

        Raises:
            IllegalStorchExposeError: If the name of ``func`` is listed in ``exception_methods``.
        """
        if kwargs is None:
            kwargs = {}
        func_name = func.__name__
        if func_name in exception_methods:
            raise IllegalStorchExposeError(
                "Calling method " + func_name + " with storch tensors is not allowed."
            )
        if func_name in excluded_methods:
            # Excluded methods are called directly, without the deterministic wrapper.
            return func(*args, **kwargs)
        if func_name in expand_methods:
            # Automatically expand empty plate dimensions. This is necessary for some loss functions, which
            # assume both inputs have exactly the same elements.
            return storch.wrappers._handle_deterministic(
                func, args, kwargs, expand_plates=True
            )
        # if func_name in unwrap_only_methods:
        #     return storch.wrappers._unpack_wrapper(func)(*args, *kwargs)
        return storch.wrappers._handle_deterministic(func, args, kwargs)

    def __getattr__(self, item) -> Any:
        """
        Called whenever an attribute is called on a storch.Tensor object that is not directly implemented by storch.Tensor.
        It defers it to the underlying torch.Tensor. If it is a callable (ie, torch.Tensor implements a function
        with the name item), it will wrap this callable with a deterministic wrapper.

        The lookup happens on the :class:`torch.Tensor` class, so an unknown name raises ``AttributeError``.

        TODO: This should probably filter the methods
        """
        attr = getattr(torch.Tensor, item)
        if callable(attr):
            func_name = attr.__name__
            if func_name in exception_methods:
                raise IllegalStorchExposeError(
                    "Calling method "
                    + func_name
                    + " with storch tensors is not allowed."
                )
            if func_name in excluded_methods:
                return attr
            # if func_name in unwrap_only_methods:
            #     return storch.wrappers._unpack_wrapper(attr, self=self)
            return storch.wrappers._self_deterministic(attr, self)
        # NOTE(review): attributes that exist on torch.Tensor but are not callable yield None here rather
        # than the value of the wrapped tensor. Kept as-is; confirm whether this is intended.
        return None

    @property
    def name(self) -> Optional[str]:
        return self._name

    @property
    def is_sparse(self) -> bool:
        """
        Returns: True if the underlying tensor is sparse.
        """
        return self._tensor.is_sparse

    def __str__(self) -> str:
        t = (
            (self.name + ": " if self.name else "") + "Stochastic"
            if self.stochastic
            else ("Cost" if self.is_cost else "Deterministic")
        )
        return t + " " + str(self._tensor) + " Batch links: " + str(self.plates)

    def __repr__(self) -> str:
        return f"[{repr(self.name)}, {repr(self._tensor)}, {repr(self.plates)}]"

    def __hash__(self) -> int:
        # __eq__ is overloaded as an element-wise operation, so hashing is by identity.
        return object.__hash__(self)

    @storch.deterministic
    def __eq__(self, other) -> bool:
        return self.__eq__(other)

    @storch.deterministic(l_broadcast=False)
    def __getitem__(self, index):
        return self.__getitem__(index)

    @storch.deterministic(l_broadcast=False)
    def __setitem__(self, index, value):
        return self.__setitem__(index, value)

    def _walk_backwards(
        self,
        expand_fn: Callable[[Tensor], Iterator[Tuple[Tensor, bool]]],
        depth_first=True,
        reverse=False,  # Only supported for breadth-first
        only_differentiable=False,
        repeat_visited=False,
        walk_fn=lambda x: x,
    ) -> Iterator[Tensor]:
        """
        Walks the graph starting at (and including) this node, following the links returned by ``expand_fn``.

        Args:
            expand_fn: Maps a node to its outgoing ``(node, differentiable)`` links.
            depth_first: True to use depth first, otherwise breadth first.
            reverse: Yield the nodes in reversed visiting order. Ignored for depth-first search.
            only_differentiable: True to only follow links that are marked differentiable.
            repeat_visited: True to yield nodes again when they are reached more than once.
            walk_fn: Function applied to every node before it is yielded.

        Returns:
            Iterator over the results of ``walk_fn`` on the nodes found.
        """
        visited = set()
        if depth_first:
            stack = [self]
            while stack:
                v = stack.pop()
                if repeat_visited or not v.is_in(visited):
                    yield walk_fn(v)
                    visited.add(v)
                    for w, d in expand_fn(v):
                        if d or not only_differentiable:
                            stack.append(w)
        else:
            visited_ordered = []
            queue: Deque[Tensor] = deque()
            visited.add(self)
            queue.append(self)
            while queue:
                v = queue.popleft()
                if reverse:
                    # Delay yielding until the whole walk is known, so the order can be reversed.
                    visited_ordered.append(v)
                else:
                    yield walk_fn(v)
                for w, d in expand_fn(v):
                    if (repeat_visited or not w.is_in(visited)) and (
                        d or not only_differentiable
                    ):
                        visited.add(w)
                        queue.append(w)
            if reverse:
                for v in reversed(visited_ordered):
                    # Apply walk_fn here as well, so both orders yield the same kind of values.
                    yield walk_fn(v)

    def walk_parents(
        self,
        depth_first=True,
        reverse=False,
        only_differentiable=False,
        repeat_visited=False,
        walk_fn=lambda x: x,
    ) -> Iterator[Tensor]:
        """
        Searches through the parents of this Tensor in the stochastic computation graph.

        Args:
            depth_first: True to use depth first, otherwise breadth first.
            reverse: Reverse the order: If true, instead of first returning the immediate parents, return
                the parents furthest up, working towards the immediate parents.
                Currently only supported for breadth-first search.
            only_differentiable: True to only walk over edges that are differentiable
            repeat_visited: True to return nodes again when they are reached more than once.
            walk_fn: Optional function on :class:`storch.Tensor` that manipulates the nodes found.

        Returns:
            Iterator of type that is equal to the output type of ``walk_fn``.
        """
        return self._walk_backwards(
            lambda p: p._parents,
            depth_first,
            reverse,
            only_differentiable,
            repeat_visited,
            walk_fn,
        )

    def walk_children(
        self,
        depth_first=True,
        reverse=False,
        only_differentiable=False,
        repeat_visited=False,
        walk_fn=lambda x: x,
    ) -> Iterator[Tensor]:
        """
        Searches through the children of this Tensor in the stochastic computation graph.

        Args:
            depth_first: True to use depth first, otherwise breadth first.
            reverse: Reverse the order in which the nodes are returned.
                Currently only supported for breadth-first search.
            only_differentiable: True to only walk over edges that are differentiable
            repeat_visited: True to return nodes again when they are reached more than once.
            walk_fn: Optional function on :class:`storch.Tensor` that manipulates the nodes found.

        Returns:
            Iterator of type that is equal to the output type of ``walk_fn``.
        """
        return self._walk_backwards(
            lambda p: p._children,
            depth_first,
            reverse,
            only_differentiable,
            repeat_visited,
            walk_fn,
        )

    def _clean(self) -> None:
        """
        Cleans up :attr:`_children` and :attr:`_parents` for all nodes in the subgraph of this node (depth first)
        """
        if self._cleaned:
            return
        # Mark first, so that the recursion terminates when it comes back to this node.
        self._cleaned = True
        for (node, _) in self._children:
            node._clean()
        for (node, _) in self._parents:
            node._clean()
        self._children = []
        self._parents = []

    def detach_tensor(self) -> torch.Tensor:
        """
        Returns: The result of ``detach()`` on the wrapped :class:`torch.Tensor`. Note that this is a plain
            :class:`torch.Tensor`, not a :class:`storch.Tensor`.
        """
        return self._tensor.detach()

    @property
    def stochastic(self) -> bool:
        """
        Returns:
            bool: True if this is a stochastic node in the stochastic computation graph, False otherwise.
        """
        return False

    @property
    def is_cost(self) -> bool:
        """
        Returns:
            bool: True if this is a cost node in the stochastic computation graph, False otherwise.
        """
        return False

    @property
    def parents(self) -> List[Tensor]:
        return [parent for parent, _ in self._parents]

    @property
    def requires_grad(self) -> bool:
        return self._tensor.requires_grad

    @property
    def plate_shape(self) -> torch.Size:
        return self._tensor.shape[: self.plate_dims]

    def size(self, *args) -> torch.Size:
        return self._tensor.size(*args)

    @property
    def shape(self) -> torch.Size:
        return self._tensor.size()

    def is_cuda(self):
        return self._tensor.is_cuda

    @property
    def dtype(self):
        return self._tensor.dtype

    @property
    def layout(self):
        return self._tensor.layout

    @property
    def device(self):
        return self._tensor.device

    @property
    def grad(self):
        return self._tensor.grad

    def dim(self):
        return self._tensor.dim()

    def ndimension(self):
        return self._tensor.ndimension()

    @property
    def ndim(self):
        return self._tensor.ndim

    def register_hook(self, hook: Callable) -> Any:
        return self._tensor.register_hook(hook)

    @property
    def event_dim_indices(self):
        return range(self.plate_dims, self._tensor.dim())

    def get_plate(self, plate_name: str) -> Plate:
        """
        Returns the plate of this tensor with the given name.

        Raises:
            IndexError: If this tensor has no plate with that name.
        """
        for plate in self.plates:
            if plate.name == plate_name:
                return plate
        raise IndexError("Tensor has no such plate: " + plate_name + ".")

    def get_plate_dim_index(self, plate_name: str) -> int:
        """
        Returns the position of the named plate among the plates with ``n > 1``.

        Raises:
            IndexError: If there is no plate with that name, or if its size is 1.
        """
        for i, plate in enumerate(self.multi_dim_plates()):
            if plate.name == plate_name:
                return i
        raise IndexError(
            "Tensor has no such plate: "
            + plate_name
            + ". Alternatively, the dimension of this batch is 1."
        )

    def iterate_plate_indices(self) -> Iterable[Tuple[int, ...]]:
        """Iterates over every combination of indices into the plate dimensions."""
        return product(*(range(size) for size in self.plate_shape))

    def multi_dim_plates(self) -> List[Plate]:
        """Returns the plates of this tensor with ``n > 1``, in order."""
        return [plate for plate in self.plates if plate.n > 1]

    def backward(
        self,
        gradient: Optional[Tensor] = None,
        keep_graph: bool = False,
        create_graph: bool = False,
        retain_graph: bool = False,
    ) -> None:
        raise NotImplementedError(
            "Cannot call .backward on storch.Tensor. Instead, register cost nodes using "
            "storch.add_cost, then use storch.backward()."
        )

    def is_in(self, tensors: Iterable[Tensor]) -> bool:
        """Returns True if this exact object (compared by identity) is in ``tensors``."""
        return any(tensor is self for tensor in tensors)

    # region OperatorOverloads

    def __len__(self) -> int:
        return self._tensor.__len__()

    def __index__(self) -> int:
        raise IllegalStorchExposeError("Cannot use storch tensors as index.")

    @storch.deterministic
    def eq(self, other) -> bool:
        return self.eq(other)

    def __getstate__(self):
        raise NotImplementedError(
            "Pickle is currently not implemented for storch tensors."
        )

    def __setstate__(self, state):
        raise NotImplementedError(
            "Pickle is currently not implemented for storch tensors."
        )

    def __bool__(self):
        raise IllegalStorchExposeError(
            "It is not allowed to convert storch tensors to boolean. Make sure to unwrap "
            "storch tensors to normal torch tensor to use this tensor as a boolean."
        )

    def __float__(self):
        raise IllegalStorchExposeError(
            "It is not allowed to convert storch tensors to float. Make sure to unwrap "
            "storch tensors to normal torch tensor to use this tensor as a float."
        )

    def __int__(self):
        raise IllegalStorchExposeError(
            "It is not allowed to convert storch tensors to int. Make sure to unwrap "
            "storch tensors to normal torch tensor to use this tensor as an int."
        )

    def __long__(self):
        raise IllegalStorchExposeError(
            "It is not allowed to convert storch tensors to long. Make sure to unwrap "
            "storch tensors to normal torch tensor to use this tensor as a long."
        )

    def __nonzero__(self) -> builtins.bool:
        raise IllegalStorchExposeError(
            "It is not allowed to convert storch tensors to boolean. Make sure to unwrap "
            "storch tensors to normal torch tensor to use this tensor as a boolean."
        )

    def __array__(self, *args, **kwargs):
        # Accept whatever arguments the caller passes, so the conversion always fails with the storch
        # error raised by numpy() instead of a TypeError about the signature.
        self.numpy()

    def __array_wrap__(self, *args, **kwargs):
        self.numpy()

    def numpy(self):
        raise IllegalStorchExposeError(
            "It is not allowed to convert storch tensors to numpy arrays. Make sure to unwrap "
            "storch tensors to normal torch tensor to use this tensor as a np.array."
        )

    def __contains__(self, item):
        raise IllegalStorchExposeError(
            "It is not allowed to expose storch tensors via in statements."
        )

    def __deepcopy__(self, memodict=None):
        raise NotImplementedError(
            "There is currently no deep copying implementation for storch Tensors."
        )

    def __iter__(self):
        # TODO: This recognizes storch.Tensor as Iterable, even though it's not implemented.
        raise NotImplementedError("Cannot currently iterate over storch Tensors.")

    def detach_(self) -> Tensor:
        raise NotImplementedError("In place detach is not allowed on storch tensors.")

    # The operators below dispatch to the torch functions, which end up in __torch_function__.

    def __add__(self, other):
        return torch.add(self, other)

    def __radd__(self, other):
        return torch.add(other, self)

    def __sub__(self, other):
        return torch.sub(self, other)

    def __rsub__(self, other):
        return torch.sub(other, self)

    def __mul__(self, other):
        return torch.mul(self, other)

    def __rmul__(self, other):
        return torch.mul(other, self)

    def __matmul__(self, other):
        return torch.matmul(self, other)

    def __rmatmul__(self, other):
        return torch.matmul(other, self)

    def __pow__(self, other):
        return torch.pow(self, other)

    def __rpow__(self, other):
        return torch.pow(other, self)

    def __div__(self, other):
        return torch.div(self, other)

    def __rdiv__(self, other):
        return torch.div(other, self)

    def __mod__(self, other):
        return torch.remainder(self, other)

    def __rmod__(self, other):
        return torch.remainder(other, self)

    def __truediv__(self, other):
        return torch.true_divide(self, other)

    def __rtruediv__(self, other):
        return torch.true_divide(other, self)

    def __floordiv__(self, other):
        return torch.floor_divide(self, other)

    def __rfloordiv__(self, other):
        return torch.floor_divide(other, self)

    def __abs__(self):
        return torch.abs(self)

    # Note: &, | and ^ map to the logical (not bitwise) torch operations.
    def __and__(self, other):
        return torch.logical_and(self, other)

    def __rand__(self, other):
        return torch.logical_and(other, self)

    def __ge__(self, other):
        return torch.ge(self, other)

    def __gt__(self, other):
        return torch.gt(self, other)

    @storch.deterministic
    def __invert__(self):
        return self.__invert__()

    def __le__(self, other):
        return torch.le(self, other)

    @storch.deterministic
    def __lshift__(self, other):
        return self.__lshift__(other)

    @storch.deterministic
    def __rlshift__(self, other):
        # Reflected operator: evaluates ``other << self``.
        return other.__lshift__(self)

    def __lt__(self, other):
        return torch.lt(self, other)

    # NOTE(review): this is a regular method; ``__ne__`` itself is not overloaded. Confirm whether that is intended.
    def ne(self, other):
        return torch.ne(self, other)

    def __neg__(self):
        return torch.neg(self)

    def __or__(self, other):
        return torch.logical_or(self, other)

    def __ror__(self, other):
        return torch.logical_or(other, self)

    def __pos__(self):
        # TODO: Is this correct?
        return self

    @storch.deterministic
    def __rshift__(self, other):
        return self.__rshift__(other)

    @storch.deterministic
    def __rrshift__(self, other):
        # Reflected operator: evaluates ``other >> self``.
        return other.__rshift__(self)

    def __xor__(self, other):
        return torch.logical_xor(self, other)

    def __rxor__(self, other):
        return torch.logical_xor(other, self)

    # endregion
class CostTensor(Tensor):
    """
    A :class:`storch.Tensor` that reports itself as a cost node through :attr:`is_cost`.
    """

    def __init__(self, tensor: torch.Tensor, parents, plate_links: [Plate], name: str):
        # The plates are called ``plate_links`` here; they are handed to the base class as its plates.
        super().__init__(tensor=tensor, parents=parents, plates=plate_links, name=name)

    @property
    def is_cost(self) -> bool:
        """
        Returns:
            bool: Always True for this class.
        """
        return True
class IndependentTensor(Tensor):
    """
    Used to denote independencies on a Tensor. This could for example be the minibatch dimension. The first dimension
    of the input tensor is taken to be independent and added as a batch dimension to the storch system.

    Args:
        tensor (torch.Tensor): The tensor to wrap. Its first dimension becomes the new plate.
        parents ([storch.Tensor]): The parents of this Tensor.
        plates ([storch.Plate]): The plates that the remaining leading dimensions belong to. This list is
            not modified; the new plate is put in front of a copy of it.
        tensor_name (str): The name of this Tensor.
        plate_name (str): The name of the newly created plate.
        weight (Optional[storch.Tensor]): The weight that is passed on to the newly created plate.

    Raises:
        ValueError: If ``plates`` already contains a plate named ``plate_name``.
    """

    def __init__(
        self,
        tensor: torch.Tensor,
        parents: List[Tensor],
        plates: List[Plate],
        tensor_name: str,
        plate_name: str,
        weight: Optional[storch.Tensor],
    ):
        n = tensor.shape[0]
        for plate in plates:
            if plate.name == plate_name:
                raise ValueError(
                    "Cannot create independent tensor with name "
                    + plate_name
                    + ". A parent sample has already used"
                    " this name. Use a different name for this independent dimension."
                )
        # The existing plates become the parents of the new plate. Work on copies, so that the list
        # owned by the caller is left untouched.
        independent_plate = Plate(plate_name, n, list(plates), weight)
        all_plates = [independent_plate] + list(plates)
        super().__init__(tensor, parents, all_plates, tensor_name)
        self.n = n

    # TODO: Should IndependentTensors be stochastic? Sometimes, like if it is denoting a minibatch, making them
    #  stochastic seems like it is correct. Are there other cases?
    @property
    def stochastic(self) -> bool:
        """
        Returns:
            bool: Always True for this class.
        """
        return True
class StochasticTensor(Tensor):
    """
    A :class:`storch.Tensor` that represents a stochastic node in the stochastic computation graph.

    Args:
        tensor (torch.Tensor): The tensor to wrap.
        parents ([storch.Tensor]): The parents of this Tensor.
        plates ([storch.Plate]): The plates of this Tensor.
        name (str): The name of this Tensor.
        n (int): The size of the plate dimension created by this stochastic node.
        distribution: The distribution of this stochastic node.
        requires_grad (bool): True if we are interested in the gradients with respect to the parameters of the distribution of
            this stochastic node.
        method (Optional[storch.method.Method]): The method associated with this stochastic node.
    """

    # TODO: Copy original tensor to make sure it cannot change using inplace
    def __init__(
        self,
        tensor: torch.Tensor,
        parents: List[Tensor],
        plates: List[Plate],
        name: str,
        n: int,
        distribution: Distribution,
        requires_grad: bool,
        method: Optional[storch.method.Method] = None,
    ):
        self.distribution: Distribution = distribution
        super().__init__(tensor, parents, plates, name)
        self._requires_grad = requires_grad
        self.n = n
        self.method = method
        # Gradients by name; exposed through the ``grad`` property.
        self.param_grads = {}
        self._grad = None
        # Callbacks and removable handles that are invoked and dropped in _clean.
        self._clean_hooks = []
        self._remove_handles = []

    @property
    def stochastic(self) -> bool:
        return True

    # TODO: Should not manually override it like this. The stochastic "requires_grad" should be a different method, so
    #  that the meaning of requires_grad is consistent everywhere
    @property
    def requires_grad(self) -> bool:
        return self._requires_grad

    @property
    def grad(self) -> Dict[str, storch.Tensor]:
        return self.param_grads

    def _set_method(self, method: storch.method.Method):
        self.method = method

    def _clean(self) -> None:
        """
        Detaches the stored gradients, runs the clean hooks, removes the registered handles and then cleans the
        links of this node like :meth:`Tensor._clean`.
        """
        # In case higher-order derivatives are stored, remove these from the graph. The detached
        # gradients have to be stored back, otherwise the originals stay in place.
        self.param_grads = {
            name: grad.detach() for name, grad in self.param_grads.items()
        }
        for clean_hook in self._clean_hooks:
            clean_hook()
        for handle in self._remove_handles:
            handle.remove()
        self._clean_hooks = []
        self._remove_handles = []
        super()._clean()
def is_tensor(a):
    """Returns True if ``a`` is a :class:`torch.Tensor` or a :class:`storch.Tensor`."""
    if isinstance(a, torch.Tensor):
        return True
    return isinstance(a, Tensor)
from storch.util import has_backwards_path