-
Notifications
You must be signed in to change notification settings - Fork 108
Add F.scaled_grouped_mm
#2747
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add F.scaled_grouped_mm
#2747
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR adds support for torch.nn.functional.scaled_grouped_mm, a PyTorch function for scaled grouped matrix multiplication.
Key changes:
- Implements the
scaled_grouped_mmtorchsymbol with input validation and shape inference - Adds three comprehensive test cases covering 2D×2D and 2D×3D tensor combinations with different scaling types
- Registers the operation in the torch executor with appropriate availability checking
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| thunder/torch/init.py | Adds scaled_grouped_mm function with shape validation, dtype checking, and output shape computation |
| thunder/tests/test_ops.py | Adds test cases for tensorwise and blockwise scaling scenarios with FP8 and MXFP8 dtypes |
| thunder/executors/torchex.py | Registers scaled_grouped_mm operation and implements checker function for executor |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
Comments suppressed due to low confidence (1)
thunder/tests/test_ops.py:1
- Both test_scaled_grouped_mm_3d2d_rowwise and test_scaled_grouped_mm_2d3d_rowwise test the same 2D @ 3D case (mat_a is 2D, mat_b after transpose is 3D). There is no test coverage for the 3D @ 2D case where mat_a would be 3D with shape (groups, m, k) and mat_b would be 2D with shape (k, n). Consider adding a test for this case or modifying one of the existing tests to cover it.
from collections.abc import Callable
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
Copilot reviewed 3 out of 3 changed files in this pull request and generated no new comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
74c59d3 to
f428a90
Compare
What does this PR do?
As per title, adds https://docs.pytorch.org/docs/main/generated/torch.nn.functional.scaled_grouped_mm.html