diff --git a/backends/arm/ao_ext/ops/mxfp_conv2d_op.py b/backends/arm/ao_ext/ops/mxfp_conv2d_op.py index 0418db5d185..8297fe47b01 100644 --- a/backends/arm/ao_ext/ops/mxfp_conv2d_op.py +++ b/backends/arm/ao_ext/ops/mxfp_conv2d_op.py @@ -10,8 +10,11 @@ """ +from typing import cast + import torch import torch.nn.functional as F + from executorch.backends.arm.ao_ext.mxfp import ( _cast_to_block_scaled_cpu_ref, mxfp_dtype_to_str, @@ -257,6 +260,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: output = output.to(self.output_dtype) return output + def extra_repr(self) -> str: + weight_qdata = cast(torch.Tensor, self.weight_qdata) + weight_shape = weight_qdata.shape + in_channels = _get_num_input_channels(weight_qdata, self.weight_dtype) + repr_parts = [ + f"in_channels={in_channels}", + f"out_channels={weight_shape[0]}", + f"kernel_size={(weight_shape[1], weight_shape[2])}", + f"stride={self.stride}", + f"padding={self.padding}", + f"dilation={self.dilation}", + f"groups={self.groups}", + f"bias={self.bias is not None}", + f"weight_dtype={self.weight_dtype}", + f"block_size={self.block_size}", + ] + return ", ".join(repr_parts) + def transform_conv2d_to_mxfp( module: torch.nn.Module, diff --git a/backends/arm/ao_ext/ops/mxfp_linear_op.py b/backends/arm/ao_ext/ops/mxfp_linear_op.py index 1bd1477b674..d4f674980e2 100644 --- a/backends/arm/ao_ext/ops/mxfp_linear_op.py +++ b/backends/arm/ao_ext/ops/mxfp_linear_op.py @@ -10,8 +10,11 @@ """ +from typing import cast + import torch import torch.nn.functional as F + from executorch.backends.arm.ao_ext.mxfp import ( _cast_to_block_scaled_cpu_ref, mxfp_dtype_to_str, @@ -179,6 +182,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: output = output.to(self.output_dtype) return output + def extra_repr(self) -> str: + weight_qdata = cast(torch.Tensor, self.weight_qdata) + weight_shape = weight_qdata.shape + in_features = _get_num_input_features(weight_qdata, self.weight_dtype) + repr_parts = [ + f"in_features={in_features}", + f"out_features={weight_shape[1]}", + f"bias={self.bias is not None}", + f"weight_dtype={self.weight_dtype}", + f"block_size={self.block_size}", + ] + return ", ".join(repr_parts) + def transform_linear_to_mxfp( module: torch.nn.Module,