[MoE] Add swiglu_oai (OAI SwiGLU) for per-token fp8 CK XDL 2-stage MoE#3886
Open
LJ-underdog wants to merge 1 commit into
Open
[MoE] Add swiglu_oai (OAI SwiGLU) for per-token fp8 CK XDL 2-stage MoE#3886LJ-underdog wants to merge 1 commit into
LJ-underdog wants to merge 1 commit into
Conversation
Contributor
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
7112b15 to
f0f75f8
Compare
1 task
01a55ec to
3d1291a
Compare
Contributor
There was a problem hiding this comment.
Pull request overview
Adds OAI-style SwiGLU (swiglu_oai) support to the CK XDL 2-stage MoE path (including per-token FP8/PTPC) by plumbing ActivationType.Swiglu through codegen and runtime activation mapping so GPT-OSS/OAI-style MoE models can run on this kernel family.
Changes:
- Add/standardize CK stage1 activation-op mapping for
Swiglu(CK ActOP=3) in both runtime dispatch (gemm_moe_ck2stages.cu) and codegen (gemm_moe_ck2stages_common.py,gen_instances.py). - Extend CK 2-stage instance generation to produce AOT “plain-f8” swiglu instances for
{per_tensor, per_token} × {f16, b16}outputs. - Make activation string parsing explicit and add
"swiglu"to supported activation strings (aiter/utility/dtypes.py).
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| csrc/ck_gemm_moe_2stages_codegen/gen_instances.py | Adds swiglu act-op mapping + AOT generation loop for plain-f8 swiglu instances; updates activation CLI choices. |
| csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu | Replaces boolean inversion with explicit ActivationType→CK ActOP mapping (adds Swiglu support). |
| csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py | Introduces shared ACT_OP_MAP/ACT_OP_NAME and switches GEMM1 kernel naming/config from bool ActOP to int ActOP. |
| aiter/utility/dtypes.py | Updates str2ActivationType to explicitly accept "swiglu" (and provide clearer errors). |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
3d1291a to
e611971
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
Enable the OAI-form SwiGLU activation (
swiglu_oai,gate * sigmoid(1.702 * gate) * (up + 1), gpt-oss style) for the per-token fp8 (PTPC) CK XDL 2-stage MoE path. This path currently supports only silu/gelu, so gpt-oss / OAI-style MoE models cannot run on it.Technical Details
fused_moe.py: addAITER_SWIGLU_CK2STAGEescape hatch and routeActivationType.Swigluinto the 2-stage path.utility/dtypes.py: swiglu activation plumbing.gemm_moe_ck2stages.cu/gemm_moe_ck2stages_common.py:map_activation_to_ck_stage1mapsSwigluto CK ActOP 3.gen_instances.py: addACT_OP_MAP, theswigluargparse choice, and plain-fp8 per_token / per_tensor swiglu codegen instances.gridwise_moe_gemm.hpp(per_token fp8 -> a8w8 tag, plain cuh), the same path as silu/gelu. The activation is computed in fp32 in the epilogue and is orthogonal to the GEMM compute (MFMA/tile/pipeline untouched) and to quantization (existing per-token dequant reused).swiglu_oai_and_mul).Test Plan
Op-isolate on gfx942 (MI308X): run the per-token fp8 swiglu_oai MoE via JIT on-demand (no manual codegen patch) and compare the output against a torch fp32 OAI-SwiGLU reference (cos_sim).
Test Result
JIT auto-generates and compiles the
swiglu x per_token x f8instances and the op runs without error. cos_sim = 0.999993 vs the torch fp32 OAI-SwiGLU reference; no NaN; dispatched to the CK 2-stageGridwiseMoeGemmkernel (verified via rocprofv3).Submission Checklist