Skip to content

Commit 147b2e5

Browse files
littledggCopilotEmmonsCurse
authored
[BugFix] Fix zero workspace returned by CUB size query under CUDA Graph in MoE dispatch (#5087)
* fix bug about CubKeyValueSorter::run * pre-commit and add comment * pre-commit * Apply suggestion from @Copilot Co-authored-by: Copilot <[email protected]> * fix precommit --------- Co-authored-by: Copilot <[email protected]> Co-authored-by: YuBaoku <[email protected]>
1 parent 0857099 commit 147b2e5

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

custom_ops/gpu_ops/moe/fused_moe_imp_op.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
*/
1717

1818
#pragma once
19-
#include <string>
2019
#include <sstream>
20+
#include <string>
2121
#include "cub/cub.cuh"
2222

2323
namespace phi {
@@ -45,7 +45,10 @@ class CubKeyValueSorter {
4545
size_t getWorkspaceSize(const size_t num_key_value_pairs,
4646
bool descending = false) {
4747
num_key_value_pairs_ = num_key_value_pairs;
48-
size_t required_storage = 0;
48+
// Initialize to 1 as workaround: under CUDA Graph capture, CUB may not
49+
// write to required_storage, and 1 is the minimum expected size in that
50+
// scenario.
51+
size_t required_storage = 1;
4952
int* null_int = nullptr;
5053
if (descending) {
5154
cub::DeviceRadixSort::SortPairsDescending(NULL,

custom_ops/gpu_ops/moe/moe_dispatch.cu

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,13 @@ void MoeDispatchKernel(
8787
int8_t *sorter_ws_ptr = reinterpret_cast<int8_t *>(ws_ptr + bytes);
8888
int *permuted_experts_ =
8989
reinterpret_cast<int *>(sorter_ws_ptr + sorter_ws_size_bytes);
90+
// If expected_ws_size > workspace_size ever occurs in sorter_.run (which
91+
// should be practically impossible), there is a contiguous, currently unused
92+
// region (permuted_experts_) right after sorter_ws_ptr. In practice, this
93+
// region is larger than what cub::DeviceRadixSort::SortPairs requires.
94+
// However, relying on this to “work” after canceling the assertion is unsafe:
95+
// it constitutes undefined behavior, and there is no guarantee it will remain
96+
// correct across inputs, CUDA/CUB versions, or architectures.
9097
int *permuted_rows_ = permuted_experts_ + num_moe_inputs;
9198

9299
int *topk_idx_ptr = topk_idx->data<int>();

0 commit comments

Comments
 (0)