Skip to content

Commit c81153a

Browse files
authored
Merge pull request #83 from ROCm/qianghan/add-turbo-grouped-gemm
Add Primus Turbo grouped GEMM support for MoE sparse matmul
2 parents 851c093 + fe25475 commit c81153a

3 files changed

Lines changed: 42 additions & 2 deletions

File tree

src/MaxText/configs/base.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -174,6 +174,7 @@ num_experts: 1
174174
num_experts_per_tok: 1
175175
megablox: True
176176
sparse_matmul: True
177+
use_turbo_grouped_gemm: false # Use Primus Turbo grouped GEMM for MoE sparse matmul. Requires sparse_matmul=True, megablox=False, no quantization, and the primus_turbo package installed.
177178
capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
178179
load_balance_loss_weight: 0.01 # weight for the load balance loss
179180
expert_balance: False # whether or not to do expert balancing

src/MaxText/configs/types.py

Lines changed: 15 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -556,6 +556,12 @@ class MoEKernels(BaseModel):
556556

557557
megablox: bool = Field(True, description="Whether to use Megablox kernels for MoE.")
558558
sparse_matmul: bool = Field(True, description="Whether to use sparse matmul kernels for MoE.")
559+
use_turbo_grouped_gemm: bool = Field(
560+
False,
561+
description="Use Primus Turbo grouped GEMM for MoE sparse matmul. "
562+
"Requires sparse_matmul=True and megablox=False. "
563+
"Requires the primus_turbo package to be installed.",
564+
)
559565
wi_tile_fwd_batch_seq: int = Field(512, description="forward pass tiling dimension for batch/sequence in GMM for wi.")
560566
wi_tile_fwd_embed_dim: int = Field(1024, description="forward pass tiling dimension for embedding in GMM for wi.")
561567
wi_tile_fwd_mlp_dim: int = Field(1024, description="forward pass tiling dimension for MLP in GMM for wi.")
@@ -1094,7 +1100,8 @@ class DevelopmentAndDebugging(BaseModel):
10941100
)
10951101
jax_distributed_initialization_timeout: int = Field(300, description="Timeout for jax.distributed.initialize.")
10961102
jax_distributed_heartbeat_timeout_seconds: int = Field(
1097-
100, description="How long before a missing heartbeat marks a task as dead. Increase for slow NFS checkpoint restores."
1103+
100,
1104+
description="How long before a missing heartbeat marks a task as dead. Increase for slow NFS checkpoint restores.",
10981105
)
10991106
jax_debug_log_modules: str = Field("", description="Set to 'jax' for verbose JAX logging.")
11001107
skip_jax_distributed_system: bool = Field(False, description="If True, do not initialize the jax distributed system.")
@@ -1899,6 +1906,13 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
18991906
)
19001907
if self.decoder_block == DecoderBlockType.GPT_OSS and not self.sparse_matmul and self.capacity_factor != -1:
19011908
raise ValueError("GPT-OSS MoE only supports dropless (capacity_factor=-1) with dense matmul.")
1909+
if self.use_turbo_grouped_gemm:
1910+
if self.quantization:
1911+
raise ValueError("use_turbo_grouped_gemm is not compatible with quantization.")
1912+
if not self.sparse_matmul:
1913+
raise ValueError("use_turbo_grouped_gemm requires sparse_matmul=True.")
1914+
if self.megablox:
1915+
raise ValueError("use_turbo_grouped_gemm requires megablox=False.")
19021916
if self.use_multimodal:
19031917
valid_mm_models = ("gemma3-4b", "gemma3-12b", "gemma3-27b", "llama4-17b-16e", "llama4-17b-128e")
19041918
if self.model_name not in valid_mm_models and self.model_name != "default":

src/MaxText/layers/moe.py

Lines changed: 26 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -855,6 +855,28 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
855855
use_tokamax_backend=self.config.use_tokamax_gmm,
856856
is_fsdp_shard_on_exp=self.config.fsdp_shard_on_exp,
857857
)
858+
elif self.config.use_turbo_grouped_gemm:
859+
try:
860+
from primus_turbo.jax.lax.grouped_gemm import grouped_gemm as turbo_grouped_gemm
861+
except ImportError:
862+
raise ImportError("use_turbo_grouped_gemm=True requires the primus_turbo package to be installed.")
863+
if not getattr(turbo_grouped_gemm, "_logged", False):
864+
max_logging.log("Using primus_turbo grouped_gemm in MoE sparse matmul")
865+
turbo_grouped_gemm._logged = True
866+
# Thread-local x64: CK kernel requires int64 group_sizes, but
867+
# global x64 breaks argsort (XLA-ROCm s32/s64 scatter mismatch).
868+
# Use jax.experimental.enable_x64() for thread-local scope,
869+
# safe for concurrent shard_map threads.
870+
# Remove this once primus_turbo accepts int32 group_lens natively.
871+
with jax.experimental.enable_x64():
872+
output = turbo_grouped_gemm(
873+
inputs,
874+
kernel,
875+
group_sizes.astype(jnp.int64),
876+
transA=False,
877+
transB=False,
878+
num_cu=-1,
879+
)
858880
else:
859881
rhs_inputs = kernel
860882
if isinstance(kernel, aqt.QTensor):
@@ -1445,7 +1467,10 @@ def get_einsum(
14451467
def aqt_einsum(*args, **kwargs): # pylint: disable=unused-argument
14461468
# simply skip kwargs, since aqt einsum doesn't support any kwargs
14471469
# like precision
1448-
is_aqt = not ( isinstance(self.quant, quantizations.Fp8Quantization) or isinstance(self.quant, quantizations.NANOOFp8Quantization) )
1470+
is_aqt = not (
1471+
isinstance(self.quant, quantizations.Fp8Quantization)
1472+
or isinstance(self.quant, quantizations.NANOOFp8Quantization)
1473+
)
14491474
kw = {"mesh_axes": rhs_mesh_axes} if is_aqt else {"dtype": self.dtype}
14501475
return self.quant.einsum(**kw)(*args) # pytype: disable=attribute-error
14511476

0 commit comments

Comments
 (0)