Conversation


@sanchitintel sanchitintel commented Nov 3, 2025

Summary

MoE GEMM without using collectives (uses the CuTe interface). It can be called directly from another GPU kernel.
New copy and MMA atoms are used. Some of the code has been adapted from the new API example.
The MMA code is similar to the Grouped GEMM MMA collective code. The epilogue was eliminated by using an FP32 -> BF16/FP16 reorder & store (see the sketch after this list).
The implementation is extensible, and MXFP4/MXFP8 GEMMs can also be added.
Currently, the example uses RowMajor B because ColumnMajor B performance is worse for the underlying vanilla GEMM.
To integrate this code into a framework, users should use appropriate TiledMMAs (with suitable WG and SG tile shapes), copy atoms, and rasterization.
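As a rough illustration of the epilogue-free store path, here is a minimal sketch (the names tCrC, tCrD, tCgC, and tiled_copy_C are assumptions for illustration, not the actual identifiers in this PR):

```cpp
// Sketch only: tCrC is the FP32 accumulator fragment produced by the mainloop;
// tCgC and tiled_copy_C (a 2D block store) are assumed to be set up elsewhere.
auto tCrD = make_fragment_like<bfloat16_t>(tCrC);  // BF16 fragment with matching shape
reorder(tCrC, tCrD);                               // FP32 -> BF16 conversion in registers
copy(tiled_copy_C, tCrD, tCgC);                    // store directly to global memory
```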

Introduction

For Mixture of Experts (MoE) layers in deep learning models such as LLMs, the MoE GEMM use-case looks like this: each expert (corresponding to a group) has an associated weight of size N * K, which is essentially a column-major B matrix (serving frameworks may change it to RowMajor as well). All the B matrices are contiguous with respect to each other, i.e. their total size is num_groups * N * K. M for each group (an individual GEMM problem) may differ. All A matrices are also contiguous with respect to each other; the set of tokens routed to an expert makes up the A matrix for that group.
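For concreteness, this layout implies simple per-group pointer arithmetic (a hedged sketch; A_base, B_base, D_base, row_offset, and the bf16_t alias are hypothetical names, with row_offset holding the exclusive prefix sum of per-expert token counts):

```cpp
using bf16_t = sycl::ext::oneapi::bfloat16;              // assumption: SYCL bf16 element type
const bf16_t* A_g = A_base + (size_t)row_offset[g] * K;  // activations of group g: M[g] x K
const bf16_t* B_g = B_base + (size_t)g * N * K;          // weights of expert g:    N x K
bf16_t*       D_g = D_base + (size_t)row_offset[g] * N;  // output rows of group g: M[g] x N
```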

Since a token may be routed to multiple experts, token activations are duplicated so that all the activations routed to a given expert are contiguous. This duplication happens before the MoE Grouped GEMM is called, though.

When multiple GEMMs are to be computed, each with its own canonical A, B, C, D matrices, Grouped GEMM is useful for ensuring high GPU utilization and avoiding the launch overhead that would otherwise occur with multiple GEMM kernel launches. In CUTLASS, the vanilla Grouped GEMM uses a persistent-kernel approach: the number of workgroups launched equals the number of Xe cores, and each workgroup keeps looping as long as it has work (here, a unit of work is the mainloop that computes one output tile of any one of the GEMMs passed to the Grouped GEMM API).
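The persistent-kernel idea, in illustrative pseudo-C++ (decode() and mainloop() are hypothetical stand-ins for the scheduler and the GEMM mainloop, not actual CUTLASS functions):

```cpp
int linear_tile = workgroup_id;                        // first tile owned by this WG
while (linear_tile < total_output_tiles) {             // tiles summed over all GEMMs
  auto [group, tile_m, tile_n] = decode(linear_tile);  // map index -> (problem, WG tile)
  mainloop(group, tile_m, tile_n);                     // compute one output tile
  linear_tile += num_workgroups;                       // stride to the next work unit
}
```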

MoE GEMM thus seems to be a natural candidate for leveraging Grouped GEMM. However, the Grouped GEMM API requires providing pointers to A, B, C, D for each GEMM problem, as well as a vector of the individual problem shapes, so its interface is not a good fit for the MoE GEMM use-case. The MoE GEMM interface is much cleaner. Moreover, MoE GEMM doesn't use the canonical C matrix.

Performance

When M-occupancy is high and all the individual GEMM problems are compute-bound, throughput is close to peak performance.

When the M dimension of the individual GEMM problems is small (typical for decoding), and WG_M and SG_M are also small (e.g. 8), the implementation attains close to peak memory-bandwidth utilization.

MoE GEMM is unlike vanilla GEMM. At runtime, the token distribution, especially for prefill, may be highly skewed, so M-occupancy (the fraction of rows in each WG_M-sized A tile that correspond to real tokens) may be low. If M-occupancy is 60%, then 40% of the MMA compute is wasted.

Using a smaller WG_M would improve M-occupancy, but smaller A tiles mean more frequent memory transfers relative to the compute they enable compared with larger A WG tiles, so they usually only help when the M dim of an individual GEMM problem is <= WG_M. Otherwise, they don't let us use the hardware as efficiently.

Requirements

Please use IGC 2.20 or newer.

Please set minimum GPU frequency value to the max frequency value for benchmarking.

```bash
# Set GPUID to the ID of the GPU whose min frequency you wish to raise to the max frequency
export GPUID=0

sudo sh -c "cat /sys/class/drm/card$GPUID/device/tile0/gt0/freq0/max_freq > /sys/class/drm/card$GPUID/device/tile0/gt0/freq0/min_freq"
```

Build instructions

Please do not use -DDPCPP_HOST_COMPILER=g++-13 for now; I'll later revise the code to make it compatible with g++ (the issue is related to the SYCL kernel launch).

```bash
source /opt/intel/oneapi/setvars.sh
export ONEAPI_DEVICE_SELECTOR=level_zero:gpu
export CMAKE_BUILD_TYPE=Release
export IGC_VISAOptions="-perfmodel"
export IGC_VectorAliasBBThreshold=100000000000
export IGC_ExtraOCLOptions="-cl-intel-256-GRF-per-thread"
export CC=icx
export CXX=icpx
mkdir build; cd build
cmake .. -GNinja -DCUTLASS_ENABLE_EXAMPLES=ON -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCUTLASS_ENABLE_SYCL=ON -DCUTLASS_SYCL_PROFILING_ENABLED=ON -DCUTLASS_ENABLE_BENCHMARKS=OFF -DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_CXX_FLAGS="-ftemplate-backtrace-limit=0 -fdiagnostics-color=always" -DDPCPP_SYCL_TARGET=intel_gpu_bmg_g21
ninja examples/12_bmg_moe_gemm_cute_interface/all
./examples/12_bmg_moe_gemm_cute_interface/12_bmg_moe_gemm_cute_interface
```

cc @EikanWang @CaoZhongZ

@pengzhao-intel

What datatypes are covered by this PR, and what is the current performance?

```cpp
}

reorder(tArA, tCrA);  // shuffle the A copy fragment into the MMA's register layout
reorder(tBrB, tCrB);  // likewise for B
```


For fp16/bf16, what's the purpose of these two reorders? If we read the data of A and B with the demanded layout, we don't need to add a reorder again.

@sanchitintel (Author) Nov 3, 2025

Thanks for reviewing!

If we read the data of A and B with the demanded layout, we don't need to add a reorder again

In that case, it's a no-op. It's explicitly mentioned in the rearch documentation:

reorder acts as a "pipe" connecting copy and MMA operations (or any other subgroup-scope operations). With reorders, the kernel writer does not need to worry about perfectly matching layouts between copy and MMA atoms. In case the layouts do match perfectly (as make_block_2d_copy_{A,B,C} try to do), the compiler is able to remove the reorder entirely, making it a no-op.
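For context, here is a hedged sketch of the mainloop pattern that quote describes (the fragment names mirror the snippet above; copy_a/copy_b from make_block_2d_copy_A/B and the tiled_mma are assumed to be constructed elsewhere):

```cpp
for (int k = 0; k < k_tile_count; ++k) {
  copy(copy_a, tAgA(_, _, _, k), tArA);  // load A tile in the copy atom's register layout
  copy(copy_b, tBgB(_, _, _, k), tBrB);  // load B tile in the copy atom's register layout
  reorder(tArA, tCrA);                   // "pipe" into the MMA layout (no-op when layouts match)
  reorder(tBrB, tCrB);
  gemm(tiled_mma, tCrA, tCrB, tCrC);     // accumulate into the FP32 fragment
}
```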


@tdeng5 tdeng5 left a comment


We have folders for MoE-related examples, like 09_bmg_grouped_gemm_fp8 and 10_bmg_grouped_gemm_mixed_dtype; are you doing something similar? If yes, can we follow the existing naming convention for the examples folder?

@sanchitintel (Author)

We have folders for MoE-related examples, like 09_bmg_grouped_gemm_fp8 and 10_bmg_grouped_gemm_mixed_dtype; are you doing something similar? If yes, can we follow the existing naming convention for the examples folder?

Will do, thanks!


@Antonyvance Antonyvance requested a review from Copilot November 5, 2025 07:17

Copilot AI left a comment


Pull Request Overview

This PR introduces a Mixture of Experts (MoE) GEMM implementation for Intel GPUs using SYCL and CuTe, enabling MoE computations without collectives that can be called directly from GPU kernels.

Key Changes

  • New persistent tile scheduler for MoE GEMM workloads with custom work distribution
  • Core MoE GEMM kernels supporting both standard 16-bit floating-point and quantized MXFP4 formats
  • Example implementation demonstrating multi-expert GEMM execution with real workload patterns

Reviewed Changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 7 comments.

| File | Description |
| --- | --- |
| examples/cute/tutorial/moe/moe_tile_scheduler.hpp | Implements a persistent tile scheduler adapted for MoE workloads with per-expert work distribution |
| examples/cute/tutorial/moe/moe_grouped_gemm.hpp | Main MoE GEMM orchestration handling expert batching and tensor pointer updates |
| examples/cute/tutorial/moe/moe_gemms.hpp | Device-side GEMM kernels with support for bf16/fp16 and MXFP4 quantized operations |
| examples/cute/tutorial/moe/moe_example.cpp | Host-side launcher and example with realistic multi-layer expert workload patterns |
| examples/cute/tutorial/CMakeLists.txt | Adds the build target for the MoE GEMM example |
Comments suppressed due to low confidence (1)

examples/cute/tutorial/moe/moe_tile_scheduler.hpp:1

  • Corrected spelling of 'Othwerwise' to 'Otherwise' in comment at line 301 of moe_gemms.hpp

@sanchitintel sanchitintel changed the title MoE GEMM example BMG/PVC BF16/FP16 MoE GEMM example Nov 7, 2025
@sanchitintel sanchitintel changed the title BMG/PVC BF16/FP16 MoE GEMM example BF16/FP16 MoE GEMM with cute interface example Nov 7, 2025
@sanchitintel sanchitintel marked this pull request as ready for review November 7, 2025 06:26
@sanchitintel sanchitintel requested review from jiyang1011, pengzhao-intel, rolandschulz and tdeng5 and removed request for pengzhao-intel November 7, 2025 06:28

@jiyang1011 jiyang1011 left a comment


LGTM

(Comment on examples/cute/tutorial/moe/moe_tile_scheduler.hpp)


There is nothing special in this MoE tile scheduler. Why not use the persistent scheduler that CUTLASS supplies?


Or, @sanchitintel, could you help me understand your motivation?

@sanchitintel (Author) Nov 7, 2025

Hi @jiyang1011, thanks for reviewing the PR!

This file has some custom code, as MoE GEMM is a special case of Grouped GEMM.

The Grouped GEMM API requires creating a vector of ProblemShape objects, one per GEMM problem, which is used by the Grouped GEMM tile scheduler. If there are 32 groups, then a vector of 32 ProblemShape objects is created. Since these shapes would not be known at compile time for a framework, they would have to be created at run time instead. For MoE GEMM, however, I just provide one dummy shape, and the custom code in the tile scheduler derives the shape of each GEMM problem (see the sketch below).

Also, another reason for having a separate scheduler is that I'll change the scheduling algorithm down the line.
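Roughly, the difference looks like this (a hedged sketch; rows_for_expert and the surrounding names are hypothetical):

```cpp
// Grouped GEMM: the host builds one ProblemShape per group at run time, e.g.
//   std::vector<ProblemShape> shapes{{M0, N, K}, {M1, N, K}, ...};  // 32 entries for 32 groups
// MoE GEMM: N and K are shared across experts, so the tile scheduler can derive
// each group's shape on the device from the per-expert token counts alone.
int  M_g       = rows_for_expert[group];  // per-expert token count (device-visible array)
auto problem_g = make_shape(M_g, N, K);   // derived in the scheduler, no host-side vector
```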

Thanks!

@sanchitintel sanchitintel changed the title BF16/FP16 MoE GEMM with cute interface example BF16/FP16 MoE GEMM with CuTe interface example Nov 7, 2025

sanchitintel commented Nov 7, 2025

Hi @rolandschulz @ratnampa, can you please clarify why the CI is using -DDPCPP_HOST_COMPILER=g++-13 instead of icpx?

Thanks!

sanchitintel added a commit that referenced this pull request Nov 7, 2025
-DDPCPP_HOST_COMPILER=g++-13 is causing some issues in #600.

ratnampa commented Nov 7, 2025

Hi @rolandschulz @ratnampa, can you please clarify why the CI is using -DDPCPP_HOST_COMPILER=g++-13 instead of icpx?

Thanks!

@sanchitintel, this is for supporting g++ as the host compiler, rather than using icpx as the host compiler.

If I change CMakeLists.txt with `if (CMAKE_CXX_COMPILER_ID STREQUAL "IntelLLVM")`, that somehow also works when g++ is used as the host compiler, although it should not, so I'm making the corresponding changes in C++ instead.
Added corresponding changes in CMake
@sanchitintel sanchitintel changed the title BF16/FP16 MoE GEMM with CuTe interface example Example of BF16/FP16 MoE Grouped GEMM with CuTe interface Nov 7, 2025
Avoid extra is_valid() calls after group changes
Fixed a bug for the case in which the first GEMM problem's output has fewer than xe_core_count WG tiles. In practice, this issue can't happen on BMG with OOB models.