diff --git a/docs/development/ADRs/next/0022-Limitations-of-embedded-concat_where.md b/docs/development/ADRs/next/0022-Limitations-of-embedded-concat_where.md new file mode 100644 index 0000000000..c625e65277 --- /dev/null +++ b/docs/development/ADRs/next/0022-Limitations-of-embedded-concat_where.md @@ -0,0 +1,53 @@ +--- +tags: [] +--- + +# Limitations of embedded concat_where + +- **Status**: valid +- **Authors**: Hannes Vogt (@havogt) +- **Created**: 2026-03-12 +- **Updated**: 2026-03-17 + +In embedded execution, `concat_where` is, for now, limited to simple but common cases. + +We do not support `concat_where` in cases + +- where the domain would be infinite and therefore can't be represented as an ndarray, e.g. `concat_where(I < 0, 0.0, somefield)` where the scalar 0.0 would be broadcast to a field reaching to -infinity; +- with multi-dimensional domains, e.g. `concat_where((I > 0) | (J > 0), a, b)`. These cases need to be represented by a nested `concat_where(I > 0, a, concat_where(J > 0, a, b))`; +- with non-contiguous (disjoint) domain conditions, e.g. `concat_where(I != 0, a, b)`. These cases need to be expressed using nested `concat_where`, e.g. `concat_where(I < 0, a, concat_where(I > 0, a, b))`. + +## Context + +`concat_where` requires expressing conditions like `I != i`, which would produce two disjoint 1D domains (everything before index `i` and everything after). We need a way to represent these non-contiguous domains. + +A complete implementation would require designing how to handle fields on non-hypercubic domains. Currently, `Domain` is a Cartesian product of per-dimension `UnitRange`s, which inherently describes hypercubic (rectangular) regions. Supporting arbitrary non-contiguous domains in multiple dimensions would mean fields could live on non-rectangular regions, requiring fundamental changes to field storage, slicing, and iteration. 
+ +## Decision + +Non-contiguous (disjoint) domains are **not supported** in the domain expression API: + +- `Dimension.__ne__(value)` raises `NotImplementedError` when called with an integer value, since it would produce two disjoint domains. +- `Domain.__or__` raises `NotImplementedError` for multidimensional domains, for 1D domains over different dimensions, and for 1D domains that are disjoint (non-overlapping and non-adjacent). + +The domain expression API only supports operations that result in a single contiguous `Domain`. + +## Consequences + +- `concat_where` with `I != i` must be rewritten as `concat_where(I < i, ..., concat_where(I > i, ..., ...))`. +- This keeps the domain expression API simple: all supported operations return a single `Domain`. + +## Alternatives considered + +### General `concat_where` with multi-dimensional domain conditions + +Implementation for multi-dimensional domain conditions (e.g. `(I != 2) | (K != 5)`) and full support for domain operations in `concat_where` would require + +1. **A `DomainTuple` class** with full algebra: a `tuple` subclass carrying `__and__`, `__or__`, `__rand__`, `__ror__` operators so that expressions like `tuple & Domain`, `Domain & tuple`, and `tuple | tuple` all work. + +2. **Normalization of domain tuples**: We need to design `DomainTuple` invariants, e.g. + + - Should all domains be promoted to the same rank (missing dimensions filled with infinite ranges)? + - Should we reduce overlapping domains to non-overlapping via box subtraction? + +Before implementing a complex `DomainTuple`, we should first decide whether we want a concept of non-contiguous fields at all and, if so, how it should be designed. 
diff --git a/docs/development/ADRs/next/README.md b/docs/development/ADRs/next/README.md index 163a7c7661..6a133aba75 100644 --- a/docs/development/ADRs/next/README.md +++ b/docs/development/ADRs/next/README.md @@ -32,7 +32,7 @@ Writing a new ADR is simple: ### Embedded Execution -_None_ +- [0022 - Limitations of embedded concat_where](0022-Limitations-of-embedded-concat_where.md) ### Transformations diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 69c8d01217..69d42d3dea 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -108,6 +108,47 @@ def __add__(self, offset: int) -> Connectivity: def __sub__(self, offset: int) -> Connectivity: return self + (-offset) + def __gt__(self, value: core_defs.IntegralScalar) -> Domain: + return Domain(dims=(self,), ranges=(UnitRange(value + 1, Infinity.POSITIVE),)) + + def __ge__(self, value: core_defs.IntegralScalar) -> Domain: + return Domain(dims=(self,), ranges=(UnitRange(value, Infinity.POSITIVE),)) + + def __lt__(self, value: core_defs.IntegralScalar) -> Domain: + return Domain(dims=(self,), ranges=(UnitRange(Infinity.NEGATIVE, value),)) + + def __le__(self, value: core_defs.IntegralScalar) -> Domain: + return Domain(dims=(self,), ranges=(UnitRange(Infinity.NEGATIVE, value + 1),)) + + @overload # type: ignore[override] # incompatible with supertype `object.__eq__` which returns `bool`. + def __eq__(self, value: Dimension) -> bool: ... + @overload + def __eq__(self, value: core_defs.IntegralScalar) -> Domain: ... 
+ def __eq__(self, value: Dimension | core_defs.IntegralScalar) -> bool | Domain: + if isinstance(value, Dimension): + return self.value == value.value + if isinstance(value, core_defs.INTEGRAL_TYPES): + int_value = cast(core_defs.IntegralScalar, value) + return Domain(dims=(self,), ranges=(UnitRange(int_value, int_value + 1),)) + # Returning `NotImplemented` lets Python try the reflected `__eq__`; if that also returns `NotImplemented`, + # Python falls back to identity comparison, see https://docs.python.org/3/reference/datamodel.html#object.__eq__. + return NotImplemented + + @overload # type: ignore[override] # incompatible with supertype `object.__ne__` which returns `bool`. + def __ne__(self, value: Dimension) -> bool: ... + @overload + def __ne__(self, value: core_defs.IntegralScalar) -> Domain: ... + def __ne__(self, value: Dimension | core_defs.IntegralScalar) -> bool | Domain: + if isinstance(value, Dimension): + return self.value != value.value + if isinstance(value, core_defs.INTEGRAL_TYPES): + raise NotImplementedError( + "'Dimension.__ne__' with an integer value produces two disjoint domains, " + "which is not supported. Use 'concat_where(dim < value, ..., " + "concat_where(dim > value, ..., ...))' to express the condition, see ADR 22." + ) + return NotImplemented + if TYPE_CHECKING: # These exist as on-the fly replacements for Dimension instances @@ -521,6 +562,36 @@ def __and__(self, other: Domain) -> Domain: ) return Domain(dims=broadcast_dims, ranges=intersected_ranges) + def __or__(self, other: Domain) -> Domain: + """ + Union of `Domain`s, currently limited to 1D overlapping or adjacent domains. + + Raises `NotImplementedError` for multidimensional domains, 1D domains of different dimensions, or disjoint 1D domains. + See ADR 22. + """ + if self.ndim > 1 or other.ndim > 1: + raise NotImplementedError( + "Union of multidimensional domains is not supported, see ADR 22." 
+ ) + if self.ndim == 0: + return other + if other.ndim == 0: + return self + if self.dims[0] != other.dims[0]: + raise NotImplementedError( + f"Union of 1D domains with different dimensions '{self.dims[0]}' and '{other.dims[0]}' is not supported." + ) + first, second = sorted((self, other), key=lambda x: x.ranges[0].start) + if first.ranges[0].stop >= second.ranges[0].start: + return Domain( + dims=(self.dims[0],), + ranges=(UnitRange(first.ranges[0].start, max(first.ranges[0].stop, second.ranges[0].stop)),), + ) + raise NotImplementedError( + f"Union of disjoint domains '{first}' and '{second}' is not supported. " + f"Use nested 'concat_where' to express non-contiguous conditions, see ADR 22." + ) + @functools.cached_property def slice_at(self) -> utils.IndexerCallable[slice, Domain]: """ diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py index fdefa1b849..83cd7660d8 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py @@ -73,7 +73,7 @@ def get_local_view( # The invariant below is ensured by calling `make_field()` to construct `FieldopData`. # The `make_field` constructor converts any local dimension, if present, to `ListType` # element type, while leaving the field domain with all global dimensions. 
- assert all(dim != gtx_common.DimensionKind.LOCAL for dim in self.gt_type.dims) + assert all(dim.kind != gtx_common.DimensionKind.LOCAL for dim in self.gt_type.dims) domain_dims = [domain_range.dim for domain_range in domain] domain_indices = gtir_domain.get_element_subset( domain_dims, origin=None diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index eff6fcc0a5..cc1c5505d9 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -676,3 +676,94 @@ def test_for_relocation(self): assert result.domain_dim == I_half assert result.codomain == I assert result.offset == 0 + + +class TestDimensionComparisonOperators: + """Test Dimension comparison operators return correct Domain objects.""" + + def test_gt(self): + result = IDim > 3 + assert result == Domain(dims=(IDim,), ranges=(UnitRange(4, Infinity.POSITIVE),)) + + def test_ge(self): + result = IDim >= 3 + assert result == Domain(dims=(IDim,), ranges=(UnitRange(3, Infinity.POSITIVE),)) + + def test_lt(self): + result = IDim < 3 + assert result == Domain(dims=(IDim,), ranges=(UnitRange(Infinity.NEGATIVE, 3),)) + + def test_le(self): + result = IDim <= 3 + assert result == Domain(dims=(IDim,), ranges=(UnitRange(Infinity.NEGATIVE, 4),)) + + def test_eq_int(self): + result = IDim == 3 + assert result == Domain(dims=(IDim,), ranges=(UnitRange(3, 4),)) + + def test_ne_int(self): + """Dimension.__ne__ with int raises NotImplementedError.""" + with pytest.raises(NotImplementedError): + IDim != 3 + + def test_reverse_gt(self): + assert (5 > IDim) == (IDim < 5) + + def test_reverse_ge(self): + assert (5 >= IDim) == (IDim <= 5) + + def test_reverse_lt(self): + assert (5 < IDim) == (IDim > 5) + + def test_reverse_le(self): + assert (5 <= IDim) == (IDim >= 5) + + def test_reverse_eq(self): + assert (3 == IDim) == (IDim == 3) + + def test_reverse_ne(self): + with pytest.raises(NotImplementedError): + 3 != IDim + + +class 
TestDomainAndOperator: + """Test Domain.__and__ (intersection).""" + + def test_same_dim(self): + d1 = Domain(dims=(IDim,), ranges=(UnitRange(0, 5),)) + d2 = Domain(dims=(IDim,), ranges=(UnitRange(3, 8),)) + assert (d1 & d2) == Domain(dims=(IDim,), ranges=(UnitRange(3, 5),)) + + def test_different_dims(self): + d1 = Domain(dims=(IDim,), ranges=(UnitRange(0, 5),)) + d2 = Domain(dims=(JDim,), ranges=(UnitRange(2, 4),)) + result = d1 & d2 + assert result == Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 5), UnitRange(2, 4))) + + +class TestDomainOrOperator: + """Test Domain.__or__ (union) — 1D only.""" + + def test_same_dim_overlapping(self): + d1 = Domain(dims=(IDim,), ranges=(UnitRange(0, 5),)) + d2 = Domain(dims=(IDim,), ranges=(UnitRange(3, 8),)) + result = d1 | d2 + assert result == Domain(dims=(IDim,), ranges=(UnitRange(0, 8),)) + + def test_same_dim_disjoint_raises(self): + d1 = Domain(dims=(IDim,), ranges=(UnitRange(0, 3),)) + d2 = Domain(dims=(IDim,), ranges=(UnitRange(5, 8),)) + with pytest.raises(NotImplementedError): + d1 | d2 + + def test_multidim_raises(self): + d1 = Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 3), UnitRange(0, 3))) + d2 = Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 8), UnitRange(5, 8))) + with pytest.raises(NotImplementedError): + d1 | d2 + + def test_different_dims_raises(self): + d1 = Domain(dims=(IDim,), ranges=(UnitRange(0, 5),)) + d2 = Domain(dims=(JDim,), ranges=(UnitRange(3, 10),)) + with pytest.raises(NotImplementedError, match="different dimensions"): + d1 | d2