Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
---
tags: []
---

# Limitations of embedded concat_where

- **Status**: valid
- **Authors**: Hannes Vogt (@havogt)
- **Created**: 2026-03-12
- **Updated**: 2026-03-17

In embedded execution, `concat_where` is, for now, limited to simple but common cases.

We do not support `concat_where` in cases

- where the domain would be infinite and therefore can't be represented as an ndarray, e.g. `concat_where(I < 0, 0.0, somefield)` where the scalar 0.0 would be broadcasted to a field reaching to -infinity;
- with multi-dimensional domains, e.g. `concat_where(I > 0 | J > 0, a, b)`. These cases need to be represented by a nested `concat_where(I > 0, a, concat_where(J > 0, a, b))`;
- with non-consecutive (disjoint) domain conditions, e.g. `concat_where(I != 0, a, b)`. These cases need to be expressed using nested `concat_where`, e.g. `concat_where(I < 0, a, concat_where(I > 0, a, b))`.

## Context

`concat_where` requires expressing conditions like `I != i`, which would produce two disjoint 1D domains (everything before index `i` and everything after). We need a way to represent these non-consecutive domains.

A complete implementation would require designing how to handle fields on non-hypercubic domains. Currently, `Domain` is a Cartesian product of per-dimension `UnitRange`s, which inherently describes hypercubic (rectangular) regions. Supporting arbitrary non-consecutive domains in multiple dimensions would mean fields could live on non-rectangular regions, requiring fundamental changes to field storage, slicing, and iteration.

## Decision

Non-consecutive (disjoint) domains are **not supported** in the domain expression API:

- `Dimension.__ne__(value)` raises `NotImplementedError` when called with an integer value, since it would produce two disjoint domains.
- `Domain.__or__` raises `NotImplementedError` for both multidimensional domains and for 1D domains that are disjoint (non-overlapping and non-adjacent).

The domain expression API only supports operations that result in a single contiguous `Domain`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only place in the ADR where contiguous is used. Everywhere else non-contiguous is used, is that on purpose?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah forgot to push the file, because my directory wasn't clean and my brain filtered .md files when adding stuff to commit...


## Consequences

- `concat_where` with `I != i` must be rewritten as `concat_where(I < i, ..., concat_where(I > i, ..., ...))`.
- This keeps the domain expression API simple: all supported operations return a single `Domain`.

## Alternatives considered

### General `concat_where` with multi-dimensional domain conditions

Implementation for multi-dimensional domain conditions (e.g. `(I != 2) | (K != 5)`) and full support for domain operations in `concat_where` would require

1. **A `DomainTuple` class** with full algebra: a `tuple` subclass carrying `__and__`, `__or__`, `__rand__`, `__ror__` operators so that expressions like `tuple & Domain`, `Domain & tuple`, and `tuple | tuple` all work.

2. **Normalization of domain tuples**: We need to design `DomainTuple` invariants, e.g.

- Should all domains be promoted to the same rank (missing dimensions filled with infinite ranges)?
- Should we reduce overlapping domains to non-overlapping via box subtraction?

Before implementing a complex `DomainTuple`, we should conclude on (if we want) a concept of non-consecutive fields.
2 changes: 1 addition & 1 deletion docs/development/ADRs/next/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Writing a new ADR is simple:

### Embedded Execution

_None_
- [0022 - Limitations of embedded concat_where](0022-Limitations-of-embedded-concat_where.md)

### Transformations

Expand Down
69 changes: 69 additions & 0 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,49 @@ def __add__(self, offset: int) -> Connectivity:
def __sub__(self, offset: int) -> Connectivity:
return self + (-offset)

def __gt__(self, value: core_defs.IntegralScalar) -> Domain: # type: ignore[misc] # returns Domain, not bool
return Domain(dims=(self,), ranges=(UnitRange(value + 1, Infinity.POSITIVE),))

def __ge__(self, value: core_defs.IntegralScalar) -> Domain: # type: ignore[misc] # returns Domain, not bool
return Domain(dims=(self,), ranges=(UnitRange(value, Infinity.POSITIVE),))

