Skip to content

Commit 88ba269

Browse files
committed
Arm backend: Add MXFP6 linear support
Add support for running torch.nn.Linear modules in MXFP6E3M2 and MXFP6E2M3. Update the `MXFPOpConfig` to support the data types. Since `torch` lacks FP6 datatypes, the string-based definitions in `torchao` are used as a workaround. The custom TOSA op receives a new argument called `str weight_payload_dtype` which tells which dtype is used for the weights. The weight tensor itself does not contain this info since the FP4 and FP6 formats are storted into uint8 tensors. The CAST_TO_BLOCK_SCALED required to transform activations from/to MXFP is also updated to support the new datatype. Its custom TOSA op gets a `str output_dtype` similar to the `weight_payload_dtype` for the MXFP linear op. Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com> Change-Id: Iea06b859429c458b1921ee3e736a65f597b37239
1 parent a70c6c6 commit 88ba269

15 files changed

Lines changed: 530 additions & 88 deletions

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ class InsertRescalePass(ArmPass):
3636

3737
_passes_required_after: Set[Type[ExportPass]] = set()
3838

39+
_mxfp_payload_dtypes = {
40+
TosaSpecialDtype.FP4E2M1,
41+
TosaSpecialDtype.FP6E2M3,
42+
TosaSpecialDtype.FP6E3M2,
43+
}
44+
3945
def _ensure_uint8_io_only(self, graph_module: GraphModule) -> None:
4046
"""Ensure uint8 tensors only appear at IO boundaries.
4147
@@ -50,25 +56,25 @@ def _ensure_uint8_io_only(self, graph_module: GraphModule) -> None:
5056
continue
5157
if meta_val.dtype != torch.uint8:
5258
continue
53-
if node.meta.get(TosaSpecialDtype.meta_key()) == TosaSpecialDtype.FP4E2M1:
54-
continue
5559
if node.op in ("placeholder", "output"):
5660
continue
57-
if node.op == "call_function" and node.target == operator.getitem:
58-
if all(user.op == "output" for user in node.users):
61+
if node.op == "call_function":
62+
if node.target == operator.getitem and all(
63+
user.op == "output" for user in node.users
64+
):
5965
continue
60-
if (
61-
node.op == "call_function"
62-
and node.target
63-
== exir_ops.edge.dim_order_ops._to_dim_order_copy.default
64-
):
65-
# dim_order is a view-like transform; allow it to preserve uint8 at IO.
66-
continue
67-
if (
68-
node.op == "call_function"
69-
and node.target == exir_ops.backend.tosa.RESCALE.default
70-
):
66+
if node.target == exir_ops.backend.tosa.RESCALE.default:
67+
continue
68+
if (
69+
node.target
70+
== exir_ops.edge.dim_order_ops._to_dim_order_copy.default
71+
):
72+
# dim_order is a view-like transform; allow it to preserve uint8 at IO.
73+
continue
74+
if node.meta.get(TosaSpecialDtype.meta_key()) in self._mxfp_payload_dtypes:
75+
# Sub-byte FP types are stored uint8 arrays, so we need an exception for those.
7176
continue
77+
7278
raise ValueError(
7379
f"Found internal uint8 tensor at node {node.name} "
7480
f"({node.target}). Uint8 is only allowed at IO boundaries."

backends/arm/_passes/rewrite_mxfp_linear.py

Lines changed: 75 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,28 +8,53 @@
88
from typing import Any, cast, Sequence, Set, Type
99

1010
import torch
11-
from executorch.backends.arm._passes import ArmPass
11+
from executorch.backends.arm._passes import ArmOpTargetedPass
1212
from executorch.backends.arm._passes.arm_pass_utils import (
1313
create_node,
1414
get_first_fake_tensor,
1515
)
16+
from executorch.backends.arm.ao_ext.mxfp import (
17+
mxfp_dtype_to_str,
18+
mxfp_str_to_dtype,
19+
MXFPDType,
20+
)
1621
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
1722
from executorch.exir.dialects._ops import ops as exir_ops
1823
from executorch.exir.pass_base import ExportPass, PassResult
24+
from torchao.prototype.mx_formats.mx_tensor import DTYPE_FP6_E2M3, DTYPE_FP6_E3M2
1925

2026

21-
def _get_block_scaled_payload_dtype(qdata: torch.Tensor) -> torch.dtype:
27+
def _get_weights_payload_dtype(
28+
qdata_node: torch.fx.Node,
29+
dtype: str = "",
30+
) -> MXFPDType:
31+
if dtype:
32+
return mxfp_str_to_dtype(dtype)
33+
qdata = get_first_fake_tensor(qdata_node)
2234
if qdata.dtype == torch.uint8:
2335
return torch.float4_e2m1fn_x2
2436
return qdata.dtype
2537

2638

27-
def _mark_fp4_payload(node: torch.fx.Node, payload_dtype: torch.dtype) -> None:
39+
def _mark_mxfp_payload(node: torch.fx.Node, payload_dtype: MXFPDType) -> None:
40+
"""Annotate uint8-backed MXFP payload nodes with their TOSA dtype.
41+
42+
PyTorch represents sub-byte MXFP payloads as ``torch.uint8`` tensors, so
43+
the tensor dtype alone cannot distinguish FP4E2M1, FP6E2M3, and FP6E3M2.
44+
Store the logical TOSA dtype in node metadata so later lowering and
45+
serialization treat the payload as MXFP data rather than ordinary uint8.
46+
FP8 payloads have native PyTorch dtypes and do not need this metadata.
47+
48+
"""
2849
if payload_dtype == torch.float4_e2m1fn_x2:
2950
node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.FP4E2M1
51+
elif payload_dtype == DTYPE_FP6_E2M3:
52+
node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.FP6E2M3
53+
elif payload_dtype == DTYPE_FP6_E3M2:
54+
node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.FP6E3M2
3055

3156

32-
class RewriteMXFPLinearPass(ArmPass):
57+
class RewriteMXFPLinearPass(ArmOpTargetedPass):
3358
"""Rewrite ``tosa_mxfp.linear`` into explicit TOSA MXFP operators.
3459
3560
For each MXFP linear custom op, the pass:
@@ -44,15 +69,24 @@ class RewriteMXFPLinearPass(ArmPass):
4469
4570
"""
4671

72+
target_ops = {
73+
torch.ops.tosa_mxfp.linear.default,
74+
exir_ops.edge.tosa_mxfp.linear.default,
75+
}
4776
_passes_required_after: Set[Type[ExportPass]] = set()
4877

4978
def __init__(self, exported_program: torch.export.ExportedProgram, *args, **kwargs):
5079
super().__init__(*args, **kwargs)
5180
self.exported_program = exported_program
5281

53-
def _get_linear_args(
54-
self, node: torch.fx.Node
55-
) -> tuple[torch.fx.Node, torch.fx.Node, torch.fx.Node, torch.fx.Node | None, int]:
82+
def _get_linear_args(self, node: torch.fx.Node) -> tuple[
83+
torch.fx.Node,
84+
torch.fx.Node,
85+
torch.fx.Node,
86+
torch.fx.Node | None,
87+
int,
88+
MXFPDType,
89+
]:
5690
"""Extract the MXFP linear operands from a custom-op node."""
5791
input_node = cast(torch.fx.Node, node.args[0])
5892
weight_qdata_node = cast(torch.fx.Node, node.args[1])
@@ -65,7 +99,26 @@ def _get_linear_args(
6599
int,
66100
node.args[4] if len(node.args) > 4 else node.kwargs.get("block_size", 32),
67101
)
68-
return input_node, weight_qdata_node, weight_scale_node, bias_node, block_size
102+
payload_dtype_str = cast(
103+
str,
104+
(
105+
node.args[5]
106+
if len(node.args) > 5
107+
else node.kwargs.get(
108+
"weight_payload_dtype",
109+
node.kwargs.get("weight_dtype", ""),
110+
)
111+
),
112+
)
113+
payload_dtype = _get_weights_payload_dtype(weight_qdata_node, payload_dtype_str)
114+
return (
115+
input_node,
116+
weight_qdata_node,
117+
weight_scale_node,
118+
bias_node,
119+
block_size,
120+
payload_dtype,
121+
)
69122

70123
def _reshape_with_view(
71124
self,
@@ -96,14 +149,15 @@ def _create_block_scaled_inputs(
96149
weight_qdata_node: torch.fx.Node,
97150
weight_scale_node: torch.fx.Node,
98151
block_size: int,
152+
payload_dtype: MXFPDType,
99153
) -> tuple[torch.fx.Node, torch.fx.Node]:
100154
"""Create rank-3 inputs for the block-scaled cast and matmul ops."""
101155
graph = graph_module.graph
102156
input_fake = get_first_fake_tensor(input_node)
103157
weight_qdata_fake = get_first_fake_tensor(weight_qdata_node)
104158
weight_scale_fake = get_first_fake_tensor(weight_scale_node)
105-
weight_dtype = _get_block_scaled_payload_dtype(weight_qdata_fake)
106-
_mark_fp4_payload(weight_qdata_node, weight_dtype)
159+
payload_dtype_str = mxfp_dtype_to_str(payload_dtype)
160+
_mark_mxfp_payload(weight_qdata_node, payload_dtype)
107161

108162
batches = reduce(operator.mul, input_fake.shape[:-1], 1)
109163
input_reshape_shape = [1, batches, input_fake.shape[-1]]
@@ -123,13 +177,13 @@ def _create_block_scaled_inputs(
123177
graph=graph,
124178
op_target=exir_ops.backend.tosa.CAST_TO_BLOCK_SCALED.default,
125179
args=(input_reshaped, block_size),
126-
kwargs={"output_dtype": weight_dtype},
180+
kwargs={"output_dtype": payload_dtype_str},
127181
from_node=mxfp_linear_node,
128182
)
129183
cast_node.meta["val"] = exir_ops.backend.tosa.CAST_TO_BLOCK_SCALED.default(
130184
get_first_fake_tensor(input_reshaped),
131185
block_size,
132-
output_dtype=weight_dtype,
186+
output_dtype=payload_dtype_str,
133187
)
134188

135189
input_qdata_node = create_node(
@@ -140,7 +194,7 @@ def _create_block_scaled_inputs(
140194
from_node=mxfp_linear_node,
141195
)
142196
input_qdata_node.meta["val"] = cast_node.meta["val"][0]
143-
_mark_fp4_payload(input_qdata_node, weight_dtype)
197+
_mark_mxfp_payload(input_qdata_node, payload_dtype)
144198

145199
input_scale_node = create_node(
146200
graph=graph,
@@ -165,8 +219,10 @@ def _create_matmul_node(
165219
weight_qdata_node: torch.fx.Node,
166220
weight_scale_node: torch.fx.Node,
167221
block_size: int,
222+
payload_dtype: MXFPDType,
168223
) -> torch.fx.Node:
169224
"""Insert ``MATMUL_T_BLOCK_SCALED`` with updated fake metadata."""
225+
payload_dtype_str = mxfp_dtype_to_str(payload_dtype)
170226
matmul_node = create_node(
171227
graph=graph_module.graph,
172228
op_target=exir_ops.backend.tosa.MATMUL_T_BLOCK_SCALED.default,
@@ -177,7 +233,7 @@ def _create_matmul_node(
177233
weight_scale_node,
178234
block_size,
179235
),
180-
kwargs={},
236+
kwargs={"payload_dtype": payload_dtype_str},
181237
from_node=mxfp_linear_node,
182238
)
183239
matmul_node.meta["val"] = exir_ops.backend.tosa.MATMUL_T_BLOCK_SCALED.default(
@@ -186,6 +242,7 @@ def _create_matmul_node(
186242
get_first_fake_tensor(weight_qdata_node),
187243
get_first_fake_tensor(weight_scale_node),
188244
block_size,
245+
payload_dtype=payload_dtype_str,
189246
)
190247
return matmul_node
191248

@@ -270,6 +327,7 @@ def _rewrite_mxfp_linear_node(
270327
weight_scale_node,
271328
bias_node,
272329
block_size,
330+
payload_dtype,
273331
) = self._get_linear_args(mxfp_linear_node)
274332

275333
with graph.inserting_before(mxfp_linear_node):
@@ -283,6 +341,7 @@ def _rewrite_mxfp_linear_node(
283341
weight_qdata_node,
284342
weight_scale_node,
285343
block_size,
344+
payload_dtype,
286345
)
287346
matmul_node = self._create_matmul_node(
288347
graph_module,
@@ -292,6 +351,7 @@ def _rewrite_mxfp_linear_node(
292351
weight_qdata_node,
293352
weight_scale_node,
294353
block_size,
354+
payload_dtype,
295355
)
296356

297357
with graph.inserting_after(matmul_node):
@@ -314,10 +374,7 @@ def call(self, graph_module: torch.fx.GraphModule):
314374
graph = graph_module.graph
315375

316376
for node in list(graph.nodes):
317-
if node.op != "call_function" or node.target not in (
318-
torch.ops.tosa_mxfp.linear.default,
319-
exir_ops.edge.tosa_mxfp.linear.default,
320-
):
377+
if node.op != "call_function" or node.target not in self.target_ops:
321378
continue
322379

323380
modified = True

backends/arm/ao_ext/mxfp.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,56 @@
1010
from executorch.exir._warnings import experimental
1111
from torchao.core.config import AOBaseConfig
1212
from torchao.prototype.mx_formats.config import ScaleCalculationMode
13+
from torchao.prototype.mx_formats.mx_tensor import DTYPE_FP6_E2M3, DTYPE_FP6_E3M2
1314
from torchao.quantization import quantize_
1415

1516

17+
# Pytorch lacks dtypes for the FP6 types, so we use ao's string representations for those.
18+
MXFPDType = torch.dtype | str
19+
20+
21+
SUPPORTED_MXFP_DTYPES: set[MXFPDType] = {
22+
torch.float4_e2m1fn_x2,
23+
torch.float8_e4m3fn,
24+
torch.float8_e5m2,
25+
# Use ao's string representations.
26+
DTYPE_FP6_E2M3,
27+
DTYPE_FP6_E3M2,
28+
}
29+
30+
31+
_DTYPE_TO_STR: dict[MXFPDType, str] = {
32+
DTYPE_FP6_E2M3: "fp6e2m3",
33+
DTYPE_FP6_E3M2: "fp6e3m2",
34+
torch.float4_e2m1fn_x2: "f4e2m1",
35+
torch.float8_e4m3fn: "f8e4m3",
36+
torch.float8_e5m2: "f8e5m2",
37+
}
38+
39+
40+
_STR_TO_DTYPE = {value: key for (key, value) in _DTYPE_TO_STR.items()}
41+
42+
43+
def mxfp_dtype_to_str(dtype: MXFPDType) -> str:
44+
try:
45+
return _DTYPE_TO_STR[dtype]
46+
except KeyError as e:
47+
supported = ", ".join(str(dtype) for dtype in _DTYPE_TO_STR)
48+
raise ValueError(
49+
f"Unsupported MXFP dtype {dtype}. Supported dtypes: {supported}"
50+
) from e
51+
52+
53+
def mxfp_str_to_dtype(dtype: str) -> MXFPDType:
54+
try:
55+
return _STR_TO_DTYPE[dtype]
56+
except KeyError as e:
57+
supported = ", ".join(sorted(_STR_TO_DTYPE))
58+
raise ValueError(
59+
f"Unsupported MXFP dtype string {dtype!r}. Supported strings: {supported}"
60+
) from e
61+
62+
1663
def _match_supported_modules(module: torch.nn.Module, _name: str) -> bool:
1764
"""Default filter function that matches supported modules."""
1865
return isinstance(module, torch.nn.Linear)
@@ -23,7 +70,7 @@ def _match_supported_modules(module: torch.nn.Module, _name: str) -> bool:
2370
class MXFPOpConfig(AOBaseConfig):
2471
"""Configuration for Arm MXFP source transforms."""
2572

26-
weight_dtype: torch.dtype = torch.float8_e4m3fn
73+
weight_dtype: MXFPDType = torch.float8_e4m3fn
2774
weight_scaling_mode: ScaleCalculationMode = ScaleCalculationMode.RCEIL
2875

2976
# Only block size of 32 is currently supported for now, so we hardcode it here.
@@ -32,11 +79,7 @@ def block_size(self) -> int:
3279
return 32
3380

3481
def __post_init__(self) -> None:
35-
if self.weight_dtype not in (
36-
torch.float4_e2m1fn_x2,
37-
torch.float8_e4m3fn,
38-
torch.float8_e5m2,
39-
):
82+
if self.weight_dtype not in SUPPORTED_MXFP_DTYPES:
4083
raise ValueError(f"Unsupported weight_dtype: {self.weight_dtype}")
4184
if not isinstance(self.weight_scaling_mode, ScaleCalculationMode):
4285
raise ValueError(

0 commit comments

Comments
 (0)