diff --git a/.gitignore b/.gitignore index 902e607c9..01f539d22 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Jupyter notebooks +*.ipynb + # uv uv.lock diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 361743a40..6574bd498 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -4,6 +4,8 @@ from torch import Tensor, nn, vmap from torch.autograd.graph import get_gradient_edge +from torchjd.sparse import make_slt + from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithCrossTerms @@ -173,7 +175,9 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: ) output_dims = list(range(output.ndim)) - jac_output = _make_initial_jac_output(output) + identity = torch.eye(output.ndim, dtype=torch.int64) + basis = torch.concatenate([identity, identity], dim=0) + jac_output = make_slt(torch.ones_like(output), basis, None) vmapped_diff = differentiation for _ in output_dims: @@ -193,15 +197,3 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: gramian_computer.reset() return gramian - - -def _make_initial_jac_output(output: Tensor) -> Tensor: - if output.ndim == 0: - return torch.ones_like(output) - p_index_ranges = [torch.arange(s, device=output.device) for s in output.shape] - p_indices_grid = torch.meshgrid(*p_index_ranges, indexing="ij") - v_indices_grid = p_indices_grid + p_indices_grid - - res = torch.zeros(list(output.shape) * 2, device=output.device, dtype=output.dtype) - res[v_indices_grid] = 1.0 - return res diff --git a/src/torchjd/sparse/__init__.py b/src/torchjd/sparse/__init__.py new file mode 100644 index 000000000..537a29ce3 --- /dev/null +++ b/src/torchjd/sparse/__init__.py @@ -0,0 +1,3 @@ +# Need to import this to execute the code inside and thus to override the functions +from . import _aten_function_overrides +from ._sparse_latticed_tensor import SparseLatticedTensor, make_slt diff --git a/src/torchjd/sparse/_aten_function_overrides/__init__.py b/src/torchjd/sparse/_aten_function_overrides/__init__.py new file mode 100644 index 000000000..b33cf8d62 --- /dev/null +++ b/src/torchjd/sparse/_aten_function_overrides/__init__.py @@ -0,0 +1 @@ +from . 
import backward, einsum, pointwise, shape diff --git a/src/torchjd/sparse/_aten_function_overrides/backward.py b/src/torchjd/sparse/_aten_function_overrides/backward.py new file mode 100644 index 000000000..beae4ef1c --- /dev/null +++ b/src/torchjd/sparse/_aten_function_overrides/backward.py @@ -0,0 +1,36 @@ +from torch import Tensor +from torch.ops import aten # type: ignore + +from torchjd.sparse._sparse_latticed_tensor import SparseLatticedTensor, impl + + +@impl(aten.threshold_backward.default) +def threshold_backward_default( + grad_output: SparseLatticedTensor, self: Tensor, threshold +) -> SparseLatticedTensor: + new_physical = aten.threshold_backward.default(grad_output.physical, self, threshold) + + return SparseLatticedTensor(new_physical, grad_output.basis, grad_output.margin) + + +@impl(aten.hardtanh_backward.default) +def hardtanh_backward_default( + grad_output: SparseLatticedTensor, + self: Tensor, + min_val: Tensor | int | float, + max_val: Tensor | int | float, +) -> SparseLatticedTensor: + if isinstance(self, SparseLatticedTensor): + raise NotImplementedError() + + new_physical = aten.hardtanh_backward.default(grad_output.physical, self, min_val, max_val) + return SparseLatticedTensor(new_physical, grad_output.basis, grad_output.margin) + + +@impl(aten.hardswish_backward.default) +def hardswish_backward_default(grad_output: SparseLatticedTensor, self: Tensor): + if isinstance(self, SparseLatticedTensor): + raise NotImplementedError() + + new_physical = aten.hardswish_backward.default(grad_output.physical, self) + return SparseLatticedTensor(new_physical, grad_output.basis, grad_output.margin) diff --git a/src/torchjd/sparse/_aten_function_overrides/einsum.py b/src/torchjd/sparse/_aten_function_overrides/einsum.py new file mode 100644 index 000000000..afbdf7a9a --- /dev/null +++ b/src/torchjd/sparse/_aten_function_overrides/einsum.py @@ -0,0 +1,257 @@ +import torch +from torch import Tensor, tensor +from torch.ops import aten # type: ignore + +from torchjd.sparse._sparse_latticed_tensor import ( + SparseLatticedTensor, + impl, + to_most_efficient_tensor, + to_sparse_latticed_tensor, +) + + +def einsum(*args: tuple[SparseLatticedTensor, list[int]], output: list[int]) -> Tensor: + raise NotImplementedError() + + # First part of the algorithm, determine how to cluster physical indices as well as the common + # p_shapes corresponding to matching v_dims. Second part translates to physical einsum. + + # get a map from einsum index to (tensor_idx, v_dims) + # get a map from einsum index to merge of strides corresponding to v_dims with that index + # use to_target_physical_strides on each physical and v_to_ps + # cluster pairs of (einsum_index, new_stride) using new_v_to_ps and possibly its corresponding + # p_to_vs + # get unique indices + # map output indices (there can be splits) + # call physical einsum + # build resulting sst + + # OVER + + # an index in the physical einsum is uniquely characterized by a virtual einsum index and a + # stride corresponding to the physical stride in the virtual one (note that as the virtual shape + # for two virtual index that match should match, then we want to match the strides and reshape + # accordingly). + # We want to cluster such indices whenever several appear in the same p_to_vs + + # TODO: Handle ellipsis + # If we have an index v for some virtual dim whose corresponding v_to_ps is a non-trivial list + # [p_1, ..., p_k], then we have to create fresh sub-indices for each dimension. 
+ # For this reason, an index is decomposed into sub-indices that are then independently + # clustered. + # So if an index i in args for some SparseLatticedTensor corresponds to a v_to_ps [j, k, l], + # We will consider three indices (i, 0), (i, 1) and (i, 2). + # If furthermore [k] correspond to the v_to_ps of some other tensor with index j, then + # (i, 1) and (j, 0) will be clustered together (and end up being mapped to the same indice in + # the resulting einsum). + # Note that this is a problem if two virtual dimensions (from possibly different + # SparseLatticedTensors) have the same size but not the same decomposition into physical + # dimension sizes. For now lets leave the responsibility to care about that in the calling + # functions, if we can factor code later on we will. + + index_parents = dict[tuple[int, int], tuple[int, int]]() + + def get_representative(index: tuple[int, int]) -> tuple[int, int]: + if index not in index_parents: + # If an index is not yet in a cluster, put it in its own. + index_parents[index] = index + current = index_parents[index] + if current != index: + # Compress path to representative + index_parents[index] = get_representative(current) + return index_parents[index] + + def group_indices(indices: list[tuple[int, int]]) -> None: + first_representative = get_representative(indices[0]) + for i in indices[1:]: + curr_representative = get_representative(i) + index_parents[curr_representative] = first_representative + + new_indices_pair = list[list[tuple[int, int]]]() + physicals = list[Tensor]() + indices_to_n_pdims = dict[int, int]() + for t, indices in args: + assert isinstance(t, SparseLatticedTensor) + physicals.append(t.physical) + for pdims, index in zip(t.v_to_ps, indices): + if index in indices_to_n_pdims: + if indices_to_n_pdims[index] != len(pdims): + raise NotImplementedError( + "einsum currently does not support having a different number of physical " + "dimensions corresponding to matching virtual dimensions of different " + f"tensors. Found {[(t.debug_info(), indices) for t, indices in args]}, " + f"output_indices={output}." + ) + else: + indices_to_n_pdims[index] = len(pdims) + p_to_vs = ... # p_to_vs_from_v_to_ps(t.v_to_ps) + for indices_ in p_to_vs: + # elements in indices[indices_] map to the same dimension, they should be clustered + # together + group_indices([(indices[i], sub_i) for i, sub_i in indices_]) + # record the physical dimensions, index[v] for v in vs will end-up mapping to the same + # final dimension as they were just clustered, so we can take the first, which exists as + # t is a valid SST. 
+ new_indices_pair.append([(indices[vs[0][0]], vs[0][1]) for vs in p_to_vs]) + + current = 0 + pair_to_int = dict[tuple[int, int], int]() + + def unique_int(pair: tuple[int, int]) -> int: + nonlocal current + if pair in pair_to_int: + return pair_to_int[pair] + pair_to_int[pair] = current + current += 1 + return pair_to_int[pair] + + new_indices = [ + [unique_int(get_representative(i)) for i in indices] for indices in new_indices_pair + ] + new_output = list[int]() + v_to_ps = list[list[int]]() + for i in output: + current_v_to_ps = [] + for j in range(indices_to_n_pdims[i]): + k = unique_int(get_representative((i, j))) + if k in new_output: + current_v_to_ps.append(new_output.index(k)) + else: + current_v_to_ps.append(len(new_output)) + new_output.append(k) + v_to_ps.append(current_v_to_ps) + + physical = torch.einsum(*[x for y in zip(physicals, new_indices) for x in y], new_output) + # Need to use the safe constructor, otherwise the dimensions may not be maximally grouped. + # Maybe there is a way to fix that though. + return to_most_efficient_tensor(physical, v_to_ps) + + +def prepare_for_elementwise_op( + t1: Tensor | int | float, t2: Tensor | int | float +) -> tuple[SparseLatticedTensor, SparseLatticedTensor]: + """ + Prepares two SLTs of the same shape from two args, one of those being a SLT, and the other being + a SLT, Tensor, int or float. + """ + + assert isinstance(t1, SparseLatticedTensor) or isinstance(t2, SparseLatticedTensor) + + if isinstance(t1, int) or isinstance(t1, float): + t1_ = tensor(t1, device=t2.device) # type: ignore[union-attr] + else: + t1_ = t1 + + if isinstance(t2, int) or isinstance(t2, float): + t2_ = tensor(t2, device=t1.device) # type: ignore[union-attr] + else: + t2_ = t2 + + t1_, t2_ = aten.broadcast_tensors.default([t1_, t2_]) + t1_ = to_sparse_latticed_tensor(t1_) + t2_ = to_sparse_latticed_tensor(t2_) + + return t1_, t2_ + + +@impl(aten.mul.Tensor) +def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: + # Element-wise multiplication with broadcasting + t1_, t2_ = prepare_for_elementwise_op(t1, t2) + all_dims = list(range(t1_.ndim)) + return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) + + +@impl(aten.div.Tensor) +def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: + t1_, t2_ = prepare_for_elementwise_op(t1, t2) + t2_ = SparseLatticedTensor(1.0 / t2_.physical, t2_.basis, t2_.margin) + all_dims = list(range(t1_.ndim)) + return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) + + +@impl(aten.mul.Scalar) +def mul_Scalar(t: SparseLatticedTensor, scalar) -> SparseLatticedTensor: + # TODO: maybe it could be that scalar is a scalar SLT and t is a normal tensor. 
Need to check + # that + + assert isinstance(t, SparseLatticedTensor) + new_physical = aten.mul.Scalar(t.physical, scalar) + return SparseLatticedTensor(new_physical, t.basis, t.margin) + + +@impl(aten.add.Tensor) +def add_Tensor( + t1: Tensor | int | float, t2: Tensor | int | float, alpha: Tensor | float = 1.0 +) -> SparseLatticedTensor: + t1_, t2_ = prepare_for_elementwise_op(t1, t2) + + if ( + torch.equal(t1_.basis, t2_.basis) + and torch.equal(t1_.offset, t2_.offset) + and torch.equal(t1_.shape_t, t2_.shape_t) + ): + new_physical = t1_.physical + t2_.physical * alpha + return SparseLatticedTensor(new_physical, t1_.basis, t1_.margin) + else: + raise NotImplementedError() + + +@impl(aten.bmm.default) +def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor: + assert isinstance(mat1, SparseLatticedTensor) or isinstance(mat2, SparseLatticedTensor) + assert ( + mat1.ndim == 3 + and mat2.ndim == 3 + and mat1.shape[0] == mat2.shape[0] + and mat1.shape[2] == mat2.shape[1] + ) + + mat1_ = to_sparse_latticed_tensor(mat1) + mat2_ = to_sparse_latticed_tensor(mat2) + + # TODO: Verify that the dimension `0` of mat1_ and mat2_ have the same physical dimension sizes + # decompositions. If not, can reshape to common decomposition? + return einsum((mat1_, [0, 1, 2]), (mat2_, [0, 2, 3]), output=[0, 1, 3]) + + +@impl(aten.mm.default) +def mm_default(mat1: Tensor, mat2: Tensor) -> Tensor: + assert isinstance(mat1, SparseLatticedTensor) or isinstance(mat2, SparseLatticedTensor) + assert mat1.ndim == 2 and mat2.ndim == 2 and mat1.shape[1] == mat2.shape[0] + + mat1_ = to_sparse_latticed_tensor(mat1) + mat2_ = to_sparse_latticed_tensor(mat2) + + return einsum((mat1_, [0, 1]), (mat2_, [1, 2]), output=[0, 2]) + + +@impl(aten.mean.default) +def mean_default(t: SparseLatticedTensor) -> Tensor: + assert isinstance(t, SparseLatticedTensor) + return aten.sum.default(t.physical) / t.numel() + + +@impl(aten.sum.default) +def sum_default(t: SparseLatticedTensor) -> Tensor: + assert isinstance(t, SparseLatticedTensor) + return aten.sum.default(t.physical) + + +@impl(aten.sum.dim_IntList) +def sum_dim_IntList( + t: SparseLatticedTensor, dim: list[int], keepdim: bool = False, dtype=None +) -> Tensor: + assert isinstance(t, SparseLatticedTensor) + + if dtype: + raise NotImplementedError() + + all_dims = list(range(t.ndim)) + result = einsum((t, all_dims), output=[d for d in all_dims if d not in dim]) + + if keepdim: + for d in dim: + result = result.unsqueeze(d) + + return result diff --git a/src/torchjd/sparse/_aten_function_overrides/pointwise.py b/src/torchjd/sparse/_aten_function_overrides/pointwise.py new file mode 100644 index 000000000..44874f335 --- /dev/null +++ b/src/torchjd/sparse/_aten_function_overrides/pointwise.py @@ -0,0 +1,125 @@ +from torch.ops import aten # type: ignore + +from torchjd.sparse._sparse_latticed_tensor import SparseLatticedTensor, impl + +# pointwise functions applied to one Tensor with `0.0 → 0` +_POINTWISE_FUNCTIONS = [ + aten.abs.default, + aten.absolute.default, + aten.asin.default, + aten.asinh.default, + aten.atan.default, + aten.atanh.default, + aten.ceil.default, + aten.erf.default, + aten.erfinv.default, + aten.expm1.default, + aten.fix.default, + aten.floor.default, + aten.hardtanh.default, + aten.leaky_relu.default, + aten.log1p.default, + aten.neg.default, + aten.negative.default, + aten.positive.default, + aten.relu.default, + aten.round.default, + aten.sgn.default, + aten.sign.default, + aten.sin.default, + aten.sinh.default, + aten.sqrt.default, + aten.square.default, + 
aten.tan.default, + aten.tanh.default, + aten.trunc.default, +] + +_IN_PLACE_POINTWISE_FUNCTIONS = [ + aten.abs_.default, + aten.absolute_.default, + aten.asin_.default, + aten.asinh_.default, + aten.atan_.default, + aten.atanh_.default, + aten.ceil_.default, + aten.erf_.default, + aten.erfinv_.default, + aten.expm1_.default, + aten.fix_.default, + aten.floor_.default, + aten.hardtanh_.default, + aten.leaky_relu_.default, + aten.log1p_.default, + aten.neg_.default, + aten.negative_.default, + aten.relu_.default, + aten.round_.default, + aten.sgn_.default, + aten.sign_.default, + aten.sin_.default, + aten.sinh_.default, + aten.sqrt_.default, + aten.square_.default, + aten.tan_.default, + aten.tanh_.default, + aten.trunc_.default, +] + + +def _override_pointwise(op): + @impl(op) + def func_(t: SparseLatticedTensor) -> SparseLatticedTensor: + assert isinstance(t, SparseLatticedTensor) + return SparseLatticedTensor(op(t.physical), t.basis, t.margin) + + return func_ + + +def _override_inplace_pointwise(op): + @impl(op) + def func_(t: SparseLatticedTensor) -> SparseLatticedTensor: + assert isinstance(t, SparseLatticedTensor) + op(t.physical) + return t + + +for pointwise_func in _POINTWISE_FUNCTIONS: + _override_pointwise(pointwise_func) + +for pointwise_func in _IN_PLACE_POINTWISE_FUNCTIONS: + _override_inplace_pointwise(pointwise_func) + + +@impl(aten.pow.Tensor_Scalar) +def pow_Tensor_Scalar(t: SparseLatticedTensor, exponent: float) -> SparseLatticedTensor: + assert isinstance(t, SparseLatticedTensor) + + if exponent <= 0.0: + # Need to densify because we don't have pow(0.0, exponent) = 0.0 + return aten.pow.Tensor_Scalar(t.to_dense(), exponent) + + new_physical = aten.pow.Tensor_Scalar(t.physical, exponent) + return SparseLatticedTensor(new_physical, t.basis, t.margin) + + +# Somehow there's no pow_.Tensor_Scalar and pow_.Scalar takes tensor and scalar. +@impl(aten.pow_.Scalar) +def pow__Scalar(t: SparseLatticedTensor, exponent: float) -> SparseLatticedTensor: + assert isinstance(t, SparseLatticedTensor) + + if exponent <= 0.0: + # Need to densify because we don't have pow(0.0, exponent) = 0.0 + # Note sure if it's even possible to densify in-place, so let's just raise an error. + raise ValueError(f"in-place pow with an exponent of {exponent} (<= 0) is not supported.") + + aten.pow_.Scalar(t.physical, exponent) + return t + + +@impl(aten.div.Scalar) +def div_Scalar(t: SparseLatticedTensor, divisor: float) -> SparseLatticedTensor: + assert isinstance(t, SparseLatticedTensor) + + new_physical = aten.div.Scalar(t.physical, divisor) + return SparseLatticedTensor(new_physical, t.basis, t.margin) diff --git a/src/torchjd/sparse/_aten_function_overrides/shape.py b/src/torchjd/sparse/_aten_function_overrides/shape.py new file mode 100644 index 000000000..07e70f7b9 --- /dev/null +++ b/src/torchjd/sparse/_aten_function_overrides/shape.py @@ -0,0 +1,207 @@ +import operator +from itertools import accumulate +from math import prod + +import torch +from torch import Tensor, arange, cat, tensor +from torch.ops import aten # type: ignore + +from torchjd.sparse._sparse_latticed_tensor import ( + SparseLatticedTensor, + impl, + print_fallback, + to_most_efficient_tensor, + unwrap_to_dense, +) + + +@impl(aten.view.default) +def view_default(t: SparseLatticedTensor, shape: list[int]) -> Tensor: + """ + The main condition that we want to respect is that the indexing in the flattened virtual + tensor should remain the same before and after the reshape, i.e. 
+ + c.T S = c'.T S' (1) + where: + * c is the reversed vector of cumulative physical shape before the reshape, i.e. + c.T = [prod(t.shape[1:]), prod(t.shape[2:]), ..., t.shape[-1], 1] + * c' is the same thing but after the reshape, i.e. + c'.T = [prod(shape[1:]), prod(shape[2:]), ..., shape[-1], 1] + * S is the original basis matrix (t.basis) + * S' is the basis matrix after reshaping. + + For u, v in Z^m and c in Z, say that u ≡ v (mod c) if u_i ≡ v_i (mod c) for all i. + Note that c'.T S' ≡ S'[-1] (mod shape[-1]) + So if we set S'[-1] = c.T S % shape[-1], we have c.T S ≡ c'.T S' (mod shape[-1]) + + (c'.T S' - S'[-1]) // shape[-1] ≡ S'[-1] (mod shape[-1]) + ... + """ + + assert isinstance(t, SparseLatticedTensor) + + if not torch.equal(t.margin, torch.zeros_like(t.margin)): + raise NotImplementedError() + + shape = infer_shape(shape, t.numel()) + + if prod(shape) != t.numel(): + raise ValueError(f"shape '{shape}' is invalid for input of size {t.numel()}") + + S = t.basis + vshape = list(t.shape) + c = _reverse_cumulative_product(vshape) + c_prime = _reverse_cumulative_product(shape) + new_basis = ((c @ S).unsqueeze(0) // c_prime.unsqueeze(1)) % tensor(shape).unsqueeze(1) + + new_margin = torch.zeros([len(shape), 2], dtype=torch.int64) + return to_most_efficient_tensor(t.physical, new_basis, new_margin) + + +def _reverse_cumulative_product(values: list[int]) -> Tensor: + return tensor(list(accumulate((values[1:] + [1])[::-1], operator.mul))[::-1]) + + +def infer_shape(shape: list[int], numel: int) -> list[int]: + if shape.count(-1) > 1: + raise ValueError("Only one dimension can be inferred") + known = 1 + for s in shape: + if s != -1: + known *= s + inferred = numel // known + return [inferred if s == -1 else s for s in shape] + + +@impl(aten._unsafe_view.default) +def _unsafe_view_default(t: SparseLatticedTensor, shape: list[int]) -> Tensor: + return view_default( + t, shape + ) # We don't do the optimizations that they do in https://github.com/pytorch/pytorch/blame/main/aten/src/ATen/native/TensorShape.cpp + + +@impl(aten.unsqueeze.default) +def unsqueeze_default(t: SparseLatticedTensor, dim: int) -> SparseLatticedTensor: + assert isinstance(t, SparseLatticedTensor) + assert -t.ndim - 1 <= dim < t.ndim + 1 + + if dim < 0: + dim = t.ndim + dim + 1 + + pdims = t.basis.shape[1] + new_basis = cat([t.basis[:dim], torch.zeros(1, pdims, dtype=torch.int64), t.basis[dim:]]) + new_margin = cat([t.margin[:dim], torch.zeros([1, 2], dtype=torch.int64), t.margin[dim:]]) + return SparseLatticedTensor(t.physical, new_basis, new_margin) + + +@impl(aten.squeeze.dims) +def squeeze_dims(t: SparseLatticedTensor, dims: list[int] | int | None) -> Tensor: + assert isinstance(t, SparseLatticedTensor) + # TODO: verify that the specified dimensions are of size 1. 
+ + if dims is None: + excluded = set(range(t.ndim)) + elif isinstance(dims, int): + excluded = {dims} + else: + excluded = set(dims) + + is_row_kept = [i not in excluded for i in range(t.ndim)] + new_basis = t.basis[is_row_kept] + new_margin = t.margin[is_row_kept] + return to_most_efficient_tensor(t.physical, new_basis, new_margin) + + +@impl(aten.permute.default) +def permute_default(t: SparseLatticedTensor, dims: list[int]) -> SparseLatticedTensor: + new_basis = t.basis[dims] + new_margin = t.margin[dims] + return SparseLatticedTensor(t.physical, new_basis, new_margin) + + +@impl(aten.cat.default) +def cat_default(tensors: list[Tensor], dim: int = 0) -> Tensor: + if any(not isinstance(t, SparseLatticedTensor) for t in tensors): + print_fallback(aten.cat.default, (tensors, dim), {}) + return aten.cat.default([unwrap_to_dense(t) for t in tensors]) + + print_fallback(aten.cat.default, (tensors, dim), {}) + return aten.cat.default([unwrap_to_dense(t) for t in tensors]) + + # TODO: add implementation based on adding some margin to tensors and summing them + + +@impl(aten.expand.default) +def expand_default(t: SparseLatticedTensor, sizes: list[int]) -> SparseLatticedTensor: + # note that sizes could also be just an int, or a torch.Size i think + assert isinstance(t, SparseLatticedTensor) + assert isinstance(sizes, list) + assert len(sizes) >= t.ndim + + # Add as many dimensions as needed at the beginning of the tensor (as torch.expand works) + for _ in range(len(sizes) - t.ndim): + t = unsqueeze_default(t, 0) + + # Try to expand each dimension to its new size + new_physical = t.physical + new_basis = t.basis + for d, (v, orig_size, new_size) in enumerate(zip(t.basis, t.shape, sizes, strict=True)): + if v.sum() > 0 and orig_size != new_size and new_size != -1: + raise ValueError( + f"Cannot expand dim {d} of size != 1. Found size {orig_size} and target size " + f"{new_size}." + ) + + if v.sum() == 0 and new_size != 1 and new_size != -1: + # Add a dimension of size new_size at the end of the physical tensor. 
+ new_physical_shape = list(new_physical.shape) + [new_size] + new_physical = new_physical.unsqueeze(-1).expand(new_physical_shape) + + # Make the basis vector of this new physical dimension be 1 at virtual dimension d and 0 + # at every other virtual dimension + new_basis_vector = torch.zeros(t.ndim, 1, dtype=torch.int64) + new_basis_vector[d, 0] = 1 + new_basis = torch.cat([new_basis, new_basis_vector], dim=1) + + return SparseLatticedTensor(new_physical, new_basis, t.margin) + + +@impl(aten.broadcast_tensors.default) +def broadcast_tensors_default(tensors: list[Tensor]) -> tuple[Tensor, Tensor]: + if len(tensors) != 2: + raise NotImplementedError() + + t1, t2 = tensors + + if t1.shape == t2.shape: + return t1, t2 + + a = t1 if t1.ndim >= t2.ndim else t2 + b = t2 if t1.ndim >= t2.ndim else t1 + + a_shape = list(a.shape) + padded_b_shape = [1] * (a.ndim - b.ndim) + list(b.shape) + + new_shape = list[int]() + + for s_a, s_b in zip(a_shape, padded_b_shape): + if s_a != 1 and s_b != 1 and s_a != s_b: + raise ValueError("Incompatible shapes for broadcasting") + else: + new_shape.append(max(s_a, s_b)) + + return aten.expand.default(t1, new_shape), aten.expand.default(t2, new_shape) + + +@impl(aten.transpose.int) +def transpose_int(t: SparseLatticedTensor, dim0: int, dim1: int) -> SparseLatticedTensor: + assert isinstance(t, SparseLatticedTensor) + + new_index = arange(t.basis.shape[0]) + new_index[dim0] = dim1 + new_index[dim1] = dim0 + + new_basis = t.basis[new_index] + new_margin = t.margin[new_index] + + return SparseLatticedTensor(t.physical, new_basis, new_margin) diff --git a/src/torchjd/sparse/_coalesce.py b/src/torchjd/sparse/_coalesce.py new file mode 100644 index 000000000..ac9c8368f --- /dev/null +++ b/src/torchjd/sparse/_coalesce.py @@ -0,0 +1,18 @@ +import torch +from torch import Tensor + + +def fix_zero_basis_vectors(physical: Tensor, basis: Tensor) -> tuple[Tensor, Tensor]: + """ + Remove basis vectors that are all 0 and sum the corresponding elements in the physical tensor. + """ + + are_vectors_zero = (basis == 0).all(dim=0) + + if not are_vectors_zero.any(): + return physical, basis + + zero_column_indices = torch.arange(len(are_vectors_zero))[are_vectors_zero].tolist() + physical = physical.sum(dim=zero_column_indices) + basis = basis[:, ~are_vectors_zero] + return physical, basis diff --git a/src/torchjd/sparse/_linalg.py b/src/torchjd/sparse/_linalg.py new file mode 100644 index 000000000..761c580d5 --- /dev/null +++ b/src/torchjd/sparse/_linalg.py @@ -0,0 +1,196 @@ +from typing import cast + +import torch +from torch import Tensor + +# TODO: Implement in C everything in this file. + + +def extended_gcd(a: int, b: int) -> tuple[int, int, int]: + """ + Extended Euclidean Algorithm (Python integers). + Returns (g, x, y) such that a*x + b*y = g. + """ + # We perform the logic in standard Python int for speed on scalars + # then cast back to torch tensors if needed, or return python ints. + if a == 0: + return b, 0, 1 + else: + g, y, x = extended_gcd(b % a, a) + return g, x - (b // a) * y, y + + +def _get_hermite_factor_rank(H: Tensor) -> int: + """ + Computes the rank of a hermit factor matrix. + """ + col_magnitudes = torch.sum(torch.abs(H), dim=0) + return cast(int, torch.count_nonzero(col_magnitudes).item()) + + +def hnf_decomposition(A: Tensor, reduced: bool) -> tuple[Tensor, Tensor, Tensor]: + """ + Computes the reduced Hermite Normal Form decomposition using PyTorch. 
For a matrix A (m x n) + computes the matrices H (m x r), U (n x r) and V (r x n) such that + V U = I_r + A = H V + H = A U + where r is the rank of A if reduced is True, and otherwise r is n. In the later case, this also + satisfies U V = I. + + Args: + A: (m x n) torch.Tensor (dtype=torch.long) + reduced: Reduce to rank if True. + + Returns: + H: (m x r) Canonical Lower Triangular HNF + U: (n x r) Unimodular transform (A @ U = H) + V: (r x n) Right inverse Unimodular transform (H @ V = A) + """ + + H = A.clone() + m, n = H.shape + + U = torch.eye(n, dtype=A.dtype) + V = torch.eye(n, dtype=A.dtype) + + col = 0 + + for row in range(m): + if n <= col: + break + row_slice = H[row, col:n] + nonzero_indices = torch.nonzero(row_slice) + + if nonzero_indices.numel() > 0: + relative_pivot_idx = nonzero_indices[0][0].item() + pivot_idx = col + relative_pivot_idx + else: + continue + + if pivot_idx != col: + H[:, [col, pivot_idx]] = H[:, [pivot_idx, col]] + U[:, [col, pivot_idx]] = U[:, [pivot_idx, col]] + V[[col, pivot_idx], :] = V[[pivot_idx, col], :] + + for j in range(col + 1, n): + if H[row, j] != 0: + a_val = cast(int, H[row, col].item()) + b_val = cast(int, H[row, j].item()) + + g, x, y = extended_gcd(a_val, b_val) + + c1 = -b_val // g + c2 = a_val // g + + H_col = H[:, col] + H_j = H[:, j] + + H[:, [col, j]] = torch.stack([H_col * x + H_j * y, H_col * c1 + H_j * c2], dim=1) + + U_col = U[:, col] + U_j = U[:, j] + U[:, [col, j]] = torch.stack([U_col * x + U_j * y, U_col * c1 + U_j * c2], dim=1) + + V_row_c = V[col, :] + V_row_j = V[j, :] + V[[col, j], :] = torch.stack( + [V_row_c * c2 - V_row_j * c1, V_row_c * (-y) + V_row_j * x], dim=0 + ) + + pivot_val = H[row, col] + + if pivot_val != 0: + H_row_prefix = H[row, 0:col] + factors = torch.div(H_row_prefix, pivot_val, rounding_mode="floor") + H[:, 0:col] -= factors.unsqueeze(0) * H[:, col].unsqueeze(1) + U[:, 0:col] -= factors.unsqueeze(0) * U[:, col].unsqueeze(1) + V[col, :] += factors @ V[0:col, :] + + col += 1 + + # TODO: Should actually make 2 functions, one for full and one for reduced + if reduced: + rank = _get_hermite_factor_rank(H) + + H = H[:, :rank] + U = U[:, :rank] + V = V[:rank, :] + + return H, U, V + + +def compute_gcd(S1: Tensor, S2: Tensor) -> tuple[Tensor, Tensor, Tensor]: + """ + Computes a GCD and the projection factors, i.e. + S1 = G @ K1 + S2 = G @ K2 + with G having minimal rank r. + + Args: + S1, S2: torch.Tensors (m x n1), (m x n2) + + Returns: + G: (m x r) The Matrix GCD (Canonical Base) + K1: (r x n1) Factors for S1 + K2: (r x n2) Factors for S2 + + Implementation logic: + The concatenated matrix [S1 | S2] spans exactly the sum of the lattices generated by S1 and S2. + This is because S1 @ u1 + S2 @ u2 = [S1 | S2] @ [u1.T | u2.T].T + The reduced HNF decomposition of [S1 | S2] yields G, U, V where the G.shape[1] is the rank of + [S1 | S2] and [S1 | S2] = G @ V. This means that + S1 = G @ V[:, :m1] + S2 = G @ V[:, m1:] + This is the target common factorization. It is the greatest as the lattice spanned by G is the + same as that spanned by [S1 | S2]. + """ + assert S1.shape[0] == S2.shape[0] + m, n1 = S1.shape + + A = torch.cat([S1, S2], dim=1) + G, _, V = hnf_decomposition(A, True) + + K1 = V[:, :n1] + K2 = V[:, n1:] + + return G, K1, K2 + + +def compute_lcm(S1: Tensor, S2: Tensor) -> tuple[Tensor, Tensor, Tensor]: + """ + Computes a LCM and the projection multipliers, i.e. + L = S1 @ M1 = S2 @ M2 + with L having maximal rank r. 
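+    For instance, for S1 = [[8]] and S2 = [[12]] (one of the test cases of this diff), the
+    intersection of the two generated lattices is 24Z, so L spans the same lattice as [[24]].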
+
+    Args:
+        S1, S2: torch.Tensors (m x n1), (m x n2)
+
+    Returns:
+        L: (m x r) The Matrix LCM
+        M1: (n1 x r) Multiplier for S1
+        M2: (n2 x r) Multiplier for S2
+
+    Implementation logic:
+    The lattice kernel of the concatenated matrix [S1 | -S2] is the set of all vectors
+    [u1.T | u2.T].T such that S1 @ u1 - S2 @ u2 = 0, or equivalently S1 @ u1 = S2 @ u2.
+    This means that the images of the two components of the kernel through S1 and S2 respectively
+    are the same, which is exactly the intersection of the lattices generated by S1 and S2.
+    The full HNF decomposition of [S1 | -S2] yields H, U, V such that
+    H = [S1 | -S2] @ U
+    If [S1 | -S2] has rank r', then every column of H after the first r' contains only zeros, and
+    therefore U[:, r':] spans the kernel of [S1 | -S2]. We have
+    S1 @ U[:n1, r':] = S2 @ U[n1:, r':]
+    which yields the desired decomposition with r=n1+n2-r'.
+    """
+    assert S1.shape[0] == S2.shape[0]
+    m, n1 = S1.shape
+
+    B = torch.cat([S1, -S2], dim=1)
+    H, U, _ = hnf_decomposition(B, False)
+
+    rank = _get_hermite_factor_rank(H)
+    M1 = U[:n1, -rank:]
+    M2 = U[n1:, -rank:]
+    L = S1 @ M1
+    return L, M1, M2
diff --git a/src/torchjd/sparse/_sparse_latticed_tensor.py b/src/torchjd/sparse/_sparse_latticed_tensor.py
new file mode 100644
index 000000000..ddc1263b5
--- /dev/null
+++ b/src/torchjd/sparse/_sparse_latticed_tensor.py
@@ -0,0 +1,313 @@
+import itertools
+from collections.abc import Callable
+from functools import wraps
+from math import prod
+
+import torch
+from torch import Tensor, arange, meshgrid, stack, tensor, tensordot, zeros
+from torch.utils._pytree import tree_map
+
+
+class SparseLatticedTensor(Tensor):
+    _HANDLED_FUNCTIONS = dict[Callable, Callable]()
+
+    @staticmethod
+    def __new__(
+        cls,
+        physical: Tensor,
+        basis: Tensor,
+        margin: Tensor | None = None,
+    ):
+        assert basis.dtype == torch.int64
+
+        # Note [Passing requires_grad=true tensors to subclasses]
+        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+        # Calling _make_subclass directly in an autograd context is
+        # never the right thing to do, as this will detach you from
+        # the autograd graph. You must create an autograd function
+        # representing the "constructor" (SparseLatticedTensor, in this case)
+        # and call that instead. This assert helps prevent direct usage
+        # (which is bad!)
+        assert not physical.requires_grad or not torch.is_grad_enabled()
+
+        if margin is None:
+            margin = torch.zeros([basis.shape[0], 2], dtype=torch.int64)
+
+        pshape_t = tensor(physical.shape, dtype=torch.int64)
+        shape = physical_image_size(basis, pshape_t) + margin.sum(dim=1)
+
+        return Tensor._make_wrapper_subclass(
+            cls, list(shape), dtype=physical.dtype, device=physical.device
+        )
+
+    def __init__(
+        self,
+        physical: Tensor,
+        basis: Tensor,
+        margin: Tensor | None,
+    ):
+        """
+        This constructor is made for specifying physical and basis exactly. It should not modify
+        them.
+
+        For this reason, another constructor will be made to either modify the physical / basis to
+        simplify the result, or to create a dense tensor directly if it's already dense.
+
+        :param physical: The dense tensor holding the actual data.
+        :param basis: Integer (int64) tensor of shape [virtual_ndim, physical_ndim], representing
+            the linear transformation between an index in the physical tensor and the corresponding
+            index in the virtual tensor, i.e. v_index = basis @ p_index + margin[:, 0]
+        :param margin: Number of extra elements at the start and end of each virtual dimension.
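+
+        Following the index formula above (and the diagonal-equivalence test in this diff), a
+        physical tensor of shape [n] with basis [[1], [1]] and no margin maps physical index i to
+        virtual index (i, i): the resulting SLT represents the n x n matrix whose diagonal holds
+        the physical values.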
+        """
+
+        if margin is None:
+            margin = torch.zeros([basis.shape[0], 2], dtype=torch.int64)
+
+        if any(s == 1 for s in physical.shape):
+            raise ValueError(
+                "physical must not contain any dimension of size 1. Found physical.shape="
+                f"{physical.shape}."
+            )
+        if basis.dtype is not torch.int64:
+            raise ValueError(f"basis should be of int64 dtype. Found basis.dtype={basis.dtype}.")
+        if not (basis >= 0).all():
+            raise ValueError(f"All basis vectors must be non-negative. Found basis={basis}.")
+        if basis.shape[1] != physical.ndim:
+            raise ValueError(
+                f"basis should have 1 column per physical dimension. Found basis={basis} and "
+                f"physical.shape={physical.shape}."
+            )
+        if (basis.sum(dim=0) == 0).any():
+            raise ValueError(
+                f"basis should not have any column full of zeros. Found basis={basis}."
+            )
+        groups = get_groupings(list(physical.shape), basis)
+        if any(len(group) != 1 for group in groups):
+            raise ValueError(
+                f"Dimensions must be maximally grouped. Found basis={basis} and " f"groups={groups}"
+            )
+
+        self.physical = physical
+        self.basis = basis
+        self.margin = margin
+        self.shape_t = tensor(self.shape, dtype=torch.int64)
+        self.pshape_t = tensor(physical.shape, dtype=torch.int64)
+
+    def to_dense(
+        self, dtype: torch.dtype | None = None, *, masked_grad: bool | None = None
+    ) -> Tensor:
+        assert dtype is None  # We may add support for this later
+        assert masked_grad is None  # We may add support for this later
+
+        if self.physical.ndim == 0:
+            return self.physical
+
+        p_index_ranges = [arange(s) for s in self.physical.shape]
+        p_indices_grid = stack(meshgrid(*p_index_ranges, indexing="ij"))
+
+        # addmm_cuda is not implemented for Long tensors, so these tensors have to stay on CPU
+        reshaped_offset = self.offset.reshape([-1] + [1] * self.physical.ndim)
+        v_indices_grid = tensordot(self.basis, p_indices_grid, dims=1) + reshaped_offset
+        # v_indices_grid is of shape [n_virtual_dims] + physical_shape
+        res = zeros(self.shape, device=self.device, dtype=self.dtype)
+        res[tuple(v_indices_grid)] = self.physical
+        return res
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func in cls._HANDLED_FUNCTIONS:
+            return cls._HANDLED_FUNCTIONS[func](*args, **kwargs)
+
+        print_fallback(func, args, kwargs)
+        unwrapped_args = tree_map(unwrap_to_dense, args)
+        unwrapped_kwargs = tree_map(unwrap_to_dense, kwargs)
+        return func(*unwrapped_args, **unwrapped_kwargs)
+
+    def __repr__(self, *, tensor_contents=None) -> str:  # type: ignore[override]
+        return f"SparseLatticedTensor(physical={self.physical}, basis={self.basis}, margin={self.margin})"
+
+    @classmethod
+    def implements(cls, torch_function):
+        """Register a torch function override for SparseLatticedTensor"""
+
+        @wraps(torch_function)
+        def decorator(func):
+            cls._HANDLED_FUNCTIONS[torch_function] = func
+            return func
+
+        return decorator
+
+    @property
+    def offset(self) -> Tensor:
+        """
+        Returns the margin at the start of each virtual dimension. Can be negative.
+
+        The result is an int tensor of shape [virtual_ndim].
+ """ + + return self.margin[:, 0] + + +impl = SparseLatticedTensor.implements + + +def min_natural_virtual_indices(basis: Tensor, pshape: Tensor) -> Tensor: + # Basis where each positive element is replaced by 0 + non_positive_basis = torch.min(basis, torch.zeros_like(basis)) + max_physical_index = pshape - 1 + return (non_positive_basis * max_physical_index.unsqueeze(0)).sum(dim=1) + + +def max_natural_virtual_indices(basis: Tensor, pshape: Tensor) -> Tensor: + # Basis where each negative element is replaced by 0 + non_negative = torch.max(basis, torch.zeros_like(basis)) + max_physical_index = pshape - 1 + return (non_negative * max_physical_index.unsqueeze(0)).sum(dim=1) + + +def physical_image_size(basis: Tensor, pshape: Tensor) -> Tensor: + """ + Returns the shape of the image of the physical through the basis transform. + + The result is an int tensor of shape [virtual_ndim]. + """ + + one = torch.ones(basis.shape[0], dtype=torch.int64) + max_idx = max_natural_virtual_indices(basis, pshape) + min_idx = min_natural_virtual_indices(basis, pshape) + return max_idx - min_idx + one + + +def print_fallback(func, args, kwargs) -> None: + def tensor_to_str(t: Tensor) -> str: + result = f"{t.__class__.__name__} - shape: {t.shape}" + if isinstance(t, SparseLatticedTensor): + result += f" - pshape: {t.physical.shape} - basis: {t.basis} - margin: {t.margin}" + + return result + + print(f"Falling back to dense for {func.__name__}") + if len(args) > 0: + print("* args:") + for arg in args: + if isinstance(arg, Tensor): + print(f" > {tensor_to_str(arg)}") + elif isinstance(arg, list) and len(arg) > 0 and isinstance(arg[0], Tensor): + list_content = "\n ".join([tensor_to_str(t) for t in arg]) + print(f" > [{list_content}]") + else: + print(f" > {arg}") + if len(kwargs) > 0: + print("* kwargs:") + for k, v in kwargs.items(): + print(f" > {k}: {v}") + print() + + +def get_groupings(pshape: list[int], basis: Tensor) -> list[list[int]]: + basis_time_pshape = basis * tensor(pshape, dtype=torch.int64) + groups = {i: {i} for i, column in enumerate(basis.T)} + group_ids = [i for i in range(len(basis.T))] + for i1, i2 in itertools.combinations(range(basis.shape[1]), 2): + if torch.equal(basis[:, i1], basis_time_pshape[:, i2]): + groups[group_ids[i1]].update(groups[group_ids[i2]]) + group_ids[i2] = group_ids[i1] + + new_columns = [sorted(groups[group_id]) for group_id in sorted(set(group_ids))] + + if len(new_columns) != len(pshape): + print(f"Combined pshape with the following new columns: {new_columns}.") + + return new_columns + + +def to_sparse_latticed_tensor(t: Tensor) -> SparseLatticedTensor: + if isinstance(t, SparseLatticedTensor): + return t + else: + return make_slt(t, torch.eye(t.ndim, dtype=torch.int64), None) + + +def to_most_efficient_tensor( + physical: Tensor, + basis: Tensor, + margin: Tensor | None, +) -> Tensor: + physical, basis = fix_dim_of_size_1(physical, basis) + physical, basis = fix_ungrouped_dims(physical, basis) + + if (basis.sum(dim=0) == 1).all(): + print("Turning supposedly dense SLT into Tensor. This can be bugged and slow.") + # TODO: this condition is broken if basis is allowed to have negative values. It also only + # works when size is the default and offset is 0. 
+ # TODO: this can be done more efficiently (without even creating the SLT) + return SparseLatticedTensor(physical, basis, margin).to_dense() + else: + return SparseLatticedTensor(physical, basis, margin) + + +def unwrap_to_dense(t: Tensor): + if isinstance(t, SparseLatticedTensor): + return t.to_dense() + else: + return t + + +def get_full_source(source: list[int], destination: list[int], ndim: int) -> list[int]: + """ + Doing a movedim with source and destination is always equivalent to doing a movedim with + [0, 1, ..., ndim-1] (aka "full_destination") as destination, and the "full_source" as source. + + This function computes the full_source based on a source and destination. + + Example: + source=[2, 4] + destination=[0, 3] + ndim=5 + + full_source = [2, 0, 1, 4, 3] + full_destination = [0, 1, 2, 3, 4] + """ + + idx = torch.full((ndim,), -1, dtype=torch.int64) + idx[destination] = tensor(source, dtype=torch.int64) + source_set = set(source) + idx[idx.eq(-1)] = tensor([i for i in range(ndim) if i not in source_set], dtype=torch.int64) + + return idx.tolist() + + +def fix_dim_of_size_1(physical: Tensor, basis: Tensor) -> tuple[Tensor, Tensor]: + """ + Removes physical dimensions of size one and returns the corresponding new physical and new basis + """ + + is_of_size_1 = tensor([s == 1 for s in physical.shape], dtype=torch.bool) + return physical.squeeze(), basis[:, ~is_of_size_1] + + +def fix_ungrouped_dims(physical: Tensor, basis: Tensor) -> tuple[Tensor, Tensor]: + """Squash together physical dimensions that can be squashed.""" + + groups = get_groupings(list(physical.shape), basis) + nphysical = physical.reshape([prod([physical.shape[dim] for dim in group]) for group in groups]) + basis_mapping = torch.zeros(physical.ndim, nphysical.ndim, dtype=torch.int64) + for j, group in enumerate(groups): + basis_mapping[group[-1], j] = 1 + + new_basis = basis @ basis_mapping + return nphysical, new_basis + + +def make_slt( + physical: Tensor, + basis: Tensor, + margin: Tensor | None, +) -> SparseLatticedTensor: + """Fix physical and basis and create a SparseLatticedTensor with them.""" + + physical, basis = fix_dim_of_size_1(physical, basis) + physical, basis = fix_ungrouped_dims(physical, basis) + return SparseLatticedTensor(physical, basis, margin) diff --git a/tests/unit/sparse/__init__.py b/tests/unit/sparse/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/sparse/test_linalg.py b/tests/unit/sparse/test_linalg.py new file mode 100644 index 000000000..0dac10b55 --- /dev/null +++ b/tests/unit/sparse/test_linalg.py @@ -0,0 +1,64 @@ +import torch +from pytest import mark +from torch import Tensor, tensor + +from torchjd.sparse._linalg import compute_lcm, hnf_decomposition + + +@mark.parametrize( + ["shape", "max_rank"], + [ + ([5, 7], 3), + ([1, 7], 1), + ([5, 1], 1), + ([7, 5], 2), + ([5, 7], 5), + ([7, 5], 5), + ], +) +@mark.parametrize("reduced", [True, False]) +def test_hnf_decomposition(shape: tuple[int, int], max_rank: int, reduced: bool): + # Generate a matrix A of desired shape and rank max_rank with high probability and lower + # otherwise. 
+ U = torch.randint(-10, 11, [shape[0], max_rank], dtype=torch.int64) + V = torch.randint(-10, 11, [max_rank, shape[1]], dtype=torch.int64) + A = U @ V + H, U, V = hnf_decomposition(A, reduced) + + r = H.shape[1] + + # Note that with these assert, the rank is typically correct as it is at most max_rank, which it + # is with high probability, and we can reconstruct A=H @ V, so the rank of H is at least that of + # A, similarly, the rank of H is at most that of A. + if reduced: + assert r <= max_rank + else: + assert torch.equal(U @ V, torch.eye(r, dtype=torch.int64)) + assert torch.equal(V @ U, torch.eye(r, dtype=torch.int64)) + assert torch.equal(H @ V, A) + assert torch.equal(A @ U, H) + + # Check H is lower triangular (its upper triangle must be zero) + assert torch.equal(torch.triu(H, diagonal=1), torch.zeros_like(H)) + + +@mark.parametrize( + ["S1", "S2"], + [ + (tensor([[8]]), tensor([[12]])), + (tensor([[8, 2]]), tensor([[12, 3]])), + (tensor([[8], [2]]), tensor([[12], [3]])), + (tensor([[8, 5]]), tensor([[12, 9]])), + (tensor([[8, 6], [4, 2]]), tensor([[16, 4], [2, 2]])), + ], +) +def test_compute_lcm(S1: Tensor, S2: Tensor): + L, M1, M2 = compute_lcm(S1, S2) + + print() + print(L) + print(M1) + print(M2) + + assert torch.equal(S1 @ M1, L) + assert torch.equal(S2 @ M2, L) diff --git a/tests/unit/sparse/test_sparse_latticed_tensor.py b/tests/unit/sparse/test_sparse_latticed_tensor.py new file mode 100644 index 000000000..d72dea425 --- /dev/null +++ b/tests/unit/sparse/test_sparse_latticed_tensor.py @@ -0,0 +1,392 @@ +import torch +from pytest import mark +from torch import Tensor, tensor +from torch.ops import aten # type: ignore +from torch.testing import assert_close +from utils.tensors import randn_, tensor_, zeros_ + +from torchjd.sparse._aten_function_overrides.einsum import einsum +from torchjd.sparse._aten_function_overrides.pointwise import ( + _IN_PLACE_POINTWISE_FUNCTIONS, + _POINTWISE_FUNCTIONS, +) +from torchjd.sparse._coalesce import fix_zero_basis_vectors +from torchjd.sparse._sparse_latticed_tensor import ( + SparseLatticedTensor, + fix_ungrouped_dims, + get_full_source, + get_groupings, +) + + +def test_to_dense(): + n = 2 + m = 3 + a = randn_([n, m]) + b = SparseLatticedTensor(a, tensor([[1, 0], [0, 1], [0, 1], [1, 0]]), margin=None) + c = b.to_dense() + + for i in range(n): + for j in range(m): + assert c[i, j, j, i] == a[i, j] + + +def test_to_dense2(): + a = tensor_([1.0, 2.0, 3.0]) + b = SparseLatticedTensor(a, tensor([[4]]), margin=None) + c = b.to_dense() + expected = tensor_([1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]) + assert torch.all(torch.eq(c, expected)) + + +@mark.parametrize( + ["a_pshape", "a_basis", "b_pshape", "b_basis", "a_indices", "b_indices", "output_indices"], + [ + ( + [4, 5], + tensor([[1, 0], [1, 0], [0, 1]]), + [4, 5], + tensor([[1, 0], [0, 1], [0, 1]]), + [0, 1, 2], + [0, 2, 3], + [0, 1, 3], + ), + ( + [2, 3, 5], + tensor([[3, 1, 0], [1, 0, 2]]), + [10, 3], + tensor([[1, 0], [0, 1]]), + [0, 1], + [1, 2], + [0, 2], + ), + ( + [6, 2, 3], + tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + [2, 3], + tensor([[3, 1], [1, 0], [0, 1]]), + [0, 1, 2], + [0, 1, 2], + [0, 1, 2], + ), + ], +) +def test_einsum( + a_pshape: list[int], + a_basis: Tensor, + b_pshape: list[int], + b_basis: Tensor, + a_indices: list[int], + b_indices: list[int], + output_indices: list[int], +): + a = SparseLatticedTensor(randn_(a_pshape), a_basis, margin=None) + b = SparseLatticedTensor(randn_(b_pshape), b_basis, margin=None) + + res = einsum((a, a_indices), (b, b_indices), 
output=output_indices) + + expected = torch.einsum(a.to_dense(), a_indices, b.to_dense(), b_indices, output_indices) + + assert isinstance(res, SparseLatticedTensor) + assert_close(res.to_dense(), expected) + + +@mark.parametrize( + "shape", + [ + [], + [2], + [2, 3], + [2, 3, 4], + ], +) +def test_sparse_latticed_tensor_scalar(shape: list[int]): + a = randn_(shape) + b = SparseLatticedTensor(a, torch.eye(len(shape), dtype=torch.int64), margin=None) + + assert_close(a, b.to_dense()) + + +@mark.parametrize("dim", [2, 3, 4, 5, 10]) +def test_diag_equivalence(dim: int): + a = randn_([dim]) + b = SparseLatticedTensor(a, tensor([[1], [1]]), margin=None) + + diag_a = torch.diag(a) + + assert_close(b.to_dense(), diag_a) + + +def test_three_virtual_single_physical(): + dim = 10 + a = randn_([dim]) + b = SparseLatticedTensor(a, tensor([[1], [1], [1]]), margin=None) + + expected = zeros_([dim, dim, dim]) + for i in range(dim): + expected[i, i, i] = a[i] + + assert_close(b.to_dense(), expected) + + +@mark.parametrize("func", _POINTWISE_FUNCTIONS) +def test_pointwise(func): + dim = 10 + a = randn_([dim]) + b = SparseLatticedTensor(a, tensor([[1], [1]]), margin=None) + c = b.to_dense() + res = func(b) + assert isinstance(res, SparseLatticedTensor) + + assert_close(res.to_dense(), func(c), equal_nan=True) + + +@mark.parametrize("func", _IN_PLACE_POINTWISE_FUNCTIONS) +def test_inplace_pointwise(func): + dim = 10 + a = randn_([dim]) + b = SparseLatticedTensor(a, tensor([[1], [1]]), margin=None) + c = b.to_dense() + func(b) + assert isinstance(b, SparseLatticedTensor) + + assert_close(b.to_dense(), func(c), equal_nan=True) + + +@mark.parametrize("func", [torch.mean, torch.sum]) +def test_unary(func): + dim = 10 + a = randn_([dim]) + b = SparseLatticedTensor(a, tensor([[1], [1]]), margin=None) + c = b.to_dense() + + res = func(b) + assert_close(res.to_dense(), func(c)) + + +@mark.parametrize( + ["physical_shape", "basis", "target_shape", "expected_physical_shape", "expected_basis"], + [ + ( + [2, 3], + tensor([[1, 0], [1, 0], [0, 1]]), + [2, 2, 3], + [2, 3], + tensor([[1, 0], [1, 0], [0, 1]]), + ), # no change of shape + ( + [2, 3], + tensor([[1, 0], [3, 1]]), + [2, 6], + [2, 3], + tensor([[1, 0], [3, 1]]), + ), # no change of shape + ( + [2, 3], + tensor([[1, 0], [1, 0], [0, 1]]), + [2, 6], + [2, 3], + tensor([[1, 0], [3, 1]]), + ), # squashing 2 dims + ( + [2, 3], + tensor([[1, 0], [3, 1]]), + [2, 2, 3], + [2, 3], + tensor([[1, 0], [1, 0], [0, 1]]), + ), # unsquashing into 2 dims + ( + [2, 3], + tensor([[9, 1]]), + [2, 6], + [2, 3], + tensor([[1, 0], [3, 1]]), + ), # unsquashing into 2 dims + ( + [2, 3], + tensor([[1, 0], [1, 0], [0, 1]]), + [12], + [2, 3], + tensor([[9, 1]]), + ), # squashing 3 dims + ( + [2, 3], + tensor([[9, 1]]), + [2, 2, 3], + [2, 3], + tensor([[1, 0], [1, 0], [0, 1]]), + ), # unsquashing into 3 dims + ( + [4], + tensor([[1], [1]]), + [2, 2, 4], + [2, 2], + tensor([[1, 0], [0, 1], [2, 1]]), + ), # unsquashing physical dim + ( + [4], + tensor([[1], [1]]), + [4, 2, 2], + [2, 2], + tensor([[2, 1], [1, 0], [0, 1]]), + ), # unsquashing physical dim + ( + [2, 3, 4], + tensor([[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]), + [4, 12], + [2, 12], + tensor([[3, 0], [0, 1]]), + ), # world boss + ( + [2, 12], + tensor([[3, 0], [0, 1]]), + [2, 2, 3, 4], + [2, 3, 4], + tensor([[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]), + ), # world boss + ], +) +def test_view( + physical_shape: list[int], + basis: Tensor, + target_shape: list[int], + expected_physical_shape: list[int], + expected_basis: Tensor, 
+): + a = randn_(tuple(physical_shape)) + t = SparseLatticedTensor(a, basis, margin=None) + + result = aten.view.default(t, target_shape) + expected = t.to_dense().reshape(target_shape) + + assert isinstance(result, SparseLatticedTensor) + assert list(result.physical.shape) == expected_physical_shape + assert torch.equal(result.basis, expected_basis) + assert torch.all(torch.eq(result.to_dense(), expected)) + + +@mark.parametrize( + ["pshape", "basis", "expected"], + [ + ( + [[32, 2, 3, 4, 5]], + torch.tensor([[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 0], [0, 60, 20, 5, 1]]), + [[0], [1, 2, 3, 4]], + ) + ], +) +def test_get_groupings(pshape: list[int], basis: torch.Tensor, expected: list[list[int]]): + result = get_groupings(pshape, basis) + assert result == expected + + +@mark.parametrize( + ["physical_shape", "basis", "expected_physical_shape", "expected_basis"], + [ + ( + [3, 4, 5], + tensor([[20, 5, 1], [4, 1, 12], [0, 0, 1]]), + [12, 5], + tensor([[5, 1], [1, 12], [0, 1]]), + ), + ( + [32, 20, 8], + tensor([[1, 0, 0], [1, 32, 0], [0, 0, 1]]), + [32, 20, 8], + tensor([[1, 0, 0], [1, 32, 0], [0, 0, 1]]), + ), + ([3, 3, 4], tensor([[3, 1, 0], [0, 4, 1]]), [3, 3, 4], tensor([[3, 1, 0], [0, 4, 1]])), + ], +) +def test_fix_ungrouped_dims( + physical_shape: list[int], + basis: Tensor, + expected_physical_shape: list[int], + expected_basis: Tensor, +): + physical = randn_(physical_shape) + fixed_physical, fixed_basis = fix_ungrouped_dims(physical, basis) + + assert list(fixed_physical.shape) == expected_physical_shape + assert torch.equal(fixed_basis, expected_basis) + + +@mark.parametrize( + [ + "source", + "destination", + "ndim", + ], + [ + ([2, 4], [0, 3], 5), + ([5, 3, 6], [2, 0, 5], 8), + ], +) +def test_get_column_indices(source: list[int], destination: list[int], ndim: int): + # TODO: this test should be improved / removed. It creates quite big tensors for nothing. 
+ + t = randn_(list(torch.randint(3, 8, size=(ndim,)))) + full_destination = list(range(ndim)) + full_source = get_full_source(source, destination, ndim) + assert torch.equal(t.movedim(full_source, full_destination), t.movedim(source, destination)) + + +@mark.parametrize( + ["slt_args", "dim", "expected_densify"], + [ + ([([3], tensor([[1], [1]])), ([3], tensor([[1], [1]]))], 1, False), + ([([3], tensor([[2]])), ([4], tensor([[2]]))], 0, True), + ([([3, 2], tensor([[1, 0], [1, 3]])), ([3, 2], tensor([[1, 0], [1, 3]]))], 1, False), + ], +) +def test_concatenate( + slt_args: list[tuple[list[int], Tensor]], + dim: int, + expected_densify: bool, +): + tensors = [ + SparseLatticedTensor(randn_(pshape), basis, margin=None) for pshape, basis in slt_args + ] + res = aten.cat.default(tensors, dim) + expected = aten.cat.default([t.to_dense() for t in tensors], dim) + + if expected_densify: + assert not isinstance(res, SparseLatticedTensor) + else: + assert isinstance(res, SparseLatticedTensor) + + assert torch.all(torch.eq(res.to_dense(), expected)) + + +@mark.parametrize( + ["physical", "basis", "expected_physical", "expected_basis"], + [ + ( + tensor_([[1, 2, 3], [4, 5, 6]]), + tensor([[1, 0], [1, 0], [2, 0]]), + tensor_([6, 15]), + tensor([[1], [1], [2]]), + ), + ( + tensor_([[1, 2, 3], [4, 5, 6]]), + tensor([[1, 1], [1, 0], [2, 0]]), + tensor_([[1, 2, 3], [4, 5, 6]]), + tensor([[1, 1], [1, 0], [2, 0]]), + ), + ( + tensor_([[3, 2, 1], [6, 5, 4]]), + tensor([[0, 0], [0, 0], [0, 0]]), + tensor_(21), + tensor([[], [], []], dtype=torch.int64), + ), + ], +) +def test_fix_zero_basis_vectors( + physical: Tensor, + basis: Tensor, + expected_physical: Tensor, + expected_basis: Tensor, +): + physical, basis = fix_zero_basis_vectors(physical, basis) + assert torch.equal(physical, expected_physical) + assert torch.equal(basis, expected_basis)
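For reference, a minimal usage sketch of the torchjd.sparse API introduced above, mirroring test_diag_equivalence and test_pointwise from the tests in this diff; the concrete shape and the choice of relu are illustrative only.

import torch
from torch.ops import aten  # type: ignore

from torchjd.sparse import SparseLatticedTensor, make_slt

# A 1-D physical tensor with basis [[1], [1]] stores only the diagonal of a 2-D virtual tensor.
a = torch.randn(5)
diag = make_slt(a, torch.tensor([[1], [1]], dtype=torch.int64), None)

assert isinstance(diag, SparseLatticedTensor)
torch.testing.assert_close(diag.to_dense(), torch.diag(a))

# Pointwise ops that map 0.0 to 0.0 dispatch to the overrides registered above and stay sparse.
result = aten.relu.default(diag)
assert isinstance(result, SparseLatticedTensor)
torch.testing.assert_close(result.to_dense(), torch.relu(diag.to_dense()))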