diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index 636ed0a4da..c6289587a7 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -267,16 +267,10 @@ def freeze_origin_domain_sdfg( wrapper_sdfg = SDFG("frozen_" + inner_sdfg.name) state = wrapper_sdfg.add_state("frozen_" + inner_sdfg.name + "_state") - inputs = set() - outputs = set() - for node, parent in inner_sdfg.all_nodes_recursive(): - if not isinstance(node, nodes.AccessNode) or inner_sdfg.arrays[node.data].transient: - continue - - if node.has_reads(parent): - inputs.add(node.data) - if node.has_writes(parent): - outputs.add(node.data) + # gather inputs & outputs (i.e. reads/writes without transients) + inputs, outputs = inner_sdfg.read_and_write_sets() + inputs = set(filter(lambda name: not inner_sdfg.arrays[name].transient, inputs)) + outputs = set(filter(lambda name: not inner_sdfg.arrays[name].transient, outputs)) # fake DebugInfo to avoid calls to `inspect` nsdfg = state.add_nested_sdfg(inner_sdfg, inputs, outputs, debuginfo=DebugInfo(123456)) diff --git a/src/gt4py/cartesian/gtc/dace/passes/push_vertical_map_down.py b/src/gt4py/cartesian/gtc/dace/passes/push_vertical_map_down.py index 5ebc914acc..9ed3a13578 100644 --- a/src/gt4py/cartesian/gtc/dace/passes/push_vertical_map_down.py +++ b/src/gt4py/cartesian/gtc/dace/passes/push_vertical_map_down.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from copy import deepcopy + from dace.sdfg.analysis.schedule_tree import treenodes as tn from gt4py.cartesian.gtc.dace.passes import utils @@ -54,14 +56,17 @@ def _push_K_loop_in_IJ(self, node: tn.MapScope | tn.ForScope): # New loop with MapEntry (`node`) from parent and children from `child` if isinstance(node, tn.MapScope): - new_loop = tn.MapScope(node=parent.node, children=child.children) - new_loop.parent = child - elif isinstance(node, tn.ForScope): - new_loop = node - node.children = child.children - node.parent = child + new_loop = tn.MapScope( + node=deepcopy(parent.node), children=[c for c in child.children], parent=child + ) else: - raise ValueError(f"Unknown node of type {type(node)}") + assert isinstance(node, tn.ForScope) + new_loop = tn.ForScope( + loop=deepcopy(parent.loop), + children=[c for c in child.children], + parent=child, + ) + child.children = [new_loop] child.parent = grandparent grandparent_children.insert(k_loop_index, child) diff --git a/src/gt4py/cartesian/gtc/dace/passes/swap_horizontal_maps.py b/src/gt4py/cartesian/gtc/dace/passes/swap_horizontal_maps.py index 6a2e1ca4a0..403be8909f 100644 --- a/src/gt4py/cartesian/gtc/dace/passes/swap_horizontal_maps.py +++ b/src/gt4py/cartesian/gtc/dace/passes/swap_horizontal_maps.py @@ -29,17 +29,17 @@ class SwapHorizontalMaps(tn.ScheduleNodeVisitor): // computation here """ - def visit_MapScope(self, node: tn.MapScope): - if node.node.params[0].startswith(Axis.J.iteration_symbol()) and node.node.params[ - 1 - ].startswith(Axis.I.iteration_symbol()): + def visit_MapScope(self, node: tn.MapScope) -> None: + params = node.node.map.params + first_param_J = params[0].startswith(Axis.J.iteration_symbol()) + + if first_param_J and params[1].startswith(Axis.I.iteration_symbol()): # Swap params - tmp_index = node.node.params[0] - node.node.params[0] = node.node.params[1] - node.node.params[1] = tmp_index + param_J = params[0] + params[0] = params[1] + params[1] = param_J + # Swap ranges - tmp_bounds = node.node.range[0] - node.node.range[0] = node.node.range[1] - node.node.range[1] = tmp_bounds + node.node.map.range.reorder([1, 0]) self.visit(node.children) diff --git a/tests/cartesian_tests/unit_tests/test_gtc/dace/dace/__init__.py b/tests/cartesian_tests/unit_tests/test_gtc/dace/passes/__init__.py similarity index 100% rename from tests/cartesian_tests/unit_tests/test_gtc/dace/dace/__init__.py rename to tests/cartesian_tests/unit_tests/test_gtc/dace/passes/__init__.py diff --git a/tests/cartesian_tests/unit_tests/test_gtc/dace/dace/test_push_vertical_map_down.py b/tests/cartesian_tests/unit_tests/test_gtc/dace/passes/test_push_vertical_map_down.py similarity index 70% rename from tests/cartesian_tests/unit_tests/test_gtc/dace/dace/test_push_vertical_map_down.py rename to tests/cartesian_tests/unit_tests/test_gtc/dace/passes/test_push_vertical_map_down.py index 9f0553b4e8..cd4a03ce2a 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/dace/dace/test_push_vertical_map_down.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/dace/passes/test_push_vertical_map_down.py @@ -9,6 +9,8 @@ import pytest from dace import nodes, subsets +from dace.properties import CodeBlock +from dace.sdfg.state import LoopRegion from dace.sdfg.analysis.schedule_tree import treenodes as tn from gt4py.cartesian.gtc.dace.passes import PushVerticalMapDown @@ -49,6 +51,43 @@ def test_push_vertical_map_down(): assert root.children[0].children[0].node.map.params == ["__k"] +def test_push_vertical_map_down_for_scope(): + root = tn.ScheduleTreeRoot(name="tester", children=[]) + k_loop = tn.ForScope( + loop=LoopRegion( + "vertical map", + loop_var="__k", + initialize_expr=CodeBlock("__k = 0"), + condition_expr=CodeBlock("__k < 10"), + update_expr=CodeBlock("__k += 1"), + ), + children=[], + ) + k_loop.parent = root + ij_loop = tn.MapScope( + node=nodes.MapEntry( + map=nodes.Map("horizontal maps", ["__i", "__j"], subsets.Range.from_string("0:5,0:8")) + ), + children=[], + ) + ij_loop.parent = k_loop + k_loop.children.append(ij_loop) + root.children.append(k_loop) + + flipper = PushVerticalMapDown() + flipper.visit(root) + + tn.validate_children_and_parents_align(root) + + assert len(root.children) == 1 + assert isinstance(root.children[0], tn.MapScope) + assert root.children[0].node.map.params == ["__i", "__j"] + + assert len(root.children[0].children) == 1 + assert isinstance(root.children[0].children[0], tn.ForScope) + assert root.children[0].children[0].loop.loop_variable == "__k" + + def test_push_vertical_map_down_multiple_horizontal_maps(): root = tn.ScheduleTreeRoot(name="tester", children=[]) k_loop = tn.MapScope( diff --git a/tests/cartesian_tests/unit_tests/test_gtc/dace/passes/test_swap_horizontal_maps.py b/tests/cartesian_tests/unit_tests/test_gtc/dace/passes/test_swap_horizontal_maps.py new file mode 100644 index 0000000000..1246fd5a2f --- /dev/null +++ b/tests/cartesian_tests/unit_tests/test_gtc/dace/passes/test_swap_horizontal_maps.py @@ -0,0 +1,51 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +from dace import nodes, subsets +from dace.sdfg.analysis.schedule_tree import treenodes as tn + +from gt4py.cartesian.gtc.dace.passes import SwapHorizontalMaps + +# Because "dace tests" filter by `requires_dace`, we still need to add the marker. +# This global variable adds the marker to all test functions in this module. +pytestmark = pytest.mark.requires_dace + + +def test_swap_horizontal_maps() -> None: + root = tn.ScheduleTreeRoot(name="tester", children=[]) + k_loop = tn.MapScope( + node=nodes.MapEntry( + map=nodes.Map("vertical map", ["__k"], subsets.Range.from_string("0:10:2")) + ), + children=[], + ) + k_loop.parent = root + ji_loop = tn.MapScope( + node=nodes.MapEntry( + map=nodes.Map( + "horizontal maps", ["__j", "__i"], subsets.Range([(0, 4, 1), (0, 7, 2, 2)]) + ) + ), + children=[], + ) + ji_loop.parent = k_loop + k_loop.children.append(ji_loop) + root.children.append(k_loop) + + flipper = SwapHorizontalMaps() + flipper.visit(root) + + horizontal_maps = ji_loop.node.map + assert horizontal_maps.params[0] == "__i" + assert horizontal_maps.range[0] == (0, 7, 2) + assert horizontal_maps.range.tile_sizes[0] == 2 + assert horizontal_maps.params[1] == "__j" + assert horizontal_maps.range[1] == (0, 4, 1) + assert horizontal_maps.range.tile_sizes[1] == 1 diff --git a/tests/cartesian_tests/unit_tests/test_gtc/dace/dace/test_utils.py b/tests/cartesian_tests/unit_tests/test_gtc/dace/passes/test_utils.py similarity index 100% rename from tests/cartesian_tests/unit_tests/test_gtc/dace/dace/test_utils.py rename to tests/cartesian_tests/unit_tests/test_gtc/dace/passes/test_utils.py diff --git a/uv.lock b/uv.lock index 4ab7296e93..a02b7efd02 100644 --- a/uv.lock +++ b/uv.lock @@ -1177,7 +1177,7 @@ wheels = [ [[package]] name = "dace" version = "1.0.0" -source = { git = "https://github.com/GridTools/dace?branch=romanc%2Fstree-v2#1fb397865e89c6b8907c4de0cded046e153b48ac" } +source = { git = "https://github.com/GridTools/dace?branch=romanc%2Fstree-v2#0d9f3b4ede7a87aa3c86913481740390431e2b21" } resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'win32'", "python_full_version >= '3.14' and sys_platform == 'emscripten'", @@ -1739,7 +1739,7 @@ build = [ { name = "wheel" }, ] dace-cartesian = [ - { name = "dace", version = "1.0.0", source = { git = "https://github.com/GridTools/dace?branch=romanc%2Fstree-v2#1fb397865e89c6b8907c4de0cded046e153b48ac" } }, + { name = "dace", version = "1.0.0", source = { git = "https://github.com/GridTools/dace?branch=romanc%2Fstree-v2#0d9f3b4ede7a87aa3c86913481740390431e2b21" } }, ] dace-next = [ { name = "dace", version = "43!2026.2.12", source = { registry = "https://gridtools.github.io/pypi/" } },