Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ def gt_auto_optimize(
optimization_hooks = optimization_hooks or {}

with dace.config.temporary_config():
# Enable deterministic SDFG sorting for reproducible code generation.
dace.Config.set("compiler", "sdfg_alphabetical_sorting", value=True)

# Do not store which transformations were applied inside the SDFG.
dace.Config.set("store_history", value=False)

Expand Down Expand Up @@ -434,6 +437,8 @@ def _gt_auto_process_top_level_maps(
The function assumes that `gt_simplify()` has been called on the SDFG
before it is passed to this function.
"""
# Sort SDFG for deterministic pattern matching.
sdfg.sort_sdfg_alphabetically()

# NOTE: Inside this function we have to disable the consolidation of edges.
# This is because it might block the application of `SpliAccessNode`. As
Expand Down Expand Up @@ -690,6 +695,8 @@ def _gt_auto_process_dataflow_inside_maps(
over a constant range, e.g. the number of neighbours, which is known at compile
time, so the compiler will fully unroll them anyway.
"""
# Sort SDFG for deterministic pattern matching.
sdfg.sort_sdfg_alphabetically()

# Separate Tasklets into dependent and independent parts to promote data
# reusability. It is important that this step has to be performed before
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ def gt_eliminate_dead_dataflow(
Todo:
Implement a better way of applying the `DeadMemletElimination` transformation.
"""
# Sort SDFG for deterministic pattern matching.
sdfg.sort_sdfg_alphabetically()

find_single_use_data = dace_analysis.FindSingleUseData()
single_use_data = find_single_use_data.apply_pass(sdfg, None)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,8 @@ def restrict_fusion_to_newly_created_maps_horizontal(
# Now try to fuse the maps together, but restrict them that at least one map
# needs to be new.
# TODO(phimuell): Improve this by replacing it by an explicit loop.
# Sort SDFG for deterministic pattern matching.
sdfg.sort_sdfg_alphabetically()
sdfg.apply_transformations_repeated(
[
gtx_transformations.MapFusionVertical(
Expand Down Expand Up @@ -791,6 +793,8 @@ def gt_remove_trivial_gpu_maps(
Todo: Improve this function.
"""
# Sort SDFG for deterministic pattern matching.
sdfg.sort_sdfg_alphabetically()

# First we try to promote and fuse them with other non-trivial maps.
sdfg.apply_transformations_once_everywhere(
Expand Down Expand Up @@ -828,6 +832,8 @@ def restrict_to_trivial_gpu_maps(
return True

# TODO(phimuell): Replace this with a more performant loop.
# Sort SDFG for deterministic pattern matching.
sdfg.sort_sdfg_alphabetically()
sdfg.apply_transformations_repeated(
[
gtx_transformations.MapFusionVertical(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ def gt_horizontal_map_split_fusion(
validate: Perform validation during the steps.
validate_all: Perform extensive validation.
"""
# Sort SDFG for deterministic pattern matching.
sdfg.sort_sdfg_alphabetically()

transformations = [
HorizontalSplitMapRange(
Expand Down Expand Up @@ -216,6 +218,9 @@ def gt_vertical_map_split_fusion(
- Due to a bug in the transformation, not all Maps, that were created by
the splitting were fused. Especially "chains" might still be present.
"""
# Sort SDFG for deterministic pattern matching.
sdfg.sort_sdfg_alphabetically()

if single_use_data is None:
find_single_use_data = dace_analysis.FindSingleUseData()
single_use_data = find_single_use_data.apply_pass(sdfg, None)
Expand Down Expand Up @@ -795,4 +800,6 @@ def _restrict_fusion_to_newly_created_maps(
trafo._single_use_data = self._single_use_data

# This is not efficient, but it is currently the only way to run it
# Sort SDFG for deterministic pattern matching.
sdfg.sort_sdfg_alphabetically()
sdfg.apply_transformations_repeated(trafo, validate=False, validate_all=False)
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def gt_set_iteration_order(
validate: Perform validation at the end of the function.
validate_all: Perform validation also on intermediate steps.
"""
# Sort SDFG for deterministic pattern matching.
sdfg.sort_sdfg_alphabetically()

return sdfg.apply_transformations_once_everywhere(
MapIterationOrder(
unit_strides_dims=unit_strides_dim,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@
from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations


def _node_sort_key(node: dace_nodes.Node) -> str:
"""Return a deterministic string key for sorting DaCe nodes.

Used to impose a stable iteration order on sets/collections of nodes,
preventing non-deterministic code generation caused by arbitrary set
iteration order.
"""
label = getattr(node, "data", getattr(node, "label", ""))
return f"{type(node).__name__}_{label}"


@dace_properties.make_properties
class MoveDataflowIntoIfBody(dace_transformation.SingleStateTransformation):
"""The transformation moves dataflow into the if branches.
Expand Down Expand Up @@ -320,7 +331,7 @@ def _replicate_dataflow_into_branch(
# Add the SDFGState to the key of the dictionary because we have to create
# new node for the different branches.
unique_old_nodes: list[dace_nodes.Node] = []
for old_node in nodes_to_move:
for old_node in sorted(nodes_to_move, key=_node_sort_key):
if (old_node, branch_state) in old_to_new_nodes_map:
continue
unique_old_nodes.append(old_node)
Expand Down Expand Up @@ -838,7 +849,7 @@ def filter_nodes(
def _partition_if_block(
self,
if_block: dace_nodes.NestedSDFG,
) -> Optional[tuple[set[str], set[str]]]:
) -> Optional[tuple[list[str], list[str]]]:
"""Check if `if_block` can be processed and partition the input connectors.

The function will check if `if_block` has the right structure, i.e. if it is
Expand All @@ -849,10 +860,10 @@ def _partition_if_block(
Returns:
If `if_block` is unsuitable the function will return `None`.
If `if_block` meets the structural requirements the function will return
two sets of strings. The first set contains the connectors that can be
relocated and the second one of the conditions that can not be relocated.
two sorted lists of strings. The first list contains the connectors that
can be relocated and the second one the connectors that can not be relocated.
Sorting ensures deterministic downstream iteration order.
"""
# TODO(phimuell): Change the return type to `tuple[list[str], list[str]]` and sort the connectors, such that the operation is deterministic.
# There shall only be one output and three inputs with given names.
if len(if_block.out_connectors.keys()) == 0:
return None
Expand Down Expand Up @@ -899,14 +910,14 @@ def _partition_if_block(
# So the ones that can be relocated were found exactly once. Zero would
# mean they can not be relocated and more than one means that we do not
# support it yet.
relocatable_connectors = {
relocatable_connectors = sorted(
conn_name for conn_name, conn_count in reference_count.items() if conn_count == 1
}
non_relocatable_connectors = {
)
non_relocatable_connectors = sorted(
conn_name
for conn_name in reference_count.keys()
if conn_name not in relocatable_connectors
}
)

if len(non_relocatable_connectors) == 0:
return None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def gt_multi_state_global_self_copy_elimination(
The function will also run `MultiStateGlobalSelfCopyElimination2`, but the
results are merged together.
"""
# Sort SDFG for deterministic pattern matching.
sdfg.sort_sdfg_alphabetically()

transforms = [
gtx_transformations.MultiStateGlobalSelfCopyElimination(),
gtx_transformations.MultiStateGlobalSelfCopyElimination2(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def gt_remove_copy_chain(
single_use_data: Which data descriptors are used only once.
If not passed the function will run `FindSingleUseData`.
"""
# Sort SDFG for deterministic pattern matching.
sdfg.sort_sdfg_alphabetically()

# To ensures that the `{src,dst}_subset` are properly set, run initialization.
# See [issue 1703](https://github.com/spcl/dace/issues/1703)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import collections
import copy
import uuid
import hashlib
import warnings
from typing import Any, Iterable, Optional, TypeAlias

Expand Down Expand Up @@ -83,6 +83,9 @@ def gt_simplify(
elimination at the end. The whole process is run inside a loop that ensures
that `gt_simplify()` results in a fix point.
"""
# Sort SDFG for deterministic pattern matching.
sdfg.sort_sdfg_alphabetically()

# Ensure that `skip` is a `set`
skip = gtx_transformations.constants.GT_SIMPLIFY_DEFAULT_SKIP_SET if skip is None else set(skip)

Expand Down Expand Up @@ -476,6 +479,9 @@ def gt_reduce_distributed_buffering(
validate_all: bool = False,
) -> Optional[dict[dace.SDFG, dict[dace.SDFGState, set[str]]]]:
"""Removes distributed write back buffers."""
# Sort SDFG for deterministic pattern matching.
sdfg.sort_sdfg_alphabetically()

pipeline = dace_ppl.Pipeline([DistributedBufferRelocator()])
all_result = {}

Expand Down Expand Up @@ -1007,8 +1013,15 @@ def apply(

# This is the tasklet that we will put inside the map, note we have to do it
# this way to avoid some name clash stuff.
# Use a deterministic hash instead of uuid.uuid1() to ensure stable code
# generation across runs. The hash combines properties that are unique to
# this specific clone context.
_clone_key = (
f"{tasklet.label}_{tasklet.code.as_string}_{map_entry.label}_{connector_name}_{access_node.data}"
)
_clone_hash = hashlib.md5(_clone_key.encode("utf-8")).hexdigest()
inner_tasklet: dace_nodes.Tasklet = graph.add_tasklet(
name=f"{tasklet.label}__clone_{str(uuid.uuid1()).replace('-', '_')}",
name=f"{tasklet.label}__clone_{_clone_hash}",
outputs=tasklet.out_connectors.keys(),
inputs=set(),
code=tasklet.code,
Expand Down
Loading