
Support addmm and split_copy in the Metal (AOTI) backend#19924

Open
abdelaziz-mahdy wants to merge 2 commits into
pytorch:mainfrom
abdelaziz-mahdy:metal-addmm-split-fix

Conversation

@abdelaziz-mahdy
Contributor

Summary

The experimental Metal (AOTI) backend can't lower common CNNs (MobileNetV3, YOLO) — export fails with "missing fallback kernels". Fixes #19907.

Two ops hit unsupported AOTI fallbacks:

  • aten::split_copy.Tensor falls back to the proxy executor, which the AOTI runtime doesn't support.
  • aoti_torch_mps_addmm_out is emitted by inductor's mm + bias → addmm fusion (torch/_inductor/fx_passes/post_grad.py) during MPS codegen, but the libtorch-free runtime has no shim for it. Graph-level decomposition alone is insufficient: inductor re-fuses mm + bias back into addmm, and for batch=1 it folds the size-1 unsqueeze that DecomposeLinearPass inserts, so the fusion fires anyway.
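
To make the fusion concrete, here is a minimal reference sketch of why `mm` followed by a bias add is the same computation as `addmm` with `beta = alpha = 1`. Plain Python lists stand in for tensors, and the helper names (`matmul`, `mm_plus_bias`, `addmm`) are illustrative only, not inductor or runtime APIs.

```python
# Reference semantics only; plain Python lists stand in for tensors.

def matmul(a, b):
    """Multiply an [M, K] matrix by a [K, N] matrix."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def mm_plus_bias(x, w, bias):
    """Unfused form: mm, then a broadcast add of a 1-D bias of shape [N]."""
    return [[v + b for v, b in zip(row, bias)] for row in matmul(x, w)]


def addmm(bias, mat1, mat2, beta=1.0, alpha=1.0):
    """Fused form: beta * bias + alpha * (mat1 @ mat2), bias broadcast over rows."""
    return [
        [beta * b + alpha * v for v, b in zip(row, bias)]
        for row in matmul(mat1, mat2)
    ]


x = [[1.0, 2.0, 3.0]]                      # batch=1, as in the classifier case
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # [K=3, N=2]
bias = [0.5, -0.5]

unfused = mm_plus_bias(x, w, bias)
fused = addmm(bias, x, w)
print(unfused, fused)
```

Because the two forms agree, a fusion pass can legally rewrite one into the other, which is why the runtime needs an `addmm` shim even when the graph was decomposed beforehand.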

Changes

  • ReplaceViewCopyWithViewPass: map split_copy.Tensor → split.Tensor (core + edge dialect), mirroring the existing slice_copy/select_copy handling, so inductor codegens it as views instead of a proxy-executor fallback.
  • New runtime/ops/op_addmm.mm: implements aoti_torch_mps_addmm_out (out = beta·self + alpha·(mat1 @ mat2)) via MPSGraph, mirroring op_mm.mm (transposed-mat2 handling + graph cache; cache key includes beta/alpha). Registered in CMakeLists.txt.
  • MetalBackend.get_supported_fallback_kernels: allow-list aoti_torch_mps_addmm_out.
  • tests/test_modules.py: add addmm, split_cat, and linear_bias_batch1 (batch=1 → the MobileNet-classifier case) regression modules.
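
The view-copy rewrite can be pictured as a lookup table from `*_copy` targets to their view counterparts. The sketch below is a toy model under stated assumptions: strings replace the real op overloads so it runs without PyTorch, and `VIEW_COPY_TO_VIEW` and `replace_view_copy_with_view` are hypothetical names, not the pass's actual code.

```python
# Hypothetical stand-in for the op-target mapping. Strings replace real
# op overloads so the sketch runs without PyTorch installed.
VIEW_COPY_TO_VIEW = {
    "aten.slice_copy.Tensor": "aten.slice.Tensor",
    "aten.select_copy.int": "aten.select.int",
    "aten.split_copy.Tensor": "aten.split.Tensor",  # the kind of entry this PR adds
}


def replace_view_copy_with_view(targets):
    """Return the node targets with every known *_copy op swapped for its view op."""
    return [VIEW_COPY_TO_VIEW.get(target, target) for target in targets]


graph = ["aten.convolution.default", "aten.split_copy.Tensor", "aten.cat.default"]
rewritten = replace_view_copy_with_view(graph)
print(rewritten)
```

Targets without a mapping pass through unchanged, so the rewrite only touches the ops it knows how to express as views.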

Test plan

Built executor_runner with -DEXECUTORCH_BUILD_METAL=ON on macOS arm64 (Apple silicon) and ran the exported .ptes (input = ones); runtime outputs match eager:

| model | runtime | eager |
| --- | --- | --- |
| addmm (batch=1) | 2.745056 | 2.7451 |
| linear_bias_batch1 (all 101 values) | -3.4252 | -3.4252 |
| split_cat | 40.0 | 40.0 |
| MobileNetV3-small | runs, [1, 1000] | |

AOT export of MobileNetV3-small / MobileNetV2 / a YOLO-style head no longer raises "missing fallback kernels".
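
For readers who want to reproduce the comparison, a tolerance check along these lines would accept the runtime and eager values listed in the test plan. This is only a sketch: the relative tolerance `REL_TOL` is an assumption (the PR does not state one), and `matches` is a hypothetical helper.

```python
import math

# (runtime, eager) pairs copied from the test plan.
results = {
    "addmm (batch=1)": (2.745056, 2.7451),
    "linear_bias_batch1": (-3.4252, -3.4252),
    "split_cat": (40.0, 40.0),
}

REL_TOL = 1e-4  # assumed tolerance, not stated in the PR


def matches(runtime, eager, rel_tol=REL_TOL):
    """True when the runtime value agrees with eager within rel_tol."""
    return math.isclose(runtime, eager, rel_tol=rel_tol)


report = {name: matches(*pair) for name, pair in results.items()}
print(report)
```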

The experimental Metal backend could not lower common CNNs (MobileNet,
YOLO) because two ops hit unsupported AOTI fallback kernels:

- aten::split_copy.Tensor fell back to the proxy executor, which the AOTI
  runtime does not support.
- aoti_torch_mps_addmm_out is emitted by inductor's mm+bias fusion
  (torch/_inductor/fx_passes/post_grad.py) during MPS codegen, but the
  libtorch-free runtime had no shim for it. Graph-level decomposition is
  insufficient because inductor re-fuses mm+bias back into addmm (and folds
  the size-1 unsqueeze that DecomposeLinearPass inserts for batch=1).

Changes:
- Map split_copy.Tensor -> split.Tensor in ReplaceViewCopyWithViewPass so
  inductor codegens it as views (like the existing slice_copy/select_copy).
- Implement aoti_torch_mps_addmm_out (op_addmm.mm) via MPSGraph, mirroring
  op_mm.mm, and allow-list it in get_supported_fallback_kernels.
- Add regression modules (addmm, split_cat, batch-1 linear) to test_modules.

Verified end-to-end with executor_runner on macOS arm64: MobileNetV3-small,
plus addmm / split_cat / linear(batch=1) numerics matching eager.
Copilot AI review requested due to automatic review settings June 1, 2026 23:23
@pytorch-bot

pytorch-bot Bot commented Jun 1, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19924

Note: Links to docs will display an error until the docs builds have been completed.

⚠️ 15 Awaiting Approval

As of commit 7dc3552 with merge base 40b0a35:

AWAITING APPROVAL - The following workflows need approval before CI can run:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label Jun 1, 2026
@linux-foundation-easycla

linux-foundation-easycla Bot commented Jun 1, 2026

CLA Signed
The committers listed above are authorized under a signed CLA.

Contributor

Copilot AI left a comment


Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Adds Metal backend support for aten.addmm via a new c-shim that wraps an MPSGraph-cached addmm implementation, registers the kernel as a supported fallback, and extends the view-copy replacement pass to also rewrite split_copy (which has no c-shim) into the native split view op. Tests are added for both addmm and split+cat module patterns.

Changes:

  • New aoti_torch_mps_addmm_out Metal shim with graph caching, transposed-mat2 detection, and bias broadcasting.
  • Registration of the addmm shim in the backend's supported fallback kernel list and CMake build.
  • Extended replace_view_copy_with_view to map split_copy → split, plus new Addmm, LinearWithBias (batch=1), and SplitCat test modules.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.

| File | Description |
| --- | --- |
| backends/apple/metal/runtime/ops/op_addmm.mm | New MPSGraph-backed addmm c-shim implementation. |
| backends/apple/metal/metal_backend.py | Registers aoti_torch_mps_addmm_out as a supported fallback kernel. |
| backends/apple/metal/CMakeLists.txt | Adds op_addmm.mm to the Metal AOTI sources. |
| backends/aoti/passes/replace_view_copy_with_view.py | Maps aten.split_copy.Tensor to aten.split.Tensor (both torch and edge variants). |
| backends/apple/metal/tests/test_modules.py | Adds Addmm, SplitCat, and linear_bias_batch1 test registry entries. |


Comment on lines +125 to +130
// Bias may be 1-D [N] or 2-D [M, N]; feed its physical shape and rely on
// MPSGraph broadcasting in the addition.
NSMutableArray<NSNumber*>* biasShape = [NSMutableArray array];
for (size_t i = 0; i < static_cast<size_t>(bias_tensor->dim()); ++i) {
  [biasShape addObject:@(bias_tensor->sizes()[i])];
}
Contributor Author


Good catch — now keying on the full bias shape (rank + each dim) instead of just the rank. Fixed in 7dc3552.

Comment on lines +251 to +253
@try {
  stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT);
} @catch (NSException* exception) {
Contributor Author


Releases now run in an @finally, so the throw path frees them too. 7dc3552

Comment on lines +259 to +265
throw std::runtime_error("MPSGraph execution failed with NSException");
}

[mat1Data release];
[mat2Data release];
[biasData release];
[outputData release];
Contributor Author


Same fix as above — moved the releases into the @finally. 7dc3552

Comment on lines +56 to +57
// Validate matmul operand dimensions.
if (mat1_tensor->dim() != 2 || mat2_tensor->dim() != 2) {
Contributor Author


Left as-is for parity with op_mm (AOTInductor always allocates out at the right shape), but I added a dtype check across operands, which covers the most likely mismatch here.

Comment on lines +110 to +119
int32_t dtype = static_cast<int32_t>(mat1_tensor->scalar_type());
MPSDataType mps_dtype;
if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) {
  mps_dtype = MPSDataTypeFloat32;
} else if (dtype == static_cast<int32_t>(SupportedDTypes::BFLOAT16)) {
  mps_dtype = MPSDataTypeBFloat16;
} else {
  ET_LOG(Error, "aoti_torch_mps_addmm_out: unsupported dtype %d", dtype);
  throw std::runtime_error("Unsupported data type for addmm");
}
Contributor Author


Added — verifies mat2/self/out all match mat1 dtype before building the graph. 7dc3552

Comment on lines +57 to +64
if (mat1_tensor->dim() != 2 || mat2_tensor->dim() != 2) {
  std::string error_msg =
      "aoti_torch_mps_addmm_out: mat1/mat2 must be 2-D, got " +
      std::to_string(mat1_tensor->dim()) + " and " +
      std::to_string(mat2_tensor->dim());
  ET_LOG(Error, "%s", error_msg.c_str());
  throw std::runtime_error(error_msg);
}
Contributor Author


Kept consistent with the sibling shims (op_mm/op_bmm validate the same way) so addmm does not diverge.

@github-actions

github-actions Bot commented Jun 1, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

- Cache key: include the full bias shape (rank + each dim), not just the
  rank, so equal-rank but differently-shaped biases (e.g. [N] vs [1], or
  [M, N] vs [1, N]) don't collide and reuse a graph whose biasPlaceholder
  has the wrong shape.
- Release the MPSGraphTensorData objects in an @finally so they aren't
  leaked when executeMPSGraph throws.
- Validate that mat2/self/out share mat1's dtype before building the graph
  (return InvalidArgument on mismatch) to avoid silently reinterpreting
  buffers.
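
The cache-key collision described in the first bullet can be illustrated with a small sketch. The key layout here is hypothetical (the real key format in op_addmm.mm is not shown in this conversation); `rank_only_key` and `full_shape_key` are illustrative names chosen to contrast the two strategies.

```python
# Hypothetical key layouts; the actual key format in op_addmm.mm may differ.

def rank_only_key(m, k, n, bias_shape, beta, alpha):
    """Key that records only the bias rank (the pre-fix behavior described above)."""
    return (m, k, n, len(bias_shape), beta, alpha)


def full_shape_key(m, k, n, bias_shape, beta, alpha):
    """Key that records the bias rank and every dimension."""
    return (m, k, n, len(bias_shape), tuple(bias_shape), beta, alpha)


# Two calls that differ only in the bias shape: [N] versus [1].
call_a = dict(m=4, k=8, n=16, bias_shape=[16], beta=1.0, alpha=1.0)
call_b = dict(m=4, k=8, n=16, bias_shape=[1], beta=1.0, alpha=1.0)

collides = rank_only_key(**call_a) == rank_only_key(**call_b)
distinct = full_shape_key(**call_a) != full_shape_key(**call_b)
print(collides, distinct)
```

With the rank-only key both calls would share one cached graph, so the second call would feed a `[1]` bias into a placeholder built for `[16]`. Including each dimension keeps the two graphs separate.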

Labels

CLA Signed


Development

Successfully merging this pull request may close these issues.

Metal backend (AOTI): MobileNet/YOLO fail to export — missing fallback c-shims (addmm / split_copy / slice_copy)

2 participants