Skip to content

Commit a70c6c6

Browse files
committed
Arm backend: Add FP4E2M1 support to MXFP linear
Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com> Change-Id: I29ca5ed16db5fe15331402e5139aead1040ce1b6
1 parent 0da9ca3 commit a70c6c6

13 files changed

Lines changed: 355 additions & 86 deletions

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from executorch.backends.arm._passes.quant_args import QuantArgs
2020
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
21+
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
2122
from executorch.exir.dialects._ops import ops as exir_ops
2223
from executorch.exir.pass_base import ExportPass, PassResult
2324
from torch.fx import GraphModule, Node
@@ -49,6 +50,8 @@ def _ensure_uint8_io_only(self, graph_module: GraphModule) -> None:
4950
continue
5051
if meta_val.dtype != torch.uint8:
5152
continue
53+
if node.meta.get(TosaSpecialDtype.meta_key()) == TosaSpecialDtype.FP4E2M1:
54+
continue
5255
if node.op in ("placeholder", "output"):
5356
continue
5457
if node.op == "call_function" and node.target == operator.getitem:

backends/arm/_passes/rewrite_mxfp_linear.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,22 @@
1313
create_node,
1414
get_first_fake_tensor,
1515
)
16+
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
1617
from executorch.exir.dialects._ops import ops as exir_ops
1718
from executorch.exir.pass_base import ExportPass, PassResult
1819

1920

