diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 033f2331ae9..66f90ead413 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -769,33 +769,43 @@ def dump_operator_distribution( Returns self for daisy-chaining. """ line = "#" * 10 - to_print = f"{line} {self.cur} Operator Distribution {line}\n" + to_print = f"\n{line} {self.cur} Operator Distribution {line}\n" - if ( - self.cur - in ( - StageType.PARTITION, - StageType.TO_EDGE_TRANSFORM_AND_LOWER, - ) - and print_table + if self.cur in ( + StageType.PARTITION, + StageType.TO_EDGE_TRANSFORM_AND_LOWER, ): graph_module = self.get_artifact().exported_program().graph_module delegation_info = get_delegation_info(graph_module) - if print_table: - op_dist = delegation_info.get_operator_delegation_dataframe() op_dist = _get_tosa_operator_distribution(graph_module, include_dtypes) - if include_dtypes: - op_dist = { - "Operator": [op_type[0] for op_type, _ in op_dist], - "Dtype": [op_type[1] for op_type, _ in op_dist], - "Count": [count for _, count in op_dist], - } + if print_table: + aten_op_dist = delegation_info.get_operator_delegation_dataframe() + to_print += "Aten operators:\n" + _format_dict( + dict(aten_op_dist), print_table + ) + + if include_dtypes: + op_dist_dict = { + "Operator": [op_type[0] for op_type, _ in op_dist], + "Dtype": [op_type[1] for op_type, _ in op_dist], + "Count": [count for _, count in op_dist], + } + else: + op_dist_dict = { + "Operator": [op for op, _ in op_dist], + "Count": [count for _, count in op_dist], + } else: - op_dist = { - "Operator": [op for op, _ in op_dist], - "Count": [count for _, count in op_dist], - } - to_print += "TOSA operators:\n" + _format_dict(dict(op_dist), print_table) + if include_dtypes: + op_dtype_dist_dict: Dict[str, Dict[str, int]] = defaultdict(dict) + for op_dtype, count in op_dist: + op = op_dtype[0] + dtype = op_dtype[1] + op_dtype_dist_dict[op].update({dtype: count}) + op_dist_dict = dict(op_dtype_dist_dict) + else: + op_dist_dict = dict(op_dist) # type: ignore[arg-type] + to_print += "\nTOSA operators:\n" + _format_dict(op_dist_dict, print_table) to_print += "\n" + delegation_info.get_summary() else: graph = self.get_graph(self.cur) @@ -805,17 +815,28 @@ def dump_operator_distribution( op_dist = _get_operator_distribution(graph) if print_table: if include_dtypes: - op_dist = { + op_dist_dict = { "Operator": [op_dtype[0] for op_dtype, _ in op_dist], "Dtype": [op_dtype[1] for op_dtype, _ in op_dist], "Count": [count for _, count in op_dist], } else: - op_dist = { + op_dist_dict = { "Operator": [op for op, _ in op_dist], "Count": [count for _, count in op_dist], } - to_print += _format_dict(op_dist, print_table) + "\n" + else: + if include_dtypes: + op_dtype_dist_dict = defaultdict(dict) + for op_dtype, count in op_dist: + op = op_dtype[0] + dtype = op_dtype[1] + op_dtype_dist_dict[op].update({dtype: count}) + op_dist_dict = dict(op_dtype_dist_dict) + else: + op_dist_dict = dict(op_dist) # type: ignore[arg-type] + + to_print += _format_dict(op_dist_dict, print_table) + "\n" _dump_str(to_print, path_to_dump)