[Common] Add dense router output for fused router #3129
Greptile Summary
This PR adds an optional dense
Confidence Score: 5/5
Safe to merge. The dense-index path is a fully additive feature with no changes to existing routing-map entry points; all new guard checks are defensive and correct. Both forward and backward kernels are mathematically consistent with the existing routing-map paths across all three score functions and both
No files require special attention. The new
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["fused_topk_with_score_function (Python)"] --> B{topk_indices provided?}
B -- No --> C["nvte_fused_topk_with_score_function_forward_v2\n(BYTEMAP / BITMAP_U8 routing map)"]
B -- Yes --> D["nvte_fused_topk_with_score_function_forward_with_indices\n(dense int16/int32/int64 index buffer)"]
C --> E{use_radix?}
D --> F{use_radix?}
E -- No --> G["fused_topk_forward_simple_kernel\n(Naive, routing_map written)"]
E -- Yes --> H["fused_topk_with_score_function_forward_kernel\n(Radix, routing_map written)"]
F -- No --> I["fused_topk_forward_simple_kernel\n(Naive, routing_map=nullptr, topk_indices_output written)"]
F -- Yes --> J["fused_topk_with_score_function_forward_kernel\n(Radix, routing_map=nullptr, topk_indices_output written)"]
G & H & I & J --> K["probs + routing_output returned to Python"]
K --> L{Backward}
L -- use_dense_indices=False --> M["nvte_fused_topk_with_score_function_backward_v2\n(routing_map → grad_logits)"]
L -- use_dense_indices=True --> N["nvte_fused_topk_with_score_function_backward_with_indices\n(fused_topk_backward_selected_indices_kernel)"]
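The flowchart sends backward through either the routing-map kernel or the dense-index kernel. As a minimal sketch of why the two can agree, assuming a dense index row is simply a compact encoding of the same selected-expert set, the pure-Python helpers below expand an index row into a BYTEMAP-style boolean row and check that both encodings pick out the same per-expert gradient entries. The function names are hypothetical (they are not Transformer Engine APIs), plain lists stand in for tensors, and no score function's derivative is modeled; only the "which experts were selected" step is shown.

```python
from typing import List


def routing_map_from_indices(indices: List[int], num_experts: int) -> List[bool]:
    """Hypothetical helper: expand one dense top-k index row into a boolean row."""
    row = [False] * num_experts
    for e in indices:
        # Defensive range check on each selected expert id.
        if not 0 <= e < num_experts:
            raise ValueError(f"expert index {e} out of range [0, {num_experts})")
        row[e] = True
    return row


def grad_select_with_map(grad: List[float], routing_map: List[bool]) -> List[float]:
    # Routing-map style: scan all num_experts entries, keep only the selected ones.
    return [g if selected else 0.0 for g, selected in zip(grad, routing_map)]


def grad_select_with_indices(grad: List[float], indices: List[int]) -> List[float]:
    # Dense-index style: visit only the top_k selected entries.
    out = [0.0] * len(grad)
    for e in indices:
        out[e] = grad[e]
    return out


grad = [0.1, -0.2, 0.3, 0.4]
indices = [3, 1]  # order-insensitive: [1, 3] selects the same experts
via_map = grad_select_with_map(grad, routing_map_from_indices(indices, len(grad)))
via_indices = grad_select_with_indices(grad, indices)
print(via_map == via_indices)  # True
```

The index-based variant touches only top_k entries per token instead of num_experts, which is one plausible reason a dedicated selected-indices backward kernel is worth having.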
Reviews (4): Last reviewed commit: "[PyTorch] Preserve router leading dimens..."
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
49c6553 to 6a957c3
/te-ci L0
The fused router forward kernels currently produce the routing map in a sparse boolean or bitmap format, which may need an additional conversion to the topk_indices format (shape [*leading_dims, top_k]) before it is passed to the dispatcher. For example, NCCL EP accepts topk_indices as its routing map. To avoid an extra kernel for routing-map conversion, the fused router can write directly into that format.

For NCCL EP, each dense topk_indices row is consumed as an order-insensitive set of selected experts. The dense output therefore does not guarantee score-sorted or expert-sorted order; it preserves the set of experts selected by the chosen top-k kernel path.

Summary
- topk_indices output path to fused router top-k.
- [*leading_dims, topk] index buffer.
- int16, int32, and int64 dense index buffers.
- int16 support for TE CUDA graph weak-ref tensors.

Testing
NVTE_BUILD_THREADS_PER_JOB=4 NVTE_CUDA_ARCHS="90;100;103a;120" NVTE_USE_CCACHE=1 pip install --no-build-isolation -e .[test] --verbose
python -m pytest -q tests/pytorch/test_fused_router.py
3203 passed, 444 skipped, 3 warnings in 44.17s
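The description notes that, without a dense output, a dispatcher such as NCCL EP would need an extra conversion from the sparse routing map to topk_indices, and that each row is consumed as an order-insensitive set. A minimal pure-Python sketch of that conversion and of an order-insensitive comparison follows. The function names are hypothetical, nested lists stand in for tensors of shape [*leading_dims, num_experts] and [*leading_dims, top_k], and this is an illustration under those assumptions rather than Transformer Engine's implementation.

```python
def routing_map_to_topk_indices(routing_map, top_k):
    """Hypothetical helper: convert a boolean map [*leading_dims, num_experts]
    into dense expert indices [*leading_dims, top_k]."""
    if routing_map and isinstance(routing_map[0], bool):
        # Innermost row: one token's per-expert selection flags.
        selected = [e for e, flag in enumerate(routing_map) if flag]
        if len(selected) != top_k:
            raise ValueError(f"expected {top_k} selected experts, got {len(selected)}")
        return selected
    # Recurse over the leading dimensions so their shape is preserved.
    return [routing_map_to_topk_indices(sub, top_k) for sub in routing_map]


def same_expert_sets(a, b):
    """Compare two [*leading_dims, top_k] index buffers row by row as sets."""
    if len(a) != len(b):
        return False
    if a and not isinstance(a[0], list):
        # Innermost rows: ignore order, since no sorting is guaranteed.
        return sorted(a) == sorted(b)
    return all(same_expert_sets(x, y) for x, y in zip(a, b))


# Leading dims [1, 2], num_experts = 4, top_k = 2.
rmap = [[[False, True, False, True], [True, False, True, False]]]
converted = routing_map_to_topk_indices(rmap, top_k=2)
print(converted)  # [[[1, 3], [0, 2]]]

kernel_rows = [[[3, 1], [0, 2]]]  # e.g. a score-ordered dense output
print(same_expert_sets(converted, kernel_rows))  # True
```

Comparing rows as sets rather than element by element matches the stated lack of an ordering guarantee: a fused kernel may emit [3, 1] where a map-to-indices conversion emits [1, 3].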