def __lt__(self, value: core_defs.IntegralScalar) -> Domain: # type: ignore[misc] # returns Domain, not bool
return Domain(dims=(self,), ranges=(UnitRange(Infinity.NEGATIVE, value),))

def __le__(self, value: core_defs.IntegralScalar) -> Domain: # type: ignore[misc] # returns Domain, not bool
return Domain(dims=(self,), ranges=(UnitRange(Infinity.NEGATIVE, value + 1),))

@overload
def __eq__(self, value: Dimension) -> bool: ...
@overload
def __eq__(self, value: core_defs.IntegralScalar) -> Domain: ... # type: ignore[overload-overlap] # intentionally returns Domain, not bool
@overload
def __eq__(self, value: object) -> bool: ...
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if this works, but IIUC, this is the implemented behavior. If this works, a similar change could be applied to __ne__

Suggested change
def __eq__(self, value: object) -> bool: ...
def __eq__(self, value: object) -> Literal[False]: ...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mypy doesn't like it: Overloaded function signatures 1 and 3 overlap with incompatible return types

def __eq__(self, value: object) -> bool | Domain:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait a second, do we really need to add the overload for object? It doesn't make sense because it overlaps with everything. I think the correct type hint is:

Suggested change
@overload
def __eq__(self, value: Dimension) -> bool: ...
@overload
def __eq__(self, value: core_defs.IntegralScalar) -> Domain: ... # type: ignore[overload-overlap] # intentionally returns Domain, not bool
@overload
def __eq__(self, value: object) -> bool: ...
def __eq__(self, value: object) -> bool | Domain:
@overload
def __eq__(self, value: Dimension) -> bool: ...
@overload
def __eq__(self, value: core_defs.IntegralScalar) -> Domain: ... # type: ignore[overload-overlap] # intentionally returns Domain, not bool
def __eq__(self, value: Dimension | core_defs.IntegralScalar) -> bool | Domain:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we want/need to support Dimension(...) == some_other_object comparison?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we go with your proposal, I should remove the else branch, right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But does it makes sense to support Dimension == Any? Shouldn't that just raise a TypeError? Something like:

Suggested change
@overload
def __eq__(self, value: Dimension) -> bool: ...
@overload
def __eq__(self, value: core_defs.IntegralScalar) -> Domain: ... # type: ignore[overload-overlap] # intentionally returns Domain, not bool
@overload
def __eq__(self, value: object) -> bool: ...
def __eq__(self, value: object) -> bool | Domain:
if isinstance(value, Dimension):
return self.value == value.value
if isinstance(value, core_defs.INTEGRAL_TYPES):
int_value = cast(core_defs.IntegralScalar, value)
return Domain(dims=(self,), ranges=(UnitRange(int_value, int_value + 1),))
raise TypeError(....)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we have cases, but I can try...

if isinstance(value, Dimension):
return self.value == value.value
elif isinstance(value, core_defs.INTEGRAL_TYPES):
int_value = cast(core_defs.IntegralScalar, value)
return Domain(dims=(self,), ranges=(UnitRange(int_value, int_value + 1),))
else:
return False

@overload
def __ne__(self, value: Dimension) -> bool: ...
@overload
def __ne__(self, value: object) -> bool: ...
def __ne__(self, value: object) -> bool:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above

if isinstance(value, Dimension):
return self.value != value.value
elif isinstance(value, core_defs.INTEGRAL_TYPES):
raise NotImplementedError(
"'Dimension.__ne__' with an integer value produces two disjoint domains, "
"which is not supported. Use 'concat_where(dim < value, ...) "
"concat_where(dim > value, ...)' to express the condition, see ADR 22."
)
else:
return True


class Infinity(enum.Enum):
"""Describes an unbounded `UnitRange`."""
Expand Down Expand Up @@ -499,6 +542,32 @@ def __and__(self, other: Domain) -> Domain:
)
return Domain(dims=broadcast_dims, ranges=intersected_ranges)