21+
def _get_block_scaled_payload_dtype(qdata: torch.Tensor) -> torch.dtype:
22+
if qdata.dtype == torch.uint8:
23+
return torch.float4_e2m1fn_x2
24+
return qdata.dtype
25+
26+
27+
def _mark_fp4_payload(node: torch.fx.Node, payload_dtype: torch.dtype) -> None:
28+
if payload_dtype == torch.float4_e2m1fn_x2:
29+
node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.FP4E2M1
30+
31+
2032
class RewriteMXFPLinearPass(ArmPass):
2133
"""Rewrite ``tosa_mxfp.linear`` into explicit TOSA MXFP operators.
2234
@@ -90,6 +102,8 @@ def _create_block_scaled_inputs(
90102
input_fake = get_first_fake_tensor(input_node)
91103
weight_qdata_fake = get_first_fake_tensor(weight_qdata_node)
92104
weight_scale_fake = get_first_fake_tensor(weight_scale_node)
105+
weight_dtype = _get_block_scaled_payload_dtype(weight_qdata_fake)
106+
_mark_fp4_payload(weight_qdata_node, weight_dtype)
93107

94108
batches = reduce(operator.mul, input_fake.shape[:-1], 1)
95109
input_reshape_shape = [1, batches, input_fake.shape[-1]]
@@ -109,13 +123,13 @@ def _create_block_scaled_inputs(
109123
graph=graph,
110124
op_target=exir_ops.backend.tosa.CAST_TO_BLOCK_SCALED.default,
111125
args=(input_reshaped, block_size),
112-
kwargs={"output_dtype": weight_qdata_fake.dtype},
126+
kwargs={"output_dtype": weight_dtype},
113127
from_node=mxfp_linear_node,
114128
)
115129
cast_node.meta["val"] = exir_ops.backend.tosa.CAST_TO_BLOCK_SCALED.default(
116130
get_first_fake_tensor(input_reshaped),
117131
block_size,
118-
output_dtype=weight_qdata_fake.dtype,
132+
output_dtype=weight_dtype,
119133
)
120134

121135
input_qdata_node = create_node(
@@ -126,6 +140,7 @@ def _create_block_scaled_inputs(
126140
from_node=mxfp_linear_node,
127141
)
128142
input_qdata_node.meta["val"] = cast_node.meta["val"][0]
143+
_mark_fp4_payload(input_qdata_node, weight_dtype)
129144

130145
input_scale_node = create_node(
131146
graph=graph,

backends/arm/ao_ext/mxfp.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,11 @@ def block_size(self) -> int:
3232
return 32
3333

3434
def __post_init__(self) -> None:
35-
if self.weight_dtype not in (torch.float8_e4m3fn, torch.float8_e5m2):
35+
if self.weight_dtype not in (
36+
torch.float4_e2m1fn_x2,
37+
torch.float8_e4m3fn,
38+
torch.float8_e5m2,
39+
):
3640
raise ValueError(f"Unsupported weight_dtype: {self.weight_dtype}")
3741
if not isinstance(self.weight_scaling_mode, ScaleCalculationMode):
3842
raise ValueError(

backends/arm/ao_ext/ops/mxfp_linear_op.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,19 @@
2323
)
2424

2525

26+
def _get_mx_elem_dtype(weight_qdata: torch.Tensor) -> torch.dtype:
27+
if weight_qdata.dtype == torch.uint8:
28+
return torch.float4_e2m1fn_x2
29+
return weight_qdata.dtype
30+
31+
32+
def _get_num_input_features(weight_qdata: torch.Tensor) -> int:
33+
num_input_features = weight_qdata.shape[-1]
34+
if _get_mx_elem_dtype(weight_qdata) == torch.float4_e2m1fn_x2:
35+
num_input_features *= 2
36+
return num_input_features
37+
38+
2639
@torch.library.register_fake("tosa_mxfp::linear", lib=MXFP_TOSA_LIB) # type: ignore[misc]
2740
def _mxfp_linear_fake(
2841
input: torch.Tensor,
@@ -39,15 +52,16 @@ def _mxfp_linear_fake(
3952
raise ValueError(
4053
f"Expected weight_qdata batch dim to be 1, got {weight_qdata.shape[0]}"
4154
)
42-
if input.shape[-1] != weight_qdata.shape[-1]:
55+
num_input_features = _get_num_input_features(weight_qdata)
56+
if input.shape[-1] != num_input_features:
4357
raise ValueError(
4458
f"Input last dim {input.shape[-1]} must match linear in_features "
45-
f"{weight_qdata.shape[-1]}"
59+
f"{num_input_features}"
4660
)
4761
expected_scale_shape = (
4862
1,
4963
weight_qdata.shape[1],
50-
weight_qdata.shape[-1] // block_size,
64+
num_input_features // block_size,
5165
)
5266
if tuple(weight_scale.shape) != expected_scale_shape:
5367
raise ValueError(
@@ -92,17 +106,19 @@ def _mxfp_linear_cpu(
92106
if weight_qdata.ndim != 3 or weight_scale.ndim != 3:
93107
raise ValueError("Expected rank-3 weight tensors for MXFP linear")
94108

109+
elem_dtype = _get_mx_elem_dtype(weight_qdata)
110+
95111
# Cast the input to block-scaled format and back again to match the
96112
# expected input format of the TOSA
97113
dequantized_input = _cast_to_block_scaled_cpu_ref(
98114
input,
99-
weight_qdata.dtype,
115+
elem_dtype,
100116
block_size,
101117
)
102118
dequantized_weight = to_dtype(
103119
weight_qdata,
104120
weight_scale,
105-
weight_qdata.dtype,
121+
elem_dtype,
106122
block_size,
107123
torch.float32,
108124
)

backends/arm/operators/op_tosa_matmul_t_block_scaled.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def define_node(
5353
validate_valid_dtype(
5454
self.target,
5555
[A_data, B_data],
56-
[ts.DType.FP8E4M3, ts.DType.FP8E5M2],
56+
[ts.DType.FP4E2M1, ts.DType.FP8E4M3, ts.DType.FP8E5M2],
5757
self.tosa_spec,
5858
)
5959
validate_valid_dtype(

backends/arm/process_node.py

Lines changed: 59 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,68 @@ def _tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
5656

5757

5858
def _prepare_const_values_for_tosa_dtype(
59-
values: np.ndarray, tosa_dtype: ts.DType
59+
values: np.ndarray, tosa_arg: TosaArg
6060
) -> np.ndarray:
6161
"""Normalize constant storage to the expected TOSA serializer dtype."""
62-
if tosa_dtype == ts.DType.INT48 and values.dtype != np.int64:
62+
if tosa_arg.dtype == ts.DType.INT48 and values.dtype != np.int64:
6363
return values.astype(np.int64)
6464
return values
6565

6666

67+
def _get_const_shape(values: np.ndarray, tosa_arg: TosaArg) -> list[int]:
68+
"""Return the TOSA logical shape for a serialized constant."""
69+
if tosa_arg.dtype == ts.DType.FP4E2M1:
70+
return normalize_symint(tosa_arg.shape)
71+
return normalize_symint(values.shape)
72+
73+
74+
def _is_packed_fp4_const(values: np.ndarray, tosa_arg: TosaArg) -> bool:
75+
"""FP4 elements are pairwise in each byte of a uint8 tensor.
76+
77+
This function checks if the given values and TOSA argument represent a
78+
packed FP4 constant.
79+
80+
"""
81+
82+
return (
83+
tosa_arg.dtype == ts.DType.FP4E2M1
84+
and values.dtype == np.uint8
85+
and values.shape[-1] * 2 == tosa_arg.shape[-1]
86+
)
87+
88+
89+
def _add_const(
90+
tosa_graph: Any,
91+
values: np.ndarray,
92+
tosa_arg: TosaArg,
93+
name: str,
94+
) -> None:
95+
"""Add a constant, preserving packed FP4 storage when required."""
96+
if _is_packed_fp4_const(values, tosa_arg):
97+
# TOSA FP4 tensors have logical FP4 shape, but constants are stored as
98+
# packed bytes (two values per byte). Add the raw bytes as INT8 first
99+
# then set TOSA dtype and shape correctly on the tensor metadata.
100+
tosa_graph.addConst(
101+
normalize_symint(values.shape),
102+
ts.DType.INT8,
103+
values,
104+
name=name,
105+
)
106+
tensor = tosa_graph.currRegion.currBasicBlock.tensors[name]
107+
tensor.setDtype(ts.DType.FP4E2M1)
108+
for dim, size in enumerate(normalize_symint(tosa_arg.shape)):
109+
tensor.SetDimSize(dim, size)
110+
return
111+
112+
prepared_values = _prepare_const_values_for_tosa_dtype(values, tosa_arg)
113+
tosa_graph.addConst(
114+
_get_const_shape(prepared_values, tosa_arg),
115+
tosa_arg.dtype,
116+
prepared_values,
117+
name=name,
118+
)
119+
120+
67121
def process_call_function(
68122
node: torch.fx.Node,
69123
tosa_graph: Any,
@@ -154,16 +208,7 @@ def process_inputs_to_parameters(
154208
f"{type(parameter_data).__name__}"
155209
)
156210
parameter_values = _tensor_to_numpy(parameter_data)
157-
parameter_values = _prepare_const_values_for_tosa_dtype(
158-
parameter_values, tosa_arg.dtype
159-
)
160-
161-
tosa_graph.addConst(
162-
normalize_symint(parameter_values.shape),
163-
tosa_arg.dtype,
164-
parameter_values,
165-
name=tosa_arg.name,
166-
)
211+
_add_const(tosa_graph, parameter_values, tosa_arg, name=tosa_arg.name)
167212

168213

169214
def process_inputs_to_buffers(
@@ -188,14 +233,7 @@ def process_inputs_to_buffers(
188233
f"{type(buffer_data).__name__}"
189234
)
190235
buffer_values = _tensor_to_numpy(buffer_data)
191-
buffer_values = _prepare_const_values_for_tosa_dtype(buffer_values, tosa_arg.dtype)
192-
193-
tosa_graph.addConst(
194-
normalize_symint(buffer_values.shape),
195-
tosa_arg.dtype,
196-
buffer_values,
197-
name=tosa_arg.name,
198-
)
236+
_add_const(tosa_graph, buffer_values, tosa_arg, name=tosa_arg.name)
199237

200238

201239
def process_inputs_to_lifted_tensor_constants(
@@ -217,14 +255,7 @@ def process_inputs_to_lifted_tensor_constants(
217255
f"{type(tensor).__name__}"
218256
)
219257
tensor_values = _tensor_to_numpy(tensor)
220-
tensor_values = _prepare_const_values_for_tosa_dtype(tensor_values, tosa_arg.dtype)
221-
222-
tosa_graph.addConst(
223-
normalize_symint(tensor_values.shape),
224-
tosa_arg.dtype,
225-
tensor_values,
226-
name=tosa_arg.name,
227-
)
258+
_add_const(tosa_graph, tensor_values, tosa_arg, name=tosa_arg.name)
228259

229260

230261
def _is_submodule_input(

backends/arm/test/misc/test_mxfp_linear_ao.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,45 @@ def test_mxfp_linear_quantize_swaps_module() -> None:
3131
assert tuple(model.linear.weight_scale.shape) == (1, 8, 1)
3232

3333

34-
def test_mxfp_linear_export_preserves_custom_op() -> None:
34+
def test_mxfp4_linear_quantize_swaps_module() -> None:
3535
model = LinearModule().eval()
36-
to_mxfp(model, MXFPOpConfig())
36+
37+
to_mxfp(
38+
model,
39+
MXFPOpConfig(weight_dtype=torch.float4_e2m1fn_x2),
40+
)
41+
42+
assert isinstance(model.linear, MXFPLinearOp)
43+
assert model.linear.weight_qdata.dtype == torch.uint8
44+
assert model.linear.weight_scale.dtype == torch.float8_e8m0fnu
45+
assert tuple(model.linear.weight_qdata.shape) == (1, 8, 16)
46+
assert tuple(model.linear.weight_scale.shape) == (1, 8, 1)
47+
48+
49+
def test_mxfp_linear_quantize_filter_fn_selects_modules() -> None:
50+
class TwoLinearModule(torch.nn.Module):
51+
def __init__(self) -> None:
52+
super().__init__()
53+
self.selected = torch.nn.Linear(32, 8)
54+
self.skipped = torch.nn.Linear(32, 8)
55+
56+
def forward(self, x: torch.Tensor) -> torch.Tensor:
57+
return self.selected(x) + self.skipped(x)
58+
59+
def _is_selected_linear(module: torch.nn.Module, fqn: str) -> bool:
60+
return isinstance(module, torch.nn.Linear) and fqn == "selected"
61+
62+
model = TwoLinearModule().eval()
63+
64+
to_mxfp(model, MXFPOpConfig(), filter_fn=_is_selected_linear)
65+
66+
assert isinstance(model.selected, MXFPLinearOp)
67+
assert isinstance(model.skipped, torch.nn.Linear)
68+
69+
70+
def _test_mxfp_linear_export_preserves_custom_op(config: MXFPOpConfig) -> None:
71+
model = LinearModule().eval()
72+
to_mxfp(model, config)
3773

3874
exported = export(model, (torch.randn(4, 32),), strict=False)
3975

@@ -44,3 +80,13 @@ def test_mxfp_linear_export_preserves_custom_op() -> None:
4480
]
4581

4682
assert torch.ops.tosa_mxfp.linear.default in targets
83+
84+
85+
def test_mxfp_linear_export_preserves_custom_op() -> None:
86+
_test_mxfp_linear_export_preserves_custom_op(MXFPOpConfig())
87+
88+
89+
def test_mxfp4_linear_export_preserves_custom_op() -> None:
90+
_test_mxfp_linear_export_preserves_custom_op(
91+
MXFPOpConfig(weight_dtype=torch.float4_e2m1fn_x2)
92+
)

backends/arm/test/misc/test_process_node.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,18 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from types import SimpleNamespace
7+
from typing import cast
8+
69
import numpy as np
710
import torch
811
import tosa_serializer as ts
9-
from executorch.backends.arm.process_node import process_placeholder
10-
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
12+
from executorch.backends.arm.process_node import _add_const, process_placeholder
13+
from executorch.backends.arm.tosa.mapping import TosaArg, TosaSpecialDtype
1114
from executorch.backends.arm.tosa.specification import TosaSpecification
1215
from executorch.exir import to_edge
1316
from torch._export.utils import is_param
17+
from tosa.TosaGraph import TosaGraph # type: ignore[import-untyped]
1418

1519

1620
class Int32BiasModule(torch.nn.Module):
@@ -94,3 +98,35 @@ def test_process_placeholder_int48_normalizes_int32_const_values() -> None:
9498
assert tosa_graph.values is not None
9599
assert tosa_graph.values.dtype == np.int64
96100
assert tosa_graph.serialized_bytes == _expected_int48_bytes(module.bias)
101+
102+
103+
def test_add_const_fp4_in_packed_storage() -> None:
104+
packed_values = np.array([0xDE, 0xFE, 0x6D, 0x55], dtype=np.uint8).reshape(
105+
1,
106+
1,
107+
4,
108+
)
109+
tosa_arg = cast(
110+
TosaArg,
111+
SimpleNamespace(dtype=ts.DType.FP4E2M1, shape=(1, 1, 8)),
112+
)
113+
tosa_graph = ts.TosaSerializer()
114+
115+
_add_const(tosa_graph, packed_values, tosa_arg, name="fp4_weight")
116+
117+
graph = TosaGraph.GetRootAs(bytes(tosa_graph.serialize()), 0)
118+
block = graph.Regions(0).Blocks(0)
119+
tensors = {
120+
block.Tensors(index).Name().decode(): block.Tensors(index)
121+
for index in range(block.TensorsLength())
122+
}
123+
tensor = tensors["fp4_weight"]
124+
125+
assert tensor.Type() == ts.DType.FP4E2M1
126+
assert [tensor.Shape(index) for index in range(tensor.ShapeLength())] == [1, 1, 8]
127+
assert [tensor.Data(index) for index in range(tensor.DataLength())] == [
128+
0xDE,
129+
0xFE,
130+
0x6D,
131+
0x55,
132+
]

0 commit comments

Comments
 (0)