Skip to content

Commit 7aff018

Browse files
author
ssjia
committed
Update
[ghstack-poisoned]
2 parents c770b91 + 9aab737 commit 7aff018

101 files changed

Lines changed: 6003 additions & 1974 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/pull.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -852,7 +852,8 @@ jobs:
852852
strategy:
853853
matrix:
854854
dtype: [fp32]
855-
pt2e_quantize: [qnn_16a16w, qnn_8a8w]
855+
# TODO(T12345): re-enable qnn_16a16w once OOM on linux.2xlarge is resolved
856+
pt2e_quantize: [qnn_8a8w]
856857
mode: [qnn]
857858
fail-fast: false
858859
with:

.github/workflows/trunk.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -951,7 +951,8 @@ jobs:
951951
strategy:
952952
matrix:
953953
dtype: [fp32]
954-
pt2e_quantize: [qnn_16a16w, qnn_8a8w]
954+
# TODO(T12345): re-enable qnn_16a16w once OOM on linux.2xlarge is resolved
955+
pt2e_quantize: [qnn_8a8w]
955956
mode: [qnn]
956957
fail-fast: false
957958
with:

Makefile

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,8 @@ help:
127127
@echo " llava-cpu - Build Llava runner with CPU backend"
128128
@echo " gemma3-cuda - Build Gemma3 runner with CUDA backend"
129129
@echo " gemma3-cpu - Build Gemma3 runner with CPU backend"
130-
@echo " gemma4_31b-cuda - Build Gemma 4 31B runner with CUDA backend"
131-
@echo " gemma4_31b-mlx - Build Gemma 4 31B runner with MLX backend"
130+
@echo " gemma4_31b-cuda - Build Gemma 4 31B runner and worker with CUDA backend"
131+
@echo " gemma4_31b-mlx - Build Gemma 4 31B runner and worker with MLX backend"
132132
@echo " qwen3_5_moe-cuda - Build Qwen3.5 MoE runner with CUDA backend"
133133
@echo " qwen3_5_moe-metal - Build Qwen3.5 MoE runner with Metal backend"
134134
@echo " qwen3_5_moe-mlx - Build Qwen3.5 MoE runner with MLX backend"
@@ -444,20 +444,23 @@ qwen3_5_moe-cuda:
444444
gemma4_31b-cuda:
445445
@echo "==> Building and installing ExecuTorch with CUDA..."
446446
cmake --workflow --preset llm-release-cuda
447-
@echo "==> Building Gemma 4 31B runner with CUDA..."
447+
@echo "==> Building Gemma 4 31B runner, worker, and no-bleed test with CUDA..."
448448
cd examples/models/gemma4_31b && cmake --workflow --preset gemma4-31b-cuda
449449
@echo ""
450450
@echo "✓ Build complete!"
451-
@echo " Binary: cmake-out/examples/models/gemma4_31b/gemma4_31b_runner"
451+
@echo " Runner: cmake-out/examples/models/gemma4_31b/gemma4_31b_runner"
452+
@echo " Worker: cmake-out/examples/models/gemma4_31b/gemma4_31b_worker"
453+
@echo " Test: cmake-out/examples/models/gemma4_31b/test_gemma4_31b_nobleed"
452454

453455
gemma4_31b-mlx:
454456
@echo "==> Building and installing ExecuTorch with MLX..."
455457
cmake --workflow --preset mlx-release
456-
@echo "==> Building Gemma 4 31B runner with MLX..."
458+
@echo "==> Building Gemma 4 31B runner and worker with MLX..."
457459
cd examples/models/gemma4_31b && cmake --workflow --preset gemma4-31b-mlx
458460
@echo ""
459461
@echo "✓ Build complete!"
460-
@echo " Binary: cmake-out/examples/models/gemma4_31b/gemma4_31b_runner"
462+
@echo " Runner: cmake-out/examples/models/gemma4_31b/gemma4_31b_runner"
463+
@echo " Worker: cmake-out/examples/models/gemma4_31b/gemma4_31b_worker"
461464

462465
qwen3_5_moe-metal:
463466
@echo "==> Building and installing ExecuTorch with Metal..."

