diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index 636ed0a4da..c154bdfbda 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -327,7 +327,7 @@ def __init__(self, builder: StencilBuilder, debug_stree: bool = False) -> None: self.builder = builder self.debug_stree = debug_stree - def schedule_tree(self) -> tn.ScheduleTreeRoot: + def schedule_tree(self, *, validate: bool = False) -> tn.ScheduleTreeRoot: """ Schedule tree representation of the gtir (taken from the builder). @@ -358,8 +358,13 @@ def schedule_tree(self) -> tn.ScheduleTreeRoot: oir = oir_pipeline.run(oir) tir = OIRToTreeIR(self.builder).visit(oir) + stree = TreeIRToScheduleTree().visit(tir) - return TreeIRToScheduleTree().visit(tir) + if validate: + tn.validate_children_and_parents_align(stree) + tn.validate_has_no_other_node_types(stree) + + return stree @staticmethod def _strip_history(sdfg: SDFG) -> None: @@ -403,7 +408,7 @@ def sdfg_via_schedule_tree(self, *, validate: bool = False, simplify: bool = Tru # - we expect all 3-loops to be singled maps/for # - we expect the layout to be K-JI # - we expect all non-cartesian control flow to be innermost - stree = self.schedule_tree() + stree = self.schedule_tree(validate=validate) # Re-order cartesian loops to match loops with memory layout # - layout is _always_ given I-J-K @@ -423,6 +428,10 @@ def sdfg_via_schedule_tree(self, *, validate: bool = False, simplify: bool = Tru flipper = passes.PushVerticalMapDown() flipper.visit(stree) + if validate: + tn.validate_children_and_parents_align(stree) + tn.validate_has_no_other_node_types(stree) + # Create SDFG sdfg = stree.as_sdfg( validate=validate, diff --git a/src/gt4py/cartesian/gtc/dace/passes/push_vertical_map_down.py b/src/gt4py/cartesian/gtc/dace/passes/push_vertical_map_down.py index 5ebc914acc..55b44994bb 100644 --- a/src/gt4py/cartesian/gtc/dace/passes/push_vertical_map_down.py +++ b/src/gt4py/cartesian/gtc/dace/passes/push_vertical_map_down.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from copy import deepcopy + from dace.sdfg.analysis.schedule_tree import treenodes as tn from gt4py.cartesian.gtc.dace.passes import utils @@ -41,39 +43,28 @@ class PushVerticalMapDown(tn.ScheduleNodeVisitor): // computation here (2) """ - def _push_K_loop_in_IJ(self, node: tn.MapScope | tn.ForScope): + def visit_MapScope(self, scope: tn.MapScope) -> None: + if not scope.node.map.params[0].startswith("__k"): + return + # take refs before moving things around - parent = node - grandparent = node.parent - grandparent_children = node.parent.children - k_loop_index = utils.list_index(grandparent_children, parent) + parent = scope.parent + parent_children = scope.parent.children + k_loop_index = utils.list_index(parent_children, scope) - for child in node.children: + for child in scope.children: if not isinstance(child, tn.MapScope): raise NotImplementedError("We don't expect anything else than (IJ)-MapScopes here.") - # New loop with MapEntry (`node`) from parent and children from `child` - if isinstance(node, tn.MapScope): - new_loop = tn.MapScope(node=parent.node, children=child.children) - new_loop.parent = child - elif isinstance(node, tn.ForScope): - new_loop = node - node.children = child.children - node.parent = child - else: - raise ValueError(f"Unknown node of type {type(node)}") - child.children = [new_loop] - child.parent = grandparent - grandparent_children.insert(k_loop_index, child) + child.children = [ + # New loop with MapEntry (`node`) from parent and children from `child` + tn.MapScope( + node=deepcopy(scope.node), children=[c for c in child.children], parent=child + ) + ] + child.parent = parent + parent_children.insert(k_loop_index, child) k_loop_index += 1 # delete old (now unused) node - grandparent_children.remove(node) - - def visit_MapScope(self, node: tn.MapScope): - if node.node.map.params[0].startswith("__k"): - self._push_K_loop_in_IJ(node) - - def visit_ForScope(self, node: tn.ForScope): - if node.loop.loop_variable.startswith("__k"): - self._push_K_loop_in_IJ(node) + parent_children.remove(scope) diff --git a/src/gt4py/cartesian/gtc/dace/passes/swap_horizontal_maps.py b/src/gt4py/cartesian/gtc/dace/passes/swap_horizontal_maps.py index 6a2e1ca4a0..403be8909f 100644 --- a/src/gt4py/cartesian/gtc/dace/passes/swap_horizontal_maps.py +++ b/src/gt4py/cartesian/gtc/dace/passes/swap_horizontal_maps.py @@ -29,17 +29,17 @@ class SwapHorizontalMaps(tn.ScheduleNodeVisitor): // computation here """ - def visit_MapScope(self, node: tn.MapScope): - if node.node.params[0].startswith(Axis.J.iteration_symbol()) and node.node.params[ - 1 - ].startswith(Axis.I.iteration_symbol()): + def visit_MapScope(self, node: tn.MapScope) -> None: + params = node.node.map.params + first_param_J = params[0].startswith(Axis.J.iteration_symbol()) + + if first_param_J and params[1].startswith(Axis.I.iteration_symbol()): # Swap params - tmp_index = node.node.params[0] - node.node.params[0] = node.node.params[1] - node.node.params[1] = tmp_index + param_J = params[0] + params[0] = params[1] + params[1] = param_J + # Swap ranges - tmp_bounds = node.node.range[0] - node.node.range[0] = node.node.range[1] - node.node.range[1] = tmp_bounds + node.node.map.range.reorder([1, 0]) self.visit(node.children) diff --git a/tests/cartesian_tests/unit_tests/test_gtc/dace/dace/__init__.py b/tests/cartesian_tests/unit_tests/test_gtc/dace/passes/__init__.py similarity index 100% rename from tests/cartesian_tests/unit_tests/test_gtc/dace/dace/__init__.py rename to tests/cartesian_tests/unit_tests/test_gtc/dace/passes/__init__.py diff --git a/tests/cartesian_tests/unit_tests/test_gtc/dace/dace/test_push_vertical_map_down.py b/tests/cartesian_tests/unit_tests/test_gtc/dace/passes/test_push_vertical_map_down.py similarity index 67% rename from tests/cartesian_tests/unit_tests/test_gtc/dace/dace/test_push_vertical_map_down.py rename to tests/cartesian_tests/unit_tests/test_gtc/dace/passes/test_push_vertical_map_down.py index 9f0553b4e8..970f9bdd04 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/dace/dace/test_push_vertical_map_down.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/dace/passes/test_push_vertical_map_down.py @@ -9,6 +9,8 @@ import pytest from dace import nodes, subsets +from dace.properties import CodeBlock +from dace.sdfg.state import LoopRegion from dace.sdfg.analysis.schedule_tree import treenodes as tn from gt4py.cartesian.gtc.dace.passes import PushVerticalMapDown @@ -40,6 +42,7 @@ def test_push_vertical_map_down(): flipper = PushVerticalMapDown() flipper.visit(root) + # Assert that the k-loop has been pushed inside the ij-loop. assert len(root.children) == 1 assert isinstance(root.children[0], tn.MapScope) assert root.children[0].node.map.params == ["__i", "__j"] @@ -49,6 +52,42 @@ def test_push_vertical_map_down(): assert root.children[0].children[0].node.map.params == ["__k"] +def test_push_vertical_map_down_for_scope(): + root = tn.ScheduleTreeRoot(name="tester", children=[]) + k_loop = tn.ForScope( + loop=LoopRegion( + "vertical map", + loop_var="__k", + initialize_expr=CodeBlock("__k = 0"), + condition_expr=CodeBlock("__k < 10"), + update_expr=CodeBlock("__k += 1"), + ), + children=[], + ) + k_loop.parent = root + ij_loop = tn.MapScope( + node=nodes.MapEntry( + map=nodes.Map("horizontal maps", ["__i", "__j"], subsets.Range.from_string("0:5,0:8")) + ), + children=[], + ) + ij_loop.parent = k_loop + k_loop.children.append(ij_loop) + root.children.append(k_loop) + + flipper = PushVerticalMapDown() + flipper.visit(root) + + # Assert that the tree hasn't changed. + assert len(root.children) == 1 + assert isinstance(root.children[0], tn.ForScope) + assert root.children[0].loop.loop_variable == "__k" + + assert len(root.children[0].children) == 1 + assert isinstance(root.children[0].children[0], tn.MapScope) + assert root.children[0].children[0].node.map.params == ["__i", "__j"] + + def test_push_vertical_map_down_multiple_horizontal_maps(): root = tn.ScheduleTreeRoot(name="tester", children=[]) k_loop = tn.MapScope( @@ -86,6 +125,8 @@ def test_push_vertical_map_down_multiple_horizontal_maps(): flipper = PushVerticalMapDown() flipper.visit(root) + # Assert that the k-loop has been pushed inside the ij-loops, + # effectively duplicating the k-loop. assert len(root.children) == 2 for index, child in enumerate(root.children): diff --git a/tests/cartesian_tests/unit_tests/test_gtc/dace/passes/test_swap_horizontal_maps.py b/tests/cartesian_tests/unit_tests/test_gtc/dace/passes/test_swap_horizontal_maps.py new file mode 100644 index 0000000000..38085b7c3d --- /dev/null +++ b/tests/cartesian_tests/unit_tests/test_gtc/dace/passes/test_swap_horizontal_maps.py @@ -0,0 +1,47 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +from dace import nodes, subsets +from dace.sdfg.analysis.schedule_tree import treenodes as tn + +from gt4py.cartesian.gtc.dace.passes import SwapHorizontalMaps + +# Because "dace tests" filter by `requires_dace`, we still need to add the marker. +# This global variable adds the marker to all test functions in this module. +pytestmark = pytest.mark.requires_dace + + +def test_swap_horizontal_maps() -> None: + root = tn.ScheduleTreeRoot(name="tester", children=[]) + k_loop = tn.MapScope( + node=nodes.MapEntry( + map=nodes.Map("vertical map", ["__k"], subsets.Range.from_string("0:10")) + ), + children=[], + ) + k_loop.parent = root + ji_loop = tn.MapScope( + node=nodes.MapEntry( + map=nodes.Map("horizontal maps", ["__j", "__i"], subsets.Range.from_string("0:5,0:8")) + ), + children=[], + ) + ji_loop.parent = k_loop + k_loop.children.append(ji_loop) + root.children.append(k_loop) + + flipper = SwapHorizontalMaps() + flipper.visit(root) + + horizontal_maps = ji_loop.node.map + assert horizontal_maps.params[0] == "__i" + assert horizontal_maps.range[0] == (0, 7, 1) + assert horizontal_maps.params[1] == "__j" + assert horizontal_maps.range[1] == (0, 4, 1) diff --git a/tests/cartesian_tests/unit_tests/test_gtc/dace/dace/test_utils.py b/tests/cartesian_tests/unit_tests/test_gtc/dace/passes/test_utils.py similarity index 100% rename from tests/cartesian_tests/unit_tests/test_gtc/dace/dace/test_utils.py rename to tests/cartesian_tests/unit_tests/test_gtc/dace/passes/test_utils.py