def __or__(self, other: Domain) -> Domain:
"""
Union of `Domain`s, currently limited to 1D overlapping or adjacent domains.

Raises `NotImplementedError` for multidimensional domains or disjoint 1D domains.
See ADR 22.
"""
if self.ndim > 1 or other.ndim > 1:
raise NotImplementedError(
"Union of multidimensional domains is not supported, see ADR 22."
)
if self.ndim == 0:
return other
if other.ndim == 0:
return self
first, second = sorted((self, other), key=lambda x: x.ranges[0].start)
if first.ranges[0].stop >= second.ranges[0].start:
return Domain(
dims=(self.dims[0],),
ranges=(UnitRange(first.ranges[0].start, second.ranges[0].stop),),
)
raise NotImplementedError(
f"Union of disjoint domains '{first}' and '{second}' is not supported. "
f"Use nested 'concat_where' to express non-consecutive conditions, see ADR 22."
)

@functools.cached_property
def slice_at(self) -> utils.IndexerCallable[slice, Domain]:
"""
Expand Down
85 changes: 85 additions & 0 deletions tests/next_tests/unit_tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,3 +676,88 @@ def test_for_relocation(self):
assert result.domain_dim == I_half
assert result.codomain == I
assert result.offset == 0


class TestDimensionComparisonOperators:
"""Test Dimension comparison operators return correct Domain objects."""

def test_gt(self):
result = IDim > 3
assert result == Domain(dims=(IDim,), ranges=(UnitRange(4, Infinity.POSITIVE),))

def test_ge(self):
result = IDim >= 3
assert result == Domain(dims=(IDim,), ranges=(UnitRange(3, Infinity.POSITIVE),))

def test_lt(self):
result = IDim < 3
assert result == Domain(dims=(IDim,), ranges=(UnitRange(Infinity.NEGATIVE, 3),))

def test_le(self):
result = IDim <= 3
assert result == Domain(dims=(IDim,), ranges=(UnitRange(Infinity.NEGATIVE, 4),))

def test_eq_int(self):
result = IDim == 3
assert result == Domain(dims=(IDim,), ranges=(UnitRange(3, 4),))

def test_ne_int(self):
"""Dimension.__ne__ with int raises NotImplementedError."""
with pytest.raises(NotImplementedError):
IDim != 3

def test_reverse_gt(self):
assert (5 > IDim) == (IDim < 5)

def test_reverse_ge(self):
assert (5 >= IDim) == (IDim <= 5)

def test_reverse_lt(self):
assert (5 < IDim) == (IDim > 5)

def test_reverse_le(self):
assert (5 <= IDim) == (IDim >= 5)

def test_reverse_eq(self):
assert (3 == IDim) == (IDim == 3)

def test_reverse_ne(self):
with pytest.raises(NotImplementedError):
3 != IDim


class TestDomainAndOperator:
"""Test Domain.__and__ (intersection)."""

def test_same_dim(self):
d1 = Domain(dims=(IDim,), ranges=(UnitRange(0, 5),))
d2 = Domain(dims=(IDim,), ranges=(UnitRange(3, 8),))
assert (d1 & d2) == Domain(dims=(IDim,), ranges=(UnitRange(3, 5),))

def test_different_dims(self):
d1 = Domain(dims=(IDim,), ranges=(UnitRange(0, 5),))
d2 = Domain(dims=(JDim,), ranges=(UnitRange(2, 4),))
result = d1 & d2
assert result == Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 5), UnitRange(2, 4)))


class TestDomainOrOperator:
"""Test Domain.__or__ (union) — 1D only."""

def test_same_dim_overlapping(self):
d1 = Domain(dims=(IDim,), ranges=(UnitRange(0, 5),))
d2 = Domain(dims=(IDim,), ranges=(UnitRange(3, 8),))
result = d1 | d2
assert result == Domain(dims=(IDim,), ranges=(UnitRange(0, 8),))

def test_same_dim_disjoint_raises(self):
d1 = Domain(dims=(IDim,), ranges=(UnitRange(0, 3),))
d2 = Domain(dims=(IDim,), ranges=(UnitRange(5, 8),))
with pytest.raises(NotImplementedError):
d1 | d2

def test_multidim_raises(self):
d1 = Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 3), UnitRange(0, 3)))
d2 = Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 8), UnitRange(5, 8)))
with pytest.raises(NotImplementedError):
d1 | d2
Loading