Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@
from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations


def _node_sort_key(node: dace_nodes.Node) -> str:
    """Build a stable string key used to order DaCe nodes deterministically.

    Iterating over sets of nodes yields an arbitrary order, which can make
    code generation non-deterministic. Sorting with this key imposes a
    reproducible order. The key combines the node's class name with its
    ``data`` attribute (AccessNodes) or, failing that, its ``label``.
    """
    if hasattr(node, "data"):
        descriptor = node.data
    else:
        descriptor = getattr(node, "label", "")
    return f"{type(node).__name__}_{descriptor}"


@dace_properties.make_properties
class MoveDataflowIntoIfBody(dace_transformation.SingleStateTransformation):
"""The transformation moves dataflow into the if branches.
Expand Down Expand Up @@ -320,7 +331,7 @@ def _replicate_dataflow_into_branch(
# Add the SDFGState to the key of the dictionary because we have to create
# new node for the different branches.
unique_old_nodes: list[dace_nodes.Node] = []
for old_node in nodes_to_move:
for old_node in sorted(nodes_to_move, key=_node_sort_key):
if (old_node, branch_state) in old_to_new_nodes_map:
continue
unique_old_nodes.append(old_node)
Expand Down Expand Up @@ -501,7 +512,7 @@ def _remove_outside_dataflow(
# Before we can clean the original nodes, we must clean the dataflow. If a
# node, that was relocated, has incoming connections we must remove them
# and the parent dataflow.
for node_to_remove in all_relocatable_dataflow:
for node_to_remove in sorted(all_relocatable_dataflow, key=_node_sort_key):
for iedge in list(state.in_edges(node_to_remove)):
if iedge.src in all_relocatable_dataflow:
continue
Expand Down Expand Up @@ -838,7 +849,7 @@ def filter_nodes(
def _partition_if_block(
self,
if_block: dace_nodes.NestedSDFG,
) -> Optional[tuple[set[str], set[str]]]:
) -> Optional[tuple[list[str], list[str]]]:
"""Check if `if_block` can be processed and partition the input connectors.

The function will check if `if_block` has the right structure, i.e. if it is
Expand All @@ -849,10 +860,10 @@ def _partition_if_block(
Returns:
If `if_block` is unsuitable the function will return `None`.
If `if_block` meets the structural requirements the function will return
two sets of strings. The first set contains the connectors that can be
relocated and the second one of the conditions that can not be relocated.
two sorted lists of strings. The first list contains the connectors that
can be relocated and the second one the connectors that can not be relocated.
Sorting ensures deterministic downstream iteration order.
"""
# TODO(phimuell): Change the return type to `tuple[list[str], list[str]]` and sort the connectors, such that the operation is deterministic.
# There shall only be one output and three inputs with given names.
if len(if_block.out_connectors.keys()) == 0:
return None
Expand Down Expand Up @@ -899,14 +910,14 @@ def _partition_if_block(
# So the ones that can be relocated were found exactly once. Zero would
# mean they can not be relocated and more than one means that we do not
# support it yet.
relocatable_connectors = {
relocatable_connectors = sorted(
conn_name for conn_name, conn_count in reference_count.items() if conn_count == 1
}
non_relocatable_connectors = {
)
non_relocatable_connectors = sorted(
conn_name
for conn_name in reference_count.keys()
if conn_name not in relocatable_connectors
}
)

if len(non_relocatable_connectors) == 0:
return None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import collections
import copy
import uuid
import hashlib
import warnings
from typing import Any, Iterable, Optional, TypeAlias

Expand Down Expand Up @@ -1007,8 +1007,15 @@ def apply(

# This is the tasklet that we will put inside the map, note we have to do it
# this way to avoid some name clash stuff.
# Use a deterministic hash instead of uuid.uuid1() to ensure stable code
# generation across runs. The hash combines properties that are unique to
# this specific clone context.
_clone_key = (
f"{tasklet.label}_{tasklet.code.as_string}_{map_entry.label}_{access_node.data}"
)
_clone_hash = hashlib.md5(_clone_key.encode("utf-8")).hexdigest()
inner_tasklet: dace_nodes.Tasklet = graph.add_tasklet(
name=f"{tasklet.label}__clone_{str(uuid.uuid1()).replace('-', '_')}",
name=f"{tasklet.label}__clone_{_clone_hash}",
outputs=tasklet.out_connectors.keys(),
inputs=set(),
code=tasklet.code,
Expand Down
Loading