Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 4 additions & 10 deletions src/gt4py/cartesian/backend/dace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,16 +267,10 @@ def freeze_origin_domain_sdfg(
wrapper_sdfg = SDFG("frozen_" + inner_sdfg.name)
state = wrapper_sdfg.add_state("frozen_" + inner_sdfg.name + "_state")

inputs = set()
outputs = set()
for node, parent in inner_sdfg.all_nodes_recursive():
if not isinstance(node, nodes.AccessNode) or inner_sdfg.arrays[node.data].transient:
continue

if node.has_reads(parent):
inputs.add(node.data)
if node.has_writes(parent):
outputs.add(node.data)
# gather inputs & outputs (i.e. reads/writes without transients)
inputs, outputs = inner_sdfg.read_and_write_sets()
inputs = set(filter(lambda name: not inner_sdfg.arrays[name].transient, inputs))
outputs = set(filter(lambda name: not inner_sdfg.arrays[name].transient, outputs))

# fake DebugInfo to avoid calls to `inspect`
nsdfg = state.add_nested_sdfg(inner_sdfg, inputs, outputs, debuginfo=DebugInfo(123456))
Expand Down
19 changes: 12 additions & 7 deletions src/gt4py/cartesian/gtc/dace/passes/push_vertical_map_down.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from copy import deepcopy

from dace.sdfg.analysis.schedule_tree import treenodes as tn

from gt4py.cartesian.gtc.dace.passes import utils
Expand Down Expand Up @@ -54,14 +56,17 @@ def _push_K_loop_in_IJ(self, node: tn.MapScope | tn.ForScope):

# New loop with MapEntry (`node`) from parent and children from `child`
if isinstance(node, tn.MapScope):
new_loop = tn.MapScope(node=parent.node, children=child.children)
new_loop.parent = child
elif isinstance(node, tn.ForScope):
new_loop = node
node.children = child.children
node.parent = child
new_loop = tn.MapScope(
node=deepcopy(parent.node), children=[c for c in child.children], parent=child
)
else:
raise ValueError(f"Unknown node of type {type(node)}")
assert isinstance(node, tn.ForScope)
new_loop = tn.ForScope(
loop=deepcopy(parent.loop),
children=[c for c in child.children],
parent=child,
)

child.children = [new_loop]
child.parent = grandparent
grandparent_children.insert(k_loop_index, child)
Expand Down
20 changes: 10 additions & 10 deletions src/gt4py/cartesian/gtc/dace/passes/swap_horizontal_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,17 @@ class SwapHorizontalMaps(tn.ScheduleNodeVisitor):
// computation here
"""

def visit_MapScope(self, node: tn.MapScope):
if node.node.params[0].startswith(Axis.J.iteration_symbol()) and node.node.params[
1
].startswith(Axis.I.iteration_symbol()):
def visit_MapScope(self, node: tn.MapScope) -> None:
params = node.node.map.params
first_param_J = params[0].startswith(Axis.J.iteration_symbol())

if first_param_J and params[1].startswith(Axis.I.iteration_symbol()):
# Swap params
tmp_index = node.node.params[0]
node.node.params[0] = node.node.params[1]
node.node.params[1] = tmp_index
param_J = params[0]
params[0] = params[1]
params[1] = param_J

# Swap ranges
tmp_bounds = node.node.range[0]
node.node.range[0] = node.node.range[1]
node.node.range[1] = tmp_bounds
node.node.map.range.reorder([1, 0])

self.visit(node.children)
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import pytest

from dace import nodes, subsets
from dace.properties import CodeBlock
from dace.sdfg.state import LoopRegion
from dace.sdfg.analysis.schedule_tree import treenodes as tn

from gt4py.cartesian.gtc.dace.passes import PushVerticalMapDown
Expand Down Expand Up @@ -49,6 +51,43 @@ def test_push_vertical_map_down():
assert root.children[0].children[0].node.map.params == ["__k"]


def test_push_vertical_map_down_for_scope():
root = tn.ScheduleTreeRoot(name="tester", children=[])
k_loop = tn.ForScope(
loop=LoopRegion(
"vertical map",
loop_var="__k",
initialize_expr=CodeBlock("__k = 0"),
condition_expr=CodeBlock("__k < 10"),
update_expr=CodeBlock("__k += 1"),
),
children=[],
)
k_loop.parent = root
ij_loop = tn.MapScope(
node=nodes.MapEntry(
map=nodes.Map("horizontal maps", ["__i", "__j"], subsets.Range.from_string("0:5,0:8"))
),
children=[],
)
ij_loop.parent = k_loop
k_loop.children.append(ij_loop)
root.children.append(k_loop)

flipper = PushVerticalMapDown()
flipper.visit(root)

tn.validate_children_and_parents_align(root)

assert len(root.children) == 1
assert isinstance(root.children[0], tn.MapScope)
assert root.children[0].node.map.params == ["__i", "__j"]

assert len(root.children[0].children) == 1
assert isinstance(root.children[0].children[0], tn.ForScope)
assert root.children[0].children[0].loop.loop_variable == "__k"


def test_push_vertical_map_down_multiple_horizontal_maps():
root = tn.ScheduleTreeRoot(name="tester", children=[])
k_loop = tn.MapScope(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import pytest

from dace import nodes, subsets
from dace.sdfg.analysis.schedule_tree import treenodes as tn

from gt4py.cartesian.gtc.dace.passes import SwapHorizontalMaps

# Because "dace tests" filter by `requires_dace`, we still need to add the marker.
# This global variable adds the marker to all test functions in this module.
pytestmark = pytest.mark.requires_dace


def test_swap_horizontal_maps() -> None:
root = tn.ScheduleTreeRoot(name="tester", children=[])
k_loop = tn.MapScope(
node=nodes.MapEntry(
map=nodes.Map("vertical map", ["__k"], subsets.Range.from_string("0:10:2"))
),
children=[],
)
k_loop.parent = root
ji_loop = tn.MapScope(
node=nodes.MapEntry(
map=nodes.Map(
"horizontal maps", ["__j", "__i"], subsets.Range([(0, 4, 1), (0, 7, 2, 2)])
)
),
children=[],
)
ji_loop.parent = k_loop
k_loop.children.append(ji_loop)
root.children.append(k_loop)

flipper = SwapHorizontalMaps()
flipper.visit(root)

horizontal_maps = ji_loop.node.map
assert horizontal_maps.params[0] == "__i"
assert horizontal_maps.range[0] == (0, 7, 2)
assert horizontal_maps.range.tile_sizes[0] == 2
assert horizontal_maps.params[1] == "__j"
assert horizontal_maps.range[1] == (0, 4, 1)
assert horizontal_maps.range.tile_sizes[1] == 1
4 changes: 2 additions & 2 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading