diff --git a/pyproject.toml b/pyproject.toml index 08d5e7a1a0..be909de9e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -318,6 +318,8 @@ markers = [ 'uses_max_over: tests that use the max_over builtin', 'uses_mesh_with_skip_values: tests that use a mesh with skip values', 'uses_concat_where: tests that use the concat_where builtin', + 'embedded_concat_where_infinite_domain: tests with concat_where resulting in an infinite domain', + 'embedded_concat_where_non_contiguous_domain: tests with concat_where on non-contiguous domains', 'uses_program_metrics: tests that require backend support for program metrics', 'uses_program_with_sliced_out_arguments: tests that use a sliced argument which is not supported for non-mutable arrays, e.g. JAX', 'checks_specific_error: tests that rely on the backend to produce a specific error message' diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index e9aff84a15..2948c7fc8b 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -18,16 +18,7 @@ from numpy import typing as npt from gt4py._core import definitions as core_defs -from gt4py.eve.extended_typing import ( - ClassVar, - Iterable, - Never, - Optional, - ParamSpec, - TypeAlias, - TypeVar, - cast, -) +from gt4py.eve.extended_typing import ClassVar, Never, Optional, ParamSpec, TypeAlias, TypeVar, cast from gt4py.next import common, utils from gt4py.next.embedded import ( common as embedded_common, @@ -820,39 +811,6 @@ def _hyperslice( NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where")) -def _compute_mask_slices( - mask: core_defs.NDArrayObject, -) -> list[tuple[bool, slice]]: - """Take a 1-dimensional mask and return a sequence of mappings from boolean values to slices.""" - # TODO: does it make sense to upgrade this naive algorithm to numpy? - assert mask.ndim == 1 - cur = bool(mask[0].item()) - ind = 0 - res = [] - for i in range(1, mask.shape[0]): - # Use `.item()` to extract the scalar from a 0-d array in case of e.g. cupy - if (mask_i := bool(mask[i].item())) != cur: - res.append((cur, slice(ind, i))) - cur = mask_i - ind = i - res.append((cur, slice(ind, mask.shape[0]))) - return res - - -def _trim_empty_domains( - lst: Iterable[tuple[bool, common.Domain]], -) -> list[tuple[bool, common.Domain]]: - """Remove empty domains from beginning and end of the list.""" - lst = list(lst) - if not lst: - return lst - if lst[0][1].is_empty(): - return _trim_empty_domains(lst[1:]) - if lst[-1][1].is_empty(): - return _trim_empty_domains(lst[:-1]) - return lst - - def _to_field( value: common.Field | core_defs.Scalar, nd_array_field_type: type[NdArrayField] ) -> common.Field: @@ -906,85 +864,95 @@ def _stack_domains(*domains: common.Domain, dim: common.Dimension) -> Optional[c def _concat(*fields: common.Field, dim: common.Dimension) -> common.Field: # TODO(havogt): this function could be extended to a general concat - # currently only concatenate along the given dimension and requires the fields to be ordered + # currently only concatenate along the given dimension + sorted_fields = sorted(fields, key=lambda f: f.domain[dim].unit_range.start) if ( - len(fields) > 1 - and not embedded_common.domain_intersection(*[f.domain for f in fields]).is_empty() + len(sorted_fields) > 1 + and not embedded_common.domain_intersection(*[f.domain for f in sorted_fields]).is_empty() ): raise ValueError("Fields to concatenate must not overlap.") - new_domain = _stack_domains(*[f.domain for f in fields], dim=dim) + new_domain = _stack_domains(*[f.domain for f in sorted_fields], dim=dim) if new_domain is None: raise embedded_exceptions.NonContiguousDomain(f"Cannot concatenate fields along {dim}.") - nd_array_class = _get_nd_array_class(*fields) + nd_array_class = _get_nd_array_class(*sorted_fields) return nd_array_class.from_array( nd_array_class.array_ns.concatenate( - [nd_array_class.array_ns.broadcast_to(f.ndarray, f.domain.shape) for f in fields], + [ + nd_array_class.array_ns.broadcast_to(f.ndarray, f.domain.shape) + for f in sorted_fields + ], axis=new_domain.dim_index(dim, allow_missing=False), ), domain=new_domain, ) +def _invert_domain(domain: common.Domain) -> tuple[common.Domain, ...]: + assert domain.ndim == 1 + dim = domain.dims[0] + rng = domain.ranges[0] + + result = [] + if rng.start is not common.Infinity.NEGATIVE: + result.append( + common.Domain( + dims=(dim,), ranges=(common.UnitRange(common.Infinity.NEGATIVE, rng.start),) + ) + ) + if rng.stop is not common.Infinity.POSITIVE: + result.append( + common.Domain( + dims=(dim,), ranges=(common.UnitRange(rng.stop, common.Infinity.POSITIVE),) + ) + ) + return tuple(result) + + +def _size0_field( + nd_array_class: type[NdArrayField], dims: tuple[common.Dimension, ...], dtype: core_defs.DType +) -> NdArrayField: + return nd_array_class.from_array( + nd_array_class.array_ns.empty((0,) * len(dims), dtype=dtype.scalar_type), + domain=common.Domain(dims=dims, ranges=(common.UnitRange(0, 0),) * len(dims)), + ) + + def _concat_where( - mask_field: common.Field, true_field: common.Field, false_field: common.Field + mask: common.Domain, + true_field: common.Field, + false_field: common.Field, ) -> common.Field: - cls_ = _get_nd_array_class(mask_field, true_field, false_field) - xp = cls_.array_ns - if mask_field.domain.ndim != 1: + if mask.ndim != 1: raise NotImplementedError( "'concat_where': Can only concatenate fields with a 1-dimensional mask." ) - mask_dim = mask_field.domain.dims[0] + mask_dim = mask.dims[0] # intersect the field in dimensions orthogonal to the mask, then all slices in the mask field have same domain t_broadcasted, f_broadcasted = _intersect_fields(true_field, false_field, ignore_dims=mask_dim) - # TODO(havogt): for clarity, most of it could be implemented on named_range in the masked dimension, but we currently lack the utils - # compute the consecutive ranges (first relative, then domain) of true and false values - mask_values_to_slices_mapping: Iterable[tuple[bool, slice]] = _compute_mask_slices( - mask_field.ndarray - ) - mask_values_to_domain_mapping: Iterable[tuple[bool, common.Domain]] = ( - (mask, mask_field.domain.slice_at[domain_slice]) - for mask, domain_slice in mask_values_to_slices_mapping - ) - # mask domains intersected with the respective fields - mask_values_to_intersected_domains_mapping: Iterable[tuple[bool, common.Domain]] = ( - ( - mask_value, - embedded_common.domain_intersection( - t_broadcasted.domain if mask_value else f_broadcasted.domain, mask_domain - ), - ) - for mask_value, mask_domain in mask_values_to_domain_mapping - ) + true_domain = embedded_common.domain_intersection(t_broadcasted.domain, mask) + t_slices = () if true_domain.is_empty() else (t_broadcasted[true_domain],) - # remove the empty domains from the beginning and end - mask_values_to_intersected_domains_mapping = _trim_empty_domains( - mask_values_to_intersected_domains_mapping + inverted_masks = _invert_domain(mask) + false_domains = tuple( + intersection + for d in inverted_masks + if not ( + intersection := embedded_common.domain_intersection(f_broadcasted.domain, d) + ).is_empty() ) - if any(d.is_empty() for _, d in mask_values_to_intersected_domains_mapping): - raise embedded_exceptions.NonContiguousDomain( - f"In 'concat_where', cannot concatenate the following 'Domain's: {[d for _, d in mask_values_to_intersected_domains_mapping]}." - ) - - # slice the fields with the domain ranges - transformed = [ - t_broadcasted[d] if v else f_broadcasted[d] - for v, d in mask_values_to_intersected_domains_mapping - ] + f_slices = tuple(f_broadcasted[d] for d in false_domains) - # stack the fields together - if transformed: - return _concat(*transformed, dim=mask_dim) - else: - result_domain = common.Domain(common.NamedRange(mask_dim, common.UnitRange(0, 0))) - result_array = xp.empty(result_domain.shape) - return cls_.from_array(result_array, domain=result_domain) + if len(t_slices) + len(f_slices) == 0: + # no data to concatenate, return an empty field + nd_array_class = _get_nd_array_class(true_field, false_field) + return _size0_field(nd_array_class, dims=t_broadcasted.domain.dims, dtype=true_field.dtype) + return _concat(*f_slices, *t_slices, dim=mask_dim) -NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[arg-type] # TODO(havogt): this is still the "old" concat_where, needs to be replaced in a next PR +NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[arg-type] def _make_reduction( diff --git a/src/gt4py/next/ffront/experimental.py b/src/gt4py/next/ffront/experimental.py index e3527e3add..25d915a423 100644 --- a/src/gt4py/next/ffront/experimental.py +++ b/src/gt4py/next/ffront/experimental.py @@ -25,25 +25,25 @@ def concat_where( false_field: common.Field | core_defs.ScalarT | Tuple | named_collections.CustomNamedCollection, /, ) -> common.Field | Tuple: - """ - Concatenates two field fields based on a 1D mask. - - The resulting domain is the concatenation of the mask subdomains with the domains of the respective true or false fields. - Empty domains at the beginning or end are ignored, but the interior must result in a consecutive domain. - - TODO(havogt): I can't get this doctest to run, even after copying the __doc__ in the decorator - Example: - >>> I = common.Dimension("I") - >>> mask = common._field([True, False, True], domain={I: (0, 3)}) - >>> true_field = common._field([1, 2], domain={I: (0, 2)}) - >>> false_field = common._field([3, 4, 5], domain={I: (1, 4)}) - >>> assert concat_where(mask, true_field, false_field) == _field([1, 3], domain={I: (0, 2)}) - - >>> mask = common._field([True, False, True], domain={I: (0, 3)}) - >>> true_field = common._field([1, 2, 3], domain={I: (0, 3)}) - >>> false_field = common._field( - ... [4], domain={I: (2, 3)} - ... ) # error because of non-consecutive domain: missing I(1), but has I(0) and I(2) values + """Assemble a field by selecting from ``true_field`` where ``cond`` applies and from ``false_field`` elsewhere. + + Unlike ``where`` (element-wise selection via a boolean mask field), ``concat_where`` + works on **domain regions**: the condition is a ``Domain`` (not a ``Field``), and the + result is the concatenation of slices from the two fields along one dimension. + Each field only needs to cover its own region — they may be non-overlapping. + + The condition must be a 1D ``Domain`` (e.g. ``I < 5``). + + Args: + cond: 1D Domain specifying the "true" region. + true_field: Field (or scalar) providing values inside the domain region. + false_field: Field (or scalar) providing values outside the domain region. + + Returns: + A new field whose domain is the concatenation of the contributed regions. + + Raises: + NonContiguousDomain: If the resulting domain has interior gaps. """ raise NotImplementedError() diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 3396d93d3c..c7a5186ce6 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -125,6 +125,8 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_PROGRAM_METRICS = "uses_program_metrics" USES_SCALAR_IN_DOMAIN_AND_FO = "uses_scalar_in_domain_and_fo" USES_CONCAT_WHERE = "uses_concat_where" +EMBEDDED_CONCAT_WHERE_INFINITE_DOMAIN = "embedded_concat_where_infinite_domain" +EMBEDDED_CONCAT_WHERE_NON_CONTIGUOUS_DOMAIN = "embedded_concat_where_non_contiguous_domain" USES_PROGRAM_WITH_SLICED_OUT_ARGUMENTS = "uses_program_with_sliced_out_arguments" CHECKS_SPECIFIC_ERROR = "checks_specific_error" @@ -167,7 +169,8 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): XFAIL, UNSUPPORTED_MESSAGE, ), # we can't extract the field type from scan args - (USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE), + (EMBEDDED_CONCAT_WHERE_INFINITE_DOMAIN, XFAIL, UNSUPPORTED_MESSAGE), + (EMBEDDED_CONCAT_WHERE_NON_CONTIGUOUS_DOMAIN, XFAIL, UNSUPPORTED_MESSAGE), ] JAX_EMBEDDED_SKIP_LIST = EMBEDDED_SKIP_LIST + [ (USES_PROGRAM_WITH_SLICED_OUT_ARGUMENTS, XFAIL, UNSUPPORTED_MESSAGE), @@ -178,9 +181,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_TUPLES_ARGS_WITH_DIFFERENT_BUT_PROMOTABLE_DIMS, XFAIL, UNSUPPORTED_MESSAGE), (USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE), ] -GTIR_EMBEDDED_SKIP_LIST = ROUNDTRIP_SKIP_LIST + [ - (USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE), -] +GTIR_EMBEDDED_SKIP_LIST = ROUNDTRIP_SKIP_LIST + [] GTFN_SKIP_TEST_LIST = ( COMMON_SKIP_TEST_LIST + DOMAIN_INFERENCE_SKIP_LIST diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 0ec8608a37..0c2e1839fa 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -78,6 +78,7 @@ def testee(a: cases.IJKField, b: cases.IJKField, N: np.int32) -> cases.IJKField: cases.verify(cartesian_case, testee, a, b, N, out=out, ref=a.asnumpy()) +@pytest.mark.embedded_concat_where_infinite_domain def test_concat_where_scalar_broadcast(cartesian_case): @gtx.field_operator def testee(a: np.int32, b: cases.IJKField, N: np.int32) -> cases.IJKField: @@ -97,6 +98,7 @@ def testee(a: np.int32, b: cases.IJKField, N: np.int32) -> cases.IJKField: cases.verify(cartesian_case, testee, a, b, cartesian_case.default_sizes[KDim], out=out, ref=ref) +@pytest.mark.embedded_concat_where_infinite_domain def test_concat_where_scalar_broadcast_on_empty_branch(cartesian_case): """Output domain such that the scalar branch is never active.""" @@ -253,6 +255,7 @@ def testee(interior: cases.KField, boundary: cases.KField) -> cases.KField: cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) +@pytest.mark.embedded_concat_where_non_contiguous_domain def test_dimension_two_conditions_or(cartesian_case): @gtx.field_operator def testee(interior: cases.KField, boundary: cases.KField) -> cases.KField: @@ -272,11 +275,19 @@ def test_lap_like(cartesian_case): def testee( inp: cases.IJField, boundary: np.int32, shape: tuple[np.int32, np.int32] ) -> cases.IJField: - # TODO add support for multi-dimensional concat_where masks + # TODO(havogt) add support for multi-dimensional concat_where and non-contigous unions return concat_where( - (IDim == 0) | (IDim == shape[0] - 1), + (IDim == 0), boundary, - concat_where((JDim == 0) | (JDim == shape[1] - 1), boundary, inp), + concat_where( + IDim == shape[0] - 1, + boundary, + concat_where( + JDim == 0, + boundary, + concat_where(JDim == shape[1] - 1, boundary, inp), + ), + ), ) out = cases.allocate(cartesian_case, testee, cases.RETURN)() diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 871c538807..5c74590204 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -1104,91 +1104,44 @@ def test_hyperslice(index_array, expected): @pytest.mark.uses_concat_where @pytest.mark.parametrize( - "mask_data, true_data, false_data, expected", + "cond, true_data, false_data, expected", [ + (D0 == 0, ([0, 0], None), ([1, 1], None), ([0, 1], None)), + (D0 == -1, ([0, 0], {D0: (-1, 1)}), ([1, 1], {D0: (0, 2)}), ([0, 1, 1], {D0: (-1, 2)})), + (D0 < 0, ([0, 0], {D0: (-2, 0)}), ([1, 1], {D0: (0, 2)}), ([0, 0, 1, 1], {D0: (-2, 2)})), + (D0 == 1, ([0, 0, 0], None), ([1, 1, 1], None), ([1, 0, 1], None)), + # non-contiguous domain + (D0 <= 0, ([0, 0], {D0: (-2, 0)}), ([1, 1], {D0: (0, 2)}), None), + # empty result domain ( - ([True, False, True, False, True], None), - ([1, 2, 3, 4, 5], None), - ([6, 7, 8, 9, 10], None), - ([1, 7, 3, 9, 5], None), - ), - ( - ([True, False, True, False], None), - ([1, 2, 3, 4, 5], {D0: (-2, 3)}), - ([6, 7, 8, 9], {D0: (1, 5)}), - ([3, 6, 5, 8], {D0: (0, 4)}), - ), - ( - ([True, False, True, False, True], None), - ([1, 2, 3, 4, 5], {D0: (-2, 3)}), - ([6, 7, 8, 9, 10], {D0: (1, 6)}), - ([3, 6, 5, 8], {D0: (0, 4)}), - ), - ( - ([True, False, True, False, True], None), - ([1, 2, 3, 4, 5], {D0: (-2, 3)}), - ([6, 7, 8, 9, 10], {D0: (2, 7)}), - None, - ), - ( - # empty result domain - ([True, False, True, False, True], None), - ([1, 2, 3, 4, 5], {D0: (-5, 0)}), - ([6, 7, 8, 9, 10], {D0: (5, 10)}), + D0 < 0, + ([0, 0], {D0: (0, 2)}), + ([1, 1], {D0: (-2, 0)}), ([], {D0: (0, 0)}), ), - ( - ([True, False, True, False, True], None), - ([1, 2, 3, 4, 5], {D0: (-4, 1)}), - ([6, 7, 8, 9, 10], {D0: (5, 10)}), - ([5], {D0: (0, 1)}), - ), - ( - # broadcasting true_field - ([True, False, True, False, True], {D0: 5}), - ([1, 2, 3, 4, 5], {D0: 5}), - ([[6, 11], [7, 12], [8, 13], [9, 14], [10, 15]], {D0: 5, D1: 2}), - ([[1, 1], [7, 12], [3, 3], [9, 14], [5, 5]], {D0: 5, D1: 2}), - ), - ( - ([True, False, True, False, True], None), - (42, None), - ([6, 7, 8, 9, 10], None), - ([42, 7, 42, 9, 42], None), - ), - ( - # parts of mask_ranges are concatenated - ([True, True, False, False], None), - ([1, 2], {D0: (1, 3)}), - ([3, 4], {D0: (1, 3)}), - ([1, 4], {D0: (1, 3)}), - ), - ( - # parts of mask_ranges are concatenated and yield non-contiguous domain - ([True, False, True, False], None), - ([1, 2], {D0: (0, 2)}), - ([3, 4], {D0: (2, 4)}), - None, + # broadcasting from scalar (needs infinite domain support) + pytest.param( + D0 == 0, + ([0, 0], None), + (1, None), + ([0, 1], None), + marks=[ + pytest.mark.embedded_concat_where_infinite_domain, + pytest.mark.xfail(reason="requires infinite domain support"), + ], ), ], ) def test_concat_where( nd_array_implementation, - mask_data: tuple[list[bool], Optional[common.DomainLike]], + cond: common.Domain, true_data: tuple[list[int], Optional[common.DomainLike]], false_data: tuple[list[int], Optional[common.DomainLike]], expected: Optional[tuple[list[int], Optional[common.DomainLike]]], ): - mask_lst, mask_domain = mask_data true_lst, true_domain = true_data false_lst, false_domain = false_data - mask_field = _make_field_or_scalar( - mask_lst, - nd_array_implementation=nd_array_implementation, - domain=common.domain(mask_domain) if mask_domain is not None else None, - dtype=bool, - ) true_field = _make_field_or_scalar( true_lst, nd_array_implementation=nd_array_implementation, @@ -1204,7 +1157,7 @@ def test_concat_where( if expected is None: with pytest.raises(embedded_exceptions.NonContiguousDomain): - nd_array_field._concat_where(mask_field, true_field, false_field) + nd_array_field._concat_where(cond, true_field, false_field) else: expected_lst, expected_domain_like = expected expected_array = np.asarray(expected_lst) @@ -1214,7 +1167,133 @@ def test_concat_where( else _make_default_domain(expected_array.shape) ) - result = nd_array_field._concat_where(mask_field, true_field, false_field) + result = nd_array_field._concat_where(cond, true_field, false_field) assert expected_domain == result.domain np.testing.assert_allclose(result.asnumpy(), expected_array) + + +@pytest.mark.parametrize( + "domain, expected", + [ + # finite domain → two complement regions (left and right of the domain) + ( + common.Domain(dims=(D0,), ranges=(UnitRange(2, 5),)), + ( + common.Domain(dims=(D0,), ranges=(UnitRange(common.Infinity.NEGATIVE, 2),)), + common.Domain(dims=(D0,), ranges=(UnitRange(5, common.Infinity.POSITIVE),)), + ), + ), + # single-point domain + ( + D0 == 3, + ( + common.Domain(dims=(D0,), ranges=(UnitRange(common.Infinity.NEGATIVE, 3),)), + common.Domain(dims=(D0,), ranges=(UnitRange(4, common.Infinity.POSITIVE),)), + ), + ), + # open on the left (D0 < 5) → only a right complement + ( + common.Domain(dims=(D0,), ranges=(UnitRange(common.Infinity.NEGATIVE, 5),)), + (common.Domain(dims=(D0,), ranges=(UnitRange(5, common.Infinity.POSITIVE),)),), + ), + # open on the right (D0 >= 5) → only a left complement + ( + common.Domain(dims=(D0,), ranges=(UnitRange(5, common.Infinity.POSITIVE),)), + (common.Domain(dims=(D0,), ranges=(UnitRange(common.Infinity.NEGATIVE, 5),)),), + ), + # full line (both infinite) → empty complement + ( + common.Domain( + dims=(D0,), + ranges=(UnitRange(common.Infinity.NEGATIVE, common.Infinity.POSITIVE),), + ), + (), + ), + # empty domain [3, 3) → normalized to [0, 0), invert sees (0, 0) + ( + common.Domain(dims=(D0,), ranges=(UnitRange(3, 3),)), + ( + common.Domain(dims=(D0,), ranges=(UnitRange(common.Infinity.NEGATIVE, 0),)), + common.Domain(dims=(D0,), ranges=(UnitRange(0, common.Infinity.POSITIVE),)), + ), + ), + ], +) +def test_invert_domain(domain, expected): + result = nd_array_field._invert_domain(domain) + assert result == expected + + +@pytest.mark.uses_concat_where +@pytest.mark.parametrize( + "fields_data, dim, expected_data, expect_error", + [ + # two adjacent fields, already ordered + ( + [([1, 2], {D0: (0, 2)}), ([3, 4], {D0: (2, 4)})], + D0, + ([1, 2, 3, 4], {D0: (0, 4)}), + None, + ), + # two adjacent fields, reverse order → _concat sorts them + ( + [([3, 4], {D0: (2, 4)}), ([1, 2], {D0: (0, 2)})], + D0, + ([1, 2, 3, 4], {D0: (0, 4)}), + None, + ), + # three fields + ( + [([1], {D0: (0, 1)}), ([2], {D0: (1, 2)}), ([3], {D0: (2, 3)})], + D0, + ([1, 2, 3], {D0: (0, 3)}), + None, + ), + # single field (trivial concat) + ( + [([10, 20, 30], {D0: (5, 8)})], + D0, + ([10, 20, 30], {D0: (5, 8)}), + None, + ), + # gap between fields → NonContiguousDomain + ( + [([1, 2], {D0: (0, 2)}), ([3, 4], {D0: (3, 5)})], + D0, + None, + embedded_exceptions.NonContiguousDomain, + ), + # overlapping fields → ValueError + ( + [([1, 2, 3], {D0: (0, 3)}), ([4, 5, 6], {D0: (2, 5)})], + D0, + None, + ValueError, + ), + # negative domain indices + ( + [([1, 2], {D0: (-3, -1)}), ([3], {D0: (-1, 0)})], + D0, + ([1, 2, 3], {D0: (-3, 0)}), + None, + ), + ], +) +def test_concat(fields_data, dim, expected_data, expect_error): + fields = [ + common._field(np.asarray(data, dtype=np.int32), domain=common.domain(domain)) + for data, domain in fields_data + ] + + if expect_error is not None: + with pytest.raises(expect_error): + nd_array_field._concat(*fields, dim=dim) + else: + expected_array = np.asarray(expected_data[0], dtype=np.int32) + expected_domain = common.domain(expected_data[1]) + + result = nd_array_field._concat(*fields, dim=dim) + + assert result.domain == expected_domain + np.testing.assert_allclose(result.asnumpy(), expected_array)