-
-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[WIP] Fix RMS and test MoE for batch invariance [4/n] #26136
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: Bram Wasti <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces significant improvements for batch invariance, particularly for RMSNorm and MoE layers, and strengthens the e2e testing framework. The addition of a batch-invariant Triton kernel for RMSNorm and the corresponding Python overrides are key changes. The updates to the testing logic, making it more rigorous by using randomized batch sizes and prompts, are a great enhancement. The modifications for the FlashInfer backend to ensure deterministic behavior under batch invariance are also well-implemented. However, I've identified a critical issue in the fused_add_rms_norm
implementation and a high-severity issue in the MoE softmax kernel that could undermine the goal of batch invariance.
btw, I'm counting on some RMS test to run here. this shouldn't land until that passes |
f867388
to
648ce25
Compare
Signed-off-by: Bram Wasti <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
btw, I'm counting on some RMS test to run here. this shouldn't land until that passes
Sounds great, please change the title once it is ready
This change gets Qwen/Qwen3-30B-A3B on 1GPU working at bitwise parity across batch sizes.
It also improves the e2e testing logic (makes it more harsh). Please only look at the last commit (it is rebased on #25769)
Purpose
Add RMS kernel directly and override in Python. The change to the csrc version doesn't provide full coverage, but does increase invariance.
Test Plan
Test Result
Pass
Essential Elements of an Effective PR Description Checklist
supported_models.md
andexamples
for a new model.