
Conversation

tom-pollak

The `triton_kernels.routing` module was deprecated and removed in triton commit 30ede52aa (triton-lang/triton#8375).

Replaced the deprecated `routing()` call with new primitives in `compute_routing()`.

---

Upgrade to `triton>=3.5`. `triton_kernels` HEAD uses `tl.target_info()`, which is not available in 3.4.
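For reference, the version bump can be applied with a standard pip requirement specifier (assuming a pip-managed environment; adjust for your package manager):

```shell
pip install --upgrade "triton>=3.5"
```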

@tom-pollak tom-pollak force-pushed the migrate-routing-squash branch from 3557af1 to 04ebbfd Compare October 13, 2025 22:36
@tom-pollak tom-pollak force-pushed the migrate-routing-squash branch from 04ebbfd to 76cf668 Compare October 13, 2025 22:38

💡 Codex Review

    dispatch_indx = sparse_logits.mask_metadata.col_sorted_indx
    combine_indx = sparse_logits.mask_metadata.row_sorted_indx
    ragged_batch_metadata = make_ragged_tensor_metadata(
        sparse_logits.mask_metadata.col_sum,
        dispatch_indx.shape[0]
    )
    gate_scal = sparse_logits.vals.flatten()[combine_indx]
    rdata = RoutingData(
        gate_scal,
        ragged_batch_metadata.batch_sizes,
        n_expts_tot,
        n_expts_act,
        ragged_batch_metadata
    )
    gather_indx = GatherIndx(combine_indx, dispatch_indx)
    scatter_indx = ScatterIndx(dispatch_indx, combine_indx)

P0: Arrange gate scalars using column-sorted indices

In compute_routing() the gate weights are built with gate_scal = sparse_logits.vals.flatten()[combine_indx], where combine_indx is mask_metadata.row_sorted_indx. However the ragged metadata and dispatch/scatter indices are derived from mask_metadata.col_sorted_indx/col_sum (expert-major ordering). This means RoutingData.gate_scal is grouped by token rows while the metadata expects segments grouped by expert columns, so when an expert receives more than one token the subsequent matmul_ogs(..., gammas=rdata.gate_scal) multiplies outputs by weights belonging to different experts. Any configuration with multiple experts or top‑k>1 will therefore return incorrect routing probabilities.
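For intuition, the distinction at issue is between a permutation and its inverse. Below is a toy, pure-Python sketch (a hypothetical three-pair, top-1 assignment; not the actual `triton_kernels` metadata) showing that gathering gate values with the inverse of the expert-major permutation loses the per-expert grouping:

```python
# Toy top-1 routing: 3 tokens pick experts [1, 1, 0].
expert_of_pair = [1, 1, 0]
vals = [0.9, 0.8, 0.7]  # gate value of each (token, expert) pair, row-major

# dispatch: row-major flat positions, stable-sorted by expert id
dispatch = sorted(range(3), key=lambda i: expert_of_pair[i])  # [2, 0, 1]
# combine: the inverse permutation of dispatch
combine = sorted(range(3), key=lambda j: dispatch[j])         # [1, 2, 0]

# Gathering with `dispatch` groups the values by expert ...
print([vals[i] for i in dispatch])  # [0.7, 0.9, 0.8] (expert 0, then expert 1)
# ... while gathering with the inverse permutation does not:
print([vals[i] for i in combine])   # [0.8, 0.7, 0.9] (not expert-grouped)
```

Which of the two index arrays `row_sorted_indx` actually corresponds to is exactly what the follow-up discussion turns on.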


@tom-pollak
Author

@codex I don't think that's correct; it doesn't match the original PR: triton-lang/triton#8375

def legacy_routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act):
    sparse_logits = SparseMatrix(indx=expt_indx, vals=expt_scal, mask=bitmatrix)
    dispatch_indx = sparse_logits.mask.metadata.col_sorted_indx
    combine_indx = sparse_logits.mask.metadata.row_sorted_indx
    ragged_batch_metadata = make_ragged_tensor_metadata(sparse_logits.mask.metadata.col_sum, dispatch_indx.shape[0])
    gate_scal = sparse_logits.vals.flatten()[combine_indx] ###### combine_indx, not dispatch.
    routing_data = RoutingData(gate_scal, ragged_batch_metadata.batch_sizes, n_expts_tot, n_expts_act,
                               ragged_batch_metadata)
    gather_idx = GatherIndx(combine_indx, dispatch_indx)
    scatter_idx = ScatterIndx(dispatch_indx, combine_indx)
    return routing_data, gather_idx, scatter_idx


def legacy_routing(logits, n_expts_act, sm_first=False, expt_indx=None, n_rows=None):
    if sm_first:
        logits = torch.softmax(logits, dim=-1)
    sparse_logits = topk(logits, n_expts_act, apply_softmax=not sm_first, y_indx=expt_indx, n_rows=n_rows)
    return legacy_routing_from_bitmatrix(sparse_logits.mask, sparse_logits.vals, sparse_logits.indx, logits.shape[-1],
                                         n_expts_act)

I think it's confused because `row_sorted_indx` references col-major positions, but its entries are grouped by row.
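One way to see why the naming can mislead an automated reviewer: the two index arrays are inverse permutations of each other, so each can be read either as "where entry i goes" or "which entry lands in slot i". A toy sketch (illustrative data only, not the real `mask.metadata` fields):

```python
# Same toy assignment: 3 (token, expert) pairs, experts [1, 1, 0].
expert_of_pair = [1, 1, 0]

# col-sorted view: row-major flat positions rearranged into expert-major order
col_sorted = sorted(range(3), key=lambda i: expert_of_pair[i])  # [2, 0, 1]

# its inverse: for each row-major pair, its slot in the col-sorted view
inv = [0] * 3
for slot, pos in enumerate(col_sorted):
    inv[pos] = slot                                             # [1, 2, 0]

# Round-tripping through both permutations is the identity, so either
# array fully determines the other; only the naming convention differs.
roundtrip = [col_sorted[inv[i]] for i in range(3)]
print(roundtrip)  # [0, 1, 2]
```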

