Merged
@@ -0,0 +1,54 @@
---
tags: []
---

# Limitations of embedded concat_where

- **Status**: valid
- **Authors**: Hannes Vogt (@havogt)
- **Created**: 2026-03-12
- **Updated**: 2026-03-12

In embedded execution, `concat_where` is, for now, limited to simple but common cases.

We do not support `concat_where` in cases

- where the domain would be infinite and therefore cannot be represented as an ndarray, e.g. `concat_where(I < 0, 0.0, somefield)`, where the scalar `0.0` would be broadcast to a field extending to -infinity;
- with multi-dimensional domain conditions, e.g. `concat_where((I > 0) | (J > 0), a, b)`. These cases must be written as a nested `concat_where(I > 0, a, concat_where(J > 0, a, b))`.

Additionally, we only support the most basic case of non-consecutive domains: a tuple of `Domain`s as produced by e.g. `I != 0` or the equivalent `(I < 0) | (I > 0)`. Operations on tuples of `Domain`s are not supported.
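
A practical note on writing such conditions: in Python, `|` and `&` bind more tightly than comparison operators, so combined conditions need explicit parentheses. The sketch below uses only the standard `ast` module (no gt4py) to show how the two spellings parse; the helper name `top_level_node` is made up for this illustration.

```python
import ast


def top_level_node(expr: str) -> str:
    """Return the class name of the outermost AST node of an expression."""
    return type(ast.parse(expr, mode="eval").body).__name__


# Without parentheses, `|` binds tighter than `>`, so this parses as the
# chained comparison `I > (0 | J) > 0`, not as a union of two conditions.
unparenthesized = top_level_node("I > 0 | J > 0")

# With parentheses, the outermost operation is the `|` between two comparisons.
parenthesized = top_level_node("(I > 0) | (J > 0)")

print(unparenthesized, parenthesized)
```

The first expression is a single chained comparison, so it never reaches a union operator between two domains.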

## Context

`concat_where` requires expressing conditions like `I != i`, which produce two disjoint 1D domains (everything before index `i` and everything after it). We need a way to represent these non-consecutive domains.

A complete implementation would require designing how to handle fields on non-hypercubic domains. Currently, `Domain` is a Cartesian product of per-dimension `UnitRange`s, which inherently describes hypercubic (rectangular) regions. Supporting arbitrary non-consecutive domains in multiple dimensions would mean fields could live on non-rectangular regions, requiring fundamental changes to field storage, slicing, and iteration.
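
To illustrate why a union of conditions in two dimensions leaves the hypercubic setting, the following sketch enumerates the points of a small 2x2 index box that satisfy `(I > 0) | (J > 0)`. It uses plain Python sets instead of gt4py `Domain`s, so every name in it is an assumption made for illustration.

```python
from itertools import product

# A 2x2 box: the Cartesian product of two ranges, as a `Domain` would describe it.
box = set(product(range(2), range(2)))

# Points selected by `(I > 0) | (J > 0)`: an L-shaped region without (0, 0).
region = {(i, j) for i, j in box if i > 0 or j > 0}

# A rectangular region equals the Cartesian product of its per-dimension projections.
i_values = {i for i, _ in region}
j_values = {j for _, j in region}
is_rectangular = region == set(product(i_values, j_values))

print(sorted(region), is_rectangular)
```

The region has three points, while the product of its projections has four, so no single Cartesian product of ranges can describe it.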

## Decision

We use a simple tuple-of-`Domain` representation for non-consecutive domains, restricted to:

- **1D only**: `Domain.__or__` raises `NotImplementedError` for multidimensional domains.
- **At most 2 domains**: `Dimension.__ne__` produces exactly 2 disjoint domains. No attempt is made to support arbitrary numbers of disjoint regions.

This is sufficient for the most common `concat_where` use case (`I != i` splits a dimension into two parts) without requiring a general solution for non-hypercubic domains/fields.
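
The representation can be pictured with a minimal stand-in, assuming a 1D domain is just a half-open interval. `Range1D` and `not_equal` are hypothetical names for this sketch and are not part of gt4py; `math.inf` plays the role of the infinite bounds.

```python
from __future__ import annotations

import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Range1D:
    """Hypothetical stand-in for a 1D domain: the half-open interval [start, stop)."""

    start: float
    stop: float


def not_equal(index: int) -> tuple[Range1D, Range1D]:
    """Model `I != index` as the two disjoint ranges on either side of `index`."""
    return (Range1D(-math.inf, index), Range1D(index + 1, math.inf))


before, after = not_equal(3)
print(before, after)
```

A plain tuple of exactly two entries is enough here because excluding a single index can never produce more than two pieces.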

## Consequences

- `concat_where` works for 1D domain conditions, which covers the primary use case of vertical level exclusion.
- Combining multiple exclusions (e.g. `(I != 2) & (I != 5)`) is not supported because it would require a custom tuple type to implement the intersection/union operations.
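
The second consequence follows from plain Python semantics: a built-in `tuple` defines neither `&` nor `|`. The sketch below uses string placeholders in place of real `Domain` objects (an assumption for illustration) to show what happens when two tuple results are combined.

```python
# Placeholders for the two-domain tuples that `I != 2` and `I != 5` would return.
left = ("I < 2", "I > 2")
right = ("I < 5", "I > 5")

error = None
try:
    left & right  # tuple has no __and__, so Python raises TypeError
except TypeError as exc:
    error = str(exc)

print(error)
```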

## Alternatives considered

### General `concat_where` with multi-dimensional domain conditions

Implementation for multi-dimensional domain conditions (e.g. `(I != 2) | (K != 5)`) and full support for domain operations in `concat_where` would require

1. **A `DomainTuple` class** with full algebra: a `tuple` subclass carrying `__and__`, `__or__`, `__rand__`, `__ror__` operators so that expressions like `tuple & Domain`, `Domain & tuple`, and `tuple | tuple` all work.

2. **Normalization of domain tuples**: We need to design `DomainTuple` invariants, e.g.

   - Should all domains be promoted to the same rank (missing dimensions filled with infinite ranges)?
   - Should we reduce overlapping domains to non-overlapping ones via box subtraction?

Before implementing a complex `DomainTuple`, we should decide whether we want a concept of non-consecutive fields and, if so, what it should look like.
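
As a rough idea of what item 1 involves, here is a 1D-only sketch of a tuple subclass with an intersection operator. `Range1D`, `RangeTuple` and `not_equal` are hypothetical names, the sketch ignores multiple dimensions and the reflected operators entirely, and it is not a proposal for the actual `DomainTuple` design.

```python
from __future__ import annotations

import math
from typing import NamedTuple


class Range1D(NamedTuple):
    """Hypothetical stand-in for a 1D domain: the half-open interval [start, stop)."""

    start: float
    stop: float

    def intersect(self, other: Range1D) -> Range1D | None:
        start, stop = max(self.start, other.start), min(self.stop, other.stop)
        return Range1D(start, stop) if start < stop else None


class RangeTuple(tuple):
    """Hypothetical tuple subclass holding disjoint ranges, with an `&` operator."""

    def __and__(self, other: RangeTuple) -> RangeTuple:
        # Intersection distributes over union: intersect all pairs, drop empty results.
        pieces = (a.intersect(b) for a in self for b in other)
        return RangeTuple(sorted(p for p in pieces if p is not None))


def not_equal(index: int) -> RangeTuple:
    return RangeTuple((Range1D(-math.inf, index), Range1D(index + 1, math.inf)))


result = not_equal(2) & not_equal(5)
print(result)
```

Even in 1D, combining two exclusions yields three disjoint ranges, which is why a fixed two-element tuple cannot carry the result.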
2 changes: 1 addition & 1 deletion docs/development/ADRs/next/README.md
@@ -32,7 +32,7 @@ Writing a new ADR is simple:

### Embedded Execution

-_None_
+- [0022 - Limitations of embedded concat_where](0022-Limitations-of-embedded-concat_where.md)

### Transformations

70 changes: 70 additions & 0 deletions src/gt4py/next/common.py
@@ -108,6 +108,54 @@ def __add__(self, offset: int) -> Connectivity:
    def __sub__(self, offset: int) -> Connectivity:
        return self + (-offset)

    def __gt__(self, value: core_defs.IntegralScalar) -> Domain:  # type: ignore[misc]  # returns Domain, not bool
        return Domain(dims=(self,), ranges=(UnitRange(value + 1, Infinity.POSITIVE),))

    def __ge__(self, value: core_defs.IntegralScalar) -> Domain:  # type: ignore[misc]  # returns Domain, not bool
        return Domain(dims=(self,), ranges=(UnitRange(value, Infinity.POSITIVE),))

    def __lt__(self, value: core_defs.IntegralScalar) -> Domain:  # type: ignore[misc]  # returns Domain, not bool
        return Domain(dims=(self,), ranges=(UnitRange(Infinity.NEGATIVE, value),))

    def __le__(self, value: core_defs.IntegralScalar) -> Domain:  # type: ignore[misc]  # returns Domain, not bool
        return Domain(dims=(self,), ranges=(UnitRange(Infinity.NEGATIVE, value + 1),))

    @overload
    def __eq__(self, value: Dimension) -> bool: ...
    @overload
    def __eq__(self, value: core_defs.IntegralScalar) -> Domain: ...  # type: ignore[overload-overlap]  # intentionally returns Domain, not bool
    @overload
    def __eq__(self, value: object) -> bool: ...

Contributor

I'm not sure if this works, but IIUC, this is the implemented behavior. If this works, a similar change could be applied to `__ne__`.

Suggested change
-    def __eq__(self, value: object) -> bool: ...
+    def __eq__(self, value: object) -> Literal[False]: ...

Contributor Author

mypy doesn't like it: Overloaded function signatures 1 and 3 overlap with incompatible return types

    def __eq__(self, value: object) -> bool | Domain:

Contributor

Wait a second, do we really need to add the overload for object? It doesn't make sense because it overlaps with everything. I think the correct type hint is:

Suggested change
-    @overload
-    def __eq__(self, value: Dimension) -> bool: ...
-    @overload
-    def __eq__(self, value: core_defs.IntegralScalar) -> Domain: ...  # type: ignore[overload-overlap]  # intentionally returns Domain, not bool
-    @overload
-    def __eq__(self, value: object) -> bool: ...
-    def __eq__(self, value: object) -> bool | Domain:
+    @overload
+    def __eq__(self, value: Dimension) -> bool: ...
+    @overload
+    def __eq__(self, value: core_defs.IntegralScalar) -> Domain: ...  # type: ignore[overload-overlap]  # intentionally returns Domain, not bool
+    def __eq__(self, value: Dimension | core_defs.IntegralScalar) -> bool | Domain:

Contributor Author

I think we want/need to support Dimension(...) == some_other_object comparison?

Contributor Author

If we go with your proposal, I should remove the else branch, right?

Contributor

But does it make sense to support `Dimension == Any`? Shouldn't that just raise a `TypeError`? Something like:

Suggested change
-    @overload
-    def __eq__(self, value: Dimension) -> bool: ...
-    @overload
-    def __eq__(self, value: core_defs.IntegralScalar) -> Domain: ...  # type: ignore[overload-overlap]  # intentionally returns Domain, not bool
-    @overload
-    def __eq__(self, value: object) -> bool: ...
-    def __eq__(self, value: object) -> bool | Domain:
+        if isinstance(value, Dimension):
+            return self.value == value.value
+        if isinstance(value, core_defs.INTEGRAL_TYPES):
+            int_value = cast(core_defs.IntegralScalar, value)
+            return Domain(dims=(self,), ranges=(UnitRange(int_value, int_value + 1),))
+        raise TypeError(....)

Contributor Author

I think we have cases, but I can try...

        if isinstance(value, Dimension):
            return self.value == value.value
        elif isinstance(value, core_defs.INTEGRAL_TYPES):
            int_value = cast(core_defs.IntegralScalar, value)
            return Domain(dims=(self,), ranges=(UnitRange(int_value, int_value + 1),))
        else:
            return False

    @overload
    def __ne__(self, value: Dimension) -> bool: ...
    @overload
    def __ne__(self, value: core_defs.IntegralScalar) -> tuple[Domain, Domain]: ...  # type: ignore[overload-overlap]  # intentionally returns tuple[Domain, Domain], not bool
    @overload
    def __ne__(self, value: object) -> bool: ...
    def __ne__(self, value: object) -> bool | tuple[Domain, Domain]:
        # TODO(havogt): Non-consecutive domains are represented as tuples, this limits domain operations to
        # simple, common cases. E.g. `I != 3 & I != 5` can't be handled as it would require a custom
        # domain tuple type and operations on it, see ADR 22.
        if isinstance(value, Dimension):
            return self.value != value.value
        elif isinstance(value, core_defs.INTEGRAL_TYPES):
            int_value = cast(core_defs.IntegralScalar, value)
            return (
                Domain(dims=(self,), ranges=(UnitRange(Infinity.NEGATIVE, int_value),)),
                Domain(dims=(self,), ranges=(UnitRange(int_value + 1, Infinity.POSITIVE),)),
            )
        else:
            return True


class Infinity(enum.Enum):
    """Describes an unbounded `UnitRange`."""
@@ -499,6 +547,28 @@ def __and__(self, other: Domain) -> Domain:
        )
        return Domain(dims=broadcast_dims, ranges=intersected_ranges)

    def __or__(self, other: Domain) -> Domain | tuple[Domain, Domain]:
        """
        Union of `Domain`s, currently limited to 1D domains.

        Returns a single `Domain` if the ranges overlap or are adjacent,
        otherwise returns a tuple of two disjoint `Domain`s.
        """
        if self.ndim > 1 or other.ndim > 1:
            # TODO(havogt): Domain union is currently limited to 1D domains, see ADR 22.
            raise NotImplementedError("Union of multidimensional domains is not supported.")
        if self.ndim == 0:
            return other
        if other.ndim == 0:
            return self
        first, second = sorted((self, other), key=lambda x: x.ranges[0].start)
        if first.ranges[0].stop >= second.ranges[0].start:
            return Domain(
                dims=(self.dims[0],),
                ranges=(UnitRange(first.ranges[0].start, second.ranges[0].stop),),
            )
        return (first, second)

    @functools.cached_property
    def slice_at(self) -> utils.IndexerCallable[slice, Domain]:
        """
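
The merge rule documented in `Domain.__or__` (a single domain if the ranges overlap or are adjacent, otherwise two) can be tried out on a stand-alone model. `Range1D` and `union` below are hypothetical names and do not use gt4py. One choice of this sketch, not taken from the diff, is to take the maximum of both stops when merging; how the merged stop behaves when one range fully contains the other is worth checking against the actual implementation.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True, order=True)
class Range1D:
    """Hypothetical stand-in for a 1D domain: the half-open interval [start, stop)."""

    start: int
    stop: int


def union(a: Range1D, b: Range1D) -> Range1D | tuple[Range1D, Range1D]:
    """Union of two half-open ranges: one range if they touch, else two."""
    first, second = sorted((a, b))  # ordered by start, then stop
    if first.stop >= second.start:
        # Overlapping or adjacent. Taking the max of both stops is a choice of
        # this sketch; it keeps the result correct when `first` contains `second`.
        return Range1D(first.start, max(first.stop, second.stop))
    return (first, second)


overlapping = union(Range1D(0, 5), Range1D(3, 8))
adjacent = union(Range1D(0, 3), Range1D(3, 6))
contained = union(Range1D(0, 10), Range1D(2, 5))
disjoint = union(Range1D(5, 8), Range1D(0, 3))
print(overlapping, adjacent, contained, disjoint)
```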
88 changes: 88 additions & 0 deletions tests/next_tests/unit_tests/test_common.py
@@ -676,3 +676,91 @@ def test_for_relocation(self):
        assert result.domain_dim == I_half
        assert result.codomain == I
        assert result.offset == 0


class TestDimensionComparisonOperators:
    """Test Dimension comparison operators return correct Domain objects."""

    def test_gt(self):
        result = IDim > 3
        assert result == Domain(dims=(IDim,), ranges=(UnitRange(4, Infinity.POSITIVE),))

    def test_ge(self):
        result = IDim >= 3
        assert result == Domain(dims=(IDim,), ranges=(UnitRange(3, Infinity.POSITIVE),))

    def test_lt(self):
        result = IDim < 3
        assert result == Domain(dims=(IDim,), ranges=(UnitRange(Infinity.NEGATIVE, 3),))

    def test_le(self):
        result = IDim <= 3
        assert result == Domain(dims=(IDim,), ranges=(UnitRange(Infinity.NEGATIVE, 4),))

    def test_eq_int(self):
        result = IDim == 3
        assert result == Domain(dims=(IDim,), ranges=(UnitRange(3, 4),))

    def test_ne_int(self):
        """Dimension.__ne__ with int returns tuple of two Domains."""
        result = IDim != 3
        assert isinstance(result, tuple)
        assert len(result) == 2
        assert result[0] == Domain(dims=(IDim,), ranges=(UnitRange(Infinity.NEGATIVE, 3),))
        assert result[1] == Domain(dims=(IDim,), ranges=(UnitRange(4, Infinity.POSITIVE),))

    def test_reverse_gt(self):
        assert (5 > IDim) == (IDim < 5)

    def test_reverse_ge(self):
        assert (5 >= IDim) == (IDim <= 5)

    def test_reverse_lt(self):
        assert (5 < IDim) == (IDim > 5)

    def test_reverse_le(self):
        assert (5 <= IDim) == (IDim >= 5)

    def test_reverse_eq(self):
        assert (3 == IDim) == (IDim == 3)

    def test_reverse_ne(self):
        assert (3 != IDim) == (IDim != 3)


class TestDomainAndOperator:
    """Test Domain.__and__ (intersection)."""

    def test_same_dim(self):
        d1 = Domain(dims=(IDim,), ranges=(UnitRange(0, 5),))
        d2 = Domain(dims=(IDim,), ranges=(UnitRange(3, 8),))
        assert (d1 & d2) == Domain(dims=(IDim,), ranges=(UnitRange(3, 5),))

    def test_different_dims(self):
        d1 = Domain(dims=(IDim,), ranges=(UnitRange(0, 5),))
        d2 = Domain(dims=(JDim,), ranges=(UnitRange(2, 4),))
        result = d1 & d2
        assert result == Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 5), UnitRange(2, 4)))


class TestDomainOrOperator:
    """Test Domain.__or__ (union) — 1D only."""

    def test_same_dim_overlapping(self):
        d1 = Domain(dims=(IDim,), ranges=(UnitRange(0, 5),))
        d2 = Domain(dims=(IDim,), ranges=(UnitRange(3, 8),))
        result = d1 | d2
        assert result == Domain(dims=(IDim,), ranges=(UnitRange(0, 8),))

    def test_same_dim_disjoint(self):
        d1 = Domain(dims=(IDim,), ranges=(UnitRange(0, 3),))
        d2 = Domain(dims=(IDim,), ranges=(UnitRange(5, 8),))
        result = d1 | d2
        assert isinstance(result, tuple)
        assert len(result) == 2

    def test_multidim_raises(self):
        d1 = Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 3), UnitRange(0, 3)))
        d2 = Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 8), UnitRange(5, 8)))
        with pytest.raises(NotImplementedError):
            d1 | d2
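
The `test_reverse_*` cases rely on Python's reflected comparison dispatch rather than on any extra methods in the diff: when `int` cannot handle the comparison it returns `NotImplemented`, and Python then tries the mirrored operator on the right-hand operand. A minimal sketch with a hypothetical `Dim` class (not gt4py's `Dimension`; the result tuples only mimic a range) shows the mechanism.

```python
class Dim:
    """Hypothetical stand-in for a dimension whose comparisons return range-like tuples."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __lt__(self, value: int) -> tuple:
        return (self.name, None, value)  # models [-inf, value)

    def __gt__(self, value: int) -> tuple:
        return (self.name, value + 1, None)  # models [value + 1, +inf)


I = Dim("I")

# int.__gt__(5, I) returns NotImplemented, so Python falls back to I.__lt__(5).
reversed_gt = 5 > I
# int.__lt__(5, I) returns NotImplemented, so Python falls back to I.__gt__(5).
reversed_lt = 5 < I

print(reversed_gt, reversed_lt)
```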