Skip to content

Commit 49c6553

Browse files
committed
[PyTorch] Preserve router leading dimensions
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
1 parent 5ad179b commit 49c6553

3 files changed

Lines changed: 37 additions & 22 deletions

File tree

tests/pytorch/test_fused_router.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,33 @@ def test_topk_softmax(
390390
)
391391

392392

393+
@pytest.mark.parametrize("topk_index_dtype", [None, torch.int16])
394+
def test_topk_preserves_leading_dims(topk_index_dtype):
395+
num_tokens = 128
396+
num_experts = 32
397+
topk = 4
398+
logits = torch.randn(num_tokens, 2, num_experts, device="cuda", dtype=torch.float32)
399+
topk_indices = None
400+
if topk_index_dtype is not None:
401+
topk_indices = torch.empty(num_tokens, 2, topk, device="cuda", dtype=topk_index_dtype)
402+
403+
probs, routing_output = fused_topk_with_score_function(
404+
logits=logits,
405+
topk=topk,
406+
use_pre_softmax=False,
407+
num_groups=None,
408+
group_topk=None,
409+
scaling_factor=None,
410+
score_function="softmax",
411+
expert_bias=None,
412+
topk_indices=topk_indices,
413+
)
414+
415+
assert probs.shape == logits.shape
416+
expected_routing_shape = topk_indices.shape if topk_indices is not None else logits.shape
417+
assert routing_output.shape == expected_routing_shape
418+
419+
393420
@pytest.mark.parametrize("dtype", [torch.float32])
394421
@pytest.mark.parametrize("num_tokens", [2048, 7168])
395422
@pytest.mark.parametrize("num_experts", [1024, 256, 128, 32])

transformer_engine/pytorch/csrc/extensions/router.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,19 +47,18 @@ static bool is_supported_dense_index_dtype(at::ScalarType dtype) {
4747
}
4848

4949
static void check_dense_topk_indices(const at::Tensor &topk_indices, const at::Tensor &ref,
50-
int64_t num_tokens, int topk) {
50+
c10::IntArrayRef leading_dims, int topk) {
5151
TORCH_CHECK(topk_indices.is_cuda(), "topk_indices must be a CUDA tensor");
5252
TORCH_CHECK(topk_indices.device() == ref.device(), "topk_indices must be on the same device as ",
5353
"the logits/grad tensor");
5454
TORCH_CHECK(topk_indices.is_contiguous(), "topk_indices must be contiguous");
5555
TORCH_CHECK(is_supported_dense_index_dtype(topk_indices.scalar_type()),
5656
"topk_indices dtype must be int16, int32, or int64, got ",
5757
topk_indices.scalar_type());
58-
TORCH_CHECK(topk_indices.numel() == num_tokens * static_cast<int64_t>(topk),
59-
"topk_indices must contain num_tokens * topk elements, got ", topk_indices.numel(),
60-
" but expected ", num_tokens * static_cast<int64_t>(topk));
61-
TORCH_CHECK(topk_indices.dim() >= 1 && topk_indices.size(-1) == topk,
62-
"topk_indices last dimension must be topk=", topk, ", got shape ",
58+
std::vector<int64_t> expected_shape(leading_dims.begin(), leading_dims.end());
59+
expected_shape.push_back(static_cast<int64_t>(topk));
60+
TORCH_CHECK(topk_indices.sizes() == expected_shape,
61+
"topk_indices shape must be [*leading_dims, topk]=", expected_shape, ", got ",
6362
topk_indices.sizes());
6463
}
6564

@@ -97,7 +96,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_topk_with_score_function_fw
9796
TORCH_CHECK(routing_map_format == NVTE_ROUTING_MAP_FORMAT_BYTEMAP,
9897
"topk_indices output cannot be combined with non-default routing_map_format; "
9998
"dense top-k indices are returned instead of a routing map.");
100-
check_dense_topk_indices(topk_indices.value(), logits, num_tokens, topk);
99+
check_dense_topk_indices(topk_indices.value(), logits, sizes.slice(0, sizes.size() - 1), topk);
101100
}
102101

103102
// Reformat the input to make it compatible with the kernel
@@ -179,7 +178,7 @@ void fused_topk_with_score_function_bwd(at::Tensor routing_map, at::Tensor inter
179178
TORCH_CHECK(topk > 0 && topk <= num_experts, "topk must be in [1, num_experts], got topk=", topk,
180179
" num_experts=", num_experts);
181180
if (use_dense_indices) {
182-
check_dense_topk_indices(routing_map, grad_probs, num_tokens, topk);
181+
check_dense_topk_indices(routing_map, grad_probs, sizes.slice(0, sizes.size() - 1), topk);
183182
}
184183

185184
auto scaling_factor_value = scaling_factor.has_value() ? scaling_factor.value() : 1.0f;

transformer_engine/pytorch/router.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,6 @@ def forward(
8686
topk_indices: Optional[torch.Tensor],
8787
):
8888
# pylint: disable=missing-function-docstring
89-
tensor_shape = logits.shape
90-
logits = logits.view(-1, tensor_shape[-1])
91-
num_tokens = logits.size(0)
92-
num_experts = logits.size(1)
9389
probs, routing_output, intermediate_output = tex.fused_topk_with_score_function_fwd(
9490
logits,
9591
topk,
@@ -104,31 +100,25 @@ def forward(
104100
)
105101
if topk_indices is not None:
106102
routing_output = topk_indices
107-
probs = probs.view(tensor_shape)
108103
if topk_indices is not None:
109104
ctx.mark_dirty(topk_indices)
110105
ctx.mark_non_differentiable(routing_output)
111106
ctx.save_for_backward(routing_output, intermediate_output)
112-
ctx.num_tokens = num_tokens
113-
ctx.num_experts = num_experts
114-
ctx.tensor_shape = tensor_shape
115107
ctx.use_pre_softmax = use_pre_softmax
116108
ctx.topk = topk
117109
ctx.scaling_factor = scaling_factor
118110
ctx.score_function = score_function
119111
ctx.routing_map_format = routing_map_format
120-
ctx.logits_dtype = logits.dtype
121112
ctx.use_dense_indices = topk_indices is not None
122113
return probs, routing_output
123114

124115
@staticmethod
125116
def backward(ctx, grad_probs, _):
126117
# pylint: disable=missing-function-docstring
127118
routing_map, intermediate_output = ctx.saved_tensors
128-
grad_probs = grad_probs.contiguous().view(-1, ctx.tensor_shape[-1])
129-
grad_logits = torch.empty(
130-
(ctx.num_tokens, ctx.num_experts), dtype=ctx.logits_dtype, device=grad_probs.device
131-
)
119+
if not grad_probs.is_contiguous():
120+
grad_probs = grad_probs.contiguous()
121+
grad_logits = torch.empty_like(grad_probs)
132122
tex.fused_topk_with_score_function_bwd(
133123
routing_map,
134124
intermediate_output,
@@ -141,7 +131,6 @@ def backward(ctx, grad_probs, _):
141131
ctx.use_dense_indices,
142132
ctx.routing_map_format,
143133
)
144-
grad_logits = grad_logits.view(ctx.tensor_shape)
145134
return grad_logits, None, None, None, None, None, None, None, None, None
146135

147136

0 commit comments

Comments
 (0)