Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ markers = [
'uses_max_over: tests that use the max_over builtin',
'uses_mesh_with_skip_values: tests that use a mesh with skip values',
'uses_concat_where: tests that use the concat_where builtin',
'embedded_concat_where_infinite_domain: tests with concat_where resulting in an infinite domain',
'uses_program_metrics: tests that require backend support for program metrics',
'uses_program_with_sliced_out_arguments: tests that use a sliced argument which is not supported for non-mutable arrays, e.g. JAX',
'checks_specific_error: tests that rely on the backend to produce a specific error message'
Expand Down
156 changes: 62 additions & 94 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,7 @@
from numpy import typing as npt

from gt4py._core import definitions as core_defs
from gt4py.eve.extended_typing import (
ClassVar,
Iterable,
Never,
Optional,
ParamSpec,
TypeAlias,
TypeVar,
cast,
)
from gt4py.eve.extended_typing import ClassVar, Never, Optional, ParamSpec, TypeAlias, TypeVar, cast
from gt4py.next import common, utils
from gt4py.next.embedded import (
common as embedded_common,
Expand Down Expand Up @@ -820,39 +811,6 @@ def _hyperslice(
# Register the implementation of the `where` builtin for `NdArrayField`; the callable is
# produced by `_make_builtin("where", "where")`.
NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where"))


def _compute_mask_slices(
mask: core_defs.NDArrayObject,
) -> list[tuple[bool, slice]]:
"""Take a 1-dimensional mask and return a sequence of mappings from boolean values to slices."""
# TODO: does it make sense to upgrade this naive algorithm to numpy?
assert mask.ndim == 1
cur = bool(mask[0].item())
ind = 0
res = []
for i in range(1, mask.shape[0]):
# Use `.item()` to extract the scalar from a 0-d array in case of e.g. cupy
if (mask_i := bool(mask[i].item())) != cur:
res.append((cur, slice(ind, i)))
cur = mask_i
ind = i
res.append((cur, slice(ind, mask.shape[0])))
return res


def _trim_empty_domains(
lst: Iterable[tuple[bool, common.Domain]],
) -> list[tuple[bool, common.Domain]]:
"""Remove empty domains from beginning and end of the list."""
lst = list(lst)
if not lst:
return lst
if lst[0][1].is_empty():
return _trim_empty_domains(lst[1:])
if lst[-1][1].is_empty():
return _trim_empty_domains(lst[:-1])
return lst


def _to_field(
value: common.Field | core_defs.Scalar, nd_array_field_type: type[NdArrayField]
) -> common.Field:
Expand Down Expand Up @@ -906,85 +864,95 @@ def _stack_domains(*domains: common.Domain, dim: common.Dimension) -> Optional[c

def _concat(*fields: common.Field, dim: common.Dimension) -> common.Field:
    """
    Concatenate ``fields`` along ``dim``.

    The fields are ordered by the start of their range in ``dim`` before they are
    stacked, so they may be passed in any order.

    Args:
        fields: The fields to concatenate.
        dim: The dimension along which the fields are concatenated.

    Returns:
        A field on the stacked domain whose buffer is the concatenation of the inputs.

    Raises:
        ValueError: If more than one field is given and the intersection of all field
            domains is not empty.
        NonContiguousDomain: If ``_stack_domains`` cannot combine the domains along ``dim``.
    """
    # TODO(havogt): this function could be extended to a general concat
    # currently only concatenate along the given dimension
    sorted_fields = sorted(fields, key=lambda f: f.domain[dim].unit_range.start)

    if (
        len(sorted_fields) > 1
        and not embedded_common.domain_intersection(*[f.domain for f in sorted_fields]).is_empty()
    ):
        raise ValueError("Fields to concatenate must not overlap.")
    new_domain = _stack_domains(*[f.domain for f in sorted_fields], dim=dim)
    if new_domain is None:
        raise embedded_exceptions.NonContiguousDomain(f"Cannot concatenate fields along {dim}.")
    nd_array_class = _get_nd_array_class(*sorted_fields)
    return nd_array_class.from_array(
        nd_array_class.array_ns.concatenate(
            [
                # Expand every buffer to the full shape of its domain before stacking.
                nd_array_class.array_ns.broadcast_to(f.ndarray, f.domain.shape)
                for f in sorted_fields
            ],
            axis=new_domain.dim_index(dim, allow_missing=False),
        ),
        domain=new_domain,
    )


def _invert_domain(domain: common.Domain) -> tuple[common.Domain, ...]:
    """
    Return the parts of the axis that lie outside a 1-dimensional ``domain``.

    Up to two domains over the same dimension are produced: one from negative infinity
    to the start of ``domain`` and one from the stop of ``domain`` to positive infinity.
    A side is left out when ``domain`` is already unbounded there.
    """
    assert domain.ndim == 1
    dim = domain.dims[0]
    rng = domain.ranges[0]

    outside_ranges = []
    if rng.start is not common.Infinity.NEGATIVE:
        # Everything below the lower bound.
        outside_ranges.append(common.UnitRange(common.Infinity.NEGATIVE, rng.start))
    if rng.stop is not common.Infinity.POSITIVE:
        # Everything at or above the upper bound.
        outside_ranges.append(common.UnitRange(rng.stop, common.Infinity.POSITIVE))

    return tuple(common.Domain(dims=(dim,), ranges=(r,)) for r in outside_ranges)


def _size0_field(
    nd_array_class: type[NdArrayField], dims: tuple[common.Dimension, ...], dtype: core_defs.DType
) -> NdArrayField:
    """Create a field of type ``nd_array_class`` with extent 0 in every dimension of ``dims``."""
    rank = len(dims)
    # Every dimension gets the range [0, 0), matched by a buffer of shape (0, ..., 0).
    zero_domain = common.Domain(dims=dims, ranges=(common.UnitRange(0, 0),) * rank)
    zero_buffer = nd_array_class.array_ns.empty((0,) * rank, dtype=dtype.scalar_type)
    return nd_array_class.from_array(zero_buffer, domain=zero_domain)


def _concat_where(
    mask: common.Domain,
    true_field: common.Field,
    false_field: common.Field,
) -> common.Field:
    """
    Assemble a field from ``true_field`` inside ``mask`` and ``false_field`` outside of it.

    Args:
        mask: 1-dimensional domain selecting the region taken from ``true_field``.
        true_field: Provides the values inside ``mask``.
        false_field: Provides the values outside ``mask``.

    Returns:
        The concatenation, along the mask dimension, of the non-empty pieces of both
        fields. If neither field contributes a piece, a size-0 field with the dtype of
        ``true_field`` is returned.

    Raises:
        NotImplementedError: If ``mask`` is not 1-dimensional.
    """
    if mask.ndim != 1:
        raise NotImplementedError(
            "'concat_where': Can only concatenate fields with a 1-dimensional mask."
        )
    mask_dim = mask.dims[0]

    # intersect the field in dimensions orthogonal to the mask, then all slices in the mask field have same domain
    t_broadcasted, f_broadcasted = _intersect_fields(true_field, false_field, ignore_dims=mask_dim)

    # the part of the true field that lies inside the mask (at most one piece)
    true_domain = embedded_common.domain_intersection(t_broadcasted.domain, mask)
    t_slices = () if true_domain.is_empty() else (t_broadcasted[true_domain],)

    # the parts of the false field that lie outside the mask (one piece per side of the mask)
    inverted_masks = _invert_domain(mask)
    false_domains = tuple(
        intersection
        for d in inverted_masks
        if not (
            intersection := embedded_common.domain_intersection(f_broadcasted.domain, d)
        ).is_empty()
    )
    f_slices = tuple(f_broadcasted[d] for d in false_domains)

    # stack the fields together
    if len(t_slices) + len(f_slices) == 0:
        # no data to concatenate, return an empty field
        nd_array_class = _get_nd_array_class(true_field, false_field)
        return _size0_field(nd_array_class, dims=t_broadcasted.domain.dims, dtype=true_field.dtype)
    return _concat(*f_slices, *t_slices, dim=mask_dim)


# Register `_concat_where` as the `NdArrayField` implementation of `experimental.concat_where`.
NdArrayField.register_builtin_func(experimental.concat_where, _concat_where)  # type: ignore[arg-type]


def _make_reduction(
Expand Down
38 changes: 19 additions & 19 deletions src/gt4py/next/ffront/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,25 @@ def concat_where(
false_field: common.Field | core_defs.ScalarT | Tuple | named_collections.CustomNamedCollection,
/,
) -> common.Field | Tuple:
"""
Concatenates two field fields based on a 1D mask.

The resulting domain is the concatenation of the mask subdomains with the domains of the respective true or false fields.
Empty domains at the beginning or end are ignored, but the interior must result in a consecutive domain.

TODO(havogt): I can't get this doctest to run, even after copying the __doc__ in the decorator
Example:
>>> I = common.Dimension("I")
>>> mask = common._field([True, False, True], domain={I: (0, 3)})
>>> true_field = common._field([1, 2], domain={I: (0, 2)})
>>> false_field = common._field([3, 4, 5], domain={I: (1, 4)})
>>> assert concat_where(mask, true_field, false_field) == _field([1, 3], domain={I: (0, 2)})

>>> mask = common._field([True, False, True], domain={I: (0, 3)})
>>> true_field = common._field([1, 2, 3], domain={I: (0, 3)})
>>> false_field = common._field(
... [4], domain={I: (2, 3)}
... ) # error because of non-consecutive domain: missing I(1), but has I(0) and I(2) values
"""Assemble a field by selecting from ``true_field`` where ``cond`` applies and from ``false_field`` elsewhere.

Unlike ``where`` (element-wise selection via a boolean mask field), ``concat_where``
works on **domain regions**: the condition is a ``Domain`` (not a ``Field``), and the
result is the concatenation of slices from the two fields along one dimension.
Each field only needs to cover its own region — they may be non-overlapping.

The condition must be a 1D ``Domain`` (e.g. ``I < 5``).

Args:
cond: 1D Domain specifying the "true" region.
true_field: Field (or scalar) providing values inside the domain region.
false_field: Field (or scalar) providing values outside the domain region.

Returns:
A new field whose domain is the concatenation of the contributed regions.

Raises:
NonContiguousDomain: If the resulting domain has interior gaps.
"""
raise NotImplementedError()

Expand Down
5 changes: 3 additions & 2 deletions tests/next_tests/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
USES_PROGRAM_METRICS = "uses_program_metrics"
USES_SCALAR_IN_DOMAIN_AND_FO = "uses_scalar_in_domain_and_fo"
USES_CONCAT_WHERE = "uses_concat_where"
EMBEDDED_CONCAT_WHERE_INFINITE_DOMAIN = "embedded_concat_where_infinite_domain"
USES_PROGRAM_WITH_SLICED_OUT_ARGUMENTS = "uses_program_with_sliced_out_arguments"
CHECKS_SPECIFIC_ERROR = "checks_specific_error"

Expand Down Expand Up @@ -167,7 +168,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
XFAIL,
UNSUPPORTED_MESSAGE,
), # we can't extract the field type from scan args
(USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE),
(EMBEDDED_CONCAT_WHERE_INFINITE_DOMAIN, XFAIL, UNSUPPORTED_MESSAGE),
]
JAX_EMBEDDED_SKIP_LIST = EMBEDDED_SKIP_LIST + [
(USES_PROGRAM_WITH_SLICED_OUT_ARGUMENTS, XFAIL, UNSUPPORTED_MESSAGE),
Expand All @@ -179,7 +180,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
(USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE),
]
GTIR_EMBEDDED_SKIP_LIST = ROUNDTRIP_SKIP_LIST + [
(USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE),
(EMBEDDED_CONCAT_WHERE_INFINITE_DOMAIN, XFAIL, UNSUPPORTED_MESSAGE),
]
GTFN_SKIP_TEST_LIST = (
COMMON_SKIP_TEST_LIST
Expand Down
Loading
Loading