gpt_oss/triton/moe.py (31 additions, 4 deletions)
@@ -5,10 +5,10 @@
 import triton_kernels.swiglu
 from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
 from triton_kernels.matmul_ogs import PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
-from triton_kernels.matmul_ogs import matmul_ogs
+from triton_kernels.matmul_ogs import matmul_ogs, RoutingData, GatherIndx, ScatterIndx
 from triton_kernels.numerics import InFlexData
-from triton_kernels.routing import routing
-from triton_kernels.tensor import convert_layout
+from triton_kernels.topk import topk
+from triton_kernels.tensor import convert_layout, make_ragged_tensor_metadata
 from triton_kernels.tensor_details.layout import StridedLayout, HopperMXScaleLayout, HopperMXValueLayout
 from triton_kernels.tensor import wrap_torch_tensor, FP4
@@ -31,6 +31,32 @@ def swiglu(x, alpha: float = 1.702, limit: float = 7.0, interleaved: bool = True
     return out_glu * (x_linear + 1)
 
 
+def compute_routing(logits, n_expts_act, n_expts_tot):
+    sparse_logits = topk(logits, n_expts_act, apply_softmax=True)
+
+    dispatch_indx = sparse_logits.mask_metadata.col_sorted_indx
+    combine_indx = sparse_logits.mask_metadata.row_sorted_indx
+
+    ragged_batch_metadata = make_ragged_tensor_metadata(
+        sparse_logits.mask_metadata.col_sum,
+        dispatch_indx.shape[0]
+    )
+
+    gate_scal = sparse_logits.vals.flatten()[combine_indx]
+
+    rdata = RoutingData(
+        gate_scal,
+        ragged_batch_metadata.batch_sizes,
+        n_expts_tot,
+        n_expts_act,
+        ragged_batch_metadata
+    )
+    gather_indx = GatherIndx(combine_indx, dispatch_indx)
+    scatter_indx = ScatterIndx(dispatch_indx, combine_indx)
+
+    return rdata, gather_indx, scatter_indx
+
+
 def moe(x, wg, w1, w1_mx, w2, w2_mx, bg, b1, b2, experts_per_token=4, num_experts=128, swiglu_limit=7.0, fused_act=True, interleaved=True):
     if x.numel() == 0:
         return x
@@ -41,8 +67,9 @@ def moe(x, wg, w1, w1_mx, w2, w2_mx, bg, b1, b2, experts_per_token=4, num_expert
 
     with record_function("wg"):
         logits = matmul_ogs(x, wg, bg, precision_config=pcg)
+
     with record_function("routing"):
-        rdata, gather_indx, scatter_indx = routing(logits, experts_per_token, simulated_ep=1)
+        rdata, gather_indx, scatter_indx = compute_routing(logits, experts_per_token, num_experts)
 
     if fused_act:
         assert interleaved, "Fused activation requires interleaved weights"
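The new compute_routing helper rebuilds the routing metadata that the removed triton_kernels.routing.routing call used to produce, assembling it from topk and make_ragged_tensor_metadata instead. Below is a minimal sketch of calling it in isolation, assuming the triton_kernels objects imported in the diff above (RoutingData, GatherIndx, ScatterIndx) and a CUDA-capable setup; shapes, dtypes, and the module path are illustrative only, not part of the change:

# Illustrative sketch: exercise compute_routing on random logits.
import torch
from gpt_oss.triton.moe import compute_routing  # hypothetical import path

n_tokens, num_experts, experts_per_token = 16, 128, 4
logits = torch.randn(n_tokens, num_experts, device="cuda")

# rdata carries the gate weights and per-expert batch sizes; gather_indx and
# scatter_indx map tokens into and out of the expert-sorted order consumed by
# matmul_ogs, mirroring what triton_kernels.routing.routing returned before.
rdata, gather_indx, scatter_indx = compute_routing(logits, experts_per_token, num_experts)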
pyproject.toml (1 addition, 1 deletion)
@@ -24,7 +24,7 @@ requires-python = ">=3.12"
 version = "0.0.8"
 
 [project.optional-dependencies]
-triton = ["triton>=3.4", "safetensors>=0.5.3", "torch>=2.7.0"]
+triton = ["triton>=3.5", "safetensors>=0.5.3", "torch>=2.7.0"]
 torch = ["safetensors>=0.5.3", "torch>=2.7.0"]
 metal = ["numpy", "tqdm", "safetensors", "torch"]
 test = ["pytest>=8.4.1", "httpx>=0.28.1"]
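The triton extra's version floor moves from 3.4 to 3.5, presumably because the triton_kernels topk and make_ragged_tensor_metadata usage above needs the newer release. A quick, hypothetical sanity check for an existing environment:

# Hypothetical check that the installed Triton satisfies the new floor.
import triton

major, minor = (int(p) for p in triton.__version__.split(".")[:2])
assert (major, minor) >= (3, 5), f"triton {triton.__version__} is older than 3.5"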