Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions backends/arm/test/misc/test_partition_cycle_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.arm.tosa.partitioner import (
_find_connected_components,
_validate_partition,
)


def _build_linear_graph():
    """Create the straight-line FX graph ``x -> a -> b -> c -> output``.

    ``a``, ``b`` and ``c`` are ``torch.add``, ``torch.mul`` and ``torch.sub``
    calls respectively; each one receives the preceding node as both of its
    operands.

    Returns:
        A ``(graph, (x, a, b, c, output))`` pair.

    """
    fx_graph = torch.fx.Graph()
    chain = [fx_graph.placeholder("x")]
    # Every stage consumes the node appended just before it, twice.
    for op in (torch.add, torch.mul, torch.sub):
        previous = chain[-1]
        chain.append(fx_graph.call_function(op, (previous, previous)))
    chain.append(fx_graph.output(chain[-1]))
    return fx_graph, tuple(chain)


class TestValidatePartition(unittest.TestCase):
    """Checks ``_validate_partition`` against subsets of the linear graph."""

    def test_contiguous_partition_is_valid(self):
        """Two adjacent nodes can be extracted without creating a cycle."""
        chain = _build_linear_graph()[1]
        first, second = chain[1], chain[2]
        self.assertTrue(_validate_partition({first, second}))

    def test_non_contiguous_partition_has_cycle(self):
        """Selecting {a, c} while leaving b outside must be rejected.

        b consumes a (inside the partition) and c (also inside) consumes b,
        so the extracted subgraph would both feed and depend on b.
        """
        chain = _build_linear_graph()[1]
        head, tail = chain[1], chain[3]
        self.assertFalse(_validate_partition({head, tail}))

    def test_single_node_is_valid(self):
        """A partition made of one node is accepted."""
        chain = _build_linear_graph()[1]
        only = chain[1]
        self.assertTrue(_validate_partition({only}))

    def test_full_graph_interior_is_valid(self):
        """The three call_function nodes together are a valid partition."""
        chain = _build_linear_graph()[1]
        interior = set(chain[1:4])
        self.assertTrue(_validate_partition(interior))


class TestFindConnectedComponents(unittest.TestCase):
    """Checks ``_find_connected_components`` on small hand-built graphs."""

    def test_single_component(self):
        """Adjacent nodes a and b end up in the same component."""
        chain = _build_linear_graph()[1]
        a, b = chain[1], chain[2]
        found = _find_connected_components({a, b})
        self.assertEqual(len(found), 1)
        self.assertEqual(found[0], {a, b})

    def test_disconnected_components(self):
        """With b left out, a and c each form their own component."""
        chain = _build_linear_graph()[1]
        a, c = chain[1], chain[3]
        found = _find_connected_components({a, c})
        self.assertEqual(len(found), 2)
        # Freeze each component so membership can be tested regardless of
        # the order in which components are returned.
        frozen_groups = [frozenset(group) for group in found]
        self.assertIn(frozenset({a}), frozen_groups)
        self.assertIn(frozenset({c}), frozen_groups)

    def test_empty_set(self):
        """An empty node set yields no components."""
        found = _find_connected_components(set())
        self.assertEqual(len(found), 0)

    def test_branching_graph(self):
        """Fork ``x -> a -> b`` and ``x -> a -> c``.

        b and c only touch each other through a, so they are separate
        components unless a is part of the set.
        """
        fork = torch.fx.Graph()
        source = fork.placeholder("x")
        a = fork.call_function(torch.add, (source, source))
        b = fork.call_function(torch.mul, (a, a))
        c = fork.call_function(torch.sub, (a, a))
        fork.output((b, c))

        without_hub = _find_connected_components({b, c})
        self.assertEqual(len(without_hub), 2)

        # Including the shared producer a joins both branches.
        with_hub = _find_connected_components({a, b, c})
        self.assertEqual(len(with_hub), 1)


# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
168 changes: 139 additions & 29 deletions backends/arm/tosa/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import logging
import operator
from collections import deque
from itertools import count
from pathlib import Path
from typing import Callable, cast, List, Optional, Sequence, Tuple
Expand Down Expand Up @@ -155,6 +156,77 @@ def reject_partition(
)


def _validate_partition(nodes: set[torch.fx.Node]) -> bool:
    """Report whether ``nodes`` can be pulled out as a subgraph without a cycle.

    The search starts at every user of a partition node that lies outside the
    partition and walks forward along ``users`` edges breadth-first. Reaching
    a partition node again means some external node sits between two
    partition nodes, so extraction would introduce a dependency cycle.

    Args:
        nodes: The set of FX nodes that form the partition.

    Returns:
        True if no partition node is reachable from the partition's external
        users, False otherwise.

    """
    # Seed the frontier with the consumers that live outside the partition.
    frontier = deque(
        consumer
        for member in nodes
        for consumer in member.users
        if consumer not in nodes
    )
    seen: set[torch.fx.Node] = set()
    while frontier:
        candidate = frontier.popleft()
        if candidate in seen:
            continue
        seen.add(candidate)
        if candidate in nodes:
            # Data left the partition and flowed back into it.
            return False
        frontier.extend(
            downstream for downstream in candidate.users if downstream not in seen
        )
    return True


