Skip to content

XNNPACK: Lift constant mul scalars for partitioning#20515

Open
mansnils wants to merge 2 commits into
pytorch:mainfrom
mansnils:ops
Open

XNNPACK: Lift constant mul scalars for partitioning#20515
mansnils wants to merge 2 commits into
pytorch:mainfrom
mansnils:ops

Conversation

@mansnils

@mansnils mansnils commented Jun 25, 2026

Copy link
Copy Markdown
Collaborator

XNNPACK supports the tensor overload for multiply, but plain aten.mul.Scalar is not selected by the XNNPACK partitioner.

Adds a narrow scalar-lifting pass that rewrites aten.mul.Scalar to aten.mul.Tensor by registering the scalar as a small constant buffer. This avoids introducing an aten.full op while allowing the existing multiply lowering path to partition the op.

Keep SDPA scale multipliers as aten.mul.Scalar so ConvertToSDPAPass can still recover the attention scale before replacing the pattern. Add test coverage for that guard.

Allow the XNNPACK tester to pass transform passes through to to_edge_transform_and_lower. This keeps op tests on the same path as existing XNNPACK model tests that already use explicit transform passes.

For DeIT Tiny, this removes 24 portable aten.mul.Scalar nodes and reduces delegate count from 62 to 50. In current local timing checks the latency impact is modest: about 1% faster on both Android SME2 and the aarch64 XNNPACK/KleidiAI NEON-class host runner. These are modest uplifts but may introduce more opportunities for improvements.

cc @GregoryComer @digantdesai @cbilgin @freddan80 @per @zingo @oscarandersson8218 @Sebastian-Larsson @robell @rascani

XNNPACK supports the tensor overload for multiply, but plain
aten.mul.Scalar is not selected by the XNNPACK partitioner.

Adds a narrow scalar-lifting pass that rewrites aten.mul.Scalar to
aten.mul.Tensor by registering the scalar as a small constant buffer.
This avoids introducing an aten.full op while allowing the existing
multiply lowering path to partition the op.

Keep SDPA scale multipliers as aten.mul.Scalar so ConvertToSDPAPass can
still recover the attention scale before replacing the pattern. Add test
coverage for that guard.

Allow the XNNPACK tester to pass transform passes through to
to_edge_transform_and_lower. This keeps op tests on the same path as
existing XNNPACK model tests that already use explicit transform passes.

For DeIT Tiny, this removes 24 portable aten.mul.Scalar nodes and
reduces delegate count from 62 to 50. In current local timing checks
the latency impact is modest: about 1% faster on both Android SME2 and
the aarch64 XNNPACK/KleidiAI NEON-class host runner.
These are modest uplifts but may introduce more opportunities for
improvements.

Signed-off-by: Måns Nilsson <mans.nilsson@arm.com>
Change-Id: I83b6ad53925edb72afdf0077b5dbb99b5d9c4648
@mansnils mansnils requested a review from digantdesai as a code owner June 25, 2026 14:20
Copilot AI review requested due to automatic review settings June 25, 2026 14:20
@mansnils mansnils added module: xnnpack Issues related to xnnpack delegation and the code under backends/xnnpack/ partner: arm For backend delegation, kernels, demo, etc. from the 3rd-party partner, Arm ciflow/trunk release notes: xnnpack Changes to the XNNPack backend delegate labels Jun 25, 2026
@pytorch-bot

pytorch-bot Bot commented Jun 25, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20515

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures, 4 Unrelated Failures, 2 Unclassified Failures

As of commit 3952854 with merge base 82a605d (image):

NEW FAILURES - The following jobs have failed:

UNCLASSIFIED FAILURES - DrCI could not classify the following jobs because the workflow did not run on the merge base. The failures may be pre-existing on trunk or introduced by this PR:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 25, 2026
@mansnils

Copy link
Copy Markdown
Collaborator Author

@claude review this PR

@claude

claude Bot commented Jun 25, 2026

Copy link
Copy Markdown

Claude finished @mansnils's task in 9m 47s —— View job


PR Review: XNNPACK — Lift constant mul scalars for partitioning

Reviewed the diff against origin/main. Overall this is a clean, well-scoped change with good test coverage and a thoughtful SDPA guard. A few points worth considering before merge.