backends/arm/_passes/arm_pass_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,7 @@ def _tosa_pipeline(
618618
RewriteMatmulPass(),
619619
RewritePadPass(),
620620
FuseViewCopyTransformPass(),
621-
RemovePermutesAroundElementwiseTosaOps(),
621+
RemovePermutesAroundElementwiseTosaOps(exported_program),
622622
CanonicalizeViewCopyPermutePass(),
623623
FuseCascadedTransposeOrPermuteOps(),
624624
RewriteHighRankSingletonPermutePass(),

backends/arm/_passes/remove_permutes_around_elementwise_tosa_ops.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,19 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import torch
7+
8+
from executorch.backends.arm._passes.arm_pass_utils import is_param_node
69
from executorch.backends.arm._passes.insert_table_ops import TableOps
710
from executorch.backends.transforms.remove_permutes_around_elementwise_ops import (
811
RemovePermutesAroundElementwiseOps,
912
)
13+
from executorch.exir import ExportedProgram
1014
from executorch.exir.dialects._ops import ops as exir_ops
1115

1216

1317
class RemovePermutesAroundElementwiseTosaOps(RemovePermutesAroundElementwiseOps):
14-
def __init__(self) -> None:
18+
def __init__(self, exported_program: ExportedProgram) -> None:
1519
super().__init__(
1620
extra_permutable_ops={
1721
*TableOps.unary_table_ops.keys(),
@@ -20,16 +24,19 @@ def __init__(self) -> None:
2024
exir_ops.backend.tosa.TABLE.default,
2125
}
2226
)
27+
self.exported_program = exported_program
28+
29+
def _is_constant(self, node: torch.fx.Node) -> bool:
30+
# Override fragile string match check with exported program check
31+
return super()._is_constant(node) or is_param_node(self.exported_program, node)
2332

2433
def permute_subgraph(self, subgraph) -> bool:
25-
# Original function will always permute constant nodes which is wrong for table ops
26-
# Remove constant tosa.TABLE edges before running full function
34+
# TABLE lookup inputs are already tied to the table layout.
2735
new_constant_edges_in = set()
2836
for const_node, user_node in subgraph.constant_edges_in:
2937
if user_node.target == exir_ops.backend.tosa.TABLE.default:
3038
continue
31-
else:
32-
new_constant_edges_in.add((const_node, user_node))
39+
new_constant_edges_in.add((const_node, user_node))
3340

3441
subgraph.constant_edges_in = new_constant_edges_in
3542
return super().permute_subgraph(subgraph)

backends/arm/ao_ext/ops/mxfp_conv2d_op.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@
3232
)
3333

3434

