From 16d9bc1726a972f1b2f179b4c969cacf0fe62020 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Tue, 24 Mar 2026 14:50:27 +0100 Subject: [PATCH 1/2] fix[cartesian]: Support for data dims of size one in dace backends This PR fixes an issue in the dace backends. Access of data dimensions with size one would generate invalid SDFGs. In case an array has size one, there's only one value, so the only valid way to access array `a` is `a[0]`. When accessed inside a `Tasklet`, this access is already fully defined in the Memlet. --- .../cartesian/gtc/dace/oir_to_tasklet.py | 23 ++++++-- .../feature_tests/test_call_interface.py | 52 +++++++++++++++++++ 2 files changed, 71 insertions(+), 4 deletions(-) diff --git a/src/gt4py/cartesian/gtc/dace/oir_to_tasklet.py b/src/gt4py/cartesian/gtc/dace/oir_to_tasklet.py index cf83476b13..2329128d70 100644 --- a/src/gt4py/cartesian/gtc/dace/oir_to_tasklet.py +++ b/src/gt4py/cartesian/gtc/dace/oir_to_tasklet.py @@ -7,6 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause import operator +import warnings from dataclasses import dataclass from functools import reduce from typing import Any, Final @@ -101,9 +102,26 @@ def visit_FieldAccess(self, node: oir.FieldAccess, ctx: Context, is_target: bool # Gather all parts of the variable name in this list name_parts = [tasklet_name] + # Domain sizes of data dimensions + data_domains: list[int] = ( + ctx.tree.containers[node.name].shape[-len(node.data_index) :] if node.data_index else [] + ) + # Data dimension subscript data_indices: list[str] = [] - for index in node.data_index: + for dimension, index in enumerate(node.data_index): + # special case: domain size of this data dimension is one + if data_domains[dimension] == 1: + if not isinstance(index, oir.Literal): + warnings.warn( + f"Data dimension {dimension} of field {node.name} has static size one. Accessing without a Literal is suspicious.", + stacklevel=2, + ) + + # no need for data dimension subscript because that access + # is entirely captured in the memlet. + continue + data_indices.append(self.visit(index, ctx=ctx, is_target=False)) if isinstance(node.offset, oir.AbsoluteKIndex): @@ -139,9 +157,6 @@ def visit_FieldAccess(self, node: oir.FieldAccess, ctx: Context, is_target: bool return "".join(filter(None, name_parts)) # Build Memlet and add it to inputs/outputs - data_domains: list[int] = ( - ctx.tree.containers[node.name].shape[-len(node.data_index) :] if node.data_index else [] - ) memlet = Memlet( data=node.name, subset=_memlet_subset(node, data_domains, ctx), diff --git a/tests/cartesian_tests/integration_tests/feature_tests/test_call_interface.py b/tests/cartesian_tests/integration_tests/feature_tests/test_call_interface.py index 75f1d71159..5f551672c9 100644 --- a/tests/cartesian_tests/integration_tests/feature_tests/test_call_interface.py +++ b/tests/cartesian_tests/integration_tests/feature_tests/test_call_interface.py @@ -11,6 +11,7 @@ import numpy as np import pytest +from gt4py.cartesian.frontend.exceptions import GTScriptSyntaxError import gt4py.cartesian.gtscript as gtscript import gt4py.storage as gt_storage from gt4py.cartesian.gtscript import Field, K @@ -404,6 +405,57 @@ def test_mismatch(self, sample_stencil): ) ) + @pytest.mark.parametrize("backend", ALL_BACKENDS) + def test_data_dimension_1d(self, backend: str): + @gtscript.stencil(backend=backend) + def data_dimension_1d(field_out: gtscript.Field[gtscript.IJ, (np.float64, (1,))]): + with computation(FORWARD), interval(...): + field_out[0, 0][0] = 42.0 + + ones = gt_storage.ones( + shape=(2, 3), + dimensions=["I", "J"], + dtype=(np.float64, (1,)), + backend=backend, + aligned_index=(0, 0), + ) + data_dimension_1d(ones) + assert ones[0, 0][0] == 42.0 + + def test_data_dimension_1d_warning(self): + backend = "dace:cpu" + + # Expect parsing to emit a warning + with pytest.warns(UserWarning, match="Accessing without a Literal is suspicious."): + + @gtscript.stencil(backend=backend) + def data_dimension_1d_warning( + field_out: gtscript.Field[gtscript.IJ, (np.float64, (1,))], index: int + ): + with computation(FORWARD), interval(...): + # 1. warn about accessing an array of static size 1 with a variable index + # 2. add a runtime check and error out in case index != 0 + field_out[0, 0][index] = 42.0 + + ones = gt_storage.ones( + shape=(2, 3), + dimensions=["I", "J"], + dtype=(np.float64, (1,)), + backend=backend, + aligned_index=(0, 0), + ) + data_dimension_1d_warning(ones, index=0) + assert ones[0, 0][0] == 42.0 + + def test_data_dimensions_1d_error(self): + # Expect out of bounds write to be caught at stencil parse time + with pytest.raises(GTScriptSyntaxError, match="Data index out of bounds."): + + @gtscript.stencil(backend=self.backend) + def data_dimension_1d_error(field_out: gtscript.Field[gtscript.IJ, (np.float64, (1,))]): + with computation(FORWARD), interval(...): + field_out[0, 0][1] = 42.0 + @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_origin_unchanged(backend): From a981062562c7ba77e7acfce92382d8fd2925faf0 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Tue, 24 Mar 2026 15:53:11 +0100 Subject: [PATCH 2/2] fixup: require dace for dace-only test --- .../integration_tests/feature_tests/test_call_interface.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/cartesian_tests/integration_tests/feature_tests/test_call_interface.py b/tests/cartesian_tests/integration_tests/feature_tests/test_call_interface.py index 5f551672c9..7fcea58f8b 100644 --- a/tests/cartesian_tests/integration_tests/feature_tests/test_call_interface.py +++ b/tests/cartesian_tests/integration_tests/feature_tests/test_call_interface.py @@ -422,6 +422,7 @@ def data_dimension_1d(field_out: gtscript.Field[gtscript.IJ, (np.float64, (1,))] data_dimension_1d(ones) assert ones[0, 0][0] == 42.0 + @pytest.mark.requires_dace def test_data_dimension_1d_warning(self): backend = "dace:cpu"