diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index 3b0f47c02c..66e3bf5fd4 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -6,13 +6,13 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import collections import copy import functools from typing import Any, Optional import dace from dace import ( - data as dace_data, dtypes as dace_dtypes, properties as dace_properties, subsets as dace_sbs, @@ -125,11 +125,13 @@ def can_be_applied( return False # Test if the `if_block` is valid. This will also give us the names. - if_block_spec = self._partition_if_block(if_block) + if_block_spec = self._partition_if_block(sdfg, if_block) if if_block_spec is None: return False + relocatable_connectors, non_relocatable_connectors, connector_usage_location = if_block_spec # Compute the dataflow that is relocated. + # NOTE: That the nodes sets are not sorted in any way, however, the raw_relocatable_dataflow, non_relocatable_dataflow = ( { conn_name: gtx_transformations.utils.find_upstream_nodes( @@ -140,7 +142,7 @@ def can_be_applied( ) for conn_name in conn_names } - for conn_names in if_block_spec + for conn_names in [relocatable_connectors, non_relocatable_connectors] ) relocatable_dataflow = self._filter_relocatable_dataflow( sdfg=sdfg, @@ -148,11 +150,10 @@ def can_be_applied( if_block=if_block, raw_relocatable_dataflow=raw_relocatable_dataflow, non_relocatable_dataflow=non_relocatable_dataflow, + connector_usage_location=connector_usage_location, enclosing_map=enclosing_map, ) - - # If no branch has something to inline then we are done. 
- if all(len(rel_df) == 0 for rel_df in relocatable_dataflow.values()): + if len(relocatable_dataflow) == 0: return False # Check if relatability is possible. @@ -172,18 +173,17 @@ def can_be_applied( # transformation is applied in a loop until it applies nowhere anymore. # NOTE: This is a restriction due to the current implementation. if not self.ignore_upstream_blocks: - for reloc_dataflow in relocatable_dataflow.values(): - if any( - self._has_if_block_relocatable_dataflow( - sdfg=sdfg, - state=graph, - upstream_if_block=upstream_if_block, - enclosing_map=enclosing_map, - ) - for upstream_if_block in reloc_dataflow - if isinstance(upstream_if_block, dace_nodes.NestedSDFG) - ): - return False + if any( + self._has_if_block_relocatable_dataflow( + sdfg=sdfg, + state=graph, + upstream_if_block=upstream_if_block, + enclosing_map=enclosing_map, + ) + for upstream_if_block in relocatable_dataflow + if isinstance(upstream_if_block, dace_nodes.NestedSDFG) + ): + return False return True @@ -193,9 +193,10 @@ def apply( sdfg: dace.SDFG, ) -> None: if_block: dace_nodes.NestedSDFG = self.if_block - if_block_spec = self._partition_if_block(if_block) - assert if_block_spec is not None enclosing_map = graph.scope_dict()[if_block] + relocatable_connectors, non_relocatable_connectors, connector_usage_location = ( + self._partition_if_block(sdfg, if_block) # type: ignore[misc] # Guaranteed to be not None. + ) # Find the dataflow that should be relocated. 
raw_relocatable_dataflow, non_relocatable_dataflow = ( @@ -208,54 +209,60 @@ def apply( ) for conn_name in conn_names } - for conn_names in if_block_spec + for conn_names in [relocatable_connectors, non_relocatable_connectors] ) - relocatable_dataflow = self._filter_relocatable_dataflow( + relocatable_dataflow: set[dace_nodes.Node] = self._filter_relocatable_dataflow( sdfg=sdfg, state=graph, if_block=if_block, raw_relocatable_dataflow=raw_relocatable_dataflow, non_relocatable_dataflow=non_relocatable_dataflow, + connector_usage_location=connector_usage_location, enclosing_map=enclosing_map, ) - # Create a mapping from connector names to the corresponding AccessNode inside the branch and the branch state. - # This is necessary to properly patch the dataflow inside the branch, i.e. to connect it to the global AccessNode - # corresponding to the connector if necessary. We need to gather this information for all the connectors because - # a node in the dataflow of one connector might be the global AccessNode of another connector and we have to handle - # it properly in `_replicate_dataflow_into_branch`. - conn_name_to_access_node_map: dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]] = {} - for conn_name in relocatable_dataflow.keys(): - conn_name_to_access_node_map[conn_name] = self._find_branch_for( - if_block=if_block, - connector=conn_name, - ) + # Bring the nodes in a deterministic order, which is induced by the underlying state. + # NOTE: The following key function is equivalent to use `lambda n: graph.node_id(n)` + # but instead of O[N^2] it is O[N]. + node_keys = {node: i for i, node in enumerate(graph.nodes())} + nodes_to_move = sorted(relocatable_dataflow, key=lambda n: node_keys[n]) + + # For each node we have to find out in which state inside the `if_block` it will + # end up. `relocation_destination` has a fixed order. 
+ relocation_destination: dict[dace_nodes.Node, dace.SDFGState] = {} + for node_to_move in nodes_to_move: + # Although `node_to_move` could be reached through different connectors + # they are all associated to the same branch. + target_state: Optional[dace.SDFGState] = None + for conn, raw_reloc_dataflow_of_conn in raw_relocatable_dataflow.items(): + if node_to_move in raw_reloc_dataflow_of_conn: + target_state = connector_usage_location[conn][0] + break + else: + raise ValueError(f"Could not find node '{node_to_move}'") + assert target_state is not None + relocation_destination[node_to_move] = target_state - # Create a map of the old to the new nodes to keep track of the old nodes copied - # and their corresponding new nodes. The map should be per branch of the ConditionalBlock - # because there could be a node that has to be copied in all branches. Since the - # `relocatable_dataflow` nodes are not disjoint we need this mapping to avoid copying the same - # node multiple times and to properly connect the copied nodes. - old_to_new_nodes_map: dict[tuple[dace_nodes.Node, dace.SDFGState], dace_nodes.Node] = dict() - # Finally relocate the dataflow - for conn_name, nodes_to_move in relocatable_dataflow.items(): - self._replicate_dataflow_into_branch( - state=graph, - sdfg=sdfg, - if_block=if_block, - enclosing_map=enclosing_map, - nodes_to_move=nodes_to_move, - connector=conn_name, - conn_name_to_access_node_map=conn_name_to_access_node_map, - old_to_new_nodes_map=old_to_new_nodes_map, - ) + # Relocate the dataflow. + self._replicate_dataflow_into_branch( + state=graph, + sdfg=sdfg, + if_block=if_block, + enclosing_map=enclosing_map, + relocation_destination=relocation_destination, + connector_usage_location=connector_usage_location, + ) - self._update_symbol_mapping(if_block, sdfg) + # Must be performed after relocation.
+ self._update_symbol_mapping( + sdfg=sdfg, + if_block=if_block, + ) self._remove_outside_dataflow( sdfg=sdfg, state=graph, - relocatable_dataflow=relocatable_dataflow, + relocation_destination=relocation_destination, ) # Because we relocate some node it seems that DaCe gets a bit confused. @@ -273,143 +280,81 @@ def _replicate_dataflow_into_branch( state: dace.SDFGState, if_block: dace_nodes.NestedSDFG, enclosing_map: dace_nodes.MapEntry, - nodes_to_move: set[dace_nodes.Node], - connector: str, - conn_name_to_access_node_map: dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]], - old_to_new_nodes_map: dict[tuple[dace_nodes.Node, dace.SDFGState], dace_nodes.Node], + relocation_destination: dict[dace_nodes.Node, dace.SDFGState], + connector_usage_location: dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]], ) -> None: - """Replicate the dataflow in `nodes_to_move` from `state` into `if_block`. - - First the function will determine into which branch, inside `if_block`, - the dataflow has to be replicated. It will then copy the dataflow, nodes - listed in `nodes_to_move` and insert them into that state. - The function will then create the edges to connect them in the same way - as they where outside. If there is an outer data dependency, for example - a read to a global memory, then the function will patch that inside the - `if_block`. - At the end the function will remove the `connector`, but it will not remove - the original dataflow. - In case of nodes existing in multiple `nodes_to_move` sets coming from multiple - connectors, the function will only copy the necessary nodes and edges only once - based on the `old_to_new_nodes_map` keys and values. In case a connector to the - NestedSDFG is exists in the dataflow of another connector, the function - will remove the connection of the original connector and replace the global - AccessNode of the NestedSDFG with a temporary one and remove the original connector. 
+ """Replicate the dataflow in `relocation_destination` into `if_block`. + + The function will replicate the dataflow listed in `relocation_destination`, + that needs + to be connected, in some way, to the `if_block`. It will remove the connectors + that are no longer needed, but it will not remove the original dataflow nor + update the symbol mapping. Args: sdfg: The sdfg that we process, the one that contains `state`. state: The state we operate on, the one that contains `if_block`. if_block: The `if_block` into which we inline. enclosing_map: The enclosing map. - nodes_to_move: The list of nodes that should be removed. - connector: The connector that should be inlined. - conn_name_to_access_node_map: A mapping from connector names to the - corresponding AccessNode inside the branch and the branch state. - old_to_new_nodes_map: A mapping from the old nodes to the new nodes. - The keys of the mapping are tuples of the old node and the branch - state for which the new node was created. The values are the new nodes. + relocation_destination: Maps each node that should be moved to its destination state inside `if_block`. + connector_usage_location: Maps connector names to the state and AccessNode + where they appear inside the nested SDFG. """ - # Nothing to relocate nothing to do. - if len(nodes_to_move) == 0: - return - - inner_sdfg: dace.SDFG = if_block.sdfg - branch_state, connector_node = conn_name_to_access_node_map[connector] - - # Replicate the nodes and store them in the `old_to_new_nodes_map` mapping. - # Add the SDFGState to the key of the dictionary because we have to create - # new node for the different branches.
- unique_old_nodes: list[dace_nodes.Node] = [] - for old_node in nodes_to_move: - if (old_node, branch_state) in old_to_new_nodes_map: - continue - unique_old_nodes.append(old_node) - copy_of_old_node = copy.deepcopy(old_node) - old_to_new_nodes_map[(old_node, branch_state)] = copy_of_old_node - branch_state.add_node(copy_of_old_node) - - # There might be AccessNodes inside `nodes_to_move`, we now have to make sure - # that they are present inside the nested ones. By our base assumption they - # are transients and single use, because they are only used in one place. Make - # sure that we don't add new nodes if the nodes are already in the arrays of - # the inner SDFG. - if not isinstance(old_node, dace_nodes.AccessNode): + inner_sdfg = if_block.sdfg + + # Maps old nodes to the new relocated nodes inside the `if_block`. Note that + # the state _inside_ the `if_block` is part of the key. This is needed to + # handle the "outside Map data" which must be mapped into multiple states. + node_map: dict[tuple[dace_nodes.Node, dace.SDFGState], dace_nodes.Node] = dict() + rename_map: dict[tuple[str, dace.SDFGState], str] = dict() + + # Data that has been fully mapped into the `if_block` and its name inside it. + fully_mapped_in_data: dict[str, set[str]] = collections.defaultdict(set) + for if_iedge in state.in_edges(if_block): + if if_iedge.data.is_empty(): continue - if old_node.data in inner_sdfg.arrays: - continue - assert sdfg.arrays[old_node.data].transient - # TODO(phimuell): Handle the case we need to rename something. - inner_sdfg.add_datadesc( - old_node.data, - sdfg.arrays[old_node.data].clone(), - find_new_name=False, - ) - - # Now add the edges between the edges that have been replicated inside the - # branch state, these are the outgoing edges. - # Now add the outgoing edges between the replicated nodes inside the branch state. - # The data dependencies (incomming edges) of the nodes, i.e. the not relocated dataflow, - # are still missing. 
- for node in unique_old_nodes: - for oedge in state.out_edges(node): - if oedge.dst is if_block: - if oedge.dst_conn == connector: - # This connection maps the outside data into the nested SDFG, thus its destination is technically the same data. - # TODO(phimuell): Make subsets complete. - # TODO(phimuell): Check if this Memlet is always correct, especially in case of slicing. - branch_state.add_edge( - old_to_new_nodes_map[(oedge.src, branch_state)], - oedge.src_conn, - connector_node, - None, - dace.Memlet.from_memlet(oedge.data), - ) - else: - assert oedge.dst_conn in conn_name_to_access_node_map - # Some of the nodes_to_move are also in the dataflow of some other connector whose dataflow - # we should also move. Since we already moved the dataflow of the other connector we connect - # the dataflow to the global AccessNode corresponding to the other connector - branch_state.add_edge( - old_to_new_nodes_map[(oedge.src, branch_state)], - oedge.src_conn, - conn_name_to_access_node_map[oedge.dst_conn][1], - None, - dace.Memlet.from_memlet(oedge.data), - ) - # Handle the other connector AccessNode and connector - inner_access_node_of_connector_name = conn_name_to_access_node_map[ - oedge.dst_conn - ][1].data - if not inner_sdfg.arrays[inner_access_node_of_connector_name].transient: - inner_sdfg.arrays[inner_access_node_of_connector_name].transient = True - if_block.remove_in_connector(inner_access_node_of_connector_name) - else: - # Restore a connection between two nodes that were relocated. - assert oedge.dst in nodes_to_move - branch_state.add_edge( - old_to_new_nodes_map[(oedge.src, branch_state)], - oedge.src_conn, - old_to_new_nodes_map[(oedge.dst, branch_state)], - oedge.dst_conn, - dace.Memlet.from_memlet(oedge.data), - ) - - # Now we have to satisfy the data dependencies, i.e. forward all nodes that - # could not have been moved inside `if_block` but are still needed to compute - # the final result. 
We find them by scanning the input edges of the nodes - # that have been relocated. - for node in unique_old_nodes: - for iedge in state.in_edges(node): - if iedge.src in nodes_to_move: - # Inner data dependency, there is nothing to do and the edge was - # created above. + outer_data = if_iedge.data.data + mapped_in_range = if_iedge.data.subset # Is always `.subset`. + outer_desc = sdfg.arrays[outer_data] + if mapped_in_range.covers(dace_sbs.Range.from_array(outer_desc)) == True: # noqa: E712 [true-false-comparison] # SymPy comparison + fully_mapped_in_data[outer_data].add(if_iedge.dst_conn) + + # Replicate the nodes into the `if_block` and create the needed data The + # "outside Map data" will be handled when we handle the incoming edges. + for origin_node, branch_state in relocation_destination.items(): + reloc_node = copy.deepcopy(origin_node) + node_map[(origin_node, branch_state)] = reloc_node + branch_state.add_node(reloc_node) + + # If we relocate an AccessNode, we have to make sure that the data descriptor + # is also added to the nested SDFG. We allow renaming of data containers + # but we do not allow renaming of symbols, this is checked by + # `_check_for_data_and_symbol_conflicts()`. + if isinstance(origin_node, dace_nodes.AccessNode): + assert sdfg.arrays[origin_node.data].transient + # TODO(phimuell): Handle the case we need to rename something. + new_data_name = inner_sdfg.add_datadesc( + origin_node.data, + sdfg.arrays[origin_node.data].clone(), + find_new_name=True, + ) + reloc_node.data = new_data_name + rename_map[(origin_node.data, branch_state)] = new_data_name + + # We now create the mapped nodes, i.e. the nodes that are not relocated but + # have to be put inside the `if_block`. We find them by looking at the input + # edges, that do not lead to a node that is relocated. Connections between + # relocated nodes are handled later. 
+ for origin_node, branch_state in relocation_destination.items(): + for iedge in state.in_edges(origin_node): + if iedge.src in relocation_destination: + # Dependency between two relocated nodes: Handled below. continue - if iedge.data.is_empty(): - # Empty Memlets are there to maintain some order relation, "happens - # before". Depending on the situation we can remove or have to - # recreate them. The case where the connection comes from a - # node within the relocated dataflow is handled above. - assert iedge.src is enclosing_map + elif iedge.data.is_empty(): + # This is an empty Memlet that is between a node that is relocated + # and a node that is not relocated. Because we move the destination + # of the edge into the `if_block` the "happens before" relation + # is automatically handled and this edge is no longer needed. continue # Now we have to figuring out where the data is coming from, since @@ -418,106 +363,186 @@ def _replicate_dataflow_into_branch( # The data is coming from outside the Map scope, i.e. not defined # inside the Map scope, so we have to trace it back. memlet_path = state.memlet_path(iedge) - outer_data = memlet_path[0].src + outer_node = memlet_path[0].src else: # The data is defined somewhere in the Map scope itself. - outer_data = iedge.src + outer_node = iedge.src + # TODO(phimuell): It is possible that this does not lead to an # AccessNode on the outside, but to something inside the Map scope # such as the MapExit of an inner map. To handle such a case we need # to construct the set of nodes to move differently, i.e. # considering this case already there. - if not isinstance(outer_data, dace_nodes.AccessNode): + if not isinstance(outer_node, dace_nodes.AccessNode): raise NotImplementedError() - assert not gtx_transformations.utils.is_view(outer_data, sdfg) - - # If the data is not yet available in the inner SDFG made - # patch it through. 
- if outer_data.data not in inner_sdfg.arrays: - inner_desc = sdfg.arrays[outer_data.data].clone() - inner_desc.transient = False - # TODO(phimuell): Handle the case we need to rename something. - inner_sdfg.add_datadesc(outer_data.data, inner_desc, False) - # TODO(phimeull): We pass the whole data inside the SDFG. - # Find out if there are cases where this is wrong. + assert not gtx_transformations.utils.is_view(outer_node, sdfg) + + outer_data = outer_node.data + outer_desc = sdfg.arrays[outer_data] + + if (outer_node, branch_state) in node_map: + # The node is already mapped into this state. + assert (outer_data, branch_state) in rename_map + assert not node_map[(outer_node, branch_state)].desc(inner_sdfg).transient + pass + + elif outer_data in fully_mapped_in_data: + # The data has already been mapped into the `if_block`, but not in + # `branch_state`. We first look if the state contains an AccessNode + # referring to that data. + outer_aliases = fully_mapped_in_data[outer_data] + candidate_nodes: list[dace_nodes.AccessNode] = sorted( + ( + dnode + for dnode in branch_state.data_nodes() + if dnode.data in outer_aliases + ), + key=lambda dnode: dnode.data, + ) + + if len(candidate_nodes) == 0: + # There is no AccessNode in the state so we have to create one. + inner_data = sorted(outer_aliases)[0] + inner_node = branch_state.add_access(inner_data) + + else: + # There is an AccessNode in the state. To handle some legal + # but unlikely case we check that nodes we found are all + # source nodes. We have to do this to prevent read-write + # conflicts. + candidate_source_nodes = [ + dnode for dnode in candidate_nodes if branch_state.in_degree(dnode) == 0 + ] + if len(candidate_source_nodes) != len(candidate_nodes): + raise NotImplementedError() + + # We take the first node, since they are sorted it is deterministic. 
+ inner_node = candidate_source_nodes[0] + + assert (outer_data, branch_state) not in rename_map + assert not inner_sdfg.arrays[inner_node.data].transient + rename_map[(outer_data, branch_state)] = inner_node.data + node_map[(outer_node, branch_state)] = inner_node + + else: + # The data is not already mapped in and is also unknown. + # Here we rely on that we do not have to perform symbol renaming. + inner_data = inner_sdfg.add_datadesc( + outer_data, + outer_desc.clone(), + find_new_name=True, + ) + inner_sdfg.arrays[inner_data].transient = False + state.add_edge( iedge.src, iedge.src_conn, if_block, - outer_data.data, - dace.Memlet( - data=outer_data.data, subset=dace_sbs.Range.from_array(inner_desc) - ), + inner_data, + dace.Memlet.from_array(outer_data, outer_desc), ) - if_block.add_in_connector(outer_data.data) - else: - # This is the case that we found a node, that refers to data that - # was already patched into the `if_block`. We would have to remove - # this, but since this function just replicates the dataflow, - # it will not do that. Instead we postpone this to the cleanup - # phase, see `_remove_outside_dataflow()`. - pass + if_block.add_in_connector(inner_data) - if (outer_data, branch_state) not in old_to_new_nodes_map: - assert all( - outer_data.data != mapped_node.data - for mapped_node, mapped_branch_state in old_to_new_nodes_map.keys() - if isinstance(mapped_node, dace_nodes.AccessNode) - and mapped_branch_state == branch_state - ) - assert outer_data.data in inner_sdfg.arrays - assert not inner_sdfg.arrays[outer_data.data].transient - old_to_new_nodes_map[(outer_data, branch_state)] = branch_state.add_access( - outer_data.data, copy.copy(outer_data.debuginfo) - ) + inner_node = branch_state.add_access(inner_data) + rename_map[(outer_data, branch_state)] = inner_node.data + node_map[(outer_node, branch_state)] = inner_node + fully_mapped_in_data[outer_data].add(inner_data) # Now create the edge in the inner state. 
- branch_state.add_edge( - old_to_new_nodes_map[(outer_data, branch_state)], + new_edge = branch_state.add_edge( + node_map[(outer_node, branch_state)], None, - old_to_new_nodes_map[(iedge.dst, branch_state)], + node_map[(iedge.dst, branch_state)], iedge.dst_conn, copy.deepcopy(iedge.data), ) + new_edge.data.data = rename_map[(outer_data, branch_state)] + + # Now create the edges between the relocated nodes, which are all the outgoing + # edges, the `if_block` is handled as a special relocated node and its + # connectors (but not the edges) are removed to. + # NOTE: This loop can not be fused with the one above and must run after it. + for origin_node, branch_state in relocation_destination.items(): + for oedge in state.out_edges(origin_node): + if oedge.dst is if_block: + # This defines the "argument" to the nested SDFG. This means that + # the new destination now is the single node inside `if_block` + # that represents the argument. + assert not inner_sdfg.arrays[oedge.dst_conn].transient + assert branch_state is connector_usage_location[oedge.dst_conn][0] + assert isinstance(oedge.src, dace_nodes.AccessNode) + assert oedge.data.wcr is None and oedge.data.other_subset is None - # The old connector name is no longer valid. - inner_sdfg.arrays[connector].transient = True - if_block.remove_in_connector(connector) + branch_state.add_edge( + node_map[(oedge.src, branch_state)], + oedge.src_conn, + connector_usage_location[oedge.dst_conn][1], + None, + dace.Memlet( + data=rename_map[(oedge.data.data, branch_state)], + subset=oedge.data.subset, # Is always subset. + other_subset=dace_sbs.Range.from_array( + inner_sdfg.arrays[oedge.dst_conn] + ), + volume=oedge.data.volume, + dynamic=oedge.data.dynamic, + ), + ) + + # The inner data is no longer a global but has become a transient. 
+ assert oedge.dst_conn in if_block.in_connectors + inner_sdfg.arrays[oedge.dst_conn].transient = True + if_block.remove_in_connector(oedge.dst_conn) + + else: + # Edges that do not go to the `if_block` must lead to a node + # that is also relocated. + assert origin_node in relocation_destination + new_oedge = branch_state.add_edge( + node_map[(oedge.src, branch_state)], + oedge.src_conn, + node_map[(oedge.dst, branch_state)], + oedge.dst_conn, + dace.Memlet.from_memlet(oedge.data), + ) + if not oedge.data.is_empty(): + new_oedge.data.data = rename_map[(oedge.data.data, branch_state)] def _remove_outside_dataflow( self, sdfg: dace.SDFG, state: dace.SDFGState, - relocatable_dataflow: dict[str, set[dace_nodes.Node]], + relocation_destination: dict[dace_nodes.Node, dace.SDFGState], ) -> None: """Removes the original dataflow, that has been relocated. The function will also remove data containers that are no longer in use. """ - all_relocatable_dataflow: set[dace_nodes.Node] = functools.reduce( - lambda s1, s2: s1.union(s2), relocatable_dataflow.values(), set() - ) - # Before we can clean the original nodes, we must clean the dataflow. If a - # node, that was relocated, has incoming connections we must remove them - # and the parent dataflow. - for node_to_remove in all_relocatable_dataflow: + # Clean up the dataflow first, before removing the nodes. + for node_to_remove in relocation_destination: + # Create all "interface edges", i.e. connecting a relocated node with one + # that is not. This is needed to properly remove dangling Memlet paths. for iedge in list(state.in_edges(node_to_remove)): - if iedge.src in all_relocatable_dataflow: + if iedge.src in relocation_destination: continue dace_sutils.remove_edge_and_dangling_path(state, iedge) if isinstance(node_to_remove, dace_nodes.AccessNode): + # NOTE: We can remove the data here, because by assumption data that is + # referred to by an AccessNode inside a Map is single use data and + # used nowhere else. 
+ # NOTE: This will temporarily create an invalid SDFG. assert node_to_remove.desc(sdfg).transient sdfg.remove_data(node_to_remove.data, validate=False) # Remove the original nodes (data descriptors were deleted in the loop above). - state.remove_nodes_from(all_relocatable_dataflow) + state.remove_nodes_from(relocation_destination.keys()) def _update_symbol_mapping( self, + sdfg: dace.SDFG, if_block: dace_nodes.NestedSDFG, - parent: dace.SDFG, ) -> None: """Updates the symbol mapping of the nested SDFG. @@ -525,14 +550,17 @@ def _update_symbol_mapping( are available in the parent SDFG. """ symbol_mapping = if_block.symbol_mapping - missing_symbols = [ms for ms in if_block.sdfg.free_symbols if ms not in symbol_mapping] + missing_symbols = sorted( + (ms for ms in if_block.sdfg.free_symbols if ms not in symbol_mapping), + key=lambda sym: str(sym), + ) symbol_mapping.update({s: s for s in missing_symbols}) if_block.symbol_mapping = symbol_mapping # Performs conversion. # Add new global symbols to nested SDFG. # The code is based on `SDFGState.add_nested_sdfg()`. if_block_symbols = if_block.sdfg.symbols - parent_symbols = parent.symbols + parent_symbols = sdfg.symbols for new_sym in missing_symbols: if new_sym in if_block_symbols: # The symbol is already known, so we check that it is the same type as in the @@ -557,7 +585,7 @@ def _check_for_data_and_symbol_conflicts( self, sdfg: dace.SDFG, state: dace.SDFGState, - relocatable_dataflow: dict[str, set[dace_nodes.Node]], + relocatable_dataflow: set[dace_nodes.Node], if_block: dace_nodes.NestedSDFG, enclosing_map: dace_nodes.MapEntry, ) -> bool: @@ -568,46 +596,45 @@ def _check_for_data_and_symbol_conflicts( # It is probably not a problem, because of the scopes DaCe adds when # generating the C++ code. - # Create a subgraph to compute the free symbols, i.e. the symbols that - # need to be supplied from the outside. However, this are not all. - # Note, just adding some "well chosen" nodes to the set will not work. 
- all_relocated_dataflow: set[dace_nodes.Node] = functools.reduce( - lambda s1, s2: s1.union(s2), relocatable_dataflow.values(), set() - ) + # This will give us the "internal symbols" that need to be mapped into `if_block`. + # It does not include all symbols, see bellow. requiered_symbols: set[str] = dace.sdfg.state.StateSubgraphView( - state, all_relocated_dataflow + state, relocatable_dataflow ).free_symbols + assert all(isinstance(sym, str) for sym in requiered_symbols) - inner_data_names = if_block.sdfg.arrays.keys() - for node_to_check in all_relocated_dataflow: - if ( - isinstance(node_to_check, dace_nodes.AccessNode) - and node_to_check.data in inner_data_names - ): - # There is already a data descriptor that is used on the inside as on - # the outside. Thus we would have to perform some renaming, which we - # currently do not. - # TODO(phimell): Handle this case. - return False - + # The internal symbols missing the symbols that are needed by the nodes that + # are just mapped into the `if_block` as well as the connections that connects + # relocated and mapped nodes. + for node_to_check in relocatable_dataflow: for iedge in state.in_edges(node_to_check): - src_node = iedge.src - if src_node not in all_relocated_dataflow: - # This means that `src_node` is not relocated but mapped into the - # `if` block. This means that `edge` is replicated as well. - # NOTE: This code is based on the one found in `DataflowGraphView`. - # TODO(phimuell): Do we have to inspect the full Memlet path here? - assert isinstance(src_node, dace_nodes.AccessNode) or src_node is enclosing_map - requiered_symbols |= iedge.data.used_symbols(True, edge=iedge) - - # The (beyond the enclosing Map) data is also mapped into the `if` block, so we - # have to consider that as well. 
- for iedge in state.in_edges(if_block): - if iedge.src is enclosing_map and (not iedge.data.is_empty()): - outside_desc = sdfg.arrays[iedge.data.data] - if isinstance(outside_desc, dace_data.View): - return False # Handle this case. - requiered_symbols |= outside_desc.used_symbols(True) + if iedge.src in relocatable_dataflow: + continue # Ignore internal connections, handled in subgraph. + elif iedge.data.is_empty(): + continue # Empty Memlets do not have symbols. + + if iedge.src is enclosing_map: + # Outside-Map data must be mapped. Here we only have to consider + # the symbols of the node and can ignore the symbols of the edge. + memlet_path = state.memlet_path(iedge) + node_to_map = memlet_path[0].src + else: + # The mapped node is inside the Map this means we replicate this + # edge thus in addition to the symbols of the data, we need the + # symbols needed by the edge. + node_to_map = iedge.src + requiered_symbols |= { + str(sym) for sym in iedge.data.used_symbols(True, edge=iedge) + } + + # Only AccessNodes can be mapped into `if_block`. + if not isinstance(node_to_map, dace_nodes.AccessNode): + return False + + # Add the symbols of the data. + requiered_symbols |= { + str(sym) for sym in sdfg.arrays[node_to_map.data].used_symbols(True) + } # A conflicting symbol is a free symbol of the relocatable dataflow, that is not a # direct mapping. For example if there is a symbol `n` on the inside and outside @@ -621,36 +648,6 @@ def _check_for_data_and_symbol_conflicts( return True - def _find_branch_for( - self, - if_block: dace_nodes.NestedSDFG, - connector: str, - ) -> tuple[dace.SDFGState, dace_nodes.AccessNode]: - """ - Locates the branch and the AccessNode to where the dataflow should be relocated. - """ - inner_sdfg: dace.SDFG = if_block.sdfg - conditional_block: dace.sdfg.state.ConditionalBlock = next(iter(inner_sdfg.nodes())) - - # This will locate the state where the first AccessNode that refers to - # `connector` is found. 
Since `_partition_if_block()` makes sure that - # there is only one match this is okay. But it must be changed, if we - # lift this restriction. - for inner_state in conditional_block.all_states(): - connector_nodes: list[dace_nodes.AccessNode] = [ - dnode for dnode in inner_state.data_nodes() if dnode.data == connector - ] - if len(connector_nodes) == 0: - continue - break - else: - raise ValueError(f"Did not find a branch associated to '{connector}'.") - - assert isinstance(inner_state, dace.SDFGState) - assert inner_state.in_degree(connector_nodes[0]) == 0 - assert inner_state.out_degree(connector_nodes[0]) > 0 - return inner_state, connector_nodes[0] - def _has_if_block_relocatable_dataflow( self, sdfg: dace.SDFG, @@ -672,9 +669,10 @@ def _has_if_block_relocatable_dataflow( enclosing_map: The limiting node, i.e. the MapEntry of the Map `if_block` is located in. """ - if_block_spec = self._partition_if_block(upstream_if_block) + if_block_spec = self._partition_if_block(sdfg, upstream_if_block) if if_block_spec is None: return False + *classified_connectors, connector_usage_location = if_block_spec raw_relocatable_dataflow, non_relocatable_dataflow = ( { @@ -686,7 +684,7 @@ def _has_if_block_relocatable_dataflow( ) for conn_name in conn_names } - for conn_names in if_block_spec + for conn_names in classified_connectors ) filtered_relocatable_dataflow = self._filter_relocatable_dataflow( sdfg=sdfg, @@ -694,9 +692,10 @@ def _has_if_block_relocatable_dataflow( if_block=upstream_if_block, raw_relocatable_dataflow=raw_relocatable_dataflow, non_relocatable_dataflow=non_relocatable_dataflow, + connector_usage_location=connector_usage_location, enclosing_map=enclosing_map, ) - if all(len(rel_df) == 0 for rel_df in filtered_relocatable_dataflow.values()): + if len(filtered_relocatable_dataflow) == 0: return False return True @@ -708,19 +707,18 @@ def _filter_relocatable_dataflow( if_block: dace_nodes.NestedSDFG, raw_relocatable_dataflow: dict[str, set[dace_nodes.Node]], 
non_relocatable_dataflow: dict[str, set[dace_nodes.Node]], + connector_usage_location: dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]], enclosing_map: dace_nodes.MapEntry, - ) -> dict[str, set[dace_nodes.Node]]: - """Partition the dependencies. + ) -> set[dace_nodes.Node]: + """Compute the final set of the relocatable nodes. The function expects the dataflow that is upstream of every connector of the `if_block`. The function will then scan the dataflow and compute - the parts that actually can be relocated and returns a `dict` mapping - every relocatable input connector to the set of nodes that can be relocated. - - The returned sets can include duplicate nodes, i.e. a node can be in the - dataflow of multiple connectors. The function that performs the actual - relocation (`_replicate_dataflow_into_branch`) will take care of that and - make sure that such nodes are only copied once. + the parts that actually can be relocated. It will then return a `set` + containing all nodes that can actually be relocated. If this set is empty + then nothing can be relocated. + Note that the returned `set` is in an unspecific order and before processing + should be ordered. Args: state: The state on which we operate. @@ -729,116 +727,115 @@ def _filter_relocatable_dataflow( that can be relocated, not yet filtered. non_relocatable_dataflow: The connectors and their associated dataflow that can not be relocated. + connector_usage_location: Maps a connector to the state and AccessNode + inside the if block. enclosing_map: The limiting node, i.e. the MapEntry of the Map where `if_block` is located in. """ - # Remove the parts of the dataflow that is unrelocatable. + # These are the nodes that can not be relocated anyway. 
all_non_relocatable_dataflow: set[dace_nodes.Node] = functools.reduce( lambda s1, s2: s1.union(s2), non_relocatable_dataflow.values(), set() ) - relocatable_dataflow = { - conn_name: rel_df.difference(all_non_relocatable_dataflow) - for conn_name, rel_df in raw_relocatable_dataflow.items() - } - - # Find the known_nodes for each branch - known_nodes: dict[dace.SDFGState, set[dace_nodes.Node]] = dict() - for conn_name, rel_df in relocatable_dataflow.items(): - branch_state, _ = self._find_branch_for(if_block=if_block, connector=conn_name) - if branch_state not in known_nodes: - known_nodes[branch_state] = set() - known_nodes[branch_state].update(rel_df) - - multiple_df_nodes: set[dace_nodes.Node] = set() - # Find intersect of all known_nodes sets which are the nodes that are in the dataflow - # of multiple branches and thus doesn't make sense to relocate - for branch_state, known_nodes_set in known_nodes.items(): - for other_branch_state, other_known_nodes_set in known_nodes.items(): - if branch_state != other_branch_state: - multiple_df_nodes.update(known_nodes_set.intersection(other_known_nodes_set)) - - if multiple_df_nodes: - # Remove from the relocatable dataflow the nodes that appear in multiple branches - # as it doesn't make sense to relocate them and duplicate them in both branches. - relocatable_dataflow = { - conn_name: rel_df.difference(multiple_df_nodes) - for conn_name, rel_df in relocatable_dataflow.items() - } - # TODO(phimuell): If we operate outside of a Map we also have to make sure that - # the data is single use data, is not an AccessNode that refers to global - # memory nor is a source AccessNode. - def filter_nodes( - nodes_proposed_for_reloc: set[dace_nodes.Node], - ) -> set[dace_nodes.Node]: - has_been_updated = True - while has_been_updated: - has_been_updated = False - - for reloc_node in list(nodes_proposed_for_reloc): - # The node was already handled in a previous iteration. 
- if reloc_node not in nodes_proposed_for_reloc: - continue + # While we can relocate nodes that are needed by multiple connectors, we can + # not handle the case if they end up in multiple branches. + nodes_in_states: dict[dace.SDFGState, set[dace_nodes.Node]] = collections.defaultdict(set) + for conn_name, rel_df in raw_relocatable_dataflow.items(): + nodes_in_states[connector_usage_location[conn_name][0]].update(rel_df) + state_nodes_sets = list(nodes_in_states.values()) # Order is unimportant here. + for i, state_nodes in enumerate(state_nodes_sets): + for j in range(i + 1, len(state_nodes_sets)): + all_non_relocatable_dataflow.update(state_nodes.intersection(state_nodes_sets[j])) + + # The dataflow that must happen before the `if_block`, i.e that is connected + # with it by an empty Memlet can not be reconnected. + for if_block_iedge in state.in_edges(if_block): + if if_block_iedge.src is enclosing_map: + continue + elif not if_block_iedge.data.is_empty(): + continue + all_non_relocatable_dataflow.update( + gtx_transformations.utils.find_upstream_nodes( + start=if_block_iedge.src, + state=state, + ) + ) + all_non_relocatable_dataflow.add(if_block_iedge.src) + + # Instead of scanning the nodes associated to each connector separately we will + # process all of them together. We do this because a node can be associated to + # multiple connectors and as such data dependencies can show up. We will, + # after the filtering distribute them back. + nodes_proposed_for_reloc: set[dace_nodes.Node] = functools.reduce( + lambda s1, s2: s1.union(s2), raw_relocatable_dataflow.values(), set() + ) - assert ( - state.in_degree(reloc_node) > 0 - ) # Because we are currently always inside a Map - - # If the node is needed by anything that is not also moved - # into the `if` body, then it has to remain outside. For that we - # have to pretend that `if_block` is also relocated. 
- if any( - oedge.dst not in nodes_proposed_for_reloc - for oedge in state.out_edges(reloc_node) - if oedge.dst is not if_block - ): - nodes_proposed_for_reloc.remove(reloc_node) - has_been_updated = True - continue + # Filtering out all nodes that can not be relocated anyway. + if all_non_relocatable_dataflow: + nodes_proposed_for_reloc.difference_update(all_non_relocatable_dataflow) - # We do not look at all incoming nodes, but have to ignore some of them. - # We ignore `enclosed_map` because it acts as boundary, and the node - # on the other side of it is mapped into the `if` body anyway. We - # ignore the AccessNodes because they will either be relocated into - # the `if` body or be mapped (remain outside but made accessible - # inside), thus their relocation state is of no concern for - # `reloc_node`. - non_mappable_incoming_nodes: set[dace_nodes.Node] = { - iedge.src - for iedge in state.in_edges(reloc_node) - if not ( - (iedge.src is enclosing_map) - or isinstance(iedge.src, dace_nodes.AccessNode) - ) - } - if non_mappable_incoming_nodes.issubset(nodes_proposed_for_reloc): - # All nodes that can not be mapped into the `if` body are - # currently scheduled to be relocated, thus there is not - # problem. - pass + # TODO(phimuell): Better screening of empty Memlets. + has_been_updated = True + while has_been_updated: + has_been_updated = False - else: - # Only some of the non mappable nodes are selected to be - # moved inside the `if` body. This means that `reloc_node` - # can also not be moved because of its input dependencies. - # Since we can not relocate `reloc_node` this also implies - # that none of its input can. Thus we remove them from - # `nodes_proposed_for_reloc`. 
- nodes_proposed_for_reloc.difference_update(non_mappable_incoming_nodes) - nodes_proposed_for_reloc.remove(reloc_node) - has_been_updated = True - - return nodes_proposed_for_reloc - - return { - conn_name: filter_nodes(rel_df) for conn_name, rel_df in relocatable_dataflow.items() - } + for reloc_node in list(nodes_proposed_for_reloc): + # The node was already removed in a previous iteration. + if reloc_node not in nodes_proposed_for_reloc: + continue + + # Because we are currently always inside a Map + assert state.in_degree(reloc_node) > 0 + + # If the node is needed by anything that is not also moved + # into the `if` body, then it has to remain outside. For that we + # have to pretend that `if_block` is also relocated. + if any( + oedge.dst not in nodes_proposed_for_reloc + for oedge in state.out_edges(reloc_node) + if oedge.dst is not if_block + ): + nodes_proposed_for_reloc.remove(reloc_node) + has_been_updated = True + continue + + # We do not look at incoming edges that comes from nodes that are not + # mappable, i.e. AccessNodes. In addition to AccessNodes we also + # ignore `enclosing_map` because it acts as a boundary anyway and + # on its other side is an AccessNode anyway. + non_mappable_incoming_nodes: set[dace_nodes.Node] = { + iedge.src + for iedge in state.in_edges(reloc_node) + if not ( + (iedge.src is enclosing_map) or isinstance(iedge.src, dace_nodes.AccessNode) + ) + } + if non_mappable_incoming_nodes.issubset(nodes_proposed_for_reloc): + # All nodes that can not be mapped into the `if` body are + # currently scheduled to be relocated, thus there is no + # problem. + pass + + else: + # Only some of the non mappable nodes are selected to be moved + # inside the `if` body. This means that `reloc_node` can also + # not be moved because of its input dependencies. Since we can + # not relocate `reloc_node` this also implies that none of its + # inputs either. 
+ nodes_proposed_for_reloc.difference_update(non_mappable_incoming_nodes) + nodes_proposed_for_reloc.remove(reloc_node) + has_been_updated = True + + return nodes_proposed_for_reloc def _partition_if_block( self, + sdfg: dace.SDFG, if_block: dace_nodes.NestedSDFG, - ) -> Optional[tuple[set[str], set[str]]]: + ) -> Optional[ + tuple[list[str], list[str], dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]]] + ]: """Check if `if_block` can be processed and partition the input connectors. The function will check if `if_block` has the right structure, i.e. if it is @@ -847,21 +844,31 @@ def _partition_if_block( be inlined into the `if_block` and which can not. Returns: - If `if_block` is unsuitable the function will return `None`. - If `if_block` meets the structural requirements the function will return - two sets of strings. The first set contains the connectors that can be - relocated and the second one of the conditions that can not be relocated. + If `if_block` is unsuitable the function will return `None`. In case the + `if_block` is suitable a `tuple` of length three is returned. + The first element is a `list`, which is never empty, containing all + input connectors that can be relocated. The list is sorted in a stable + order. The second element is a list containing all input connectors that + can not be relocated, it can be empty and is not in a particular order. + The third element is a `dict` that maps connectors to a pair containing + the state (inside the nested SDFG) and the only `AccessNode` that refers + to that connector. + It is important that only the first element of the `tuple` has a guaranteed + order. """ - # TODO(phimuell): Change the return type to `tuple[list[str], list[str]]` and sort the connectors, such that the operation is deterministic. - # There shall only be one output and three inputs with given names. if len(if_block.out_connectors.keys()) == 0: return None - # These are all the output names. 
+ input_names: set[str] = set(if_block.in_connectors.keys()) output_names: set[str] = set(if_block.out_connectors.keys()) - # We require that the nested SDFG contains a single node, which is a - # `ConditionalBlock` containing two branches. + # If data is used as input and output we ignore it. + # TODO(phimuell): Think if this case can be handled. + input_names.difference_update(output_names) + if len(input_names) == 0: + return None + + # We require that the nested SDFG contains a single node, which is a `ConditionalBlock`. inner_sdfg: dace.SDFG = if_block.sdfg if inner_sdfg.number_of_nodes() != 1: return None @@ -869,47 +876,70 @@ def _partition_if_block( if not isinstance(inner_if_block, dace.sdfg.state.ConditionalBlock): return None - # Defining it outside will ensure that there is only one AccessNode for every - # inconnector, which is something `_find_branch_for()` relies on. - reference_count: dict[str, int] = {conn_name: 0 for conn_name in if_block.in_connectors} + # Mapping between the connector and the inner access node. + connector_usage_location: dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]] = {} + + # This is the dataflow that can not be relocated. + non_relocatable_connectors: set[str] = set() + + # Now inspect all states. + for _, if_branch in inner_if_block.branches: + for inner_state in if_branch.all_states(): + for dnode in inner_state.data_nodes(): + node_data = dnode.data + + # Check if we can skip the data. + if node_data in non_relocatable_connectors: + continue + elif node_data in output_names: + continue + elif dnode.desc(inner_sdfg).transient: + continue + assert node_data in input_names + + if node_data in connector_usage_location: + # There are multiple AccessNodes referring to the same connector + # which is currently not supported. In theory they could appear + # more, but then we would have to replicate the dataflow to + # different locations which is not supported. 
We allow such + # situations but consider the connector non relocatable. + connector_usage_location.pop(node_data) + non_relocatable_connectors.add(node_data) + + elif inner_state.in_degree(dnode) != 0: + # The node is also written to, allowed by SDFG grammar, but we + # do not allow it. + non_relocatable_connectors.add(node_data) - for _, branch in inner_if_block.branches: - output_count: dict[str, int] = {conn_name: 0 for conn_name in output_names} - for inner_state in branch.all_states(): - assert isinstance(inner_state, dace.SDFGState) - for node in inner_state.nodes(): - if not isinstance(node, dace_nodes.AccessNode): - return None - if node.data in reference_count: - reference_count[node.data] += 1 - exp_in_deg, exp_out_deg = 0, 1 - elif node.data in output_count: - output_count[node.data] += 1 - exp_in_deg, exp_out_deg = 1, 0 else: + # This is a proper input connector node. + connector_usage_location[node_data] = (inner_state, dnode) + + # If all input connectors were classified as non relocatable + # then the partition does not exist. + if len(non_relocatable_connectors) == len(input_names): + assert non_relocatable_connectors == input_names return None - if inner_state.in_degree(node) != exp_in_deg: - return None - if inner_state.out_degree(node) != exp_out_deg: - return None - # The connectors that can be pulled inside must appear exactly once. - # In theory they could appear more, but then we would have to replicate - # the dataflow to different locations which is not supported. - # So the ones that can be relocated were found exactly once. Zero would - # mean they can not be relocated and more than one means that we do not - # support it yet. - relocatable_connectors = { - conn_name for conn_name, conn_count in reference_count.items() if conn_count == 1 - } - non_relocatable_connectors = { - conn_name - for conn_name in reference_count.keys() - if conn_name not in relocatable_connectors - } + # There is nothing to relocate. 
+        if len(connector_usage_location) == 0:
+            return None
+
+        # In addition to the non relocatable connectors that were found above, we also
+        # mark all connectors that were not found as non relocatable.
+        non_relocatable_connectors.update(
+            conn for conn in input_names if conn not in connector_usage_location
+        )
+        # We require that at least one non relocatable dataflow is there, this is for
+        # the condition. This is not strictly needed, as it could also be passed as
+        # a symbol, but currently the lowering does not do this and we keep it as
+        # a sanity check.
         if len(non_relocatable_connectors) == 0:
             return None
-        if len(relocatable_connectors) == 0:
-            return None
+
+        # We only guarantee that `relocatable_connectors` has a stable order,
+        # everything else has no guaranteed order, even `connector_usage_location`.
+        relocatable_connectors = sorted(connector_usage_location.keys())
+
+        return relocatable_connectors, list(non_relocatable_connectors), connector_usage_location
diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py
index 067ff82cac..7ef95afee9 100644
--- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py
+++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py
@@ -953,15 +953,20 @@ def test_if_mover_dependent_branch_4():
     _perform_test(sdfg, explected_applies=1)
 
-    # # Examine the structure of the SDFG.
+    # Examine the structure of the SDFG.
top_ac: list[dace_nodes.AccessNode] = util.count_nodes(state, dace_nodes.AccessNode, True) assert {ac.data for ac in top_ac} == set(input_names).union(["c1", "s1"]) assert len(sdfg.arrays) == len(top_ac) + assert all(state.out_degree(ac) == 1 for ac in [s1, c1]) + assert all(oedge.dst_conn == "__arg4" for oedge in state.out_edges(s1)) top_tlet: list[dace_nodes.Tasklet] = util.count_nodes(state, dace_nodes.Tasklet, True) assert len(top_tlet) == 2 assert {"tasklet_cond", "tasklet_s1"} == {tlet.label for tlet in top_tlet} + all_mapped_in_data = [iedge.data.data for iedge in state.in_edges(if_block)] + assert len(all_mapped_in_data) == len(set(all_mapped_in_data)) + inner_ac: list[dace_nodes.AccessNode] = util.count_nodes( if_block.sdfg, dace_nodes.AccessNode, True ) @@ -970,12 +975,12 @@ def test_if_mover_dependent_branch_4(): .union(input_names) .union(["__arg1", "__arg2", "__arg3", "__arg4", "__output1", "__output2"]) ) - expected_data.difference_update(["c1", "c", "d", "f", "s"]) + expected_data.difference_update(["c1", "c", "d", "f", "s", "s1"]) assert expected_data == {ac.data for ac in inner_ac} - assert len([ac for ac in inner_ac if ac.data == "s1"]) == 1 + assert len([ac for ac in inner_ac if ac.data == "__arg4"]) == 2 assert len([ac for ac in inner_ac if ac.data == "__output1"]) == 2 assert len([ac for ac in inner_ac if ac.data == "__output2"]) == 2 - assert len(expected_data) + 3 == len(inner_ac) + assert len(expected_data) + 4 == len(inner_ac) assert if_block.sdfg.arrays.keys() == expected_data.union(["__cond"]) inner_tlet: list[dace_nodes.Tasklet] = util.count_nodes(if_block.sdfg, dace_nodes.Tasklet, True) @@ -986,9 +991,6 @@ def test_if_mover_dependent_branch_4(): assert {tlet.label for tlet in inner_tlet} == expected_tlet -@pytest.mark.xfail( - reason="This test is currently expected to fail. 
For the explanation see: https://github.com/GridTools/gt4py/pull/2514#discussion_r2906948120" -) def test_if_mover_dependent_branch_5(): """ Essentially tests the following situation: @@ -1134,17 +1136,22 @@ def test_if_mover_dependent_branch_5(): mx.add_out_connector("OUT_f") sdfg.validate() - _perform_test(sdfg, explected_applies=2) + _perform_test(sdfg, explected_applies=1) - # # Examine the structure of the SDFG. + # Examine the structure of the SDFG. top_ac: list[dace_nodes.AccessNode] = util.count_nodes(state, dace_nodes.AccessNode, True) assert {ac.data for ac in top_ac} == set(input_names).union(["c1", "s1"]) assert len(sdfg.arrays) == len(top_ac) + assert all(state.out_degree(ac) == 1 for ac in [s1, c1]) + assert all(oedge.dst_conn == "__arg4" for oedge in state.out_edges(s1)) top_tlet: list[dace_nodes.Tasklet] = util.count_nodes(state, dace_nodes.Tasklet, True) assert len(top_tlet) == 2 assert {"tasklet_cond", "tasklet_s1"} == {tlet.label for tlet in top_tlet} + all_mapped_in_data = [iedge.data.data for iedge in state.in_edges(if_block)] + assert len(all_mapped_in_data) == len(set(all_mapped_in_data)) + inner_ac: list[dace_nodes.AccessNode] = util.count_nodes( if_block.sdfg, dace_nodes.AccessNode, True ) @@ -1153,18 +1160,26 @@ def test_if_mover_dependent_branch_5(): .union(input_names) .union(["__arg1", "__arg2", "__arg3", "__arg4", "__output1", "__output2"]) ) - expected_data.difference_update(["c1", "c", "d", "f", "s"]) + expected_data.difference_update(["c1", "c", "d", "f", "s", "s1"]) assert expected_data == {ac.data for ac in inner_ac} - assert len([ac for ac in inner_ac if ac.data == "s1"]) == 1 + assert len([ac for ac in inner_ac if ac.data == "__arg4"]) == 2 assert len([ac for ac in inner_ac if ac.data == "__output1"]) == 2 assert len([ac for ac in inner_ac if ac.data == "__output2"]) == 2 - assert len(expected_data) + 3 == len(inner_ac) + assert len(expected_data) + 4 == len(inner_ac) assert if_block.sdfg.arrays.keys() == 
expected_data.union(["__cond"]) inner_tlet: list[dace_nodes.Tasklet] = util.count_nodes(if_block.sdfg, dace_nodes.Tasklet, True) - assert len(inner_tlet) == 5 + assert len(inner_tlet) == 6 expected_tlet = { - tlet.label for tlet in [tasklet_a1, tasklet_a2, tasklet_b1, tasklet_b2, tasklet_node_reuse] + tlet.label + for tlet in [ + tasklet_a1, + tasklet_a2, + tasklet_a2a, + tasklet_b1, + tasklet_b2, + tasklet_node_reuse, + ] } assert {tlet.label for tlet in inner_tlet} == expected_tlet @@ -1695,8 +1710,8 @@ def test_if_mover_symbolic_tasklet(): sdfg, explected_applies=1, ) - expected_symb = {"symbol_1", "symbol_2"} + expected_symb = {"symbol_1", "symbol_2"} assert if_block.sdfg.symbols.keys() == expected_symb.union(["__i"]) assert all(if_block.sdfg.symbols[sym] == dace.float64 for sym in expected_symb) assert if_block.sdfg.symbols["__i"] in {dace.int32, dace.int64} @@ -1953,3 +1968,139 @@ def test_if_mover_symbol_aliasing(): sdfg=sdfg, explected_applies=0, ) + + +def test_if_mover_slice_input(): + def _make_nested_sdfg(cond_name: str, iter_name: str) -> dace.SDFG: + sdfg = dace.SDFG("If_block") + + sdfg.add_scalar("arg1", dtype=dace.float64, transient=False) + sdfg.add_scalar("out", dtype=dace.float64, transient=False) + sdfg.add_scalar(cond_name, dtype=dace.bool_, transient=False) + sdfg.add_array("arg2", shape=(10,), dtype=dace.float64, transient=False) + sdfg.add_symbol(iter_name, stype=dace.int32) + + then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=sdfg) + tstate = then_body.add_state("true_branch", is_start_block=True) + tstate.add_edge( + tstate.add_access("arg1"), + None, + tstate.add_access("out"), + None, + dace.Memlet("arg1[0] -> [0]"), + ) + + else_body = dace.sdfg.state.ControlFlowRegion("else_body", sdfg=sdfg) + fstate = else_body.add_state("false_branch", is_start_block=True) + f_tasklet = fstate.add_tasklet( + "f_tasklet", inputs={"__in"}, outputs={"__out"}, code="__out = __in + 1.0" + ) + fstate.add_edge( + fstate.add_access("arg2"), 
None, f_tasklet, "__in", dace.Memlet(f"arg2[{iter_name}]") + ) + fstate.add_edge(f_tasklet, "__out", fstate.add_access("out"), None, dace.Memlet("out[0]")) + + if_region = dace.sdfg.state.ConditionalBlock(gtx_transformations.utils.unique_name("if")) + sdfg.add_node(if_region, is_start_block=True) + if_region.add_branch(dace.sdfg.state.CodeBlock(cond_name), then_body) + if_region.add_branch(dace.sdfg.state.CodeBlock(f"not {cond_name}"), else_body) + + sdfg.validate() + return sdfg + + def _make_outer_sdfg( + cond_name: str, iter_name: str + ) -> tuple[dace.SDFG, dace.SDFGState, dace_nodes.NestedSDFG]: + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_mover_slicing")) + state = sdfg.add_state(is_start_block=True) + + # Inputs + input_names = list("abcd") + for name in input_names: + sdfg.add_array( + name, + shape=((10, 10) if name.startswith("b") else (10,)), + dtype=dace.float64, + transient=False, + ) + + # Temporaries + temporary_names = ["a1", "c1"] + for name in temporary_names: + sdfg.add_scalar( + name, dtype=dace.bool_ if name.startswith("c") else dace.float64, transient=True + ) + a1, c1 = (state.add_access(name) for name in temporary_names) + + me, mx = state.add_map("map", ndrange={iter_name: "0:10"}) + for name in input_names[:-1]: + state.add_edge( + state.add_access(name), + None, + me, + f"IN_{name}", + dace.Memlet(data=name, subset=("0:10, 0:10" if name == "b" else "0:10")), + ) + me.add_scope_connectors(name) + + state.add_edge( + mx, "OUT_d", state.add_access("d"), None, dace.Memlet(data="d", subset="0:10") + ) + mx.add_scope_connectors("d") + + # First branch. 
+ tasklet_a1 = state.add_tasklet( + "tasklet_a1", + inputs={"__in1", "__in2"}, + outputs={"__out"}, + code="__out = __in1 + __in2", + ) + + state.add_edge(me, "OUT_a", tasklet_a1, "__in1", dace.Memlet(f"a[{iter_name}]")) + state.add_edge( + me, "OUT_b", tasklet_a1, "__in2", dace.Memlet(f"b[{iter_name}, {iter_name}]") + ) + state.add_edge(tasklet_a1, "__out", a1, None, dace.Memlet("a1[0]")) + + # Second branch + # There is nothing. + + # Condition + tasklet_c1 = state.add_tasklet( + "tasklet_c1", + inputs={"__in"}, + outputs={"__out"}, + code="__out = __in > 0.5", + ) + state.add_edge(me, "OUT_c", tasklet_c1, "__in", dace.Memlet(f"c[{iter_name}]")) + state.add_edge(tasklet_c1, "__out", c1, None, dace.Memlet("c1[0]")) + + # Nested SDFG + nsdfg = state.add_nested_sdfg( + sdfg=_make_nested_sdfg(cond_name=cond_name, iter_name=iter_name), + inputs={"arg1", "arg2", cond_name}, + outputs={"out"}, + symbol_mapping={iter_name: iter_name}, + ) + state.add_edge(a1, None, nsdfg, "arg1", dace.Memlet("a1[0]")) + state.add_edge(me, "OUT_b", nsdfg, "arg2", dace.Memlet(f"b[{iter_name}, 0:10]")) + state.add_edge(c1, None, nsdfg, cond_name, dace.Memlet("c1[0]")) + state.add_edge(nsdfg, "out", mx, "IN_d", dace.Memlet(f"d[{iter_name}]")) + + sdfg.validate() + return sdfg, state, nsdfg + + iter_name = "__i" + cond_name = "cond" + + sdfg, state, nsdfg = _make_outer_sdfg(cond_name=cond_name, iter_name=iter_name) + + _perform_test( + sdfg=sdfg, + explected_applies=1, + ) + + assert False, "Add structural checks." + assert False, ( + "Add checks where `b` is copied fully inside the Map scope and then sliced into the `if_block`." + )