Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions src/gt4py/cartesian/backend/dace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def __init__(self, builder: StencilBuilder, debug_stree: bool = False) -> None:
self.builder = builder
self.debug_stree = debug_stree

def schedule_tree(self) -> tn.ScheduleTreeRoot:
def schedule_tree(self, *, validate: bool = False) -> tn.ScheduleTreeRoot:
"""
Schedule tree representation of the gtir (taken from the builder).

Expand Down Expand Up @@ -358,8 +358,13 @@ def schedule_tree(self) -> tn.ScheduleTreeRoot:
oir = oir_pipeline.run(oir)

tir = OIRToTreeIR(self.builder).visit(oir)
stree = TreeIRToScheduleTree().visit(tir)

return TreeIRToScheduleTree().visit(tir)
if validate:
tn.validate_children_and_parents_align(stree)
tn.validate_has_no_other_node_types(stree)

return stree

@staticmethod
def _strip_history(sdfg: SDFG) -> None:
Expand Down Expand Up @@ -403,7 +408,7 @@ def sdfg_via_schedule_tree(self, *, validate: bool = False, simplify: bool = Tru
# - we expect all 3-loops to be singled maps/for
# - we expect the layout to be K-JI
# - we expect all non-cartesian control flow to be innermost
stree = self.schedule_tree()
stree = self.schedule_tree(validate=validate)

# Re-order cartesian loops to match loops with memory layout
# - layout is _always_ given I-J-K
Expand All @@ -423,6 +428,10 @@ def sdfg_via_schedule_tree(self, *, validate: bool = False, simplify: bool = Tru
flipper = passes.PushVerticalMapDown()
flipper.visit(stree)

if validate:
tn.validate_children_and_parents_align(stree)
tn.validate_has_no_other_node_types(stree)

# Create SDFG
sdfg = stree.as_sdfg(
validate=validate,
Expand Down
47 changes: 19 additions & 28 deletions src/gt4py/cartesian/gtc/dace/passes/push_vertical_map_down.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from copy import deepcopy

from dace.sdfg.analysis.schedule_tree import treenodes as tn

from gt4py.cartesian.gtc.dace.passes import utils
Expand Down Expand Up @@ -41,39 +43,28 @@ class PushVerticalMapDown(tn.ScheduleNodeVisitor):
// computation here (2)
"""

def _push_K_loop_in_IJ(self, node: tn.MapScope | tn.ForScope):
def visit_MapScope(self, scope: tn.MapScope) -> None:
if not scope.node.map.params[0].startswith("__k"):
return

# take refs before moving things around
parent = node
grandparent = node.parent
grandparent_children = node.parent.children
k_loop_index = utils.list_index(grandparent_children, parent)
parent = scope.parent
parent_children = scope.parent.children
k_loop_index = utils.list_index(parent_children, scope)

for child in node.children:
for child in scope.children:
if not isinstance(child, tn.MapScope):
raise NotImplementedError("We don't expect anything else than (IJ)-MapScopes here.")

# New loop with MapEntry (`node`) from parent and children from `child`
if isinstance(node, tn.MapScope):
new_loop = tn.MapScope(node=parent.node, children=child.children)
new_loop.parent = child
elif isinstance(node, tn.ForScope):
new_loop = node
node.children = child.children
node.parent = child
else:
raise ValueError(f"Unknown node of type {type(node)}")
child.children = [new_loop]
child.parent = grandparent
grandparent_children.insert(k_loop_index, child)
child.children = [
# New loop with MapEntry (`node`) from parent and children from `child`
tn.MapScope(
node=deepcopy(scope.node), children=[c for c in child.children], parent=child
)
]
child.parent = parent
parent_children.insert(k_loop_index, child)
k_loop_index += 1

# delete old (now unused) node
grandparent_children.remove(node)

def visit_MapScope(self, node: tn.MapScope):
if node.node.map.params[0].startswith("__k"):
self._push_K_loop_in_IJ(node)

