File tree Expand file tree Collapse file tree 2 files changed +12
-2
lines changed Expand file tree Collapse file tree 2 files changed +12
-2
lines changed Original file line number Diff line number Diff line change 1616 */
1717
1818#pragma once
19- #include < string>
2019#include < sstream>
20+ #include < string>
2121#include " cub/cub.cuh"
2222
2323namespace phi {
@@ -45,7 +45,10 @@ class CubKeyValueSorter {
4545 size_t getWorkspaceSize (const size_t num_key_value_pairs,
4646 bool descending = false ) {
4747 num_key_value_pairs_ = num_key_value_pairs;
48- size_t required_storage = 0 ;
48+ // Initialize to 1 as workaround: under CUDA Graph capture, CUB may not
49+ // write to required_storage, and 1 is the minimum expected size in that
50+ // scenario.
51+ size_t required_storage = 1 ;
4952 int * null_int = nullptr ;
5053 if (descending) {
5154 cub::DeviceRadixSort::SortPairsDescending (NULL ,
Original file line number Diff line number Diff line change @@ -87,6 +87,13 @@ void MoeDispatchKernel(
8787 int8_t *sorter_ws_ptr = reinterpret_cast <int8_t *>(ws_ptr + bytes);
8888 int *permuted_experts_ =
8989 reinterpret_cast <int *>(sorter_ws_ptr + sorter_ws_size_bytes);
90+ // If expected_ws_size > workspace_size ever occurs in sorter_.run (which
91+ // should be practically impossible), there is a contiguous, currently unused
92+ // region (permuted_experts_) right after sorter_ws_ptr. In practice, this
93+ // region is larger than what cub::DeviceRadixSort::SortPairs requires.
94+ // However, relying on this to "work" after disabling the assertion is unsafe:
95+ // it constitutes undefined behavior, and there is no guarantee it will remain
96+ // correct across inputs, CUDA/CUB versions, or architectures.
9097 int *permuted_rows_ = permuted_experts_ + num_moe_inputs;
9198
9299 int *topk_idx_ptr = topk_idx->data <int >();
You can’t perform that action at this time.
0 commit comments