@@ -47,19 +47,18 @@ static bool is_supported_dense_index_dtype(at::ScalarType dtype) {
4747}
4848
4949static void check_dense_topk_indices (const at::Tensor &topk_indices, const at::Tensor &ref,
50- int64_t num_tokens , int topk) {
50+ c10::IntArrayRef leading_dims , int topk) {
5151 TORCH_CHECK (topk_indices.is_cuda (), " topk_indices must be a CUDA tensor" );
5252 TORCH_CHECK (topk_indices.device () == ref.device (), " topk_indices must be on the same device as " ,
5353 " the logits/grad tensor" );
5454 TORCH_CHECK (topk_indices.is_contiguous (), " topk_indices must be contiguous" );
5555 TORCH_CHECK (is_supported_dense_index_dtype (topk_indices.scalar_type ()),
5656 " topk_indices dtype must be int16, int32, or int64, got " ,
5757 topk_indices.scalar_type ());
58- TORCH_CHECK (topk_indices.numel () == num_tokens * static_cast <int64_t >(topk),
59- " topk_indices must contain num_tokens * topk elements, got " , topk_indices.numel (),
60- " but expected " , num_tokens * static_cast <int64_t >(topk));
61- TORCH_CHECK (topk_indices.dim () >= 1 && topk_indices.size (-1 ) == topk,
62- " topk_indices last dimension must be topk=" , topk, " , got shape " ,
58+ std::vector<int64_t > expected_shape (leading_dims.begin (), leading_dims.end ());
59+ expected_shape.push_back (static_cast <int64_t >(topk));
60+ TORCH_CHECK (topk_indices.sizes () == expected_shape,
61+ " topk_indices shape must be [*leading_dims, topk]=" , expected_shape, " , got " ,
6362 topk_indices.sizes ());
6463}
6564
@@ -97,7 +96,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_topk_with_score_function_fw
9796 TORCH_CHECK (routing_map_format == NVTE_ROUTING_MAP_FORMAT_BYTEMAP ,
9897 " topk_indices output cannot be combined with non-default routing_map_format; "
9998 " dense top-k indices are returned instead of a routing map." );
100- check_dense_topk_indices (topk_indices.value (), logits, num_tokens , topk);
99+ check_dense_topk_indices (topk_indices.value (), logits, sizes. slice ( 0 , sizes. size () - 1 ) , topk);
101100 }
102101
103102 // Reformat the input to make it compatible with the kernel
@@ -179,7 +178,7 @@ void fused_topk_with_score_function_bwd(at::Tensor routing_map, at::Tensor inter
179178 TORCH_CHECK (topk > 0 && topk <= num_experts, " topk must be in [1, num_experts], got topk=" , topk,
180179 " num_experts=" , num_experts);
181180 if (use_dense_indices) {
182- check_dense_topk_indices (routing_map, grad_probs, num_tokens , topk);
181+ check_dense_topk_indices (routing_map, grad_probs, sizes. slice ( 0 , sizes. size () - 1 ) , topk);
183182 }
184183
185184 auto scaling_factor_value = scaling_factor.has_value () ? scaling_factor.value () : 1 .0f ;
0 commit comments