[Common] Add dense router output for fused router #3129

Merged
denera merged 8 commits into NVIDIA:main from harryzhou2000:hhanyu/router_dense_route on Jun 23, 2026

@harryzhou2000 harryzhou2000 commented Jun 15, 2026

Member

The fused router forward kernels currently emit a sparse routing map in boolean or bitmap format. Passing that result to a dispatcher may require an additional conversion to the dense topk_indices format, which has shape [*leading_dims, top_k]. NCCL EP, for example, accepts topk_indices as its routing map. Letting the fused router write that format directly avoids an extra kernel for routing-map conversion.

For NCCL EP, the dense topk_indices row is consumed as an order-insensitive selected-expert set. The dense output therefore does not promise score-sorted or expert-sorted order; it preserves the selected experts produced by the chosen top-k kernel path.
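
As a rough illustration of the conversion this PR avoids, and of what "order-insensitive" means for a dense row, here is a minimal pure-Python sketch. The helper names `routing_map_to_topk_indices` and `same_expert_set` are hypothetical and are not part of Transformer Engine; the real conversion would run on the GPU over whole tensors.

```python
def routing_map_to_topk_indices(routing_row, top_k):
    """Convert one boolean routing-map row into a dense list of expert IDs.

    Hypothetical helper: this is the per-token work an extra conversion
    kernel would have to do when the router only emits a routing map.
    """
    selected = [expert for expert, flag in enumerate(routing_row) if flag]
    if len(selected) != top_k:
        raise ValueError(f"expected {top_k} selected experts, got {len(selected)}")
    return selected


def same_expert_set(row_a, row_b):
    """Order-insensitive comparison of two dense index rows."""
    return sorted(row_a) == sorted(row_b)


# One token, 8 experts, top_k = 3: experts 1, 4 and 6 are selected.
routing_row = [False, True, False, False, True, False, True, False]
dense_from_map = routing_map_to_topk_indices(routing_row, top_k=3)

# A kernel may emit the same experts in a different order.
dense_from_kernel = [6, 1, 4]
rows_match = same_expert_set(dense_from_map, dense_from_kernel)
```

A consumer that treats each row as a set, as described above for NCCL EP, would regard `dense_from_map` and `dense_from_kernel` as equivalent.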

Summary

  • Add an optional dense topk_indices output path to fused router top-k.
  • Avoid materializing the full routing map when callers provide a [*leading_dims, topk] index buffer.
  • Add dense-index backward using the selected expert indices directly.
  • Support int16, int32, and int64 dense index buffers.
  • Preserve existing BYTEMAP/BITMAP_U8 routing-map paths and p3R radix/naive dispatch behavior.
  • Add guard checks for dense index shape/device/dtype, grouped top-k assumptions, routing-map format, and direct score-function API usage.
  • Add int16 support for TE CUDA graph weak-ref tensors.
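
The guard checks listed above could look roughly like the following sketch. It is written in plain Python with shapes as tuples and dtypes as strings; `validate_dense_index_buffer` is a hypothetical name, and the final range check (expert IDs must fit in the index dtype) is an assumption about what such a guard might reasonably enforce, not something this PR states.

```python
# Largest value representable by each signed index dtype named in the PR.
INDEX_DTYPE_MAX = {"int16": 2**15 - 1, "int32": 2**31 - 1, "int64": 2**63 - 1}


def validate_dense_index_buffer(logits_shape, indices_shape, dtype, top_k):
    """Hypothetical guard for a caller-supplied dense index buffer.

    Returns None when the buffer is acceptable and raises otherwise.
    """
    if dtype not in INDEX_DTYPE_MAX:
        raise TypeError(f"unsupported index dtype: {dtype}")

    *leading_dims, num_experts = logits_shape
    expected_shape = (*leading_dims, top_k)
    if tuple(indices_shape) != expected_shape:
        raise ValueError(
            f"topk_indices shape {tuple(indices_shape)} != expected {expected_shape}"
        )

    if top_k > num_experts:
        raise ValueError("top_k cannot exceed the number of experts")

    # Assumption: the largest expert ID must be representable in the dtype.
    if num_experts - 1 > INDEX_DTYPE_MAX[dtype]:
        raise ValueError(f"{num_experts} experts do not fit in {dtype}")


# Logits of shape [batch=4, seq=16, experts=64] with top_k = 8.
validate_dense_index_buffer((4, 16, 64), (4, 16, 8), "int16", top_k=8)
```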

Testing

  • Built on B200 with:
    NVTE_BUILD_THREADS_PER_JOB=4 NVTE_CUDA_ARCHS="90;100;103a;120" NVTE_USE_CCACHE=1 pip install --no-build-isolation -e .[test] --verbose
  • Ran:
    python -m pytest -q tests/pytorch/test_fused_router.py
  • Result:
    3203 passed, 444 skipped, 3 warnings in 44.17s

@harryzhou2000 harryzhou2000 changed the title [Common] Add dense top-k index output for fused router [Common] Add dense router output for fused router Jun 15, 2026
@harryzhou2000 harryzhou2000 marked this pull request as ready for review June 15, 2026 12:48
greptile-apps Bot commented Jun 15, 2026

Contributor

Greptile Summary

This PR adds an optional dense topk_indices output path to the fused router forward (and a matching backward), allowing callers such as NCCL EP to receive a [*leading_dims, topk] integer index buffer directly from the kernel instead of running a separate routing-map conversion step.

  • Forward: a new _with_indices C++ entry point and CUDA kernel launcher (Naive + Radix paths) write selected expert IDs into a caller-supplied int16/int32/int64 output buffer while skipping routing-map materialization; existing BYTEMAP/BITMAP_U8 paths are fully preserved.
  • Backward: a new fused_topk_backward_selected_indices_kernel uses the saved dense indices to reconstruct gradients without a routing map; math mirrors the existing routing-map backward for all three score functions and both use_pre_softmax modes.
  • Hardening: get_score_function_value replaces silent map-insert lookups, and new guard checks cover format/shape/dtype/device for both forward and backward paths; int16 is added to TE_ROUTER_INDEX_TYPE_SWITCH_ALL and the numpy typestring table.
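
To make the "math mirrors the routing-map backward" point concrete, here is a deliberately simplified sketch for a single token. It assumes a plain sigmoid score with no normalization, scaling, or grouping, which is much narrower than what the real kernels handle, and both function names are hypothetical. It only shows that walking the selected indices yields the same gradient as masking with a routing map.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def backward_from_routing_map(logits, grad_probs, routing_row):
    """Gradient w.r.t. logits using a boolean mask over all experts."""
    grads = []
    for x, g, selected in zip(logits, grad_probs, routing_row):
        s = sigmoid(x)
        grads.append(g * s * (1.0 - s) if selected else 0.0)
    return grads


def backward_from_indices(logits, grad_probs, topk_indices):
    """Same gradient, but visiting only the selected expert IDs."""
    grads = [0.0] * len(logits)
    for expert in topk_indices:
        s = sigmoid(logits[expert])
        grads[expert] = grad_probs[expert] * s * (1.0 - s)
    return grads


logits = [0.5, -1.0, 2.0, 0.1]
grad_probs = [1.0, 0.3, -0.7, 0.2]
routing_row = [True, False, True, False]   # experts 0 and 2 selected
topk_indices = [2, 0]                      # same set, different order

grad_map = backward_from_routing_map(logits, grad_probs, routing_row)
grad_idx = backward_from_indices(logits, grad_probs, topk_indices)
```

Because the dense row is used only to locate the selected experts, the order of `topk_indices` does not affect the result.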

Confidence Score: 5/5

Safe to merge. The dense-index path is a fully additive feature with no changes to existing routing-map entry points; all new guard checks are defensive and correct.

Both forward and backward kernels are mathematically consistent with the existing routing-map paths across all three score functions and both use_pre_softmax modes. Shape/device/dtype validation is thorough at every layer. Existing BYTEMAP/BITMAP_U8 paths are structurally untouched. The only observation is a minor style nit (duplicate if topk_indices is not None guards in router.py) that has no functional impact.

No files require special attention. The new fused_topk_backward_selected_indices_kernel in fused_topk_with_score_function.cu is the most novel logic and has been validated by the existing parametric test suite including backward comparison.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/common/fused_router/fused_topk_with_score_function.cu | Adds IndexType template parameter and topk_indices_output pointer to both forward kernels; new _with_indices launcher and backward kernel fused_topk_backward_selected_indices_kernel. Math is consistent with the routing-map backward; shared-memory init correctly guarded by routing_map != nullptr for the dense-index path. |
| transformer_engine/pytorch/csrc/extensions/router.cpp | Replaces silent map-insert with get_score_function_value() helper, adds check_dense_topk_indices validation, branches forward/backward to dedicated _with_indices NVTE calls. Logic is correct and shape/device/dtype guards are thorough. |
| transformer_engine/pytorch/router.py | Adds topk_indices optional arg to FusedTopkScoreFunction.forward/backward; use_dense_indices context flag routes to the new C++ entry point. Minor style issue with duplicate if topk_indices is not None guards. |
| transformer_engine/common/fused_router/utils.h | Adds TE_ROUTER_DENSE_INDEX_TYPE_SWITCH_ALL macro (int16/int32/int64), extends existing TE_ROUTER_INDEX_TYPE_SWITCH_ALL with int16 case, and adds inline check_routing_map_format helper. All changes are straightforward. |
| tests/pytorch/test_fused_router.py | Adds topk_output_mode / topk_index_dtype parameters to run_comparison; new test_topk_preserves_leading_dims test for multi-dim leading shapes; dense index parametrization added to all three score-function test suites covering forward + backward. |
| transformer_engine/common/include/transformer_engine/fused_router.h | Declares two new public C API functions: nvte_fused_topk_with_score_function_forward_with_indices and nvte_fused_topk_with_score_function_backward_with_indices. Docstrings accurately reflect the new semantics. |
| transformer_engine/pytorch/csrc/extensions/pybind.cpp | Adds py::arg("topk_indices") and py::arg("use_dense_indices") to pybind11 bindings, matching updated C++ signatures. |
| transformer_engine/pytorch/utils.py | Adds torch.int16 → "<i2" entry to the numpy typestring dict to support int16 CUDA graph weak-ref tensors. |
| transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu | Adds check_routing_map_format guard at the top of fused_score_for_moe_aux_loss_forward. One-line defensive hardening, no logic change. |
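
For readers unfamiliar with the "<i2" typestring mentioned for utils.py: in the NumPy array-interface convention it denotes a little-endian signed 2-byte integer, alongside "<i4" and "<i8" for the wider types. The sketch below uses only the standard-library `struct` module to show what that layout means for a row of expert IDs; the helper names and the mapping dict are illustrative and do not reproduce the actual table in utils.py.

```python
import struct

# Typestring -> struct format code for little-endian signed integers.
TYPESTR_TO_STRUCT_CODE = {"<i2": "h", "<i4": "i", "<i8": "q"}


def pack_indices(indices, typestr):
    """Serialize expert IDs using the byte layout the typestring describes."""
    code = TYPESTR_TO_STRUCT_CODE[typestr]
    return struct.pack(f"<{len(indices)}{code}", *indices)


def unpack_indices(buffer, typestr):
    """Inverse of pack_indices."""
    code = TYPESTR_TO_STRUCT_CODE[typestr]
    itemsize = int(typestr[2:])  # "<i2" -> 2 bytes per element
    count = len(buffer) // itemsize
    return list(struct.unpack(f"<{count}{code}", buffer))


row = [5, 300, 32767]            # 32767 is the largest int16 value
packed_i16 = pack_indices(row, "<i2")
packed_i64 = pack_indices(row, "<i8")
roundtrip = unpack_indices(packed_i16, "<i2")
```

An int16 row occupies a quarter of the bytes of the same row stored as int64, which is one plausible reason to offer the narrower dtype when the expert count is small.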

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["fused_topk_with_score_function (Python)"] --> B{topk_indices provided?}
    B -- No --> C["nvte_fused_topk_with_score_function_forward_v2\n(BYTEMAP / BITMAP_U8 routing map)"]
    B -- Yes --> D["nvte_fused_topk_with_score_function_forward_with_indices\n(dense int16/int32/int64 index buffer)"]
    C --> E{use_radix?}
    D --> F{use_radix?}
    E -- No --> G["fused_topk_forward_simple_kernel\n(Naive, routing_map written)"]
    E -- Yes --> H["fused_topk_with_score_function_forward_kernel\n(Radix, routing_map written)"]
    F -- No --> I["fused_topk_forward_simple_kernel\n(Naive, routing_map=nullptr, topk_indices_output written)"]
    F -- Yes --> J["fused_topk_with_score_function_forward_kernel\n(Radix, routing_map=nullptr, topk_indices_output written)"]
    G & H & I & J --> K["probs + routing_output returned to Python"]
    K --> L{Backward}
    L -- use_dense_indices=False --> M["nvte_fused_topk_with_score_function_backward_v2\n(routing_map → grad_logits)"]
    L -- use_dense_indices=True --> N["nvte_fused_topk_with_score_function_backward_with_indices\n(fused_topk_backward_selected_indices_kernel)"]
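
The forward half of the flowchart can be read as a small decision function. The sketch below only returns the entry-point and kernel names taken from the chart; `select_forward_path` itself is a hypothetical helper for illustration, not code from the PR.

```python
def select_forward_path(topk_indices_provided, use_radix):
    """Return (C API entry point, CUDA kernel, buffer written) per the chart."""
    if topk_indices_provided:
        entry_point = "nvte_fused_topk_with_score_function_forward_with_indices"
        written = "topk_indices_output"   # routing_map is nullptr on this path
    else:
        entry_point = "nvte_fused_topk_with_score_function_forward_v2"
        written = "routing_map"           # BYTEMAP / BITMAP_U8

    if use_radix:
        kernel = "fused_topk_with_score_function_forward_kernel"  # Radix
    else:
        kernel = "fused_topk_forward_simple_kernel"               # Naive

    return entry_point, kernel, written


dense_radix = select_forward_path(topk_indices_provided=True, use_radix=True)
map_naive = select_forward_path(topk_indices_provided=False, use_radix=False)
```

Both output modes share the same two kernels; what changes is which output pointer is populated.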

Reviews (4): Last reviewed commit: "[PyTorch] Preserve router leading dimens..."

Comment thread transformer_engine/pytorch/csrc/extensions/router.cpp
Comment thread transformer_engine/pytorch/csrc/extensions/router.cpp
Comment thread transformer_engine/pytorch/router.py Outdated
@denera denera self-requested a review June 23, 2026 06:12
@denera denera self-assigned this Jun 23, 2026
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
@harryzhou2000 harryzhou2000 force-pushed the hhanyu/router_dense_route branch from 49c6553 to 6a957c3 Compare June 23, 2026 06:42
denera commented Jun 23, 2026

Collaborator

/te-ci L0

@denera denera left a comment

Collaborator

LGTM, looks like all the issues have already been addressed. CI is clean except for an unrelated, known JAX issue (the license failure passed on rerun). Thanks!

@denera denera merged commit 77054fa into NVIDIA:main Jun 23, 2026
35 of 43 checks passed