Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
---
tags: []
---

# Limitations of embedded concat_where

- **Status**: valid
- **Authors**: Hannes Vogt (@havogt)
- **Created**: 2026-03-12
- **Updated**: 2026-03-17

In embedded execution, `concat_where` is, for now, limited to simple but common cases.

We do not support `concat_where` in cases

- where the domain would be infinite and therefore can't be represented as an ndarray, e.g. `concat_where(I < 0, 0.0, somefield)` where the scalar 0.0 would be broadcasted to a field reaching to -infinity;
- with multi-dimensional domains, e.g. `concat_where(I > 0 | J > 0, a, b)`. These cases need to be represented by a nested `concat_where(I > 0, a, concat_where(J > 0, a, b))`;
- with non-contiguous (disjoint) domain conditions, e.g. `concat_where(I != 0, a, b)`. These cases need to be expressed using nested `concat_where`, e.g. `concat_where(I < 0, a, concat_where(I > 0, a, b))`.

## Context

`concat_where` requires expressing conditions like `I != i`, which would produce two disjoint 1D domains (everything before index `i` and everything after). We need a way to represent these non-contiguous domains.

A complete implementation would require designing how to handle fields on non-hypercubic domains. Currently, `Domain` is a Cartesian product of per-dimension `UnitRange`s, which inherently describes hypercubic (rectangular) regions. Supporting arbitrary non-contiguous domains in multiple dimensions would mean fields could live on non-rectangular regions, requiring fundamental changes to field storage, slicing, and iteration.

## Decision

Non-contiguous (disjoint) domains are **not supported** in the domain expression API:

- `Dimension.__ne__(value)` raises `NotImplementedError` when called with an integer value, since it would produce two disjoint domains.
- `Domain.__or__` raises `NotImplementedError` for both multidimensional domains and for 1D domains that are disjoint (non-overlapping and non-adjacent).

The domain expression API only supports operations that result in a single contiguous `Domain`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only place in the ADR where contiguous is used. Everywhere else non-contiguous is used, is that on purpose?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah forgot to push the file, because my directory wasn't clean and my brain filtered .md files when adding stuff to commit...


## Consequences

- `concat_where` with `I != i` must be rewritten as `concat_where(I < i, ..., concat_where(I > i, ..., ...))`.
- This keeps the domain expression API simple: all supported operations return a single `Domain`.

## Alternatives considered

### General `concat_where` with multi-dimensional domain conditions

Implementation for multi-dimensional domain conditions (e.g. `(I != 2) | (K != 5)`) and full support for domain operations in `concat_where` would require

1. **A `DomainTuple` class** with full algebra: a `tuple` subclass carrying `__and__`, `__or__`, `__rand__`, `__ror__` operators so that expressions like `tuple & Domain`, `Domain & tuple`, and `tuple | tuple` all work.

2. **Normalization of domain tuples**: We need to design `DomainTuple` invariants, e.g.

- Should all domains be promoted to the same rank (missing dimensions filled with infinite ranges)?
- Should we reduce overlapping domains to non-overlapping via box subtraction?

Before implementing a complex `DomainTuple`, we should conclude on (if we want) a concept of non-contiguous fields.
2 changes: 1 addition & 1 deletion docs/development/ADRs/next/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Writing a new ADR is simple:

### Embedded Execution

_None_
- [0022 - Limitations of embedded concat_where](0022-Limitations-of-embedded-concat_where.md)

### Transformations

Expand Down
71 changes: 71 additions & 0 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,47 @@ def __add__(self, offset: int) -> Connectivity:
def __sub__(self, offset: int) -> Connectivity:
return self + (-offset)

def __gt__(self, value: core_defs.IntegralScalar) -> Domain:
return Domain(dims=(self,), ranges=(UnitRange(value + 1, Infinity.POSITIVE),))