Correctness / design

  1. Constant added as module buffer + get_attr rather than a lifted constant placeholderlift_constant_scalar_operands_pass.py:55-76 uses register_buffer + graph.get_attr. In a lifted edge ExportedProgram, the idiomatic way to introduce a constant is create_constant_placeholder in backends/transforms/utils.py:98, which also registers it in exp_program.constants / the graph signature. Because this pass subclasses ExportPass (GraphModule-only call()), it has no access to the ExportedProgram, which is presumably why the buffer route was taken.

    In practice XNNPACK serialization handles raw get_attr nodes (node_visitor.py:577get_param_tensor has a get-attr fallback at utils.py:89-94), so the delegated path works. The risk is the non-delegated path: if a mapped mul.Scalar is not partitioned, the orphaned get_attr has no graph-signature entry and the ExecuTorch emitter may not handle it. The dtype guard (point 3) makes this unlikely for mul, but it's worth confirming the XNNPACK CI job covers a case where the lifted mul falls back to portable. Worth a comment explaining why a buffer + get_attr is safe here.

  2. _feeds_sdpa_qk_bmm passthrough set is narrow and coupled to SDPA internals (lift_constant_scalar_operands_pass.py:85-110). The traversal only walks through expand_copy/view_copy to reach a bmm. If the real SDPA graph has any other op (transpose/permute/to_copy) between the scale mul.Scalar and the QK^T bmm, the guard misses it, the scalar gets lifted, and ConvertToSDPAPass.get_scale (convert_to_sdpa.py:33-52) silently fails its assert scale is not None. The unit test only exercises the canonical pattern from partition/graphs/sdpa.py. Since this generic pass now encodes knowledge of the XNNPACK SDPA pattern, consider deriving the guard from the same sdpa.get_graphs() source, or at least add a test with a pattern that has an intervening op. A false positive (leaving a non-SDPA mul→bmm scalar unlifted) is only a missed optimization, so the asymmetry is in the safe direction — but a false negative breaks SDPA.

Tester change

  1. transform_passes silently dropped when a prebuilt stage is passed (test/tester/tester.py, Tester.to_edge_transform_and_lower). In the else branch (caller supplies to_edge_and_lower_stage), partitioners and edge_compile_config are propagated onto the stage but transform_passes is not. Passing both a stage and transform_passes= would silently ignore the passes. Either propagate it (to_edge_and_lower_stage.transform_passes = transform_passes) for consistency, or document that they're mutually exclusive. Fix this →

Minor

  1. self._modified can be a local (lift_constant_scalar_operands_pass.py:53,113,143,149-150). It's set/reset entirely within call() and never read by helpers, so a local modified would be simpler and avoids surprising state on a reused pass instance. Per CLAUDE.md ("choose simpler, more concise"), worth collapsing.

  2. Production lowering path doesn't pick this up. get_transform_passes() now includes the pass, but the XNNPACK recipe (recipes/xnnpack_recipe_provider.py:103-106, LoweringRecipe) sets no edge_transform_passes, so the standard recipe-based export does not benefit — only callers that explicitly pass get_transform_passes() (tests, size_analysis_tool). If the DeIT delegate-count/latency wins are meant to reach normal users, consider wiring it into the LoweringRecipe. If opt-in is intentional for now, a one-line note in the PR description would help.

  3. 0-d constant tensortorch.tensor(value, ...) (:66) produces a rank-0 tensor relying on mul.Tensor broadcasting. The end-to-end test_fp32_mul_scalar covers this, but it's worth confirming the XNNPACK CI job is green since 0-dim operands have historically been finicky in some backends.

Tests

  • test/passes/test_lift_constant_scalar_operands_pass.py nicely covers lift, unmapped-op passthrough, and the SDPA guard.
  • test_fp32_mul_scalar exercises the full export→serialize→run path. Consider adding a case that asserts a non-canonical SDPA-like graph still recovers scale, to lock in point 2.

Nothing here is blocking on its own; #1 and #2 are the items I'd most want addressed (or explicitly reasoned about in a comment) before merge.
· branch ops

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR improves XNNPACK partitioning coverage by rewriting constant aten.mul.Scalar occurrences into aten.mul.Tensor with a lifted constant buffer, enabling existing multiply lowering to pick them up without introducing runtime aten.full. It also extends the XNNPACK test harness to allow passing explicit transform passes into to_edge_transform_and_lower, and adds tests to protect SDPA scale handling.

Changes:

  • Add LiftConstantScalarOperandsPass to lift constant scalar operands (currently focused on mul.Scalar) into small tensor buffers to enable XNNPACK partitioning.
  • Enable XNNPACK Tester.to_edge_transform_and_lower(...) to thread transform_passes through to executorch.exir.to_edge_transform_and_lower.
  • Add unit and op-test coverage for scalar lifting and for the SDPA scale guard behavior.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
backends/xnnpack/utils/configs.py Adds the scalar-lifting pass to the default XNNPACK transform pass pipeline.
backends/xnnpack/test/tester/tester.py Extends the XNNPACK tester stage to forward explicit transform_passes into to_edge_transform_and_lower.
backends/xnnpack/test/passes/test_lift_constant_scalar_operands_pass.py New unit tests validating scalar lifting and guarding SDPA scale multipliers.
backends/xnnpack/test/ops/test_multiply.py Adds an op test to ensure mul-scalar gets delegated via the transform-pass path.
backends/xnnpack/_passes/lift_constant_scalar_operands_pass.py Implements the scalar-to-tensor lifting pass with an SDPA-specific preservation guard.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread backends/xnnpack/test/tester/tester.py
Comment thread backends/xnnpack/test/tester/tester.py
Comment thread backends/xnnpack/test/tester/tester.py
Comment thread backends/xnnpack/test/passes/test_lift_constant_scalar_operands_pass.py Outdated
Signed-off-by: Måns Nilsson <mans.nilsson@arm.com>
Change-Id: I4ab9a85590762c24f474b935a26f61d90b9a0565
@mansnils

