88from typing import Any , cast , Sequence , Set , Type
99
1010import torch
11- from executorch .backends .arm ._passes import ArmPass
11+ from executorch .backends .arm ._passes import ArmOpTargetedPass
1212from executorch .backends .arm ._passes .arm_pass_utils import (
1313 create_node ,
1414 get_first_fake_tensor ,
1515)
16+ from executorch .backends .arm .ao_ext .mxfp import (
17+ mxfp_dtype_to_str ,
18+ mxfp_str_to_dtype ,
19+ MXFPDType ,
20+ )
1621from executorch .backends .arm .tosa .mapping import TosaSpecialDtype
1722from executorch .exir .dialects ._ops import ops as exir_ops
1823from executorch .exir .pass_base import ExportPass , PassResult
24+ from torchao .prototype .mx_formats .mx_tensor import DTYPE_FP6_E2M3 , DTYPE_FP6_E3M2
1925
2026
def _get_weights_payload_dtype(
    qdata_node: torch.fx.Node,
    dtype: str = "",
) -> MXFPDType:
    """Resolve the logical MXFP payload dtype of a weight payload node.

    Args:
        qdata_node: FX node that produces the quantized weight payload.
        dtype: Optional payload dtype name. When non-empty it is handed to
            ``mxfp_str_to_dtype`` and that result is returned directly,
            without inspecting ``qdata_node``.

    Returns:
        The dtype parsed from ``dtype`` when it is non-empty. Otherwise the
        dtype of the node's first fake tensor, with ``torch.uint8`` mapped to
        ``torch.float4_e2m1fn_x2``.

    """
    if dtype:
        return mxfp_str_to_dtype(dtype)
    qdata = get_first_fake_tensor(qdata_node)
    # With no explicit dtype string, a uint8 payload is interpreted as
    # float4_e2m1fn_x2; any other tensor dtype is returned unchanged.
    if qdata.dtype == torch.uint8:
        return torch.float4_e2m1fn_x2
    return qdata.dtype
2537
2638
def _mark_mxfp_payload(node: torch.fx.Node, payload_dtype: MXFPDType) -> None:
    """Annotate uint8-backed MXFP payload nodes with their TOSA dtype.

    PyTorch represents sub-byte MXFP payloads as ``torch.uint8`` tensors, so
    the tensor dtype alone cannot distinguish FP4E2M1, FP6E2M3, and FP6E3M2.
    Store the logical TOSA dtype in node metadata so later lowering and
    serialization treat the payload as MXFP data rather than ordinary uint8.
    FP8 payloads have native PyTorch dtypes and do not need this metadata.

    Args:
        node: FX node whose ``meta`` dictionary receives the annotation.
        payload_dtype: Logical payload dtype to record. Values other than
            ``torch.float4_e2m1fn_x2``, ``DTYPE_FP6_E2M3`` and
            ``DTYPE_FP6_E3M2`` leave ``node.meta`` untouched.

    """
    if payload_dtype == torch.float4_e2m1fn_x2:
        node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.FP4E2M1
    elif payload_dtype == DTYPE_FP6_E2M3:
        node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.FP6E2M3
    elif payload_dtype == DTYPE_FP6_E3M2:
        node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.FP6E3M2
