From 92150ffd09f1f1f5738b67d377c0eede7a815d5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Lindstr=C3=B6m?= Date: Thu, 18 Jun 2026 15:55:26 +0200 Subject: [PATCH] Arm backend: Add extra_repr methods to MXFP modules MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add extra_repr methods to MXFPLinearOp and MXFPConv2dOp to make them show more detailed info when printed. Signed-off-by: Martin Lindström Change-Id: Id412b7da6369304a087f1a392f10278cab022533 --- backends/arm/ao_ext/ops/mxfp_conv2d_op.py | 21 +++++++++++++++++++++ backends/arm/ao_ext/ops/mxfp_linear_op.py | 16 ++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/backends/arm/ao_ext/ops/mxfp_conv2d_op.py b/backends/arm/ao_ext/ops/mxfp_conv2d_op.py index 0418db5d185..8297fe47b01 100644 --- a/backends/arm/ao_ext/ops/mxfp_conv2d_op.py +++ b/backends/arm/ao_ext/ops/mxfp_conv2d_op.py @@ -10,8 +10,11 @@ """ +from typing import cast + import torch import torch.nn.functional as F + from executorch.backends.arm.ao_ext.mxfp import ( _cast_to_block_scaled_cpu_ref, mxfp_dtype_to_str, @@ -257,6 +260,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: output = output.to(self.output_dtype) return output + def extra_repr(self) -> str: + weight_qdata = cast(torch.Tensor, self.weight_qdata) + weight_shape = weight_qdata.shape + in_channels = _get_num_input_channels(weight_qdata, self.weight_dtype) + repr_parts = [ + f"in_channels={in_channels}", + f"out_channels={weight_shape[0]}", + f"kernel_size={(weight_shape[1], weight_shape[2])}", + f"stride={self.stride}", + f"padding={self.padding}", + f"dilation={self.dilation}", + f"groups={self.groups}", + f"bias={self.bias is not None}", + f"weight_dtype={self.weight_dtype}", + f"block_size={self.block_size}", + ] + return ", ".join(repr_parts) + def transform_conv2d_to_mxfp( module: torch.nn.Module, diff --git a/backends/arm/ao_ext/ops/mxfp_linear_op.py b/backends/arm/ao_ext/ops/mxfp_linear_op.py index 1bd1477b674..d4f674980e2 100644 --- a/backends/arm/ao_ext/ops/mxfp_linear_op.py +++ b/backends/arm/ao_ext/ops/mxfp_linear_op.py @@ -10,8 +10,11 @@ """ +from typing import cast + import torch import torch.nn.functional as F + from executorch.backends.arm.ao_ext.mxfp import ( _cast_to_block_scaled_cpu_ref, mxfp_dtype_to_str, @@ -179,6 +182,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: output = output.to(self.output_dtype) return output + def extra_repr(self) -> str: + weight_qdata = cast(torch.Tensor, self.weight_qdata) + weight_shape = weight_qdata.shape + in_features = _get_num_input_features(weight_qdata, self.weight_dtype) + repr_parts = [ + f"in_features={in_features}", + f"out_features={weight_shape[1]}", + f"bias={self.bias is not None}", + f"weight_dtype={self.weight_dtype}", + f"block_size={self.block_size}", + ] + return ", ".join(repr_parts) + def transform_linear_to_mxfp( module: torch.nn.Module,