From 76cf668b07dc990b147ef4f902ced941cd7d70b0 Mon Sep 17 00:00:00 2001
From: Tom Pollak
Date: Mon, 13 Oct 2025 22:38:10 +0000
Subject: [PATCH] fix: port moe routing to new triton_kernels API

The `triton_kernels.routing` module was deprecated and removed in triton
commit 30ede52aa (https://github.com/triton-lang/triton/pull/8375).
Replaced the deprecated `routing()` call with the new primitives, wrapped
in `compute_routing()`.
---
Upgrade to `triton>=3.5`. `triton_kernels` HEAD depends on
`tl.target_info()`, which is not available in 3.4.
---
 gpt_oss/triton/moe.py | 35 +++++++++++++++++++++++++++++++----
 pyproject.toml        |  2 +-
 2 files changed, 32 insertions(+), 5 deletions(-)

diff --git a/gpt_oss/triton/moe.py b/gpt_oss/triton/moe.py
index 925dbd54..ccffaa44 100644
--- a/gpt_oss/triton/moe.py
+++ b/gpt_oss/triton/moe.py
@@ -5,10 +5,10 @@
 import triton_kernels.swiglu
 from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
 from triton_kernels.matmul_ogs import PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
-from triton_kernels.matmul_ogs import matmul_ogs
+from triton_kernels.matmul_ogs import matmul_ogs, RoutingData, GatherIndx, ScatterIndx
 from triton_kernels.numerics import InFlexData
-from triton_kernels.routing import routing
-from triton_kernels.tensor import convert_layout
+from triton_kernels.topk import topk
+from triton_kernels.tensor import convert_layout, make_ragged_tensor_metadata
 from triton_kernels.tensor_details.layout import StridedLayout, HopperMXScaleLayout, HopperMXValueLayout
 from triton_kernels.tensor import wrap_torch_tensor, FP4
 
@@ -31,6 +31,32 @@ def swiglu(x, alpha: float = 1.702, limit: float = 7.0, interleaved: bool = True
     return out_glu * (x_linear + 1)
 
 
+def compute_routing(logits, n_expts_act, n_expts_tot):
+    sparse_logits = topk(logits, n_expts_act, apply_softmax=True)
+
+    dispatch_indx = sparse_logits.mask_metadata.col_sorted_indx
+    combine_indx = sparse_logits.mask_metadata.row_sorted_indx
+
+    ragged_batch_metadata = make_ragged_tensor_metadata(
+        sparse_logits.mask_metadata.col_sum,
+        dispatch_indx.shape[0]
+    )
+
+    gate_scal = sparse_logits.vals.flatten()[combine_indx]
+
+    rdata = RoutingData(
+        gate_scal,
+        ragged_batch_metadata.batch_sizes,
+        n_expts_tot,
+        n_expts_act,
+        ragged_batch_metadata
+    )
+    gather_indx = GatherIndx(combine_indx, dispatch_indx)
+    scatter_indx = ScatterIndx(dispatch_indx, combine_indx)
+
+    return rdata, gather_indx, scatter_indx
+
+
 def moe(x, wg, w1, w1_mx, w2, w2_mx, bg, b1, b2, experts_per_token=4, num_experts=128, swiglu_limit=7.0, fused_act=True, interleaved=True):
     if x.numel() == 0:
         return x
@@ -41,8 +67,9 @@ def moe(x, wg, w1, w1_mx, w2, w2_mx, bg, b1, b2, experts_per_token=4, num_expert
 
     with record_function("wg"):
         logits = matmul_ogs(x, wg, bg, precision_config=pcg)
+
     with record_function("routing"):
-        rdata, gather_indx, scatter_indx = routing(logits, experts_per_token, simulated_ep=1)
+        rdata, gather_indx, scatter_indx = compute_routing(logits, experts_per_token, num_experts)
 
     if fused_act:
         assert interleaved, "Fused activation requires interleaved weights"
diff --git a/pyproject.toml b/pyproject.toml
index d2595a16..e1e0933f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,7 +24,7 @@ requires-python = ">=3.12"
 version = "0.0.8"
 
 [project.optional-dependencies]
-triton = ["triton>=3.4", "safetensors>=0.5.3", "torch>=2.7.0"]
+triton = ["triton>=3.5", "safetensors>=0.5.3", "torch>=2.7.0"]
 torch = ["safetensors>=0.5.3", "torch>=2.7.0"]
 metal = ["numpy", "tqdm", "safetensors", "torch"]
 test = ["pytest>=8.4.1", "httpx>=0.28.1"]
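
Not part of the patch: a minimal smoke-test sketch that exercises the new
compute_routing() helper in isolation. It assumes a CUDA GPU and a
triton_kernels checkout that postdates triton-lang/triton#8375; the token
count, expert counts, and dtype below are illustrative assumptions, not
values taken from the repository.

    # Hypothetical standalone check of the routing port (illustrative only).
    import torch

    from gpt_oss.triton.moe import compute_routing

    n_tokens, num_experts, experts_per_token = 32, 128, 4

    # Stand-in for the gating logits that moe() obtains from matmul_ogs(x, wg, bg, ...).
    logits = torch.randn(n_tokens, num_experts, device="cuda", dtype=torch.bfloat16)

    rdata, gather_indx, scatter_indx = compute_routing(
        logits, experts_per_token, num_experts
    )

    # The three results plug into the downstream matmul_ogs() calls exactly where
    # the old routing(logits, experts_per_token, simulated_ep=1) results did.
    print(type(rdata), type(gather_indx), type(scatter_indx))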