diff --git a/.gitignore b/.gitignore index 26ecc8b38..0b7d1aa67 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Jupyter notebooks +*.ipynb + # uv uv.lock diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 361743a40..964b94a67 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -4,6 +4,8 @@ from torch import Tensor, nn, vmap from torch.autograd.graph import get_gradient_edge +from torchjd.sparse import make_sst + from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithCrossTerms @@ -173,7 +175,9 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: ) output_dims = list(range(output.ndim)) - jac_output = _make_initial_jac_output(output) + identity = torch.eye(output.ndim, dtype=torch.int64) + strides = torch.concatenate([identity, identity], dim=0) + jac_output = make_sst(torch.ones_like(output), strides) vmapped_diff = differentiation for _ in output_dims: @@ -193,15 +197,3 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: gramian_computer.reset() return gramian - - -def _make_initial_jac_output(output: Tensor) -> Tensor: - if output.ndim == 0: - return torch.ones_like(output) - p_index_ranges = [torch.arange(s, device=output.device) for s in output.shape] - p_indices_grid = torch.meshgrid(*p_index_ranges, indexing="ij") - v_indices_grid = p_indices_grid + p_indices_grid - - res = torch.zeros(list(output.shape) * 2, device=output.device, dtype=output.dtype) - res[v_indices_grid] = 1.0 - return res diff --git a/src/torchjd/sparse/__init__.py b/src/torchjd/sparse/__init__.py new file mode 100644 index 000000000..7a161b6ad --- /dev/null +++ b/src/torchjd/sparse/__init__.py @@ -0,0 +1,3 @@ +# Need to import this to execute the code inside and thus to override the functions +from . import _aten_function_overrides +from ._structured_sparse_tensor import StructuredSparseTensor, make_sst diff --git a/src/torchjd/sparse/_aten_function_overrides/__init__.py b/src/torchjd/sparse/_aten_function_overrides/__init__.py new file mode 100644 index 000000000..b33cf8d62 --- /dev/null +++ b/src/torchjd/sparse/_aten_function_overrides/__init__.py @@ -0,0 +1 @@ +from . 
import backward, einsum, pointwise, shape
diff --git a/src/torchjd/sparse/_aten_function_overrides/backward.py b/src/torchjd/sparse/_aten_function_overrides/backward.py
new file mode 100644
index 000000000..9168c7653
--- /dev/null
+++ b/src/torchjd/sparse/_aten_function_overrides/backward.py
@@ -0,0 +1,36 @@
+from torch import Tensor
+from torch.ops import aten  # type: ignore
+
+from torchjd.sparse._structured_sparse_tensor import StructuredSparseTensor, impl
+
+
+@impl(aten.threshold_backward.default)
+def threshold_backward_default(
+    grad_output: StructuredSparseTensor, self: Tensor, threshold
+) -> StructuredSparseTensor:
+    new_physical = aten.threshold_backward.default(grad_output.physical, self, threshold)
+
+    return StructuredSparseTensor(new_physical, grad_output.strides)
+
+
+@impl(aten.hardtanh_backward.default)
+def hardtanh_backward_default(
+    grad_output: StructuredSparseTensor,
+    self: Tensor,
+    min_val: Tensor | int | float,
+    max_val: Tensor | int | float,
+) -> StructuredSparseTensor:
+    if isinstance(self, StructuredSparseTensor):
+        raise NotImplementedError()
+
+    new_physical = aten.hardtanh_backward.default(grad_output.physical, self, min_val, max_val)
+    return StructuredSparseTensor(new_physical, grad_output.strides)
+
+
+@impl(aten.hardswish_backward.default)
+def hardswish_backward_default(grad_output: StructuredSparseTensor, self: Tensor):
+    if isinstance(self, StructuredSparseTensor):
+        raise NotImplementedError()
+
+    new_physical = aten.hardswish_backward.default(grad_output.physical, self)
+    return StructuredSparseTensor(new_physical, grad_output.strides)
diff --git a/src/torchjd/sparse/_aten_function_overrides/einsum.py b/src/torchjd/sparse/_aten_function_overrides/einsum.py
new file mode 100644
index 000000000..45c6cd25d
--- /dev/null
+++ b/src/torchjd/sparse/_aten_function_overrides/einsum.py
@@ -0,0 +1,252 @@
+import torch
+from torch import Tensor, tensor
+from torch.ops import aten  # type: ignore
+
+from torchjd.sparse._structured_sparse_tensor import (
+    StructuredSparseTensor,
+    impl,
+    to_most_efficient_tensor,
+    to_structured_sparse_tensor,
+)
+
+
+def einsum(*args: tuple[StructuredSparseTensor, list[int]], output: list[int]) -> Tensor:
+
+    # First part of the algorithm: determine how to cluster physical indices, as well as the
+    # common p_shapes corresponding to matching v_dims. Second part: translate the result into an
+    # einsum over the physical tensors.
+
+    # Rough plan:
+    # - get a map from einsum index to (tensor_idx, v_dims)
+    # - get a map from einsum index to the merge of the strides corresponding to the v_dims with
+    #   that index
+    # - use to_target_physical_strides on each physical and v_to_ps
+    # - cluster pairs of (einsum_index, new_stride) using new_v_to_ps and possibly its
+    #   corresponding p_to_vs
+    # - get unique indices
+    # - map output indices (there can be splits)
+    # - call the physical einsum
+    # - build the resulting sst
+
+    # An index in the physical einsum is uniquely characterized by a virtual einsum index and a
+    # stride corresponding to the physical stride within that virtual dimension (note that since
+    # the virtual shapes of two matching virtual indices must match, we also want their strides to
+    # match, reshaping accordingly).
+    # We want to cluster such indices whenever several of them appear in the same p_to_vs.
+
+    # TODO: Handle ellipsis
+    # If we have an index v for some virtual dim whose corresponding v_to_ps is a non-trivial list
+    # [p_1, ..., p_k], then we have to create fresh sub-indices for each of those dimensions.
+    # For this reason, an index is decomposed into sub-indices that are then independently
+    # clustered.
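+    # Concrete illustration (hypothetical shapes): suppose tensor a has v_to_ps = [[0, 1]] with
+    # physical shape (2, 3) and is given einsum index i, and tensor b also has v_to_ps = [[0, 1]]
+    # with physical shape (2, 3) and the same index i. Index i is decomposed into sub-indices
+    # (i, 0) and (i, 1), shared by a and b, so the physical einsum sees two shared indices of
+    # sizes 2 and 3 instead of a single virtual index of size 6. More generally: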
+ # So if an index i in args for some StructuredSparseTensor corresponds to a v_to_ps [j, k, l], + # We will consider three indices (i, 0), (i, 1) and (i, 2). + # If furthermore [k] correspond to the v_to_ps of some other tensor with index j, then + # (i, 1) and (j, 0) will be clustered together (and end up being mapped to the same indice in + # the resulting einsum). + # Note that this is a problem if two virtual dimensions (from possibly different + # StructuredSparseTensors) have the same size but not the same decomposition into physical + # dimension sizes. For now lets leave the responsibility to care about that in the calling + # functions, if we can factor code later on we will. + + index_parents = dict[tuple[int, int], tuple[int, int]]() + + def get_representative(index: tuple[int, int]) -> tuple[int, int]: + if index not in index_parents: + # If an index is not yet in a cluster, put it in its own. + index_parents[index] = index + current = index_parents[index] + if current != index: + # Compress path to representative + index_parents[index] = get_representative(current) + return index_parents[index] + + def group_indices(indices: list[tuple[int, int]]) -> None: + first_representative = get_representative(indices[0]) + for i in indices[1:]: + curr_representative = get_representative(i) + index_parents[curr_representative] = first_representative + + new_indices_pair = list[list[tuple[int, int]]]() + physicals = list[Tensor]() + indices_to_n_pdims = dict[int, int]() + for t, indices in args: + assert isinstance(t, StructuredSparseTensor) + physicals.append(t.physical) + for pdims, index in zip(t.v_to_ps, indices): + if index in indices_to_n_pdims: + if indices_to_n_pdims[index] != len(pdims): + raise NotImplementedError( + "einsum currently does not support having a different number of physical " + "dimensions corresponding to matching virtual dimensions of different " + f"tensors. Found {[(t.debug_info(), indices) for t, indices in args]}, " + f"output_indices={output}." + ) + else: + indices_to_n_pdims[index] = len(pdims) + p_to_vs = ... # p_to_vs_from_v_to_ps(t.v_to_ps) + for indices_ in p_to_vs: + # elements in indices[indices_] map to the same dimension, they should be clustered + # together + group_indices([(indices[i], sub_i) for i, sub_i in indices_]) + # record the physical dimensions, index[v] for v in vs will end-up mapping to the same + # final dimension as they were just clustered, so we can take the first, which exists as + # t is a valid SST. + new_indices_pair.append([(indices[vs[0][0]], vs[0][1]) for vs in p_to_vs]) + + current = 0 + pair_to_int = dict[tuple[int, int], int]() + + def unique_int(pair: tuple[int, int]) -> int: + nonlocal current + if pair in pair_to_int: + return pair_to_int[pair] + pair_to_int[pair] = current + current += 1 + return pair_to_int[pair] + + new_indices = [ + [unique_int(get_representative(i)) for i in indices] for indices in new_indices_pair + ] + new_output = list[int]() + v_to_ps = list[list[int]]() + for i in output: + current_v_to_ps = [] + for j in range(indices_to_n_pdims[i]): + k = unique_int(get_representative((i, j))) + if k in new_output: + current_v_to_ps.append(new_output.index(k)) + else: + current_v_to_ps.append(len(new_output)) + new_output.append(k) + v_to_ps.append(current_v_to_ps) + + physical = torch.einsum(*[x for y in zip(physicals, new_indices) for x in y], new_output) + # Need to use the safe constructor, otherwise the dimensions may not be maximally grouped. + # Maybe there is a way to fix that though. 
+ return to_most_efficient_tensor(physical, v_to_ps) + + +def prepare_for_elementwise_op( + t1: Tensor | int | float, t2: Tensor | int | float +) -> tuple[StructuredSparseTensor, StructuredSparseTensor]: + """ + Prepares two SSTs of the same shape from two args, one of those being a SST, and the other being + a SST, Tensor, int or float. + """ + + assert isinstance(t1, StructuredSparseTensor) or isinstance(t2, StructuredSparseTensor) + + if isinstance(t1, int) or isinstance(t1, float): + t1_ = tensor(t1, device=t2.device) + else: + t1_ = t1 + + if isinstance(t2, int) or isinstance(t2, float): + t2_ = tensor(t2, device=t1.device) + else: + t2_ = t2 + + t1_, t2_ = aten.broadcast_tensors.default([t1_, t2_]) + t1_ = to_structured_sparse_tensor(t1_) + t2_ = to_structured_sparse_tensor(t2_) + + return t1_, t2_ + + +@impl(aten.mul.Tensor) +def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: + # Element-wise multiplication with broadcasting + t1_, t2_ = prepare_for_elementwise_op(t1, t2) + all_dims = list(range(t1_.ndim)) + return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) + + +@impl(aten.div.Tensor) +def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: + t1_, t2_ = prepare_for_elementwise_op(t1, t2) + t2_ = StructuredSparseTensor(1.0 / t2_.physical, t2_.strides) + all_dims = list(range(t1_.ndim)) + return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) + + +@impl(aten.mul.Scalar) +def mul_Scalar(t: StructuredSparseTensor, scalar) -> StructuredSparseTensor: + # TODO: maybe it could be that scalar is a scalar SST and t is a normal tensor. Need to check + # that + + assert isinstance(t, StructuredSparseTensor) + new_physical = aten.mul.Scalar(t.physical, scalar) + return StructuredSparseTensor(new_physical, t.strides) + + +@impl(aten.add.Tensor) +def add_Tensor( + t1: Tensor | int | float, t2: Tensor | int | float, alpha: Tensor | float = 1.0 +) -> StructuredSparseTensor: + t1_, t2_ = prepare_for_elementwise_op(t1, t2) + + if torch.equal(t1_.strides, t2_.strides): + new_physical = t1_.physical + t2_.physical * alpha + return StructuredSparseTensor(new_physical, t1_.strides) + else: + raise NotImplementedError() + + +@impl(aten.bmm.default) +def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor: + assert isinstance(mat1, StructuredSparseTensor) or isinstance(mat2, StructuredSparseTensor) + assert ( + mat1.ndim == 3 + and mat2.ndim == 3 + and mat1.shape[0] == mat2.shape[0] + and mat1.shape[2] == mat2.shape[1] + ) + + mat1_ = to_structured_sparse_tensor(mat1) + mat2_ = to_structured_sparse_tensor(mat2) + + # TODO: Verify that the dimension `0` of mat1_ and mat2_ have the same physical dimension sizes + # decompositions. If not, can reshape to common decomposition? 
+ return einsum((mat1_, [0, 1, 2]), (mat2_, [0, 2, 3]), output=[0, 1, 3]) + + +@impl(aten.mm.default) +def mm_default(mat1: Tensor, mat2: Tensor) -> Tensor: + assert isinstance(mat1, StructuredSparseTensor) or isinstance(mat2, StructuredSparseTensor) + assert mat1.ndim == 2 and mat2.ndim == 2 and mat1.shape[1] == mat2.shape[0] + + mat1_ = to_structured_sparse_tensor(mat1) + mat2_ = to_structured_sparse_tensor(mat2) + + return einsum((mat1_, [0, 1]), (mat2_, [1, 2]), output=[0, 2]) + + +@impl(aten.mean.default) +def mean_default(t: StructuredSparseTensor) -> Tensor: + assert isinstance(t, StructuredSparseTensor) + return aten.sum.default(t.physical) / t.numel() + + +@impl(aten.sum.default) +def sum_default(t: StructuredSparseTensor) -> Tensor: + assert isinstance(t, StructuredSparseTensor) + return aten.sum.default(t.physical) + + +@impl(aten.sum.dim_IntList) +def sum_dim_IntList( + t: StructuredSparseTensor, dim: list[int], keepdim: bool = False, dtype=None +) -> Tensor: + assert isinstance(t, StructuredSparseTensor) + + if dtype: + raise NotImplementedError() + + all_dims = list(range(t.ndim)) + result = einsum((t, all_dims), output=[d for d in all_dims if d not in dim]) + + if keepdim: + for d in dim: + result = result.unsqueeze(d) + + return result diff --git a/src/torchjd/sparse/_aten_function_overrides/pointwise.py b/src/torchjd/sparse/_aten_function_overrides/pointwise.py new file mode 100644 index 000000000..9d389c10b --- /dev/null +++ b/src/torchjd/sparse/_aten_function_overrides/pointwise.py @@ -0,0 +1,125 @@ +from torch.ops import aten # type: ignore + +from torchjd.sparse._structured_sparse_tensor import StructuredSparseTensor, impl + +# pointwise functions applied to one Tensor with `0.0 → 0` +_POINTWISE_FUNCTIONS = [ + aten.abs.default, + aten.absolute.default, + aten.asin.default, + aten.asinh.default, + aten.atan.default, + aten.atanh.default, + aten.ceil.default, + aten.erf.default, + aten.erfinv.default, + aten.expm1.default, + aten.fix.default, + aten.floor.default, + aten.hardtanh.default, + aten.leaky_relu.default, + aten.log1p.default, + aten.neg.default, + aten.negative.default, + aten.positive.default, + aten.relu.default, + aten.round.default, + aten.sgn.default, + aten.sign.default, + aten.sin.default, + aten.sinh.default, + aten.sqrt.default, + aten.square.default, + aten.tan.default, + aten.tanh.default, + aten.trunc.default, +] + +_IN_PLACE_POINTWISE_FUNCTIONS = [ + aten.abs_.default, + aten.absolute_.default, + aten.asin_.default, + aten.asinh_.default, + aten.atan_.default, + aten.atanh_.default, + aten.ceil_.default, + aten.erf_.default, + aten.erfinv_.default, + aten.expm1_.default, + aten.fix_.default, + aten.floor_.default, + aten.hardtanh_.default, + aten.leaky_relu_.default, + aten.log1p_.default, + aten.neg_.default, + aten.negative_.default, + aten.relu_.default, + aten.round_.default, + aten.sgn_.default, + aten.sign_.default, + aten.sin_.default, + aten.sinh_.default, + aten.sqrt_.default, + aten.square_.default, + aten.tan_.default, + aten.tanh_.default, + aten.trunc_.default, +] + + +def _override_pointwise(op): + @impl(op) + def func_(t: StructuredSparseTensor) -> StructuredSparseTensor: + assert isinstance(t, StructuredSparseTensor) + return StructuredSparseTensor(op(t.physical), t.strides) + + return func_ + + +def _override_inplace_pointwise(op): + @impl(op) + def func_(t: StructuredSparseTensor) -> StructuredSparseTensor: + assert isinstance(t, StructuredSparseTensor) + op(t.physical) + return t + + +for pointwise_func in _POINTWISE_FUNCTIONS: 
+ _override_pointwise(pointwise_func) + +for pointwise_func in _IN_PLACE_POINTWISE_FUNCTIONS: + _override_inplace_pointwise(pointwise_func) + + +@impl(aten.pow.Tensor_Scalar) +def pow_Tensor_Scalar(t: StructuredSparseTensor, exponent: float) -> StructuredSparseTensor: + assert isinstance(t, StructuredSparseTensor) + + if exponent <= 0.0: + # Need to densify because we don't have pow(0.0, exponent) = 0.0 + return aten.pow.Tensor_Scalar(t.to_dense(), exponent) + + new_physical = aten.pow.Tensor_Scalar(t.physical, exponent) + return StructuredSparseTensor(new_physical, t.strides) + + +# Somehow there's no pow_.Tensor_Scalar and pow_.Scalar takes tensor and scalar. +@impl(aten.pow_.Scalar) +def pow__Scalar(t: StructuredSparseTensor, exponent: float) -> StructuredSparseTensor: + assert isinstance(t, StructuredSparseTensor) + + if exponent <= 0.0: + # Need to densify because we don't have pow(0.0, exponent) = 0.0 + # Note sure if it's even possible to densify in-place, so let's just raise an error. + raise ValueError(f"in-place pow with an exponent of {exponent} (<= 0) is not supported.") + + aten.pow_.Scalar(t.physical, exponent) + return t + + +@impl(aten.div.Scalar) +def div_Scalar(t: StructuredSparseTensor, divisor: float) -> StructuredSparseTensor: + assert isinstance(t, StructuredSparseTensor) + + new_physical = aten.div.Scalar(t.physical, divisor) + return StructuredSparseTensor(new_physical, t.strides) diff --git a/src/torchjd/sparse/_aten_function_overrides/shape.py b/src/torchjd/sparse/_aten_function_overrides/shape.py new file mode 100644 index 000000000..a4c255607 --- /dev/null +++ b/src/torchjd/sparse/_aten_function_overrides/shape.py @@ -0,0 +1,289 @@ +import operator +from itertools import accumulate +from math import prod +from typing import cast + +import torch +from torch import Tensor, arange, tensor +from torch.ops import aten # type: ignore + +from torchjd.sparse._structured_sparse_tensor import ( + StructuredSparseTensor, + impl, + print_fallback, + to_most_efficient_tensor, + unwrap_to_dense, +) + + +@impl(aten.view.default) +def view_default(t: StructuredSparseTensor, shape: list[int]) -> Tensor: + """ + The main condition that we want to respect is that the indexing in the flattened virtual + tensor should remain the same before and after the reshape, i.e. + + c.T S = c'.T S' (1) + where: + * c is the reversed vector of cumulative physical shape before the reshape, i.e. + c.T = [prod(t.shape[1:]), prod(t.shape[2:]), ..., t.shape[-1], 1] + * c' is the same thing but after the reshape, i.e. + c'.T = [prod(shape[1:]), prod(shape[2:]), ..., shape[-1], 1] + * S is the original matrix of strides (t.strides) + * S' is the matrix of strides after reshaping. + + For u, v in Z^m and c in Z, say that u ≡ v (mod c) if u_i ≡ v_i (mod c) for all i. + Note that c'.T S' ≡ S'[-1] (mod shape[-1]) + So if we set S'[-1] = c.T S % shape[-1], we have c.T S ≡ c'.T S' (mod shape[-1]) + + (c'.T S' - S'[-1]) // shape[-1] ≡ S'[-1] (mod shape[-1]) + ... 
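+
+    Example (illustrative values, not from the original docstring): for a "diagonal" SST with
+    physical shape [3] and strides S = [[1], [1]] (virtual shape [3, 3]), reshaping to [9] gives
+    c = [3, 1], c.T S = [4], c' = [1], and therefore S' = [[4]]: the flattened virtual vector of
+    size 9 has its three non-zero entries at positions 0, 4 and 8.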
+ """ + + assert isinstance(t, StructuredSparseTensor) + + shape = infer_shape(shape, t.numel()) + + if prod(shape) != t.numel(): + raise ValueError(f"shape '{shape}' is invalid for input of size {t.numel()}") + + S = t.strides + vshape = list(t.shape) + c = _reverse_cumulative_product(vshape) + c_prime = _reverse_cumulative_product(shape) + new_strides = ((c @ S).unsqueeze(0) // c_prime.unsqueeze(1)) % tensor(shape).unsqueeze(1) + return to_most_efficient_tensor(t.physical, new_strides) + + +def _reverse_cumulative_product(values: list[int]) -> Tensor: + return tensor(list(accumulate((values[1:] + [1])[::-1], operator.mul))[::-1]) + + +def infer_shape(shape: list[int], numel: int) -> list[int]: + if shape.count(-1) > 1: + raise ValueError("Only one dimension can be inferred") + known = 1 + for s in shape: + if s != -1: + known *= s + inferred = numel // known + return [inferred if s == -1 else s for s in shape] + + +def unsquash_pdim( + physical: Tensor, strides: Tensor, pdim: int, new_pdim_shape: list[int] +) -> tuple[Tensor, Tensor]: + """ + EXAMPLE: + + physical = [ + [1, 2, 3, 4, 5, 6], + [7, 8, 9, 10, 11, 12], + [13, 14, 15, 16, 17, 18], + ] + strides = [ + [1, 1], + [0, 2], + ] + + dim = 1 + shape = [2, 3] + + new_physical = [[ + [1, 2, 3], + [4, 5, 6], + ], [ + [7, 8, 9], + [10, 11, 12], + ], [ + [13, 14, 15], + [16, 17, 18], + ]] + + new_strides = [ + [1, 3, 1], + [0, 6, 2] + """ + + # TODO: handle working with multiple dimensions at once + + old_shape = list(physical.shape) + new_shape = old_shape[:pdim] + new_pdim_shape + old_shape[pdim + 1 :] + new_physical = physical.reshape(new_shape) + + stride_multipliers = tensor([prod(new_pdim_shape[i + 1 :]) for i in range(len(new_pdim_shape))]) + + new_strides = torch.concat( + [ + strides[:, :pdim], + torch.outer(strides[:, pdim], stride_multipliers), + strides[:, pdim + 1 :], + ], + dim=1, + ) + + return new_physical, new_strides + + +@impl(aten._unsafe_view.default) +def _unsafe_view_default(t: StructuredSparseTensor, shape: list[int]) -> Tensor: + return view_default( + t, shape + ) # We don't do the optimizations that they do in https://github.com/pytorch/pytorch/blame/main/aten/src/ATen/native/TensorShape.cpp + + +@impl(aten.unsqueeze.default) +def unsqueeze_default(t: StructuredSparseTensor, dim: int) -> StructuredSparseTensor: + assert isinstance(t, StructuredSparseTensor) + assert -t.ndim - 1 <= dim < t.ndim + 1 + + if dim < 0: + dim = t.ndim + dim + 1 + + new_strides = torch.concatenate( + [t.strides[:dim], torch.zeros(1, t.strides.shape[1], dtype=torch.int64), t.strides[dim:]] + ) + return StructuredSparseTensor(t.physical, new_strides) + + +@impl(aten.squeeze.dims) +def squeeze_dims(t: StructuredSparseTensor, dims: list[int] | int | None) -> Tensor: + assert isinstance(t, StructuredSparseTensor) + + if dims is None: + excluded = set(range(t.ndim)) + elif isinstance(dims, int): + excluded = {dims} + else: + excluded = set(dims) + + is_row_kept = [i not in excluded for i in range(t.ndim)] + new_strides = t.strides[is_row_kept] + return to_most_efficient_tensor(t.physical, new_strides) + + +@impl(aten.permute.default) +def permute_default(t: StructuredSparseTensor, dims: list[int]) -> StructuredSparseTensor: + new_strides = t.strides[torch.tensor(dims)] + return StructuredSparseTensor(t.physical, new_strides) + + +@impl(aten.cat.default) +def cat_default(tensors: list[Tensor], dim: int) -> Tensor: + if any(not isinstance(t, StructuredSparseTensor) for t in tensors): + print_fallback(aten.cat.default, (tensors, dim), {}) + 
        return aten.cat.default([unwrap_to_dense(t) for t in tensors], dim)
+
+    tensors_ = [cast(StructuredSparseTensor, t) for t in tensors]
+    ref_tensor = tensors_[0]
+    ref_strides = ref_tensor.strides
+    if any(not torch.equal(t.strides, ref_strides) for t in tensors_[1:]):
+        raise NotImplementedError(
+            "Override for aten.cat.default does not support SSTs that do not all have the same "
+            f"strides. Found the following tensors:\n{[t.debug_info() for t in tensors_]} and the "
+            f"following dim: {dim}."
+        )
+
+    # We need to find the physical dimension (which, if it exists, should be unique) along which
+    # moving only moves us along virtual dimension dim. It also needs to be such that traversing
+    # it entirely brings us exactly to the end of virtual dimension dim.
+
+    ref_virtual_dim_size = ref_tensor.shape[dim]
+    indices = torch.argwhere(
+        torch.eq(ref_strides[dim] * tensor(ref_tensor.physical.shape), ref_virtual_dim_size)
+        & torch.eq(ref_strides.sum(dim=0) * tensor(ref_tensor.physical.shape), ref_virtual_dim_size)
+    )
+    assert len(indices) <= 1
+
+    if len(indices) == 0:
+        # Add a physical dimension pdim on which we can concatenate the physicals such that this
+        # translates into a concatenation of the virtuals on virtual dimension dim.
+
+        pdim = ref_tensor.physical.ndim
+        physicals = [t.physical.unsqueeze(-1) for t in tensors_]
+        new_stride_column = torch.zeros(ref_tensor.ndim, 1, dtype=torch.int64)
+        new_stride_column[dim, 0] = ref_virtual_dim_size
+        new_strides = torch.concatenate([ref_tensor.strides, new_stride_column], dim=1)
+    else:
+        # Such a physical dimension already exists. Note that an alternative implementation would
+        # be to always add the physical dimension and squash it afterwards if it turns out not to
+        # be needed.
+        physicals = [t.physical for t in tensors_]
+        pdim = indices[0][0]
+        new_strides = ref_tensor.strides
+
+    new_physical = aten.cat.default(physicals, dim=pdim)
+    return StructuredSparseTensor(new_physical, new_strides)
+
+
+@impl(aten.expand.default)
+def expand_default(t: StructuredSparseTensor, sizes: list[int]) -> StructuredSparseTensor:
+    # Note: sizes could also be an int or a torch.Size.
+    assert isinstance(t, StructuredSparseTensor)
+    assert isinstance(sizes, list)
+    assert len(sizes) >= t.ndim
+
+    # Add as many dimensions as needed at the beginning of the tensor (as torch.expand works)
+    for _ in range(len(sizes) - t.ndim):
+        t = t.unsqueeze(0)
+
+    # Try to expand each dimension to its new size
+    new_physical = t.physical
+    new_strides = t.strides
+    for d, (vstride, orig_size, new_size) in enumerate(zip(t.strides, t.shape, sizes, strict=True)):
+        if vstride.sum() > 0 and orig_size != new_size and new_size != -1:
+            raise ValueError(
+                f"Cannot expand dim {d} of size != 1. Found size {orig_size} and target size "
+                f"{new_size}."
+            )
+
+        if vstride.sum() == 0 and new_size != 1 and new_size != -1:
+            # Add a dimension of size new_size at the end of the physical tensor.
+ new_physical_shape = list(new_physical.shape) + [new_size] + new_physical = new_physical.unsqueeze(-1).expand(new_physical_shape) + + # Make this new physical dimension have a stride of 1 at virtual dimension d and 0 at + # every other virtual dimension + new_stride_column = torch.zeros(t.ndim, 1, dtype=torch.int64) + new_stride_column[d, 0] = 1 + new_strides = torch.cat([new_strides, new_stride_column], dim=1) + + return StructuredSparseTensor(new_physical, new_strides) + + +@impl(aten.broadcast_tensors.default) +def broadcast_tensors_default(tensors: list[Tensor]) -> tuple[Tensor, Tensor]: + if len(tensors) != 2: + raise NotImplementedError() + + t1, t2 = tensors + + if t1.shape == t2.shape: + return t1, t2 + + a = t1 if t1.ndim >= t2.ndim else t2 + b = t2 if t1.ndim >= t2.ndim else t1 + + a_shape = list(a.shape) + padded_b_shape = [1] * (a.ndim - b.ndim) + list(b.shape) + + new_shape = list[int]() + + for s_a, s_b in zip(a_shape, padded_b_shape): + if s_a != 1 and s_b != 1 and s_a != s_b: + raise ValueError("Incompatible shapes for broadcasting") + else: + new_shape.append(max(s_a, s_b)) + + return aten.expand.default(t1, new_shape), aten.expand.default(t2, new_shape) + + +@impl(aten.transpose.int) +def transpose_int(t: StructuredSparseTensor, dim0: int, dim1: int) -> StructuredSparseTensor: + assert isinstance(t, StructuredSparseTensor) + return StructuredSparseTensor(t.physical, _swap_rows(t.strides, dim0, dim1)) + + +def _swap_rows(matrix: Tensor, c0: int, c1: int) -> Tensor: + index = arange(matrix.shape[0]) + index[c0] = c1 + index[c1] = c0 + return matrix[index] diff --git a/src/torchjd/sparse/_structured_sparse_tensor.py b/src/torchjd/sparse/_structured_sparse_tensor.py new file mode 100644 index 000000000..11ad01b2a --- /dev/null +++ b/src/torchjd/sparse/_structured_sparse_tensor.py @@ -0,0 +1,304 @@ +import itertools +import operator +from functools import wraps +from itertools import accumulate +from math import prod + +import torch +from torch import Tensor, arange, meshgrid, stack, tensor, tensordot, zeros +from torch.utils._pytree import tree_map + + +class StructuredSparseTensor(Tensor): + _HANDLED_FUNCTIONS = dict() + + @staticmethod + def __new__(cls, physical: Tensor, strides: Tensor): + # At the moment, this class is not compositional, so we assert + # that the tensor we're wrapping is exactly a Tensor + assert type(physical) is Tensor + + # Note [Passing requires_grad=true tensors to subclasses] + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Calling _make_subclass directly in an autograd context is + # never the right thing to do, as this will detach you from + # the autograd graph. You must create an autograd function + # representing the "constructor" (NegativeView, in this case) + # and call that instead. This assert helps prevent direct usage + # (which is bad!) + assert not physical.requires_grad or not torch.is_grad_enabled() + + pshape = torch.tensor(physical.shape) + vshape = strides @ (pshape - 1) + 1 + return Tensor._make_wrapper_subclass( + cls, tuple(vshape.tolist()), dtype=physical.dtype, device=physical.device + ) + + def __init__(self, physical: Tensor, strides: Tensor): + """ + This constructor is made for specifying physical and strides exactly. It should not modify + it. + + For this reason, another constructor will be made to either modify the physical / strides to + simplify the result, or to create a dense tensor directly if it's already dense. + + :param physical: The dense tensor holding the actual data. 
+ :param strides: Integer (int64) tensor of shape [virtual_ndim, physical_ndim], representing + the linear transformation between an index in the physical tensor and the corresponding + index in the virtual tensor, i.e. v_index = strides @ p_index. + """ + + if any(s == 1 for s in physical.shape): + raise ValueError( + "physical must not contain any dimension of size 1. Found physical.shape=" + f"{physical.shape}." + ) + if strides.dtype is not torch.int64: + raise ValueError( + f"strides should be of int64 dtype. Found strides.dtype={strides.dtype}." + ) + if not (strides >= 0).all(): + raise ValueError(f"All strides must be non-negative. Found strides={strides}.") + if strides.shape[1] != physical.ndim: + raise ValueError( + f"strides should have 1 column per physical dimension. Found strides={strides} and physical.shape={physical.shape}." + ) + if (strides.sum(dim=0) == 0).any(): + raise ValueError( + f"strides should not have any column full of zeros. Found strides={strides}." + ) + if any(len(group) != 1 for group in get_groupings(list(physical.shape), strides)): + raise ValueError(f"Dimensions must be maximally grouped. Found strides={strides}.") + + self.physical = physical + self.strides = strides + + def to_dense( + self, dtype: torch.dtype | None = None, *, masked_grad: bool | None = None + ) -> Tensor: + assert dtype is None # We may add support for this later + assert masked_grad is None # We may add support for this later + + if self.physical.ndim == 0: + return self.physical + + p_index_ranges = [arange(s) for s in self.physical.shape] + p_indices_grid = stack(meshgrid(*p_index_ranges, indexing="ij")) + + # addmm_cuda not implemented for Long tensors => gotta have these tensors on cpu + v_indices_grid = tensordot(self.strides, p_indices_grid, dims=1) + res = zeros(self.shape, device=self.device, dtype=self.dtype) + res[tuple(v_indices_grid)] = self.physical + return res + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + kwargs = {} if kwargs is None else kwargs + + if func in cls._HANDLED_FUNCTIONS: + return cls._HANDLED_FUNCTIONS[func](*args, **kwargs) + + print_fallback(func, args, kwargs) + unwrapped_args = tree_map(unwrap_to_dense, args) + unwrapped_kwargs = tree_map(unwrap_to_dense, kwargs) + return func(*unwrapped_args, **unwrapped_kwargs) + + def __repr__(self, *, tensor_contents=None) -> str: + return f"StructuredSparseTensor(physical={self.physical}, strides={self.strides})" + + def debug_info(self) -> str: + info = ( + f"vshape: {self.shape}\n" + f"pshape: {self.physical.shape}\n" + f"strides: {self.strides}\n" + ) + return info + + @classmethod + def implements(cls, torch_function): + """Register a torch function override for ScalarTensor""" + + @wraps(torch_function) + def decorator(func): + cls._HANDLED_FUNCTIONS[torch_function] = func + return func + + return decorator + + +impl = StructuredSparseTensor.implements + + +def print_fallback(func, args, kwargs) -> None: + def tensor_to_str(t: Tensor) -> str: + result = f"{t.__class__.__name__} - vshape: {t.shape}" + if isinstance(t, StructuredSparseTensor): + result += f" - pshape: {t.physical.shape} - strides: {t.strides}" + + return result + + print(f"Falling back to dense for {func.__name__}") + if len(args) > 0: + print("* args:") + for arg in args: + if isinstance(arg, Tensor): + print(f" > {tensor_to_str(arg)}") + elif isinstance(arg, list) and len(arg) > 0 and isinstance(arg[0], Tensor): + list_content = "\n ".join([tensor_to_str(t) for t in arg]) + print(f" > 
[{list_content}]")
+            else:
+                print(f" > {arg}")
+    if len(kwargs) > 0:
+        print("* kwargs:")
+        for k, v in kwargs.items():
+            print(f" > {k}: {v}")
+    print()
+
+
+def strides_v2(p_dims: list[int], physical_shape: list[int]) -> list[int]:
+    """
+    From a list of physical dimensions corresponding to a virtual dimension, and from the physical
+    shape, get the stride indicating how moving on each physical dimension makes you move on the
+    virtual dimension.
+
+    Example:
+        Imagine a vector of size 3, and of value [1, 2, 3].
+        Imagine an SST t of shape [3, 3] using this vector as physical and using [[0, 0]] as
+        v_to_ps.
+        t.to_dense() is [1, 0, 0, 0, 2, 0, 0, 0, 3] (it's the flattening of the diagonal matrix
+        [[1, 0, 0], [0, 2, 0], [0, 0, 3]]).
+        When you move by 1 on physical dimension 0, you move by 4 on virtual dimension 0, i.e.
+        strides_v2([0, 0], [3]) = [4]
+        In the 2D view, you'd move by 1 row (3 indices) and 1 column (1 index).
+
+    Example:
+        strides_v2([0, 0, 1], [3, 4])  # [16, 1]
+        Moving by 1 on physical dimension 0 makes you move by 16 on the virtual dimension. Moving
+        by 1 on physical dimension 1 makes you move by 1 on the virtual dimension.
+    """
+
+    strides_v1 = list(accumulate([1] + [physical_shape[d] for d in p_dims[:0:-1]], operator.mul))[
+        ::-1
+    ]
+    result = [0 for _ in range(len(physical_shape))]
+    for i, d in enumerate(p_dims):
+        result[d] += strides_v1[i]
+    return result
+
+
+def get_groupings(pshape: list[int], strides: Tensor) -> list[list[int]]:
+    strides_time_pshape = strides * tensor(pshape)
+    groups = {i: {i} for i, column in enumerate(strides.T)}
+    group_ids = [i for i in range(len(strides.T))]
+    for i1, i2 in itertools.combinations(range(strides.shape[1]), 2):
+        if torch.equal(strides[:, i1], strides_time_pshape[:, i2]):
+            groups[group_ids[i1]].update(groups[group_ids[i2]])
+            group_ids[i2] = group_ids[i1]
+
+    new_columns = [sorted(groups[group_id]) for group_id in sorted(set(group_ids))]
+
+    if len(new_columns) != len(pshape):
+        print(f"Combined pshape with the following new columns: {new_columns}.")
+
+    return new_columns
+
+
+def to_structured_sparse_tensor(t: Tensor) -> StructuredSparseTensor:
+    if isinstance(t, StructuredSparseTensor):
+        return t
+    else:
+        return make_sst(physical=t, strides=torch.eye(t.ndim, dtype=torch.int64))
+
+
+def to_most_efficient_tensor(physical: Tensor, strides: Tensor) -> Tensor:
+    physical, strides = fix_dim_of_size_1(physical, strides)
+    physical, strides = fix_ungrouped_dims(physical, strides)
+
+    if (strides.sum(dim=0) == 1).all():
+        # All physical dimensions make you move by 1 in exactly 1 virtual dimension.
+        # Also, because all physical dimensions have been maximally grouped, we cannot have two
+        # physical dimensions that make you move in the same virtual dimension.
+        # So strides is an identity matrix with potentially some extra rows of zeros, and
+        # potentially shuffled columns.
+
+        # The first step is to unsqueeze the physical tensor for each extra row of zeros in the
+        # strides.
+        zero_row_mask = strides.sum(dim=1) == 0
+        number_of_zero_rows = int(zero_row_mask.sum())
+        for _ in range(number_of_zero_rows):
+            physical = physical.unsqueeze(-1)
+
+        # The second step is to re-order the physical dimensions so that the corresponding
+        # strides matrix would be an identity.
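+        # For instance (hypothetical values): with strides [[0, 1], [0, 0], [1, 0]], physical
+        # dimension 0 corresponds to virtual dimension 2 and physical dimension 1 to virtual
+        # dimension 0, while virtual dimension 1 (a zero row) has size 1. The physical dimensions
+        # must therefore be moved to positions [2, 0], and the singleton dimension appended above
+        # falls into the remaining position 1.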
+        source = list(range(strides.shape[1]))
+        destination = (arange(strides.shape[0]) @ strides).tolist()
+        return physical.movedim(source, destination)
+    else:
+        return StructuredSparseTensor(physical, strides)
+
+
+def unwrap_to_dense(t: Tensor):
+    if isinstance(t, StructuredSparseTensor):
+        return t.to_dense()
+    else:
+        return t
+
+
+def get_full_source(source: list[int], destination: list[int], ndim: int) -> list[int]:
+    """
+    Doing a movedim with source and destination is always equivalent to doing a movedim with
+    [0, 1, ..., ndim-1] (aka "full_destination") as destination, and the "full_source" as source.
+
+    This function computes the full_source based on a source and destination.
+
+    Example:
+        source=[2, 4]
+        destination=[0, 3]
+        ndim=5
+
+        full_source = [2, 0, 1, 4, 3]
+        full_destination = [0, 1, 2, 3, 4]
+    """
+
+    idx = torch.full((ndim,), -1, dtype=torch.int64)
+    idx[destination] = tensor(source)
+    source_set = set(source)
+    idx[idx.eq(-1)] = tensor([i for i in range(ndim) if i not in source_set])
+
+    return idx.tolist()
+
+
+def fix_dim_of_size_1(physical: Tensor, strides: Tensor) -> tuple[Tensor, Tensor]:
+    is_of_size_1 = torch.tensor([s == 1 for s in physical.shape])
+    return physical.squeeze(), strides[:, ~is_of_size_1]
+
+
+def fix_ungrouped_dims(physical: Tensor, strides: Tensor) -> tuple[Tensor, Tensor]:
+    groups = get_groupings(list(physical.shape), strides)
+    nphysical = physical.reshape([prod([physical.shape[dim] for dim in group]) for group in groups])
+    stride_mapping = torch.zeros(physical.ndim, nphysical.ndim, dtype=torch.int64)
+    for j, group in enumerate(groups):
+        stride_mapping[group[-1], j] = 1
+
+    new_strides = strides @ stride_mapping
+    return nphysical, new_strides
+
+
+def make_sst(physical: Tensor, strides: Tensor) -> StructuredSparseTensor:
+    """Fix physical and strides and create a StructuredSparseTensor with them."""
+
+    physical, strides = fix_dim_of_size_1(physical, strides)
+    physical, strides = fix_ungrouped_dims(physical, strides)
+    return StructuredSparseTensor(physical, strides)
+
+
+def fix_zero_stride_columns(physical: Tensor, strides: Tensor) -> tuple[Tensor, Tensor]:
+    """Remove columns of strides that are all 0 and sum the corresponding elements in the physical tensor."""
+    are_columns_zero = (strides == 0).all(dim=0)
+
+    if not (are_columns_zero).any():
+        return physical, strides
+
+    zero_column_indices = torch.arange(len(are_columns_zero))[are_columns_zero].tolist()
+    physical = physical.sum(dim=zero_column_indices)
+    strides = strides[:, ~are_columns_zero]
+    return physical, strides
diff --git a/tests/unit/sparse/__init__.py b/tests/unit/sparse/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/unit/sparse/test_structured_sparse_tensor.py b/tests/unit/sparse/test_structured_sparse_tensor.py
new file mode 100644
index 000000000..00f6112da
--- /dev/null
+++ b/tests/unit/sparse/test_structured_sparse_tensor.py
@@ -0,0 +1,423 @@
+import torch
+from pytest import mark
+from torch import Tensor, tensor
+from torch.ops import aten  # type: ignore
+from torch.testing import assert_close
+from utils.tensors import randn_, tensor_, zeros_
+
+from torchjd.sparse._aten_function_overrides.einsum import einsum
+from torchjd.sparse._aten_function_overrides.pointwise import (
+    _IN_PLACE_POINTWISE_FUNCTIONS,
+    _POINTWISE_FUNCTIONS,
+)
+from torchjd.sparse._aten_function_overrides.shape import unsquash_pdim
+from torchjd.sparse._structured_sparse_tensor import (
+    StructuredSparseTensor,
+    fix_ungrouped_dims,
+    fix_zero_stride_columns,
+ get_full_source, + get_groupings, +) + + +def test_to_dense(): + n = 2 + m = 3 + a = randn_([n, m]) + b = StructuredSparseTensor(a, tensor([[1, 0], [0, 1], [0, 1], [1, 0]])) + c = b.to_dense() + + for i in range(n): + for j in range(m): + assert c[i, j, j, i] == a[i, j] + + +def test_to_dense2(): + a = tensor_([1.0, 2.0, 3.0]) + b = StructuredSparseTensor(a, tensor([[4]])) + c = b.to_dense() + expected = tensor_([1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]) + assert torch.all(torch.eq(c, expected)) + + +@mark.parametrize( + ["a_pshape", "a_strides", "b_pshape", "b_strides", "a_indices", "b_indices", "output_indices"], + [ + ( + [4, 5], + tensor([[1, 0], [1, 0], [0, 1]]), + [4, 5], + tensor([[1, 0], [0, 1], [0, 1]]), + [0, 1, 2], + [0, 2, 3], + [0, 1, 3], + ), + ( + [2, 3, 5], + tensor([[3, 1, 0], [1, 0, 2]]), + [10, 3], + tensor([[1, 0], [0, 1]]), + [0, 1], + [1, 2], + [0, 2], + ), + ([2, 3], tensor([[3, 1]]), [6], tensor([[1]]), [0], [0], []), + ( + [6, 2, 3], + tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + [2, 3], + tensor([[3, 1], [1, 0], [0, 1]]), + [0, 1, 2], + [0, 1, 2], + [0, 1, 2], + ), + ], +) +def test_einsum( + a_pshape: list[int], + a_strides: Tensor, + b_pshape: list[int], + b_strides: Tensor, + a_indices: list[int], + b_indices: list[int], + output_indices: list[int], +): + a = StructuredSparseTensor(randn_(a_pshape), a_strides) + b = StructuredSparseTensor(randn_(b_pshape), b_strides) + + res = einsum((a, a_indices), (b, b_indices), output=output_indices) + + expected = torch.einsum(a.to_dense(), a_indices, b.to_dense(), b_indices, output_indices) + + assert isinstance(res, StructuredSparseTensor) + assert_close(res.to_dense(), expected) + + +@mark.parametrize( + "shape", + [ + [], + [2], + [2, 3], + [2, 3, 4], + ], +) +def test_structured_sparse_tensor_scalar(shape: list[int]): + a = randn_(shape) + b = StructuredSparseTensor(a, torch.eye(len(shape), dtype=torch.int64)) + + assert_close(a, b.to_dense()) + + +@mark.parametrize("dim", [2, 3, 4, 5, 10]) +def test_diag_equivalence(dim: int): + a = randn_([dim]) + b = StructuredSparseTensor(a, tensor([[1], [1]])) + + diag_a = torch.diag(a) + + assert_close(b.to_dense(), diag_a) + + +def test_three_virtual_single_physical(): + dim = 10 + a = randn_([dim]) + b = StructuredSparseTensor(a, tensor([[1], [1], [1]])) + + expected = zeros_([dim, dim, dim]) + for i in range(dim): + expected[i, i, i] = a[i] + + assert_close(b.to_dense(), expected) + + +@mark.parametrize("func", _POINTWISE_FUNCTIONS) +def test_pointwise(func): + dim = 10 + a = randn_([dim]) + b = StructuredSparseTensor(a, tensor([[1], [1]])) + c = b.to_dense() + res = func(b) + assert isinstance(res, StructuredSparseTensor) + + assert_close(res.to_dense(), func(c), equal_nan=True) + + +@mark.parametrize("func", _IN_PLACE_POINTWISE_FUNCTIONS) +def test_inplace_pointwise(func): + dim = 10 + a = randn_([dim]) + b = StructuredSparseTensor(a, tensor([[1], [1]])) + c = b.to_dense() + func(b) + assert isinstance(b, StructuredSparseTensor) + + assert_close(b.to_dense(), func(c), equal_nan=True) + + +@mark.parametrize("func", [torch.mean, torch.sum]) +def test_unary(func): + dim = 10 + a = randn_([dim]) + b = StructuredSparseTensor(a, tensor([[1], [1]])) + c = b.to_dense() + + res = func(b) + assert_close(res.to_dense(), func(c)) + + +@mark.parametrize( + ["physical_shape", "strides", "target_shape", "expected_physical_shape", "expected_strides"], + [ + ( + [2, 3], + tensor([[1, 0], [1, 0], [0, 1]]), + [2, 2, 3], + [2, 3], + tensor([[1, 0], [1, 0], [0, 1]]), + ), # no change of 
shape + ( + [2, 3], + tensor([[1, 0], [3, 1]]), + [2, 6], + [2, 3], + tensor([[1, 0], [3, 1]]), + ), # no change of shape + ( + [2, 3], + tensor([[1, 0], [1, 0], [0, 1]]), + [2, 6], + [2, 3], + tensor([[1, 0], [3, 1]]), + ), # squashing 2 dims + ( + [2, 3], + tensor([[1, 0], [3, 1]]), + [2, 2, 3], + [2, 3], + tensor([[1, 0], [1, 0], [0, 1]]), + ), # unsquashing into 2 dims + ( + [2, 3], + tensor([[9, 1]]), + [2, 6], + [2, 3], + tensor([[1, 0], [3, 1]]), + ), # unsquashing into 2 dims + ( + [2, 3], + tensor([[1, 0], [1, 0], [0, 1]]), + [12], + [2, 3], + tensor([[9, 1]]), + ), # squashing 3 dims + ( + [2, 3], + tensor([[9, 1]]), + [2, 2, 3], + [2, 3], + tensor([[1, 0], [1, 0], [0, 1]]), + ), # unsquashing into 3 dims + ( + [4], + tensor([[1], [1]]), + [2, 2, 4], + [2, 2], + tensor([[1, 0], [0, 1], [2, 1]]), + ), # unsquashing physical dim + ( + [4], + tensor([[1], [1]]), + [4, 2, 2], + [2, 2], + tensor([[2, 1], [1, 0], [0, 1]]), + ), # unsquashing physical dim + ( + [2, 3, 4], + tensor([[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]), + [4, 12], + [2, 12], + tensor([[3, 0], [0, 1]]), + ), # world boss + ( + [2, 12], + tensor([[3, 0], [0, 1]]), + [2, 2, 3, 4], + [2, 3, 4], + tensor([[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]), + ), # world boss + ], +) +def test_view( + physical_shape: list[int], + strides: Tensor, + target_shape: list[int], + expected_physical_shape: list[int], + expected_strides: Tensor, +): + a = randn_(tuple(physical_shape)) + t = StructuredSparseTensor(a, strides) + + result = aten.view.default(t, target_shape) + expected = t.to_dense().reshape(target_shape) + + assert isinstance(result, StructuredSparseTensor) + assert list(result.physical.shape) == expected_physical_shape + assert torch.equal(result.strides, expected_strides) + assert torch.all(torch.eq(result.to_dense(), expected)) + + +@mark.parametrize( + ["pshape", "strides", "expected"], + [ + ( + [[32, 2, 3, 4, 5]], + torch.tensor([[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 0], [0, 60, 20, 5, 1]]), + [[0], [1, 2, 3, 4]], + ) + ], +) +def test_get_groupings(pshape: list[int], strides: torch.Tensor, expected: list[list[int]]): + result = get_groupings(pshape, strides) + assert result == expected + + +@mark.parametrize( + ["physical_shape", "strides", "expected_physical_shape", "expected_strides"], + [ + ( + [3, 4, 5], + tensor([[20, 5, 1], [4, 1, 12], [0, 0, 1]]), + [12, 5], + tensor([[5, 1], [1, 12], [0, 1]]), + ), + ( + [32, 20, 8], + tensor([[1, 0, 0], [1, 32, 0], [0, 0, 1]]), + [32, 20, 8], + tensor([[1, 0, 0], [1, 32, 0], [0, 0, 1]]), + ), + ([3, 3, 4], tensor([[3, 1, 0], [0, 4, 1]]), [3, 3, 4], tensor([[3, 1, 0], [0, 4, 1]])), + ], +) +def test_fix_ungrouped_dims( + physical_shape: list[int], + strides: Tensor, + expected_physical_shape: list[int], + expected_strides: Tensor, +): + physical = randn_(physical_shape) + fixed_physical, fixed_strides = fix_ungrouped_dims(physical, strides) + + assert list(fixed_physical.shape) == expected_physical_shape + assert torch.equal(fixed_strides, expected_strides) + + +@mark.parametrize( + [ + "physical_shape", + "strides", + "pdim", + "new_pdim_shape", + "expected_physical_shape", + "expected_strides", + ], + [ + ([4], tensor([[1], [2]]), 0, [4], [4], tensor([[1], [2]])), # trivial + ([4], tensor([[1], [2]]), 0, [2, 2], [2, 2], tensor([[2, 1], [4, 2]])), + ( + [3, 4, 5], + tensor([[1, 2, 0], [1, 0, 1], [0, 1, 1]]), + 1, + [2, 1, 1, 2], + [3, 2, 1, 1, 2, 5], + tensor([[1, 4, 4, 4, 2, 0], [1, 0, 0, 0, 0, 1], [0, 2, 2, 2, 1, 1]]), + ), + ], +) +def test_unsquash_pdim( + 
physical_shape: list[int], + strides: Tensor, + pdim: int, + new_pdim_shape: list[int], + expected_physical_shape: list[int], + expected_strides: Tensor, +): + physical = randn_(physical_shape) + new_physical, new_strides = unsquash_pdim(physical, strides, pdim, new_pdim_shape) + + assert list(new_physical.shape) == expected_physical_shape + assert torch.equal(new_strides, expected_strides) + + +@mark.parametrize( + [ + "source", + "destination", + "ndim", + ], + [ + ([2, 4], [0, 3], 5), + ([5, 3, 6], [2, 0, 5], 8), + ], +) +def test_get_column_indices(source: list[int], destination: list[int], ndim: int): + # TODO: this test should be improved / removed. It creates quite big tensors for nothing. + + t = randn_(list(torch.randint(3, 8, size=(ndim,)))) + full_destination = list(range(ndim)) + full_source = get_full_source(source, destination, ndim) + assert torch.equal(t.movedim(full_source, full_destination), t.movedim(source, destination)) + + +@mark.parametrize( + ["sst_args", "dim"], + [ + ([([3], tensor([[1], [1]])), ([3], tensor([[1], [1]]))], 1), + ([([3, 2], tensor([[1, 0], [1, 3]])), ([3, 2], tensor([[1, 0], [1, 3]]))], 1), + ], +) +def test_concatenate( + sst_args: list[tuple[list[int], Tensor]], + dim: int, +): + tensors = [StructuredSparseTensor(randn_(pshape), strides) for pshape, strides in sst_args] + res = aten.cat.default(tensors, dim) + expected = aten.cat.default([t.to_dense() for t in tensors], dim) + + assert isinstance(res, StructuredSparseTensor) + assert torch.all(torch.eq(res.to_dense(), expected)) + + +@mark.parametrize( + ["physical", "strides", "expected_physical", "expected_strides"], + [ + ( + tensor_([[1, 2, 3], [4, 5, 6]]), + tensor([[1, 0], [1, 0], [2, 0]]), + tensor_([6, 15]), + tensor([[1], [1], [2]]), + ), + ( + tensor_([[1, 2, 3], [4, 5, 6]]), + tensor([[1, 1], [1, 0], [2, 0]]), + tensor_([[1, 2, 3], [4, 5, 6]]), + tensor([[1, 1], [1, 0], [2, 0]]), + ), + ( + tensor_([[3, 2, 1], [6, 5, 4]]), + tensor([[0, 0], [0, 0], [0, 0]]), + tensor_(21), + tensor([[], [], []], dtype=torch.int64), + ), + ], +) +def test_fix_zero_stride_columns( + physical: Tensor, + strides: Tensor, + expected_physical: Tensor, + expected_strides: Tensor, +): + physical, strides = fix_zero_stride_columns(physical, strides) + assert torch.equal(physical, expected_physical) + assert torch.equal(strides, expected_strides)
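+
+
+# Illustrative addition (a sketch, not part of the original test suite): the strides matrix maps a
+# physical index to a virtual index via v_index = strides @ p_index, which to_dense materializes.
+def test_stride_index_mapping_sketch():
+    physical = tensor_([[1.0, 2.0], [3.0, 4.0]])
+    strides = tensor([[1, 0], [0, 1], [1, 1]])
+    t = StructuredSparseTensor(physical, strides)
+    dense = t.to_dense()
+
+    # Virtual shape is strides @ (pshape - 1) + 1 = [2, 2, 3].
+    assert list(t.shape) == [2, 2, 3]
+    for i in range(2):
+        for j in range(2):
+            assert dense[i, j, i + j] == physical[i, j]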