3055
3156
32- class RewriteMXFPLinearPass (ArmPass ):
57+ class RewriteMXFPLinearPass (ArmOpTargetedPass ):
3358 """Rewrite ``tosa_mxfp.linear`` into explicit TOSA MXFP operators.
3459
3560 For each MXFP linear custom op, the pass:
@@ -44,15 +69,24 @@ class RewriteMXFPLinearPass(ArmPass):
4469
4570 """
4671
72+ target_ops = {
73+ torch .ops .tosa_mxfp .linear .default ,
74+ exir_ops .edge .tosa_mxfp .linear .default ,
75+ }
4776 _passes_required_after : Set [Type [ExportPass ]] = set ()
4877
4978 def __init__ (self , exported_program : torch .export .ExportedProgram , * args , ** kwargs ):
5079 super ().__init__ (* args , ** kwargs )
5180 self .exported_program = exported_program
5281
53- def _get_linear_args (
54- self , node : torch .fx .Node
55- ) -> tuple [torch .fx .Node , torch .fx .Node , torch .fx .Node , torch .fx .Node | None , int ]:
82+ def _get_linear_args (self , node : torch .fx .Node ) -> tuple [
83+ torch .fx .Node ,
84+ torch .fx .Node ,
85+ torch .fx .Node ,
86+ torch .fx .Node | None ,
87+ int ,
88+ MXFPDType ,
89+ ]:
5690 """Extract the MXFP linear operands from a custom-op node."""
5791 input_node = cast (torch .fx .Node , node .args [0 ])
5892 weight_qdata_node = cast (torch .fx .Node , node .args [1 ])
@@ -65,7 +99,26 @@ def _get_linear_args(
6599 int ,
66100 node .args [4 ] if len (node .args ) > 4 else node .kwargs .get ("block_size" , 32 ),
67101 )
68- return input_node , weight_qdata_node , weight_scale_node , bias_node , block_size
102+ payload_dtype_str = cast (
103+ str ,
104+ (
105+ node .args [5 ]
106+ if len (node .args ) > 5
107+ else node .kwargs .get (
108+ "weight_payload_dtype" ,
109+ node .kwargs .get ("weight_dtype" , "" ),
110+ )
111+ ),
112+ )
113+ payload_dtype = _get_weights_payload_dtype (weight_qdata_node , payload_dtype_str )
114+ return (
115+ input_node ,
116+ weight_qdata_node ,
117+ weight_scale_node ,
118+ bias_node ,
119+ block_size ,
120+ payload_dtype ,
121+ )
69122
70123 def _reshape_with_view (
71124 self ,
@@ -96,14 +149,15 @@ def _create_block_scaled_inputs(
96149 weight_qdata_node : torch .fx .Node ,
97150 weight_scale_node : torch .fx .Node ,
98151 block_size : int ,
152+ payload_dtype : MXFPDType ,
99153 ) -> tuple [torch .fx .Node , torch .fx .Node ]:
100154 """Create rank-3 inputs for the block-scaled cast and matmul ops."""
101155 graph = graph_module .graph
102156 input_fake = get_first_fake_tensor (input_node )
103157 weight_qdata_fake = get_first_fake_tensor (weight_qdata_node )
104158 weight_scale_fake = get_first_fake_tensor (weight_scale_node )
105- weight_dtype = _get_block_scaled_payload_dtype ( weight_qdata_fake )
106- _mark_fp4_payload (weight_qdata_node , weight_dtype )
159+ payload_dtype_str = mxfp_dtype_to_str ( payload_dtype )
160+ _mark_mxfp_payload (weight_qdata_node , payload_dtype )
107161
108162 batches = reduce (operator .mul , input_fake .shape [:- 1 ], 1 )
109163 input_reshape_shape = [1 , batches , input_fake .shape [- 1 ]]
@@ -123,13 +177,13 @@ def _create_block_scaled_inputs(
123177 graph = graph ,
124178 op_target = exir_ops .backend .tosa .CAST_TO_BLOCK_SCALED .default ,
125179 args = (input_reshaped , block_size ),
126- kwargs = {"output_dtype" : weight_dtype },
180+ kwargs = {"output_dtype" : payload_dtype_str },
127181 from_node = mxfp_linear_node ,
128182 )
129183 cast_node .meta ["val" ] = exir_ops .backend .tosa .CAST_TO_BLOCK_SCALED .default (
130184 get_first_fake_tensor (input_reshaped ),
131185 block_size ,
132- output_dtype = weight_dtype ,
186+ output_dtype = payload_dtype_str ,
133187 )
134188
135189 input_qdata_node = create_node (
@@ -140,7 +194,7 @@ def _create_block_scaled_inputs(
140194 from_node = mxfp_linear_node ,
141195 )
142196 input_qdata_node .meta ["val" ] = cast_node .meta ["val" ][0 ]
143- _mark_fp4_payload (input_qdata_node , weight_dtype )
197+ _mark_mxfp_payload (input_qdata_node , payload_dtype )
144198
145199 input_scale_node = create_node (
146200 graph = graph ,
@@ -165,8 +219,10 @@ def _create_matmul_node(
165219 weight_qdata_node : torch .fx .Node ,
166220 weight_scale_node : torch .fx .Node ,
167221 block_size : int ,
222+ payload_dtype : MXFPDType ,
168223 ) -> torch .fx .Node :
169224 """Insert ``MATMUL_T_BLOCK_SCALED`` with updated fake metadata."""
225+ payload_dtype_str = mxfp_dtype_to_str (payload_dtype )
170226 matmul_node = create_node (
171227 graph = graph_module .graph ,
172228 op_target = exir_ops .backend .tosa .MATMUL_T_BLOCK_SCALED .default ,
@@ -177,7 +233,7 @@ def _create_matmul_node(
177233 weight_scale_node ,
178234 block_size ,
179235 ),
180- kwargs = {},
236+ kwargs = {"payload_dtype" : payload_dtype_str },
181237 from_node = mxfp_linear_node ,
182238 )
183239 matmul_node .meta ["val" ] = exir_ops .backend .tosa .MATMUL_T_BLOCK_SCALED .default (
@@ -186,6 +242,7 @@ def _create_matmul_node(
186242 get_first_fake_tensor (weight_qdata_node ),
187243 get_first_fake_tensor (weight_scale_node ),
188244 block_size ,
245+ payload_dtype = payload_dtype_str ,
189246 )
190247 return matmul_node
191248
@@ -270,6 +327,7 @@ def _rewrite_mxfp_linear_node(
270327 weight_scale_node ,
271328 bias_node ,
272329 block_size ,
330+ payload_dtype ,
273331 ) = self ._get_linear_args (mxfp_linear_node )
274332
275333 with graph .inserting_before (mxfp_linear_node ):
@@ -283,6 +341,7 @@ def _rewrite_mxfp_linear_node(
283341 weight_qdata_node ,
284342 weight_scale_node ,
285343 block_size ,
344+ payload_dtype ,
286345 )
287346 matmul_node = self ._create_matmul_node (
288347 graph_module ,
@@ -292,6 +351,7 @@ def _rewrite_mxfp_linear_node(
292351 weight_qdata_node ,
293352 weight_scale_node ,
294353 block_size ,
354+ payload_dtype ,
295355 )
296356
297357 with graph .inserting_after (matmul_node ):
@@ -314,10 +374,7 @@ def call(self, graph_module: torch.fx.GraphModule):
314374 graph = graph_module .graph
315375
316376 for node in list (graph .nodes ):
317- if node .op != "call_function" or node .target not in (
318- torch .ops .tosa_mxfp .linear .default ,
319- exir_ops .edge .tosa_mxfp .linear .default ,
320- ):
377+ if node .op != "call_function" or node .target not in self .target_ops :
321378 continue
322379
323380 modified = True
0 commit comments