Commit 9114985

Squashed commit of the following:
commit a26f6a2
Merge: 9c2c36e afd0efe
Author: Philip Müller <[email protected]>
Date:   Wed Jul 2 15:19:17 2025 +0200

    Merge branch 'main' into unified_state_fission

commit 9c2c36e
Merge: b5a838e 96a1f0b
Author: Philipp Schaad <[email protected]>
Date:   Sat Jun 28 09:34:57 2025 +0200

    Merge branch 'main' into unified_state_fission

commit b5a838e
Author: Philip Mueller <[email protected]>
Date:   Thu Jun 26 07:43:17 2025 +0200

    Updated.

commit b2c4288
Author: Philip Mueller <[email protected]>
Date:   Wed Jun 25 15:24:43 2025 +0200

    Fixed an issue.

commit 17dc202
Author: Philip Mueller <[email protected]>
Date:   Wed Jun 25 14:41:29 2025 +0200

    Now view nodes in that thing are handled.

commit 1129d10
Author: Philip Mueller <[email protected]>
Date:   Wed Jun 25 14:34:46 2025 +0200

    Fixed a bug in `get_all_view_edges()` it was possible that the function went back and forth. Before the function was deciding in which direction the edge goes for every new edge. Thus it one point it could get stuck and go back and forth all the time. The new implementation decides the direction once and then uses it until the end.

commit 317df13
Author: Philip Mueller <[email protected]>
Date:   Wed Jun 25 14:15:05 2025 +0200

    Addessed Alexnick's comments.

commit 38439a3
Author: Philip Mueller <[email protected]>
Date:   Tue Jun 24 11:39:29 2025 +0200

    Added a test for real nested scope thing.

commit 5e6acfd
Author: Philip Mueller <[email protected]>
Date:   Tue Jun 24 11:28:18 2025 +0200

    Made the map test slightly difficulter.

commit d4dbe0f
Author: Philip Mueller <[email protected]>
Date:   Tue Jun 24 11:18:26 2025 +0200

    Added the handling of non global nodes.

commit 0c412ff
Author: Philip Mueller <[email protected]>
Date:   Tue Jun 24 09:57:11 2025 +0200

    A fucntion called `entry_node()` should not contain `return self.exit_node(...)` it does not make any sense.

commit dbb6325
Author: Philip Mueller <[email protected]>
Date:   Tue Jun 24 09:36:29 2025 +0200

    Updated the description of the function.

commit cc4c68f
Author: Philip Mueller <[email protected]>
Date:   Tue Jun 24 08:55:56 2025 +0200

    Solved the empty Memlet issue.

commit 3216b87
Author: Philip Mueller <[email protected]>
Date:   Mon Jun 23 16:05:02 2025 +0200

    Added a new test regarding empty memlets. However, I am not sure if it is handled yet, I kind of think that it is handled yet. I added it there to ensure that it is also handled in teh future.

commit dff5e1a
Author: Philip Mueller <[email protected]>
Date:   Mon Jun 23 15:52:15 2025 +0200

    Added a test for the handling of empty memlets. However, this case is not yet handled.

commit b35692e
Author: Philip Mueller <[email protected]>
Date:   Mon Jun 23 15:22:51 2025 +0200

    Renamed some functions.

commit 5d5943e
Author: Philip Mueller <[email protected]>
Date:   Mon Jun 23 15:14:08 2025 +0200

    Added a test case when there are views involved.

commit 0ed6a6e
Merge: f0dc4c5 186d21d
Author: Philip Mueller <[email protected]>
Date:   Mon Jun 23 14:42:00 2025 +0200

    Merge remote-tracking branch 'spcl/main' into unified_state_fission

commit f0dc4c5
Author: Philip Mueller <[email protected]>
Date:   Mon Jun 23 14:40:18 2025 +0200

    Added a test for multi write case.

commit 6c750e8
Author: Philip Mueller <[email protected]>
Date:   Mon Jun 23 13:46:06 2025 +0200

    Added a new unit test.

commit 495d64e
Author: Philip Mueller <[email protected]>
Date:   Mon Jun 23 13:40:09 2025 +0200

    Handled a special case.

commit 947ad91
Author: Philip Mueller <[email protected]>
Date:   Mon Jun 23 13:03:16 2025 +0200

    Fixed an issue in the map fusion test, caused by the introduction of the new utility header.

commit bd85d43
Author: Philip Mueller <[email protected]>
Date:   Mon Jun 23 13:02:42 2025 +0200

    The "make data" function of the test utility no longer generates data for `__return*`.

commit c7b7ec7
Author: Philip Mueller <[email protected]>
Date:   Mon Jun 23 12:57:26 2025 +0200

    Added new tests.

commit 3e33fab
Author: Philip Mueller <[email protected]>
Date:   Mon Jun 23 11:46:41 2025 +0200

    Fixed an import error in the fpg map fusion test.

commit ff786bb
Author: Philip Mueller <[email protected]>
Date:   Mon Jun 23 11:43:43 2025 +0200

    Added a new kind of test.

commit 64e6ede
Author: Philip Mueller <[email protected]>
Date:   Mon Jun 23 11:11:27 2025 +0200

    Added a first batch of tests.

commit a814670
Author: Philip Mueller <[email protected]>
Date:   Mon Jun 23 11:10:19 2025 +0200

    Fixed a strange bug.

commit 0fb1eef
Author: Philip Mueller <[email protected]>
Date:   Mon Jun 23 10:38:30 2025 +0200

    Updated how the state spliter handles isolated nodes.

commit 8d5b4fd
Author: Philip Mueller <[email protected]>
Date:   Mon Jun 23 09:49:31 2025 +0200

    Fixed a bug.

commit 21c3a03
Author: Philip Mueller <[email protected]>
Date:   Mon Jun 23 09:24:13 2025 +0200

    Created a file with test helpers.

commit ee562eb
Author: Philip Mueller <[email protected]>
Date:   Mon Jun 23 09:08:45 2025 +0200

    The splitting is now more compatible with what it should be so let's keep that.

commit 98a8b25
Author: Philip Mueller <[email protected]>
Date:   Fri Jun 13 14:15:01 2025 +0200

    Created a new state fission function. It is designed in a similar way than `isolate_nested_sdfg()` function, i.e. it preserves the number of writes to nodes. The change here is, that it might happen that some consumer of `subgraph` end up in the first state. However, The old version also had that problem, because I could pass a simple Tasklet as subset and then one would end up with an invalid function.

commit 799b065
Author: Philip Mueller <[email protected]>
Date:   Fri Jun 13 11:25:58 2025 +0200

    Reverted some of my changes to be bug comatible with the original implementation.

commit 8edc84e
Author: Philip Mueller <[email protected]>
Date:   Fri Jun 13 11:22:09 2025 +0200

    Made some annotations.

commit f7b0110
Author: Philip Mueller <[email protected]>
Date:   Fri Jun 13 11:05:00 2025 +0200

    Relocated teh `state_fission_after()` function to teh only location where it was used. There was only one place where this function was used. Given its sorry state (no unit test, no doc string and strange behaviour), I decided to remove it from the helper and put it there.

1 parent 99b2f12 commit 9114985

10 files changed

Lines changed: 1292 additions & 212 deletions

dace/sdfg/state.py

Lines changed: 1 addition & 1 deletion
@@ -1086,7 +1086,7 @@ def data_nodes(self) -> List[nd.AccessNode]:
     def entry_node(self, node: nd.Node) -> Optional[nd.EntryNode]:
         for block in self.nodes():
             if node in block.nodes():
-                return block.exit_node(node)
+                return block.entry_node(node)
         return None

     def exit_node(self, entry_node: nd.EntryNode) -> Optional[nd.ExitNode]:

dace/sdfg/utils.py

Lines changed: 76 additions & 10 deletions
@@ -174,6 +174,64 @@ def dfs_topological_sort(G, sources=None, condition=None, reverse=False):
                 stack.pop()


+def _find_nodes_impl(
+    node_to_start: Node,
+    state: SDFGState,
+    forward: bool,
+    seen: Optional[Set[Node]],
+) -> Set[Node]:
+    to_scan: List[Node] = [node_to_start]
+    scanned_nodes: Set[Node] = set() if seen is None else seen
+    if forward:
+        get_edges = state.out_edges
+        get_node = lambda e: e.dst
+    else:
+        get_edges = state.in_edges
+        get_node = lambda e: e.src
+    while len(to_scan) != 0:
+        node_to_scan = to_scan.pop()
+        if node_to_scan in scanned_nodes:
+            continue
+        to_scan.extend(get_node(edge) for edge in get_edges(node_to_scan) if get_node(edge) not in scanned_nodes)
+        scanned_nodes.add(node_to_scan)
+    return scanned_nodes
+
+
+def find_downstream_nodes(node_to_start: Node, state: SDFGState, seen: Optional[Set[Node]] = None) -> Set[Node]:
+    """Find all downstream nodes of `node_to_start`.
+
+    The function will explore the state, similar to a BFS, just that the order in which the nodes of
+    the dataflow is explored is unspecific. It is possible to pass a `set` of nodes that should be
+    considered as already visited. It is important that the function will return the set of found
+    nodes. In case `seen` was passed that `set` will be updated in place and be returned.
+
+    :param node_to_start: Where to start the exploration of the state.
+    :param state: The state on which we operate on.
+    :param seen: The set of already seen nodes.
+
+    :note: See also `find_upstream_nodes()` in case the dataflow should be explored in the reverse direction.
+    """
+    return _find_nodes_impl(node_to_start=node_to_start, state=state, seen=seen, forward=True)
+
+
+def find_upstream_nodes(node_to_start: Node, state: SDFGState, seen: Optional[Set[Node]] = None) -> Set[Node]:
+    """Find all upstream nodes of `node_to_start`.
+
+    The function will explore the state, similar to a BFS, just that the order in which the nodes of
+    the dataflow is explored is unspecific. It is possible to pass a `set` of nodes that should be
+    considered as already visited. It is important that the function will return the set of found
+    nodes. In case `seen` was passed that `set` will be updated in place and be returned.
+
+    The main difference to `find_downstream_nodes()` is that the dataflow is traversed in reverse
+    order or "against the flow".
+
+    :param node_to_start: Where to start the exploration of the state.
+    :param state: The state on which we operate on.
+    :param seen: The set of already seen nodes.
+    """
+    return _find_nodes_impl(node_to_start=node_to_start, state=state, seen=seen, forward=False)
+
+
 class StopTraversal(Exception):
     """
     Special exception that stops DFS conditional traversal beyond the current node.
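
The worklist traversal behind `find_downstream_nodes()` and `find_upstream_nodes()` can be sketched without DaCe. The following is a minimal illustration, assuming the graph is given as a plain successor dictionary; the name `find_reachable` and the `successors` argument are hypothetical and not part of `dace.sdfg.utils`. It shows the two properties the docstrings point out: the start node is part of the result, and a passed `seen` set is updated in place and prunes the exploration.

```python
from typing import Dict, List, Optional, Set


def find_reachable(
    start: str,
    successors: Dict[str, List[str]],
    seen: Optional[Set[str]] = None,
) -> Set[str]:
    """Collect `start` and every node reachable from it (hypothetical helper)."""
    to_scan: List[str] = [start]
    # Reuse the caller's set if one was given, so it is updated in place.
    scanned: Set[str] = set() if seen is None else seen
    while to_scan:
        node = to_scan.pop()
        if node in scanned:
            continue
        to_scan.extend(n for n in successors.get(node, []) if n not in scanned)
        scanned.add(node)
    return scanned


graph = {"a": ["b", "c"], "b": ["e"], "c": ["d"], "d": [], "e": []}

everything = find_reachable("a", graph)

# Marking "b" as already seen keeps the search from going through it,
# so "e" (only reachable via "b") is not discovered.
seen = {"b"}
pruned = find_reachable("a", graph, seen)
```

Passing a predecessor dictionary instead of a successor dictionary gives the upstream variant. Since the worklist is popped from its end, the visiting order is not a strict BFS; only the returned set should be relied upon.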
@@ -807,24 +865,32 @@ def get_all_view_edges(state: SDFGState, view: nd.AccessNode) -> List[gr.MultiCo
              if existent, else None
     """
     sdfg = state.parent
-    node = view
-    desc = sdfg.arrays[node.data]
+    previous_node = view
     result = []
+
+    desc = sdfg.arrays[previous_node.data]
+    forward = None
     while isinstance(desc, dt.View):
-        edge = get_view_edge(state, node)
+        edge = get_view_edge(state, previous_node)
         if edge is None:
             break
-        old_node = node
-        if edge.dst is view:
-            node = edge.src
+
+        if forward is None:
+            forward = edge.src is previous_node
+
+        if forward:
+            next_node = edge.dst
         else:
-            node = edge.dst
-        if node is old_node:
+            next_node = edge.src
+
+        if previous_node is next_node:
             break
-        if not isinstance(node, nd.AccessNode):
+        if not isinstance(next_node, nd.AccessNode):
             break
-        desc = sdfg.arrays[next_node.data]
+        desc = sdfg.arrays[next_node.data]
         result.append(edge)
+        previous_node = next_node
+
     return result

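
The direction handling in the new `get_all_view_edges()` can be illustrated on a toy model. This is a sketch under assumptions: nodes are plain strings, `view_edge` is a hypothetical dictionary that plays the role of `get_view_edge()` by mapping each view to its defining `(src, dst)` edge, and any node without an entry counts as a non-view. The direction is chosen at the first edge and reused for the rest of the chain, which is the behaviour the commit message describes.

```python
from typing import Dict, List, Optional, Tuple

Edge = Tuple[str, str]  # (src, dst)


def walk_view_chain(start: str, view_edge: Dict[str, Edge]) -> List[Edge]:
    """Follow a chain of views from `start`, fixing the direction once."""
    result: List[Edge] = []
    previous = start
    forward: Optional[bool] = None
    while previous in view_edge:  # `previous` is still a view
        src, dst = view_edge[previous]
        if forward is None:
            # Decide the direction exactly once, at the first edge.
            forward = src == previous
        nxt = dst if forward else src
        if nxt == previous:
            # The edge points against the chosen direction: stop instead
            # of turning around.
            break
        result.append((src, dst))
        previous = nxt
    return result


# Read chain: A -> v1 -> v2, queried from the last view.
backward_chain = walk_view_chain("v2", {"v1": ("A", "v1"), "v2": ("v1", "v2")})

# Write chain: v1 -> v2 -> B, queried from the first view.
forward_chain = walk_view_chain("v1", {"v1": ("v1", "v2"), "v2": ("v2", "B")})
```

In this model the walk ends as soon as it reaches a node that is not a view, or when an edge would require reversing the direction chosen at the start.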

dace/transformation/dataflow/wcr_conversion.py

Lines changed: 83 additions & 2 deletions
@@ -4,7 +4,7 @@
 import copy
 import re
 import copy
-from dace import nodes, dtypes, Memlet
+from dace import nodes, dtypes, Memlet, data
 from dace.frontend.python import astutils
 from dace.transformation import transformation
 from dace.sdfg import utils as sdutil
@@ -151,7 +151,7 @@ def apply(self, state: SDFGState, sdfg: SDFG):

         # If state fission is necessary to keep semantics, do it first
         if state.in_degree(input) > 0:
-            new_state = helpers.state_fission_after(state, tasklet)
+            new_state = self.isolate_tasklet(state)
         else:
             new_state = state

@@ -270,6 +270,87 @@ def apply(self, state: SDFGState, sdfg: SDFG):
                 # At this point we are leading to an access node again and can
                 # traverse further up

+    def isolate_tasklet(
+        self,
+        state: SDFGState,
+    ) -> SDFGState:
+        tlet: nodes.Tasklet = self.tasklet
+        newstate = state.parent_graph.add_state_after(state)
+
+        # Bookkeeping
+        nodes_to_move = set([tlet])
+        boundary_nodes = set()
+        orig_edges = set()
+
+        for edge in state.in_edges(tlet):
+            for e in state.memlet_path(edge):
+                nodes_to_move.add(e.src)
+                orig_edges.add(e)
+                if isinstance(e.src, nodes.AccessNode) and isinstance(e.src.desc(sdfg), data.View):
+                    assert state.in_degree(e.src) > 0
+                    view_edges = sdutil.get_all_view_edges(state, e.src)
+                    for edge in view_edges:
+                        nodes_to_move.add(edge.src)
+                        orig_edges.add(edge)
+
+        # Find all consumer nodes of `tlet`.
+        for edge in state.edge_bfs(tlet):
+            nodes_to_move.add(edge.dst)
+            orig_edges.add(edge)
+
+            # If a consumer is not an AccessNode we also have to relocate its dependencies.
+            if not isinstance(edge.dst, nodes.AccessNode):
+                for iedge in state.in_edges(edge.dst):
+                    if iedge == edge:
+                        continue
+                    for e in state.memlet_path(iedge):
+                        nodes_to_move.add(e.src)
+                        orig_edges.add(e)
+
+        # Define boundary nodes
+        for node in nodes_to_move:
+            if isinstance(node, nodes.AccessNode):
+                for iedge in state.in_edges(node):
+                    if iedge.src not in nodes_to_move:
+                        boundary_nodes.add(node)
+                        break
+                if node in boundary_nodes:
+                    continue
+                for oedge in state.out_edges(node):
+                    if oedge.dst not in nodes_to_move:
+                        boundary_nodes.add(node)
+                        break
+
+        # Duplicate boundary nodes
+        new_nodes = {}
+        for node in boundary_nodes:
+            node_ = copy.deepcopy(node)
+            state.add_node(node_)
+            new_nodes[node] = node_
+
+        for edge in state.edges():
+            if edge.src in boundary_nodes and edge.dst in boundary_nodes:
+                state.add_edge(new_nodes[edge.src], edge.src_conn, new_nodes[edge.dst], edge.dst_conn,
+                               copy.deepcopy(edge.data))
+            elif edge.src in boundary_nodes:
+                state.add_edge(new_nodes[edge.src], edge.src_conn, edge.dst, edge.dst_conn, copy.deepcopy(edge.data))
+            elif edge.dst in boundary_nodes:
+                state.add_edge(edge.src, edge.src_conn, new_nodes[edge.dst], edge.dst_conn, copy.deepcopy(edge.data))
+
+        state.remove_nodes_from(nodes_to_move)
+
+        # Set the new parent state
+        # TODO: Note sure if `add_node()` does it on its own?
+        for node in nodes_to_move:
+            if isinstance(node, nodes.NestedSDFG):
+                node.sdfg.parent = newstate
+
+        newstate.add_nodes_from(nodes_to_move)
+        for e in orig_edges:
+            newstate.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, e.data)
+
+        return newstate
+

 class WCRToAugAssign(transformation.SingleStateTransformation):
     """

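The "Define boundary nodes" step of `isolate_tasklet()` can be sketched on a toy graph. This is an illustration under assumptions, not the DaCe implementation: nodes are strings, edges are `(src, dst)` pairs, `access_nodes` stands in for the `isinstance(node, nodes.AccessNode)` check, and the hypothetical helper `find_boundary_nodes` assumes that both the incoming and the outgoing test apply only to access nodes. A moved access node counts as a boundary node when at least one of its edges connects to a node that stays behind.

```python
from typing import Iterable, Set, Tuple


def find_boundary_nodes(
    nodes_to_move: Set[str],
    edges: Iterable[Tuple[str, str]],
    access_nodes: Set[str],
) -> Set[str]:
    """Return the moved access nodes that still touch the remaining graph."""
    edge_list = list(edges)
    boundary: Set[str] = set()
    for node in nodes_to_move:
        if node not in access_nodes:
            continue
        # An incoming edge from a node that is not moved.
        crosses_in = any(dst == node and src not in nodes_to_move for src, dst in edge_list)
        # An outgoing edge to a node that is not moved.
        crosses_out = any(src == node and dst not in nodes_to_move for src, dst in edge_list)
        if crosses_in or crosses_out:
            boundary.add(node)
    return boundary


# w -> A -> t -> B -> r, where only A, t and B are relocated.
edges = [("w", "A"), ("A", "t"), ("t", "B"), ("B", "r")]
boundary = find_boundary_nodes({"A", "t", "B"}, edges, access_nodes={"A", "B"})
```

In the commit, the nodes found this way are deep-copied into the original state so that the producers and consumers that stay behind keep a valid access node, while the originals move to the new state.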