def __ge__(self, value: core_defs.IntegralScalar) -> Domain:
return Domain(dims=(self,), ranges=(UnitRange(value, Infinity.POSITIVE),))

def __lt__(self, value: core_defs.IntegralScalar) -> Domain:
return Domain(dims=(self,), ranges=(UnitRange(Infinity.NEGATIVE, value),))

def __le__(self, value: core_defs.IntegralScalar) -> Domain:
return Domain(dims=(self,), ranges=(UnitRange(Infinity.NEGATIVE, value + 1),))

@overload # type: ignore[override] # incompatible with supertype `object.__eq__` which returns `bool`.
def __eq__(self, value: Dimension) -> bool: ...
@overload
def __eq__(self, value: core_defs.IntegralScalar) -> Domain: ...
def __eq__(self, value: Dimension | core_defs.IntegralScalar) -> bool | Domain:
if isinstance(value, Dimension):
return self.value == value.value
if isinstance(value, core_defs.INTEGRAL_TYPES):
int_value = cast(core_defs.IntegralScalar, value)
return Domain(dims=(self,), ranges=(UnitRange(int_value, int_value + 1),))
# This will fallback to default identity comparison if reflection also returns `NotImplemented`,
# which does identity comparison, see https://docs.python.org/3/reference/datamodel.html#object.__eq__.
return NotImplemented

@overload # type: ignore[override] # incompatible with supertype `object.__ne__` which returns `bool`.
def __ne__(self, value: Dimension) -> bool: ...
@overload
def __ne__(self, value: core_defs.IntegralScalar) -> Domain: ...
def __ne__(self, value: Dimension | core_defs.IntegralScalar) -> bool | Domain:
if isinstance(value, Dimension):
return self.value != value.value
if isinstance(value, core_defs.INTEGRAL_TYPES):
raise NotImplementedError(
"'Dimension.__ne__' with an integer value produces two disjoint domains, "
"which is not supported. Use 'concat_where(dim < value, ...) "
"concat_where(dim > value, ...)' to express the condition, see ADR 22."
)
return NotImplemented


if TYPE_CHECKING:
# These exist as on-the fly replacements for Dimension instances
Expand Down Expand Up @@ -521,6 +562,36 @@ def __and__(self, other: Domain) -> Domain:
)
return Domain(dims=broadcast_dims, ranges=intersected_ranges)

def __or__(self, other: Domain) -> Domain:
"""
Union of `Domain`s, currently limited to 1D overlapping or adjacent domains.

Raises `NotImplementedError` for multidimensional domains or disjoint 1D domains.
See ADR 22.
"""
if self.ndim > 1 or other.ndim > 1:
raise NotImplementedError(
"Union of multidimensional domains is not supported, see ADR 22."
)
if self.ndim == 0:
return other
if other.ndim == 0:
return self
if self.dims[0] != other.dims[0]:
raise NotImplementedError(
f"Union of 1D domains with different dimensions '{self.dims[0]}' and '{other.dims[0]}' is not supported."
)
first, second = sorted((self, other), key=lambda x: x.ranges[0].start)
if first.ranges[0].stop >= second.ranges[0].start:
return Domain(
dims=(self.dims[0],),
ranges=(UnitRange(first.ranges[0].start, second.ranges[0].stop),),
)
raise NotImplementedError(
f"Union of disjoint domains '{first}' and '{second}' is not supported. "
f"Use nested 'concat_where' to express non-contiguous conditions, see ADR 22."
)

