-
Notifications
You must be signed in to change notification settings - Fork 733
Description
🐛 Describe the bug
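In `arm_quantizer.py`, the name filter (`module_name_filter`) assumes that module names start with `L['self'].` and strips that prefix from each name. The module names seen here do not contain that prefix, so the stripping deletes the whole name instead. As a result, the filter never matches the node, and a call such as `quantizer.set_module_name('conv1', ...)` has no effect.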
class SimpleConvModel(nn.Module):
    """Two-convolution model used to reproduce the module-name filter issue.

    ``conv1`` maps 3 input channels to 16 and ``conv2`` maps 16 to 32. Both
    use a 3x3 kernel with padding 1 (default stride 1), so the spatial size
    is unchanged. A ReLU sits between the two convolutions. The attribute
    names ``conv1`` and ``conv2`` are the module names that the script later
    passes to ``quantizer.set_module_name``.
    """

    def __init__(self):
        super().__init__()
        # conv1 is created before conv2 so parameter initialisation consumes
        # the RNG in the same order.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 (no activation after conv2)."""
        return self.conv2(torch.relu(self.conv1(x)))
def create_int8_int8_config():
    """Return the INT8-activation / INT8-weight / INT32-bias configuration.

    This is the per-channel symmetric configuration returned by
    ``get_symmetric_quantization_config``, used without modification.
    """
    config = get_symmetric_quantization_config(is_per_channel=True)
    return config
def create_int16_int8_config():
    """Return a configuration with INT16 activations, INT8 weights and INT32 bias.

    Starts from the per-channel symmetric configuration and rewrites only the
    ``input_activation`` and ``output_activation`` specs to ``torch.int16``
    with the full signed 16-bit range (-32768..32767). All other fields of
    the base configuration (e.g. weight and bias) are copied unchanged.
    """
    config = get_symmetric_quantization_config(is_per_channel=True)

    int16_overrides = {}
    for field_name in ("input_activation", "output_activation"):
        # Both a missing attribute and an unset (falsy) spec are skipped.
        spec = getattr(config, field_name, None)
        if not spec:
            continue
        int16_overrides[field_name] = replace(
            spec,
            dtype=torch.int16,
            quant_min=-32768,
            quant_max=32767,
        )

    return replace(config, **int16_overrides)
# Build the model and an example input.
model = SimpleConvModel()
model.eval()

batch_size = 4
example_input = torch.randn(batch_size, 3, 32, 32)

# Export the model, then take the unlifted graph module handed to prepare_pt2e.
print("Exporting model...")
exported_program = torch.export.export(model, (example_input,))
graph_module = exported_program.module(check_guards=False)

# Ethos-U55 target description used to construct the quantizer.
compile_spec = EthosUCompileSpec(
    target="ethos-u55-128",
    system_config="Ethos_U55_High_End_Embedded",
    memory_mode="Shared_Sram",
    extra_flags=["--output-format=raw", "--debug-force-regor"],
)

# Quantizer: a global default plus per-module overrides for conv1 and conv2.
quantizer = EthosUQuantizer(compile_spec)
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))


def _print_config_dtypes(title, config):
    """Print ``title`` followed by the dtype of each activation/weight spec."""
    print(title)
    for label, spec_name in (
        ("Input activation", "input_activation"),
        ("Output activation", "output_activation"),
        ("Weight", "weight"),
    ):
        spec = getattr(config, spec_name)
        print(f" {label} dtype: {getattr(spec, 'dtype', None)}")


# conv1: INT8 activations, INT8 weights.
int8_config = create_int8_int8_config()
quantizer.set_module_name('conv1', int8_config)
_print_config_dtypes("\nConv1 config (INT8/INT8):", int8_config)

# conv2: INT16 activations, INT8 weights.
int16_config = create_int16_int8_config()
quantizer.set_module_name('conv2', int16_config)
# Alternative tried while debugging (passes an op name instead of a module name):
# quantizer.set_module_name('aten.conv2d.default', int16_config)
_print_config_dtypes("\nConv2 config (INT16/INT8):", int16_config)

# prepare_pt2e runs the quantizer's annotation step, which is where the
# per-module-name configs are expected to be matched against graph nodes.
print("\nPreparing model for quantization...")
prepared = prepare_pt2e(graph_module, quantizer)

Versions
PyTorch version: 2.9.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 15.5 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.0.13.5)
CMake version: Could not collect
Libc version: N/A
Python version: 3.12.10 (v3.12.10:0cc81280367, Apr 8 2025, 08:46:59) [Clang 13.0.0 (clang-1300.0.29.30)] (64-bit runtime)
Python platform: macOS-15.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M4 Max
Versions of relevant libraries:
[pip3] executorch==1.0.0
[pip3] numpy==2.3.4
[pip3] torch==2.9.0
[pip3] torchao==0.14.0
[pip3] torchaudio==2.9.0
[pip3] torchcodec==0.8.1
[pip3] torchvision==0.24.0
[conda] Could not collect