def visit_ForScope(self, node: tn.ForScope):
if node.loop.loop_variable.startswith("__k"):
self._push_K_loop_in_IJ(node)
parent_children.remove(scope)
20 changes: 10 additions & 10 deletions src/gt4py/cartesian/gtc/dace/passes/swap_horizontal_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,17 @@ class SwapHorizontalMaps(tn.ScheduleNodeVisitor):
// computation here
"""

def visit_MapScope(self, node: tn.MapScope):
if node.node.params[0].startswith(Axis.J.iteration_symbol()) and node.node.params[
1
].startswith(Axis.I.iteration_symbol()):
def visit_MapScope(self, node: tn.MapScope) -> None:
params = node.node.map.params
first_param_J = params[0].startswith(Axis.J.iteration_symbol())

if first_param_J and params[1].startswith(Axis.I.iteration_symbol()):
# Swap params
tmp_index = node.node.params[0]
node.node.params[0] = node.node.params[1]
node.node.params[1] = tmp_index
param_J = params[0]
params[0] = params[1]
params[1] = param_J

# Swap ranges
tmp_bounds = node.node.range[0]
node.node.range[0] = node.node.range[1]
node.node.range[1] = tmp_bounds
node.node.map.range.reorder([1, 0])

self.visit(node.children)
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import pytest

from dace import nodes, subsets
from dace.properties import CodeBlock
from dace.sdfg.state import LoopRegion
from dace.sdfg.analysis.schedule_tree import treenodes as tn

from gt4py.cartesian.gtc.dace.passes import PushVerticalMapDown
Expand Down Expand Up @@ -40,6 +42,7 @@ def test_push_vertical_map_down():
flipper = PushVerticalMapDown()
flipper.visit(root)

# Assert that the k-loop has been pushed inside the ij-loop.
assert len(root.children) == 1
assert isinstance(root.children[0], tn.MapScope)
assert root.children[0].node.map.params == ["__i", "__j"]
Expand All @@ -49,6 +52,42 @@ def test_push_vertical_map_down():
assert root.children[0].children[0].node.map.params == ["__k"]


def test_push_vertical_map_down_for_scope():
root = tn.ScheduleTreeRoot(name="tester", children=[])
k_loop = tn.ForScope(
loop=LoopRegion(
"vertical map",
loop_var="__k",
initialize_expr=CodeBlock("__k = 0"),
condition_expr=CodeBlock("__k < 10"),
update_expr=CodeBlock("__k += 1"),
),
children=[],
)
k_loop.parent = root
ij_loop = tn.MapScope(
node=nodes.MapEntry(
map=nodes.Map("horizontal maps", ["__i", "__j"], subsets.Range.from_string("0:5,0:8"))
),
children=[],
)
ij_loop.parent = k_loop
k_loop.children.append(ij_loop)
root.children.append(k_loop)

flipper = PushVerticalMapDown()
flipper.visit(root)

# Assert that the tree hasn't changed.
assert len(root.children) == 1
assert isinstance(root.children[0], tn.ForScope)
assert root.children[0].loop.loop_variable == "__k"

assert len(root.children[0].children) == 1
assert isinstance(root.children[0].children[0], tn.MapScope)
assert root.children[0].children[0].node.map.params == ["__i", "__j"]


def test_push_vertical_map_down_multiple_horizontal_maps():
root = tn.ScheduleTreeRoot(name="tester", children=[])
k_loop = tn.MapScope(
Expand Down Expand Up @@ -86,6 +125,8 @@ def test_push_vertical_map_down_multiple_horizontal_maps():
flipper = PushVerticalMapDown()
flipper.visit(root)

# Assert that the k-loop has been pushed inside the ij-loops,
# effectively duplicating the k-loop.
assert len(root.children) == 2

for index, child in enumerate(root.children):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import pytest

from dace import nodes, subsets
from dace.sdfg.analysis.schedule_tree import treenodes as tn

from gt4py.cartesian.gtc.dace.passes import SwapHorizontalMaps

# Because "dace tests" filter by `requires_dace`, we still need to add the marker.
# This global variable adds the marker to all test functions in this module.
pytestmark = pytest.mark.requires_dace


def test_swap_horizontal_maps() -> None:
root = tn.ScheduleTreeRoot(name="tester", children=[])
k_loop = tn.MapScope(
node=nodes.MapEntry(
map=nodes.Map("vertical map", ["__k"], subsets.Range.from_string("0:10"))
),
children=[],
)
k_loop.parent = root
ji_loop = tn.MapScope(
node=nodes.MapEntry(
map=nodes.Map("horizontal maps", ["__j", "__i"], subsets.Range.from_string("0:5,0:8"))
),
children=[],
)
ji_loop.parent = k_loop
k_loop.children.append(ji_loop)
root.children.append(k_loop)

flipper = SwapHorizontalMaps()
flipper.visit(root)

horizontal_maps = ji_loop.node.map
assert horizontal_maps.params[0] == "__i"
assert horizontal_maps.range[0] == (0, 7, 1)
assert horizontal_maps.params[1] == "__j"
assert horizontal_maps.range[1] == (0, 4, 1)
Loading