@functools.cached_property
def slice_at(self) -> utils.IndexerCallable[slice, Domain]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def get_local_view(
# The invariant below is ensured by calling `make_field()` to construct `FieldopData`.
# The `make_field` constructor converts any local dimension, if present, to `ListType`
# element type, while leaving the field domain with all global dimensions.
assert all(dim != gtx_common.DimensionKind.LOCAL for dim in self.gt_type.dims)
assert all(dim.kind != gtx_common.DimensionKind.LOCAL for dim in self.gt_type.dims)
domain_dims = [domain_range.dim for domain_range in domain]
domain_indices = gtir_domain.get_element_subset(
domain_dims, origin=None
Expand Down
91 changes: 91 additions & 0 deletions tests/next_tests/unit_tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,3 +676,94 @@ def test_for_relocation(self):
assert result.domain_dim == I_half
assert result.codomain == I
assert result.offset == 0


class TestDimensionComparisonOperators:
"""Test Dimension comparison operators return correct Domain objects."""

def test_gt(self):
result = IDim > 3
assert result == Domain(dims=(IDim,), ranges=(UnitRange(4, Infinity.POSITIVE),))

def test_ge(self):
result = IDim >= 3
assert result == Domain(dims=(IDim,), ranges=(UnitRange(3, Infinity.POSITIVE),))

def test_lt(self):
result = IDim < 3
assert result == Domain(dims=(IDim,), ranges=(UnitRange(Infinity.NEGATIVE, 3),))

def test_le(self):
result = IDim <= 3
assert result == Domain(dims=(IDim,), ranges=(UnitRange(Infinity.NEGATIVE, 4),))

def test_eq_int(self):
result = IDim == 3
assert result == Domain(dims=(IDim,), ranges=(UnitRange(3, 4),))

def test_ne_int(self):
"""Dimension.__ne__ with int raises NotImplementedError."""
with pytest.raises(NotImplementedError):
IDim != 3

def test_reverse_gt(self):
assert (5 > IDim) == (IDim < 5)

def test_reverse_ge(self):
assert (5 >= IDim) == (IDim <= 5)

def test_reverse_lt(self):
assert (5 < IDim) == (IDim > 5)

def test_reverse_le(self):
assert (5 <= IDim) == (IDim >= 5)

def test_reverse_eq(self):
assert (3 == IDim) == (IDim == 3)

def test_reverse_ne(self):
with pytest.raises(NotImplementedError):
3 != IDim


class TestDomainAndOperator:
"""Test Domain.__and__ (intersection)."""

def test_same_dim(self):
d1 = Domain(dims=(IDim,), ranges=(UnitRange(0, 5),))
d2 = Domain(dims=(IDim,), ranges=(UnitRange(3, 8),))
assert (d1 & d2) == Domain(dims=(IDim,), ranges=(UnitRange(3, 5),))

def test_different_dims(self):
d1 = Domain(dims=(IDim,), ranges=(UnitRange(0, 5),))
d2 = Domain(dims=(JDim,), ranges=(UnitRange(2, 4),))
result = d1 & d2
assert result == Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 5), UnitRange(2, 4)))


class TestDomainOrOperator:
"""Test Domain.__or__ (union) — 1D only."""

def test_same_dim_overlapping(self):
d1 = Domain(dims=(IDim,), ranges=(UnitRange(0, 5),))
d2 = Domain(dims=(IDim,), ranges=(UnitRange(3, 8),))
result = d1 | d2
assert result == Domain(dims=(IDim,), ranges=(UnitRange(0, 8),))

def test_same_dim_disjoint_raises(self):
d1 = Domain(dims=(IDim,), ranges=(UnitRange(0, 3),))
d2 = Domain(dims=(IDim,), ranges=(UnitRange(5, 8),))
with pytest.raises(NotImplementedError):
d1 | d2

def test_multidim_raises(self):
d1 = Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 3), UnitRange(0, 3)))
d2 = Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 8), UnitRange(5, 8)))
with pytest.raises(NotImplementedError):
d1 | d2

def test_different_dims_raises(self):
d1 = Domain(dims=(IDim,), ranges=(UnitRange(0, 5),))
d2 = Domain(dims=(JDim,), ranges=(UnitRange(3, 10),))
with pytest.raises(NotImplementedError, match="different dimensions"):
d1 | d2
Loading