support fp8 grouped gemm #69
Conversation
Signed-off-by: mayuyuace <[email protected]>
Pull Request Overview
This PR adds support for FP8 (float8) data types in grouped GEMM operations, specifically enabling weight quantization with FP8 formats (e4m3 and e5m2) while keeping activations in FP16/BF16. The implementation introduces a scaling parameter for dequantization of FP8 weights during computation.
Key Changes:
- Added scale_B parameter throughout the grouped GEMM call chain for FP8 weight dequantization
- Implemented four new FP8 policy classes supporting e4m3/e5m2 formats with fp16/bf16 activations
- Added comprehensive test coverage for FP8 grouped GEMM operations
Reviewed Changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| vllm_xpu_kernels/fused_moe_interface.py | Added scale_B parameter to cutlass_grouped_gemm and xpu_fused_moe function signatures |
| tests/fused_moe/test_fused_moe.py | Added new test_grouped_gemm_fp8 test case and updated existing tests to pass None for scale_B |
| csrc/xpu/torch_bindings.cpp | Updated PyTorch binding signature to include optional ptr_B_scale parameter |
| csrc/xpu/ops.h | Added ptr_B_scale parameter to cutlass_grouped_gemm function declaration |
| csrc/xpu/cutlass_kernels/grouped_gemm_kernel.cpp | Integrated scale_B parameter into kernel execution and added FP8 policy template instantiations |
| csrc/xpu/cutlass_kernels/grouped_gemm.hpp | Added FP8 data type detection and routing logic with new policy-based kernel dispatching |
| csrc/xpu/cutlass_kernels/collective/gemm/moe_gemm_array_cooperative.hpp | Updated to handle scale tensor in mainloop tensors tuple |
| csrc/xpu/cutlass_kernels/collective/gemm/moe_dtype_policy.hpp | Added four new FP8 policy classes and GmemTiledCopy definitions for all policies |
| csrc/xpu/cutlass_kernels/collective/gemm/moe_array_mma.hpp | Implemented FP8 conversion logic and scale application in the MMA collective |
Signed-off-by: mayuyuace <[email protected]>
@Liangliang-Ma Please help review this.
Signed-off-by: mayuyuace <[email protected]>
I think we can put this on hold for now, since we will upgrade to sycl-tla v0.6, which may break the current fp16/bf16 MoE GEMM.
Sure, this PR can be merged after sycl-tla is upgraded.
Signed-off-by: mayuyuace <[email protected]>
This supports wfp8a16 (FP8 weights with FP16/BF16 activations), and only a per-tensor scale is supported.
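For context, here is a minimal host-side sketch of what a per-tensor scale means for wfp8a16 (names and layout are illustrative only; the actual kernel applies the scale inside the CUTLASS mainloop rather than in a separate reference loop):

```cpp
#include <cstddef>
#include <vector>

// Illustrative reference: GEMM with fp16/bf16 activations and fp8 weights that
// share a single per-tensor scale. Both inputs are widened to float here for
// readability; `scale_b` is the one scalar used to dequantize every element of B.
void wfp8a16_gemm_ref(const std::vector<float>& a,    // [m x k] activations
                      const std::vector<float>& b_q,  // [k x n] fp8 weights, widened to float
                      float scale_b,                  // per-tensor scale for B
                      std::vector<float>& c,          // [m x n] output
                      std::size_t m, std::size_t n, std::size_t k) {
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (std::size_t p = 0; p < k; ++p) {
        // Dequantize B on the fly: one multiply by the shared scale.
        acc += a[i * k + p] * (b_q[p * n + j] * scale_b);
      }
      c[i * n + j] = acc;
    }
  }
}
```

Because the scale is a single scalar, it could equivalently be folded into the accumulator after the inner product; per-channel or per-block scales would instead have to be applied element-wise.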
Currently, cutlass-sycl uses an inefficient data type conversion function from fp8 to other dtypes:
https://github.com/intel/sycl-tla/blob/5ac97000432951b2812ade99235fbf539958c5d4/include/cutlass/float8.h#L336
As the linked code shows, the CUDA path uses a hardware instruction for this conversion, so the fp8 grouped GEMM kernel has low performance. CUTLASS needs to provide a more efficient conversion function.
Update:
Optimized fp8 dequantization by replacing the native dtype conversion function (a sketch of the idea follows the numbers below).
For example, with m=1, n=5120, k=8192, e=16, topk=1 on PVC:
fp16 @ e5m2: 6.321ms --> 1.358ms
bf16 @ e5m2: 6.319ms --> 3.931ms
fp16 @ e4m3: 7.007ms --> 2.648ms
bf16 @ e4m3: 7.109ms --> 5.072ms
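For reference, the reason the e5m2 path can be widened cheaply is that e5m2 shares fp16's 5-bit exponent and bias of 15, so the conversion is a pure bit shift instead of per-element float math. Below is a minimal sketch of that shortcut (names are illustrative and the actual kernel change may be implemented differently; e4m3 additionally needs an exponent re-bias from 7 to 15, so it cannot use a plain shift):

```cpp
#include <cstdint>
#include <cstring>

// float8_e5m2 (1 sign / 5 exponent / 2 mantissa bits, bias 15) is a truncated
// IEEE fp16, so widening only requires placing the byte in the high half of a
// 16-bit pattern. No rounding or range checks are needed; Inf/NaN and
// subnormals map to their fp16 counterparts.
static inline uint16_t e5m2_to_half_bits(uint8_t raw) {
  return static_cast<uint16_t>(raw) << 8;  // fp16 bit pattern of the same value
}

// Bit-cast the pattern to a half type on compilers that provide _Float16.
static inline _Float16 e5m2_to_half(uint8_t raw) {
  const uint16_t bits = e5m2_to_half_bits(raw);
  _Float16 out;
  std::memcpy(&out, &bits, sizeof(out));  // reinterpretation, not a value conversion
  return out;
}
```

Widening to bf16 cannot use the same trick because bf16 has an 8-bit exponent, which is consistent with the bf16 configurations above remaining slower than the fp16 ones.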