[Common] Add dense router output for fused router #3129
Greptile Summary
This PR adds an optional dense
Confidence Score: 5/5
Safe to merge. The dense-index path is a fully additive feature with no changes to existing routing-map entry points; all new guard checks are defensive and correct. Both forward and backward kernels are mathematically consistent with the existing routing-map paths across all three score functions and both
No files require special attention. The new
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["fused_topk_with_score_function (Python)"] --> B{topk_indices provided?}
B -- No --> C["nvte_fused_topk_with_score_function_forward_v2\n(BYTEMAP / BITMAP_U8 routing map)"]
B -- Yes --> D["nvte_fused_topk_with_score_function_forward_with_indices\n(dense int16/int32/int64 index buffer)"]
C --> E{use_radix?}
D --> F{use_radix?}
E -- No --> G["fused_topk_forward_simple_kernel\n(Naive, routing_map written)"]
E -- Yes --> H["fused_topk_with_score_function_forward_kernel\n(Radix, routing_map written)"]
F -- No --> I["fused_topk_forward_simple_kernel\n(Naive, routing_map=nullptr, topk_indices_output written)"]
F -- Yes --> J["fused_topk_with_score_function_forward_kernel\n(Radix, routing_map=nullptr, topk_indices_output written)"]
G & H & I & J --> K["probs + routing_output returned to Python"]
K --> L{Backward}
L -- use_dense_indices=False --> M["nvte_fused_topk_with_score_function_backward_v2\n(routing_map → grad_logits)"]
L -- use_dense_indices=True --> N["nvte_fused_topk_with_score_function_backward_with_indices\n(fused_topk_backward_selected_indices_kernel)"]
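The branching in the flowchart can be restated as a small lookup. This is a minimal sketch only: `select_router_path` is a hypothetical helper that does not exist in Transformer Engine, it launches no kernels, and it just returns the entry-point and kernel names shown in the diagram. How `use_radix` is actually decided is not stated here, so it is taken as a plain boolean input.

```python
def select_router_path(topk_indices_provided: bool, use_radix: bool) -> dict:
    """Hypothetical helper: mirrors the flowchart and returns names only."""
    if topk_indices_provided:
        # Dense path: routing_map is nullptr and topk_indices_output is written.
        forward = "nvte_fused_topk_with_score_function_forward_with_indices"
        backward = "nvte_fused_topk_with_score_function_backward_with_indices"
        routing_output = "topk_indices_output"
    else:
        # Existing path: a BYTEMAP / BITMAP_U8 routing map is written.
        forward = "nvte_fused_topk_with_score_function_forward_v2"
        backward = "nvte_fused_topk_with_score_function_backward_v2"
        routing_output = "routing_map"

    # Both paths share the same two forward kernels (naive vs. radix).
    if use_radix:
        kernel = "fused_topk_with_score_function_forward_kernel"
    else:
        kernel = "fused_topk_forward_simple_kernel"

    return {
        "forward": forward,
        "kernel": kernel,
        "routing_output": routing_output,
        "backward": backward,
    }


path = select_router_path(topk_indices_provided=True, use_radix=False)
print(path["forward"])
print(path["kernel"])
```

The point of the sketch is that the dense-index choice and the naive/radix choice are independent: the first picks the entry point and the output buffer, the second picks the forward kernel.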
Reviews (4): Last reviewed commit: "[PyTorch] Preserve router leading dimens..."
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
49c6553 to 6a957c3
/te-ci L0
denera left a comment
LGTM, looks like all the issues are addressed already. CI is clean except for unrelated known JAX issue (reran license failure and it passed). Thanks!
The fused router forward kernels currently produce a sparse routing map in boolean or bitmap format, which may need an additional conversion to `topk_indices` format (in the shape of `[*leading_dims, top_k]`) before it can be passed to the dispatcher. For example, NCCL EP accepts `topk_indices` as its routing map. To avoid an extra kernel for routing map conversion, the fused router could write directly into that format.

For NCCL EP, the dense `topk_indices` row is consumed as an order-insensitive selected-expert set. The dense output therefore does not promise score-sorted or expert-sorted order; it preserves the selected experts produced by the chosen top-k kernel path.

Summary
- `topk_indices` output path to fused router top-k.
- `[*leading_dims, topk]` index buffer.
- `int16`, `int32`, and `int64` dense index buffers.
- `int16` support for TE CUDA graph weak-ref tensors.

Testing
NVTE_BUILD_THREADS_PER_JOB=4 NVTE_CUDA_ARCHS="90;100;103a;120" NVTE_USE_CCACHE=1 pip install --no-build-isolation -e .[test] --verbose
python -m pytest -q tests/pytorch/test_fused_router.py
3203 passed, 444 skipped, 3 warnings in 44.17s
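To make the motivation in the description concrete, here is a minimal host-side sketch of the routing-map-to-`topk_indices` conversion that the dense output path is meant to avoid, plus the order-insensitive comparison that matches how the description says NCCL EP consumes each row. It is plain Python on nested lists, not Transformer Engine code; `routing_map_to_topk_indices` and `same_selection` are hypothetical names introduced only for illustration.

```python
from typing import List


def routing_map_to_topk_indices(
    routing_map: List[List[bool]], top_k: int
) -> List[List[int]]:
    """Convert a boolean [num_tokens, num_experts] map to [num_tokens, top_k] indices."""
    dense = []
    for token, row in enumerate(routing_map):
        selected = [expert for expert, chosen in enumerate(row) if chosen]
        if len(selected) != top_k:
            raise ValueError(
                f"token {token} selects {len(selected)} experts, expected {top_k}"
            )
        dense.append(selected)
    return dense


def same_selection(a: List[List[int]], b: List[List[int]]) -> bool:
    """Compare two dense index buffers row by row as unordered expert sets."""
    return len(a) == len(b) and all(set(x) == set(y) for x, y in zip(a, b))


# Two tokens, four experts, top_k = 2.
routing_map = [
    [False, True, False, True],
    [True, False, True, False],
]
dense = routing_map_to_topk_indices(routing_map, top_k=2)
print(dense)  # [[1, 3], [0, 2]]

# A buffer with the same experts in a different per-row order is equivalent
# for a consumer that treats each row as a set.
print(same_selection(dense, [[3, 1], [2, 0]]))  # True
```

This reference conversion happens to emit experts in ascending order. The fused dense output makes no such promise, so a check against it should compare rows as sets, as `same_selection` does, rather than element by element.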