Add FlashInfer provider for moe_quant_group_gemm across BF16, FP8, MXFP4, and NVFP4 #1

Draft. Copilot wants to merge 5 commits into main from copilot/add-flashinfer-provider-moe-quant-group-gemm

Copilot AI commented May 21, 2026

Previously, moe_quant_group_gemm had only the base implementation, which simulates grouped expert GEMM with a Python loop that processes one expert at a time. This change adds a FlashInfer-backed provider that maps the existing MoE dispatch metadata onto a single grouped kernel launch for each of the BF16, FP8 per-tensor, FP8 block-scale, MXFP4, and NVFP4 variants.

  • FlashInfer provider

    • Adds projects/micro_perf/vendor_ops/GPU/ops/flashinfer/moe_quant_group_gemm.py
    • Registers a flashinfer vendor implementation for moe_quant_group_gemm
    • Preserves the existing op contract by inheriting behavior from the base op and overriding vendor parsing / tensor setup / run dispatch
  • Supported dtype variants

    • BF16: flashinfer.grouped_mm_bf16
    • FP8 per-tensor: flashinfer.grouped_mm_fp8(..., alpha=...)
    • FP8 block-scale: flashinfer.gemm.group_gemm_fp8_nt_groupwise
    • MXFP4: flashinfer.gemm.group_gemm_mxfp4_nt_groupwise
    • NVFP4: flashinfer.gemm.group_gemm_nvfp4_nt_groupwise
  • MoE routing integration

    • Reuses expert_dispatch_token_count / expert_dispatch_token_offset from get_moe_tokens_info()
    • Builds FlashInfer m_indptr / segment_offsets directly from the existing expert dispatch layout
    • Keeps output shape and FLOP accounting aligned with the base op
  • Tensor metadata / allocation

    • Defines variant-specific input tensors for activations, expert weights, and scale buffers
    • Handles FlashInfer-specific packed/quantized tensor layouts for FP8/MXFP4/NVFP4 paths
    • Includes explicit scale-buffer sizing for FP4 grouped kernels where FlashInfer requires aligned per-group scale storage
  • Runtime dispatch

    • Routes each dtype combination to the matching FlashInfer kernel in vendor_impl_run()
    • Guards unsupported or unexpected dtype combinations with explicit variant checks
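
The routing and guarding described in the list above can be pictured as a lookup table. This is only a minimal sketch, not the code in this PR: the `select_kernel` helper and the `KERNEL_BY_VARIANT` table are hypothetical names, the dtype strings follow the spellings used in the original prompt, and the `compute_dtype` value for BF16 is an assumption because the prompt does not state one. The kernels are kept as name strings so the sketch runs without FlashInfer installed.

```python
from typing import Dict, Tuple

# Key: (dtype, w_dtype, compute_dtype). Hypothetical layout for illustration.
VariantKey = Tuple[str, str, str]

KERNEL_BY_VARIANT: Dict[VariantKey, str] = {
    # compute_dtype for BF16 is assumed; the prompt does not specify it.
    ("bfloat16", "bfloat16", "bfloat16"): "flashinfer.grouped_mm_bf16",
    ("fp8_e4m3", "fp8_e4m3", "fp8"): "flashinfer.grouped_mm_fp8",
    ("fp8_e4m3", "fp8_e4m3", "fp8_block"): "flashinfer.gemm.group_gemm_fp8_nt_groupwise",
    ("mxfp4", "mxfp4", "mxfp4"): "flashinfer.gemm.group_gemm_mxfp4_nt_groupwise",
    ("nvfp4", "nvfp4", "nvfp4"): "flashinfer.gemm.group_gemm_nvfp4_nt_groupwise",
}


def select_kernel(dtype: str, w_dtype: str, compute_dtype: str) -> str:
    """Return the kernel name for a dtype combination, or fail loudly."""
    key = (dtype, w_dtype, compute_dtype)
    try:
        return KERNEL_BY_VARIANT[key]
    except KeyError:
        # Explicit guard: an unknown combination is an error, not a fallback.
        raise ValueError(f"unsupported dtype combination: {key}") from None
```

Failing with an explicit error, rather than silently falling back to the loop-based base op, keeps a benchmark from reporting numbers for a kernel that was never launched.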

Example dispatch shape mapping:

# Existing MoE dispatch layout
scatter_tokens.shape == (dispatch_tokens, hidden_size)
experts_weight.shape == (num_experts_per_rank, new_hidden_size, hidden_size)

segment_offsets = [
    *expert_dispatch_token_offset,
    expert_dispatch_token_offset[-1] + expert_dispatch_token_count[-1],
]

# FlashInfer grouped launch
y = flashinfer.grouped_mm_bf16(
    scatter_tokens,
    experts_weight,
    segment_offsets,
    out_dtype=torch.bfloat16,
)
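
The offset construction in the example above can be checked on its own with plain Python lists. The sketch below is illustrative only: `build_segment_offsets` is a hypothetical helper, and it assumes that tokens for consecutive experts are stored contiguously in expert order, so each expert's offset plus its count equals the next expert's offset. In the provider, the resulting list would still need to be converted to an int32 tensor of length num_experts_per_rank + 1.

```python
from typing import List, Sequence


def build_segment_offsets(
    token_offset: Sequence[int], token_count: Sequence[int]
) -> List[int]:
    """Append the end of the last expert's segment to the per-expert offsets."""
    if len(token_offset) != len(token_count):
        raise ValueError("offset and count arrays must have the same length")
    if not token_offset:
        raise ValueError("at least one expert is required")
    # Assumed invariant: segments are contiguous and ordered by expert.
    for i in range(len(token_offset) - 1):
        if token_offset[i] + token_count[i] != token_offset[i + 1]:
            raise ValueError(f"expert {i} segment is not contiguous with expert {i + 1}")
    return [*token_offset, token_offset[-1] + token_count[-1]]


# Four experts on this rank; expert 1 receives no tokens.
offsets = build_segment_offsets([0, 3, 3, 8], [3, 0, 5, 2])
print(offsets)  # [0, 3, 3, 8, 10]
```

An expert with zero tokens simply produces two equal adjacent offsets, so no special case is needed for empty groups.
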
Original prompt

Overview

Implement a FlashInfer-based provider for the moe_quant_group_gemm operation in the xpu-perf benchmark framework. The implementation should support the following data types:

  1. BF16 — using flashinfer.grouped_mm_bf16
  2. FP8 per-tensor scale — using flashinfer.grouped_mm_fp8 with per-tensor alpha
  3. FP8 block scale — using flashinfer.gemm.group_gemm_fp8_nt_groupwise with block-wise scaling
  4. MXFP4 — using flashinfer.gemm.group_gemm_mxfp4_nt_groupwise
  5. NVFP4 — using flashinfer.gemm.group_gemm_nvfp4_nt_groupwise

Context

The existing base implementation is in projects/micro_perf/op_defs/llm_ops/moe_quant_group_gemm.py. It uses a Python loop over experts with fake_quant_gemm to simulate INT8 group GEMM. The new FlashInfer implementation should replace this loop with a single FlashInfer kernel call for each data type variant.

Reference: Base Implementation Structure

The base class MoeQuantGroupGemmOp in projects/micro_perf/op_defs/llm_ops/moe_quant_group_gemm.py:

  • Uses prepare_args() to parse MoE parameters (num_tokens, hidden_size, new_hidden_size, num_experts, topk, ep_size, etc.)
  • Uses get_moe_tokens_info() to compute routing/dispatch metadata
  • Uses vendor_impl() to define input/output tensor info and set up the computation
  • Uses vendor_impl_run(tensor_mapping) to execute the actual GEMM

Implementation Requirements

Create a new file at projects/micro_perf/op_defs/llm_ops/moe_quant_group_gemm_flashinfer.py that:

  1. Registers a FlashInfer provider using @ProviderRegistry.register_provider_impl("moe_quant_group_gemm", "FlashInfer") decorator pattern.

  2. Inherits from the base MoeQuantGroupGemmOp class and overrides vendor_parser(), vendor_impl(), and vendor_impl_run().

  3. Supports multiple dtype configurations dispatched by self.dtype, self.w_dtype, self.compute_dtype, self.dst_dtype:

    • BF16: dtype=bfloat16, w_dtype=bfloat16, dst_dtype=bfloat16

      • Use flashinfer.grouped_mm_bf16(a, b, m_indptr, out_dtype=...)
      • a shape: (dispatch_tokens, hidden_size) bf16
      • b shape: (num_experts_per_rank, new_hidden_size, hidden_size) bf16
      • m_indptr shape: (num_experts_per_rank + 1,) int32, built from expert_dispatch_token_offset
    • FP8 per-tensor: dtype=fp8_e4m3, w_dtype=fp8_e4m3, compute_dtype=fp8, dst_dtype=bfloat16

      • Use flashinfer.grouped_mm_fp8(a, b, m_indptr, alpha=alpha, out_dtype=...)
      • a shape: (dispatch_tokens, hidden_size) float8_e4m3fn
      • b shape: (num_experts_per_rank, new_hidden_size, hidden_size) float8_e4m3fn
      • alpha: scalar float32 tensor (per-tensor scale)
    • FP8 block scale: dtype=fp8_e4m3, w_dtype=fp8_e4m3, compute_dtype=fp8_block, dst_dtype=bfloat16

      • Use flashinfer.gemm.group_gemm_fp8_nt_groupwise(a, b, a_scale, b_scale, segment_offsets, out=...)
      • a shape: (dispatch_tokens, hidden_size) float8_e4m3fn
      • b shape: (num_experts_per_rank, new_hidden_size, hidden_size) float8_e4m3fn
      • a_scale shape: (hidden_size // 128, dispatch_tokens) float32
      • b_scale shape: (num_experts_per_rank, hidden_size // 128, new_hidden_size // 128) float32
      • segment_offsets shape: (num_experts_per_rank + 1,) int32
    • MXFP4: dtype=mxfp4, w_dtype=mxfp4, compute_dtype=mxfp4, dst_dtype=bfloat16

      • Use flashinfer.gemm.group_gemm_mxfp4_nt_groupwise(a, b, a_scale, b_scale, segment_offsets, out=...)
      • a shape: (dispatch_tokens, hidden_size) float8_e4m3fn (activation in fp8 for mxfp4 kernel)
      • b shape: (num_experts_per_rank, new_hidden_size, hidden_size // 2) uint8 (packed fp4)
      • a_scale shape: (dispatch_tokens_aligned // 128, hidden_size // 32) uint8
      • b_scale shape: (num_experts_per_rank, new_hidden_size_aligned // 128, hidden_size // 32) uint8
    • NVFP4: dtype=nvfp4, w_dtype=nvfp4, compute_dtype=nvfp4, dst_dtype=bfloat16

      • Use flashinfer.gemm.group_gemm_nvfp4_nt_groupwise(a, b, a_scale, b_scale, segment_offsets, out=...)
      • Similar to MXFP4 but using nvfp4 quantization format
      • a shape: (dispatch_tokens, hidden_size // 2) uint8 (packed fp4)
      • b shape: (num_experts_per_rank, new_hidden_size, hidden_size // 2) uint8
      • a_scale shape: (dispatch_tokens_aligned // 128, hidden_size // 16) uint8
      • b_scale shape: (num_experts_per_rank, new_hidden_size_aligned // 128, hidden_size // 16) uint8
  4. Build m_indptr / segment_offsets from the existing expert_dispatch_token_offset and expert_dispatch_token_count arrays that are computed by get_moe_tokens_info().

  5. Compute FLOPs correctly: 2 * dispatch_tokens * hidden_size * new_hidden_size

  6. Follow the same pattern as the base implementation for input_tensor_info, output_tensor_info, tensor size calculations, and the _create_tensors_func / _run_func setup.
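
The shape and FLOP requirements above can be cross-checked with a small shape calculator. This is a sketch under stated assumptions, not the provider's code: `variant_shapes`, `round_up`, and `gemm_flops` are hypothetical helpers, the variant labels are made up for this example, and the `*_aligned` sizes are assumed to mean "rounded up to a multiple of 128", which the requirements do not define. No FlashInfer or torch calls are made.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def round_up(x: int, multiple: int) -> int:
    """Round x up to the nearest multiple."""
    return (x + multiple - 1) // multiple * multiple


def variant_shapes(variant: str, dispatch_tokens: int, hidden_size: int,
                   new_hidden_size: int, num_experts_per_rank: int) -> Dict[str, Shape]:
    """Tensor shapes per variant, transcribed from the requirement list."""
    m, k, n, e = dispatch_tokens, hidden_size, new_hidden_size, num_experts_per_rank
    if variant == "bf16":
        return {"a": (m, k), "b": (e, n, k), "m_indptr": (e + 1,)}
    if variant == "fp8_per_tensor":
        # alpha is a scalar tensor, hence the empty shape.
        return {"a": (m, k), "b": (e, n, k), "alpha": (), "m_indptr": (e + 1,)}
    if variant == "fp8_block":
        return {"a": (m, k), "b": (e, n, k),
                "a_scale": (k // 128, m),
                "b_scale": (e, k // 128, n // 128),
                "segment_offsets": (e + 1,)}
    if variant in ("mxfp4", "nvfp4"):
        group = 32 if variant == "mxfp4" else 16  # scale group size along K
        m_aligned = round_up(m, 128)  # assumption: "aligned" means multiple of 128
        n_aligned = round_up(n, 128)  # assumption, as above
        # MXFP4 keeps FP8 activations; NVFP4 packs two FP4 values per byte.
        a = (m, k) if variant == "mxfp4" else (m, k // 2)
        return {"a": a, "b": (e, n, k // 2),
                "a_scale": (m_aligned // 128, k // group),
                "b_scale": (e, n_aligned // 128, k // group),
                "segment_offsets": (e + 1,)}
    raise ValueError(f"unknown variant: {variant}")


def gemm_flops(dispatch_tokens: int, hidden_size: int, new_hidden_size: int) -> int:
    """One multiply and one add per (token, k, n) triple."""
    return 2 * dispatch_tokens * hidden_size * new_hidden_size
```

For example, with dispatch_tokens=256, hidden_size=4096, and new_hidden_size=2048, `gemm_flops` returns 4294967296, independent of the quantization variant, which is what keeps the FLOP accounting aligned with the base op.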

Key FlashInfer API signatures (for referenc...

This pull request was created from Copilot chat.

Copilot AI changed the title from "[WIP] Implement FlashInfer provider for moe_quant_group_gemm operation" to "Add FlashInfer provider for moe_quant_group_gemm across BF16, FP8, MXFP4, and NVFP4" on May 21, 2026
Copilot AI requested a review from yupengzh-intel on May 21, 2026 01:21