def _find_connected_components(nodes: set[torch.fx.Node]) -> list[set[torch.fx.Node]]:
    """Find connected components in a set of nodes treating edges as undirected.

    Two nodes are connected if one is an input or user of the other and both
    are in ``nodes``.

    The components are returned in a reproducible order: each component is
    discovered from its member with the smallest ``name``, and components are
    listed in ascending order of that name. Callers that number the
    components (for example to derive new delegation tags) therefore get the
    same numbering on every run.

    Note that connectivity alone does not imply that a component passes
    ``_validate_partition``; callers needing that guarantee must re-check.

    Args:
        nodes: The node set to partition into components. It is not modified.

    Returns:
        A list of disjoint node sets, one per connected component. The list
        is empty when ``nodes`` is empty.

    """
    unassigned = set(nodes)
    components: list[set[torch.fx.Node]] = []
    # Set iteration order is not guaranteed to be reproducible between runs,
    # so pick seeds in name order instead of popping an arbitrary element.
    for seed in sorted(nodes, key=lambda n: n.name):
        if seed not in unassigned:
            # Already absorbed into a component found from an earlier seed.
            continue
        component: set[torch.fx.Node] = {seed}
        queue = deque([seed])
        while queue:
            node = queue.popleft()
            # Follow edges in both directions, staying inside the input set.
            for neighbour in (*node.all_input_nodes, *node.users):
                if neighbour in unassigned and neighbour not in component:
                    # Mark on enqueue so a node is never queued twice.
                    component.add(neighbour)
                    queue.append(neighbour)
        unassigned -= component
        components.append(component)
    return components


class TOSAPartitioner(Partitioner):
"""Partition an exported program into TOSA-delegable subgraphs.

Expand Down Expand Up @@ -381,36 +453,74 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
detag_first_fp_node=False,
)

if self._partition_has_invalid_uint8(partition, tag):
reject_partition(
"Partition contained internal uint8 tensors. Uint8 is only supported at IO boundaries for TOSA backends.",
partition,
reporter,
)
tags.remove(tag)
continue

# Check whether the partition contains only no-op or non-computational ops. Such partitions don't make sense to delegate, and in the worst case may be optimized away during lowering, which can break compilation.
is_nocompute_partition = all(
_is_noop_clone(node)
or _is_noop_alias_copy(node)
or _is_noop_expand(node)
or _is_noop_detach_copy(node)
or _is_noop_to_dim_order_copy(node)
or _is_noop_squeeze(node)
or _is_view_copy(node)
or _is_noop_as_strided_copy(node)
or node.target in Q_OPS
or node.target in DQ_OPS
for node in partition.nodes
)
if is_nocompute_partition:
reject_partition(
"Partition contained only ops which are removed in the TOSA lowering, leading to an empty partition.",
partition,
reporter,
if self.tosa_spec.support_integer() and not self.tosa_spec.support_float():
# After de-tagging, the remaining tagged nodes may form
# dependency cycles. This happens when models contain complex
# attention blocks (e.g. MobileViT) where Q/DQ nodes act as
# bridges between partition segments. Detect such cycles and
# split the partition into valid connected components.
surviving = {n for n in partition.nodes if is_partitioned(n, tag)}
if surviving and not _validate_partition(surviving):
components = _find_connected_components(surviving)
logger.info(
f"Partition {tag} has dependency cycle after Q/DQ "
f"de-tagging. Splitting into {len(components)} "
f"sub-partition(s)."
)
# Remove the original tag from all nodes
for node in surviving:
del node.meta["delegation_tag"]
tags.remove(tag)
# Re-tag each connected component as a new partition
for component in components:
new_tag = f"tag{next(tag_iterator)}"
tags.add(new_tag)
for node in component:
node.meta["delegation_tag"] = new_tag

# After potential cycle-splitting the original tag may have been
# replaced by one or more sub-tags. Collect every active tag that
# still has nodes in this partition so checks below apply to each
# resulting sub-partition.
active_tag_nodes: dict[str, list[torch.fx.Node]] = {}
for node in partition.nodes:
node_tag = node.meta.get("delegation_tag")
if node_tag is not None and node_tag in tags:
active_tag_nodes.setdefault(node_tag, []).append(node)

for active_tag, nodes in active_tag_nodes.items():
if self._partition_has_invalid_uint8(partition, active_tag):
reject_partition(
"Partition contained internal uint8 tensors. Uint8 is only supported at IO boundaries for TOSA backends.",
Partition(nodes=nodes),
reporter,
)
if active_tag in tags:
tags.remove(active_tag)
continue

# Check whether the partition contains only no-op or non-computational ops. Such partitions don't make sense to delegate, and in the worst case may be optimized away during lowering, which can break compilation.
is_nocompute_partition = all(
_is_noop_clone(node)
or _is_noop_alias_copy(node)
or _is_noop_expand(node)
or _is_noop_detach_copy(node)
or _is_noop_to_dim_order_copy(node)
or _is_noop_squeeze(node)
or _is_view_copy(node)
or _is_noop_as_strided_copy(node)
or node.target in Q_OPS
or node.target in DQ_OPS
for node in nodes
)
tags.remove(tag)
if is_nocompute_partition:
reject_partition(
"Partition contained only ops which are removed in the TOSA lowering, leading to an empty partition.",
Partition(nodes=nodes),
reporter,
)
if active_tag in tags:
tags.remove(active_tag)
return tags

def partition(self, exported_program: ExportedProgram) -> PartitionResult:
Expand Down
Loading