35+
_SUPPORTED_OUTPUT_DTYPES: set[torch.dtype] = {
36+
torch.float32,
37+
torch.bfloat16,
38+
}
39+
40+
3541
def _get_mx_elem_dtype(
3642
weight_qdata: torch.Tensor,
3743
weight_payload_dtype: str = "",
@@ -208,10 +214,12 @@ def __init__(
208214
groups: int,
209215
weight_dtype: MXFPDType,
210216
block_size: int,
217+
output_dtype: torch.dtype = torch.float32,
211218
) -> None:
212219
super().__init__()
213220
self.weight_dtype = mxfp_dtype_to_str(weight_dtype)
214221
self.block_size = block_size
222+
self.output_dtype = output_dtype
215223

216224
self.register_buffer("weight_qdata", weight_qdata, persistent=True)
217225
self.register_buffer("weight_scale", weight_scale, persistent=True)
@@ -233,7 +241,7 @@ def __init__(
233241
self.groups = groups
234242

235243
def forward(self, x: torch.Tensor) -> torch.Tensor:
236-
return torch.ops.tosa_mxfp.conv2d.default(
244+
output = torch.ops.tosa_mxfp.conv2d.default(
237245
x,
238246
self.weight_qdata,
239247
self.weight_scale,
@@ -245,6 +253,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
245253
self.block_size,
246254
self.weight_dtype,
247255
)
256+
if self.output_dtype != torch.float32:
257+
output = output.to(self.output_dtype)
258+
return output
248259

249260

250261
def transform_conv2d_to_mxfp(
@@ -276,6 +287,9 @@ def transform_conv2d_to_mxfp(
276287
)
277288

278289
bias = module.bias.detach().to(torch.float32) if module.bias is not None else None
290+
output_dtype = weight_ohwi.dtype
291+
if output_dtype not in _SUPPORTED_OUTPUT_DTYPES:
292+
raise ValueError(f"Unsupported output_dtype: {output_dtype}")
279293
return MXFPConv2dOp(
280294
weight_qdata,
281295
weight_scale,
@@ -286,4 +300,5 @@ def transform_conv2d_to_mxfp(
286300
module.groups,
287301
config.weight_dtype,
288302
config.block_size,
303+
output_dtype,
289304
)

backends/arm/test/misc/test_mxfp_conv2d_ao.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,61 @@ def test_mxfp_conv2d_quantize_supports_fp4_weights() -> None:
159159
)
160160

161161

162+
def test_mxfp_conv2d_preserves_bfloat16_output_dtype() -> None:
163+
model = Conv2dModule().eval().to(torch.bfloat16)
164+
to_mxfp(
165+
model,
166+
MXFPOpConfig(weight_dtype=torch.float8_e4m3fn),
167+
)
168+
169+
output = model(torch.randn(1, IN_CHANNELS, 8, 8, dtype=torch.bfloat16))
170+
171+
assert isinstance(model.conv, MXFPConv2dOp)
172+
assert model.conv.output_dtype == torch.bfloat16
173+
assert output.dtype == torch.bfloat16
174+
175+
176+
def test_mxfp_conv2d_op_output_dtype_constructor_arg() -> None:
177+
model = Conv2dModule().eval()
178+
config = MXFPOpConfig(weight_dtype=torch.float8_e4m3fn)
179+
to_mxfp(
180+
model,
181+
config,
182+
)
183+
assert isinstance(model.conv, MXFPConv2dOp)
184+
185+
fp32_conv = MXFPConv2dOp(
186+
model.conv.weight_qdata,
187+
model.conv.weight_scale,
188+
model.conv.bias,
189+
model.conv.stride,
190+
model.conv.padding,
191+
model.conv.dilation,
192+
model.conv.groups,
193+
config.weight_dtype,
194+
config.block_size,
195+
)
196+
bf16_conv = MXFPConv2dOp(
197+
model.conv.weight_qdata,
198+
model.conv.weight_scale,
199+
model.conv.bias,
200+
model.conv.stride,
201+
model.conv.padding,
202+
model.conv.dilation,
203+
model.conv.groups,
204+
config.weight_dtype,
205+
config.block_size,
206+
output_dtype=torch.bfloat16,
207+
)
208+
209+
test_input = torch.randn(1, IN_CHANNELS, 8, 8)
210+
211+
assert fp32_conv.output_dtype == torch.float32
212+
assert fp32_conv(test_input).dtype == torch.float32
213+
assert bf16_conv.output_dtype == torch.bfloat16
214+
assert bf16_conv(test_input).dtype == torch.bfloat16
215+
216+
162217
def _test_mxfp_conv2d_export_preserves_custom_op(config: MXFPOpConfig) -> None:
163218
model = Conv2dModule().eval()
164219
to_mxfp(model, config)
@@ -198,6 +253,33 @@ def test_mxfp6_e3m2_conv2d_export_preserves_custom_op() -> None:
198253
)
199254

200255

256+
def test_mxfp_conv2d_export_preserves_inferred_bfloat16_output_dtype() -> None:
257+
model = Conv2dModule().eval().to(torch.bfloat16)
258+
to_mxfp(
259+
model,
260+
MXFPOpConfig(weight_dtype=torch.float8_e4m3fn),
261+
)
262+
263+
exported = export(
264+
model,
265+
(torch.randn(1, IN_CHANNELS, 8, 8, dtype=torch.bfloat16),),
266+
strict=False,
267+
)
268+
269+
cast_nodes = [
270+
node
271+
for node in exported.graph_module.graph.nodes
272+
if node.op == "call_function" and node.target == torch.ops.aten.to.dtype
273+
]
274+
275+
assert len(cast_nodes) == 1
276+
assert cast_nodes[0].args[1] == torch.bfloat16
277+
assert cast_nodes[0].meta["val"].dtype == torch.bfloat16
278+
cast_input = cast_nodes[0].args[0]
279+
assert isinstance(cast_input, torch.fx.Node)
280+
assert cast_input.target == torch.ops.tosa_mxfp.conv2d.default
281+
282+
201283
def test_mxfp_conv2d_cpu_impl_matches_ref() -> None:
202284
ref_model = Conv2dModule().eval()
203285
test_model = Conv2dModule().eval()

backends/arm/test/misc/test_transpose_counts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def forward(self, x: torch.Tensor):
453453
Model4ConvLstmLinearLayerNorm(), (torch.randn(2, 8, 32),), 3
454454
),
455455
"model_5_dwconv_gelu_layernorm_avgpool": TransposeCountCase(
456-
Model5DwConvGeluLayerNormAvgPool(), (torch.randn(1, 8, 16, 16),), 4
456+
Model5DwConvGeluLayerNormAvgPool(), (torch.randn(1, 8, 16, 16),), 2
457457
),
458458
"model_6_gru_linear": TransposeCountCase(
459459
Model6GruLinear(), (torch.randn(2, 16, 8),), 2

0 commit comments

Comments
 (0)