55
66import abc
77import re
8+ from collections import defaultdict
9+ from copy import deepcopy
810from dataclasses import dataclass
9- from typing import Union
11+ from typing import Callable , Union
1012
13+ from executorch .backends .nxp .neutron_partitioner import (
14+ NeutronPartitioner ,
15+ NXP_DELEGATION_TAG ,
16+ )
17+ from executorch .backends .nxp .tests .ops_aliases import (
18+ DequantizePerChannel ,
19+ DequantizePerTensor ,
20+ QuantizePerChannel ,
21+ QuantizePerTensor ,
22+ )
23+
24+ from executorch .exir .dialects .edge ._ops import EdgeOpOverload
25+
26+ from pytest_mock import MockerFixture
27+
28+ from torch .fx import Node
1129from torch .fx .graph import Graph
1230
1331
1432@dataclass
1533class NonDelegatedNode :
34+ """Represents an expected non-delegated node in the graph.
35+
36+ :param node_name: The name of the node to check for
37+ :param num_occurrences: Expected number of occurrences. If None, just verifies that at least one exists
38+ """
39+
1640 node_name : str
1741 num_occurrences : Union [int , None ] = None
1842
1943
2044class GraphVerifier (abc .ABC ):
45+ """Abstract base class for graph verification strategies."""
46+
2147 @abc .abstractmethod
2248 def verify_graph (self , graph : Graph ):
23- pass
49+ """Verifies the graph meets expected criteria.
2450
25- @abc .abstractmethod
26- def check_num_delegated_nodes (self , num_dlg_nodes : int ):
51+ :param graph: The FX graph to verify
52+ :raises AssertionError: If the graph does not meet expectations
53+ """
2754 pass
2855
2956
3057class BaseGraphVerifier (GraphVerifier ):
31- """Graph verifier base class. Checks for number of delegated nodes and number of selected expected nodes."""
58+ """Graph verifier base class. Checks for number of delegated nodes and number of selected expected nodes.
59+
60+ This verifier performs the following checks:
61+ - The total number of delegated call nodes matches expectations
62+ - Specific non-delegated nodes appear with the expected frequency
63+ - No unexpected aten nodes are present in the graph
64+ """
3265
3366 def __init__ (
3467 self ,
3568 exp_num_delegate_call_nodes : int ,
3669 exp_non_delegated_nodes : list [NonDelegatedNode ] = None ,
3770 ):
71+ """Initializes the BaseGraphVerifier.
72+
73+ :param exp_num_delegate_call_nodes: Expected number of delegated nodes
74+ :param exp_non_delegated_nodes: List of expected non-delegated nodes to verify
75+ """
3876 self .exp_non_delegated_nodes = (
3977 exp_non_delegated_nodes if exp_non_delegated_nodes is not None else []
4078 )
4179 self .exp_num_delegate_call_nodes = exp_num_delegate_call_nodes
4280
4381 def check_num_delegated_nodes (self , num_dlg_nodes ):
82+ """Checks that the number of delegated nodes matches expectations.
83+
84+ :param num_dlg_nodes: Actual number of delegated nodes
85+ :raises AssertionError: If the count doesn't match expectations
86+ """
4487 assert not (
4588 num_dlg_nodes < self .exp_num_delegate_call_nodes
4689 ), f"Number of delegated nodes decreased from { self .exp_num_delegate_call_nodes } to { num_dlg_nodes } ."
@@ -49,6 +92,11 @@ def check_num_delegated_nodes(self, num_dlg_nodes):
4992 ), f"Number of delegated nodes increased from { self .exp_num_delegate_call_nodes } to { num_dlg_nodes } ."
5093
5194 def verify_graph (self , graph ):
95+ """Verifies the graph meets delegation and node presence expectations.
96+
97+ :param graph: The FX graph to verify
98+ :raises AssertionError: If verification fails
99+ """
52100 nodes = list (graph .nodes )
53101
54102 # Check for specific non delegated nodes
@@ -84,3 +132,132 @@ def verify_graph(self, graph):
84132 assert (
85133 not unexpected_aten_fn_nodes
86134 ), f"Graphs contains unexpected aten nodes:\n { unexpected_aten_fn_nodes } ."
135+
136+
137+ # Type alias for operators - can be either EdgeOpOverload or any callable (e.g., operator.getitem).
138+ Operator = EdgeOpOverload | Callable
139+
140+
141+ class DetailedGraphVerifier (GraphVerifier ):
142+ """Graph verifier that checks for exact delegated and non-delegated operators.
143+
144+ This verifier captures a snapshot of the graph immediately after partitioning and verifies
145+ that specific operators were delegated/non-delegated the expected number of times. It uses
146+ mocker to intercept the partition() call and create a deep copy of the nodes before they
147+ can be modified. Quantization/dequantization operators are ignored by default as they are
148+ typically not the focus of delegation verification.
149+ """
150+
151+ default_ops_to_ignore = {
152+ QuantizePerTensor ,
153+ QuantizePerChannel ,
154+ DequantizePerTensor ,
155+ DequantizePerChannel ,
156+ }
157+
158+ def __init__ (
159+ self ,
160+ expected_delegated_ops : dict [Operator , int ],
161+ expected_non_delegated_ops : dict [Operator , int ],
162+ mocker : MockerFixture ,
163+ ops_to_ignore : set [Operator ] | None = None ,
164+ ):
165+ """Initializes the DetailedGraphVerifier and patches NeutronPartitioner.partition() to capture node state.
166+
167+ :param expected_delegated_ops: Dictionary mapping operators to their expected delegation count
168+ :param expected_non_delegated_ops: Dictionary mapping operators to their expected non-delegation count
169+ :param mocker: Pytest mocker fixture for intercepting the partition method
170+ :param ops_to_ignore: Set of operators to ignore during verification. Defaults to quantization ops
171+ """
172+ self .expected_delegated_ops = expected_delegated_ops
173+ self .expected_non_delegated_ops = expected_non_delegated_ops
174+
175+ self .ops_to_ignore = ops_to_ignore or self .default_ops_to_ignore
176+
177+ # We need to use mocker to capture a copy of the nodes returned by NeutronPartitioner.partition() to access
178+ # their partition tag. The nodes in the returned graph may be modified after partition() returns, so we
179+ # capture a deep copy immediately when the method completes.
180+ self .captured_partitioned_nodes : list [Node ] | None = None
181+
182+ # Store original partition method for the wrapper.
183+ # Note: pytest-mock automatically restores the original method after the test completes,
184+ # so manual cleanup is not required.
185+ original_partition_method = NeutronPartitioner .partition
186+
187+ def partition_wrapper (self_ , exported_program ):
188+ """Wraps NeutronPartitioner.partition() to capture a snapshot of nodes after partitioning.
189+
190+ :param self_: The NeutronPartitioner instance
191+ :param exported_program: The ExportedProgram being partitioned
192+ :return: The PartitionResult from the original partition method
193+ """
194+ result = original_partition_method (self_ , exported_program )
195+ # Capture a deep copy of the nodes with their metadata.
196+ # This ensures we have the exact state immediately after partitioning,
197+ # before any subsequent transformations modify the graph.
198+ self .captured_partitioned_nodes = list (
199+ deepcopy (exported_program .graph .nodes )
200+ )
201+ return result
202+
203+ # Patch the partition method to intercept and capture results.
204+ mocker .patch .object (NeutronPartitioner , "partition" , partition_wrapper )
205+
206+ def verify_graph (self , graph ):
207+ """Verifies that operators were delegated/non-delegated as expected by comparing actual counts against expectations.
208+
209+ :param graph: The FX graph to verify (not directly used; we use captured nodes instead)
210+ :raises AssertionError: If the NeutronPartitioner wasn't used or if delegation doesn't match expectations
211+ """
212+ assert (
213+ self .captured_partitioned_nodes is not None
214+ ), "The NeutronPartitioner was not used. Cannot access delegated nodes."
215+
216+ delegated_ops = defaultdict (int )
217+ non_delegated_ops = defaultdict (int )
218+
219+ for node in self .captured_partitioned_nodes :
220+ # Only process call_function nodes with a target
221+ if not hasattr (node , "target" ) or node .op != "call_function" :
222+ continue
223+
224+ # Skip operators we're configured to ignore (e.g., quantization ops)
225+ if node .target in self .ops_to_ignore :
226+ continue
227+
228+ # Check if the node was tagged for delegation during partitioning
229+ if NXP_DELEGATION_TAG in node .meta :
230+ delegated_ops [node .target ] += 1
231+ else :
232+ non_delegated_ops [node .target ] += 1
233+
234+ # All ops which were either expected to be delegated, or were actually delegated.
235+ all_delegated_ops = list (set (self .expected_delegated_ops ).union (delegated_ops ))
236+
237+ # All ops which were either expected to be non-delegated, or were actually non-delegated.
238+ all_non_delegated_ops = list (
239+ set (self .expected_non_delegated_ops ).union (non_delegated_ops )
240+ )
241+
242+ message = ""
243+
244+ # Check delegated operators
245+ for op in all_delegated_ops :
246+ expected_count = self .expected_delegated_ops .get (op , 0 )
247+ real_count = delegated_ops .get (op , 0 )
248+ op_name = op .name () if hasattr (op , "name" ) else str (op )
249+ if expected_count != real_count :
250+ message += f"\t `{ op_name } ` was delegated { real_count } times instead of the expected { expected_count } times.\n "
251+
252+ # Check non-delegated operators
253+ for op in all_non_delegated_ops :
254+ expected_count = self .expected_non_delegated_ops .get (op , 0 )
255+ real_count = non_delegated_ops .get (op , 0 )
256+ op_name = op .name () if hasattr (op , "name" ) else str (op )
257+ if expected_count != real_count :
258+ message += f"\t `{ op_name } ` was NON-delegated { real_count } times instead of the expected { expected_count } times.\n "
259+
260+ if message :
261+ raise AssertionError (
262+ "Some operators were not delegated as expected:\n " + message
263+ )
0 commit comments