Copy link
Copy Markdown
Collaborator Author

@claude review this PR

@claude

claude Bot commented Jun 26, 2026

Copy link
Copy Markdown

Claude finished @mansnils's task in 4m 17s —— View job


PR Review: XNNPACK — Lift constant mul scalars for partitioning

Re-reviewed the current branch state (both commits) against the prior review feedback. The Copilot and earlier Claude review items have been addressed cleanly — this is in good shape. Notes below.

Previously-raised items — verified resolved

  • transform_passes type annotation now uses a distinct ExirPassType alias (tester.py:29, used at :82, :163) while torch._export.pass_base.PassType is kept for the pass-manager stages. ✅
  • Stage param name kept as to_edge_and_transform_stage (tester.py:156) — no break for test_linear.py callers — and transform_passes is now propagated onto a caller-supplied stage with a type guard (tester.py:176-182). ✅
  • .transform() varargs — the test now passes a tuple (LiftConstantScalarOperandsPass(),) (test_lift_constant_scalar_operands_pass.py:59), avoiding the list assertion in _program.py:153. ✅
  • self._modified → local modified (lift_constant_scalar_operands_pass.py:114). ✅

Still worth considering (non-blocking)

  1. SDPA guard remains coupled to the exact pattern shape (lift_constant_scalar_operands_pass.py:86-111). _feeds_sdpa_qk_bmm only walks through expand_copy/view_copy to reach a bmm. If a real SDPA graph ever places another op (e.g. transpose/permute/to_copy) between the scale mul.Scalar and the QK^T bmm, the guard misses it → the scalar gets lifted → ConvertToSDPAPass.get_scale (convert_to_sdpa.py:33-52) returns None and trips assert scale is not None at convert_to_sdpa.py:71. A false positive (leaving a non-SDPA mul unlifted) is only a missed optimization — safe direction — but a false negative hard-fails SDPA lowering. test_keeps_sdpa_scale_mul_scalar only exercises the canonical sdpa.get_graphs()[0] pattern. Consider either deriving the passthrough set from the same SDPA source, or adding a regression test with an intervening op so this can't silently regress.

  2. Production recipe path still doesn't pick this up. get_transform_passes() (configs.py:30) includes the pass, but _get_xnnpack_lowering_recipe (xnnpack_recipe_provider.py:103-106) sets no edge_transform_passes, and LoweringRecipe supports them (export/recipe.py:130). So the DeIT delegate-count/latency wins reach only callers that explicitly thread get_transform_passes() (op tests, size-analysis tooling) — not the standard recipe export. If opt-in is intentional for now that's fine, but it'd be worth a one-line note in the PR description; if the wins are meant to reach normal users, wiring edge_transform_passes=[lambda _m, _ep: get_transform_passes()] into the recipe is the missing piece. Fix this →

  3. Non-delegated fallback for the lifted get_attr. _create_constant_node uses register_buffer + graph.get_attr (:69-77) rather than create_constant_placeholder, because ExportPass only sees the GraphModule, not the ExportedProgram/graph signature. test_lifted_mul_scalar_can_emit_without_delegation covers the orphaned-get_attr emit path, which is the right guard — good. The dtype check (:132-139, skip when input.dtype != output.dtype) keeps mixed int/float mul.Scalar out of this path, which is what makes the buffer route safe. A one-line comment near :69 noting why a buffer + get_attr is used instead of create_constant_placeholder (no EP access from ExportPass) would save the next reader the dig.

Minor

  • 0-d constant tensor (:65, torch.tensor(value, ...)) relies on mul.Tensor broadcasting against a rank-0 operand. test_fp32_mul_scalar runs it end-to-end including run_method_and_compare_outputs, so this is covered — just worth confirming the XNNPACK CI job is green, since 0-dim operands have historically been finicky in some backends.

Tests

  • test_lift_constant_scalar_operands_pass.py covers lift, unmapped-op passthrough, non-delegated emit, and the SDPA guard — good spread.
  • test_fp32_mul_scalar exercises the full export→serialize→run path and asserts neither mul.Tensor nor mul.Scalar survives in the top-level graph. 👍

Net: ready to merge once you've decided on item #1 (guard robustness / regression test) and item #2 (recipe opt-in vs. wiring) — both are judgment calls rather than defects.
· branch ops

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. module: xnnpack Issues related to xnnpack delegation and the code under backends/xnnpack/ partner: arm For backend delegation, kernels, demo, etc. from the 3rd-party partner, Arm release notes: xnnpack Changes to the XNNPack backend delegate

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants