Conversation

@mayuyuace (Collaborator) commented Nov 12, 2025

Support wfp8a16 (fp8 weights with fp16/bf16 activations).
Only per-tensor scaling is supported.

cutlass-sycl currently uses an inefficient data type conversion function when converting fp8 to other dtypes:
https://github.com/intel/sycl-tla/blob/5ac97000432951b2812ade99235fbf539958c5d4/include/cutlass/float8.h#L336
As the linked code shows, CUDA uses a hardware instruction for this conversion, while the SYCL path does not, so the fp8 grouped GEMM kernel has low performance. Cutlass needs to provide a more efficient conversion function.
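To make the cost concrete, here is a minimal sketch, assuming e5m2 weights and a single per-tensor scale, of what wfp8a16 dequantization has to do per weight element (the function name and tile shape are illustrative, not the PR's mainloop code):

```cpp
#include <cutlass/float8.h>  // cutlass::float_e5m2_t
#include <cutlass/half.h>    // cutlass::half_t

// Illustrative sketch: widen each fp8 weight to the activation dtype and
// apply one scale shared by the whole tensor before it feeds the MMA.
// The fp8 -> float conversion in the loop is the routine linked above;
// when it is a generic software path, it dominates the kernel's runtime.
template <int N>
void dequant_b_tile(const cutlass::float_e5m2_t (&w)[N], float scale_B,
                    cutlass::half_t (&out)[N]) {
  for (int i = 0; i < N; ++i) {
    out[i] = cutlass::half_t(scale_B * static_cast<float>(w[i]));
  }
}
```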

Update:
Optimized fp8 dequantization by replacing the native dtype conversion function (a rough sketch of the idea follows the numbers below).
For example, with m=1, n=5120, k=8192, e=16, topk=1 on PVC:
fp16 @ e5m2: 6.321ms --> 1.358ms
bf16 @ e5m2: 6.319ms --> 3.931ms
fp16 @ e4m3: 7.007ms --> 2.648ms
bf16 @ e4m3: 7.109ms --> 5.072ms
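For context on why this conversion can be made much cheaper: e5m2 shares fp16's 5-bit exponent width and bias, so widening e5m2 to fp16 is an exact bit shift, whereas e4m3 also needs exponent rebiasing, which is consistent with its smaller speedup above. The sketch below is illustrative only; the helper name is made up, and the PR's actual conversion code lives in moe_array_mma.hpp and may differ.

```cpp
#include <cstdint>
#include <cstring>
#include <cutlass/half.h>  // cutlass::half_t

// Sketch only, not the PR's exact code. e5m2 and IEEE fp16 use the same
// 5-bit exponent with the same bias, so widening e5m2 to fp16 is an exact
// left shift of the raw bits (including inf/NaN/subnormals); no per-element
// float round-trip is needed.
inline cutlass::half_t e5m2_bits_to_half(uint8_t raw) {
  // s eeeee mm  ->  s eeeee mm00000000
  uint16_t bits = static_cast<uint16_t>(raw) << 8;
  cutlass::half_t out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}
```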

Copilot AI review requested due to automatic review settings November 12, 2025 06:07
Copilot AI left a comment

Pull Request Overview

This PR adds support for FP8 (float8) data types in grouped GEMM operations, specifically enabling weight quantization with FP8 formats (e4m3 and e5m2) while keeping activations in FP16/BF16. The implementation introduces a scaling parameter for dequantization of FP8 weights during computation.

Key Changes:

  • Added scale_B parameter throughout the grouped GEMM call chain for FP8 weight dequantization
  • Implemented four new FP8 policy classes supporting e4m3/e5m2 formats with fp16/bf16 activations (a rough sketch of what such a policy bundles follows this list)
  • Added comprehensive test coverage for FP8 grouped GEMM operations
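As a rough illustration of what such a policy bundles, here is a hypothetical sketch; the struct and member names below are made up, and the actual definitions live in moe_dtype_policy.hpp.

```cpp
#include <cutlass/float8.h>  // cutlass::float_e5m2_t
#include <cutlass/half.h>    // cutlass::half_t

// Hypothetical sketch of an fp16-activation / e5m2-weight policy. The real
// policies also carry the copy/MMA atoms (e.g. GmemTiledCopy) the mainloop
// needs; only the element types are shown here.
struct MoePolicyFp16E5m2 {
  using ElementA = cutlass::half_t;        // activations stay fp16
  using ElementB = cutlass::float_e5m2_t;  // weights stored as fp8 (e5m2)
  using ElementAccumulator = float;        // accumulate in fp32
  using ElementScale = float;              // one per-tensor scale for B
};
```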

Reviewed Changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 5 comments.

File | Description
vllm_xpu_kernels/fused_moe_interface.py | Added scale_B parameter to cutlass_grouped_gemm and xpu_fused_moe function signatures
tests/fused_moe/test_fused_moe.py | Added new test_grouped_gemm_fp8 test case and updated existing tests to pass None for scale_B
csrc/xpu/torch_bindings.cpp | Updated PyTorch binding signature to include optional ptr_B_scale parameter
csrc/xpu/ops.h | Added ptr_B_scale parameter to cutlass_grouped_gemm function declaration
csrc/xpu/cutlass_kernels/grouped_gemm_kernel.cpp | Integrated scale_B parameter into kernel execution and added FP8 policy template instantiations
csrc/xpu/cutlass_kernels/grouped_gemm.hpp | Added FP8 data type detection and routing logic with new policy-based kernel dispatching
csrc/xpu/cutlass_kernels/collective/gemm/moe_gemm_array_cooperative.hpp | Updated to handle the scale tensor in the mainloop tensors tuple
csrc/xpu/cutlass_kernels/collective/gemm/moe_dtype_policy.hpp | Added four new FP8 policy classes and GmemTiledCopy definitions for all policies
csrc/xpu/cutlass_kernels/collective/gemm/moe_array_mma.hpp | Implemented FP8 conversion logic and scale application in the MMA collective


@mayuyuace (Collaborator, Author) commented:

@Liangliang-Ma Please help review this.

@mayuyuace (Collaborator, Author) commented:

@jikunshang @baodii

@jikunshang (Collaborator) commented:

I think we can put this on hold for now, since we will upgrade to sycl-tla v0.6, which may break the current fp16/bf16 MoE GEMM.

@mayuyuace (Collaborator, Author) commented Nov 13, 2025:

> I think we can put this on hold for now, since we will upgrade to sycl-tla v0.6, which may break the current fp16/bf16 MoE GEMM.

Sure, this PR can be merged after sycl-tla is upgraded.
