Describe the bug
There are critical bugs in the tensor shape logic for the fused_experts path inside MoETransformerObserver._hook_factory in src/reap/observer.py. They cause shape mismatches and invalid tensor operations that reliably raise runtime errors.
Details and Example Location
At line 356:

```python
_, router_scores = output  # (num_experts, total_tokens)
```

The comment claims router_scores has shape (num_experts, total_tokens). However, based on the context (the call to module.router(flat_input) and standard MoE conventions), both router_scores and router_logits should have shape (total_tokens, num_experts).
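Whether the comment or the code is wrong hinges on which axis of router_scores holds tokens. One way to make that explicit is to classify the shape before building any indices. The helper below is a hypothetical sketch (router_layout is not part of src/reap/observer.py); it accepts plain tuples as well as torch.Size, which subclasses tuple.

```python
def router_layout(scores_shape, total_tokens, num_experts):
    """Classify a 2-D router-score shape as token-major or expert-major.

    Hypothetical helper for illustration; not part of src/reap/observer.py.
    """
    shape = tuple(scores_shape)
    tokens_first = (total_tokens, num_experts)   # layout this report expects
    experts_first = (num_experts, total_tokens)  # layout claimed by the comment at line 356
    if shape not in (tokens_first, experts_first):
        raise ValueError(f"shape {shape} matches neither layout")
    if total_tokens == num_experts:
        # Both layouts give the same shape, so the shape alone cannot distinguish them.
        raise ValueError("layout is ambiguous when total_tokens == num_experts")
    return "tokens_first" if shape == tokens_first else "experts_first"


# 2 sequences of 3 tokens routed over 8 experts, so total_tokens = 6
print(router_layout((6, 8), total_tokens=6, num_experts=8))  # tokens_first
print(router_layout((8, 6), total_tokens=6, num_experts=8))  # experts_first
```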
Lines 360-365:

```python
router_indices = (
    torch.arange(batch_size * sequence_length, device=device)
    .view(1, -1)
    .expand(router_scores.size(0), -1)
)
router_indices = router_indices.reshape(-1, 1).expand(-1, hidden_dim)
```

Given the above, router_scores.size(0) is likely total_tokens (= batch_size * sequence_length), so router_indices has shape (total_tokens, total_tokens) after the first expand, then (total_tokens * total_tokens, 1) after the reshape, and finally (total_tokens * total_tokens, hidden_dim) after the second expand. This is almost certainly not what is intended: the later view(num_experts, *flat_input.shape) only works if there are num_experts * total_tokens rows.
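To see how the result depends on router_scores.size(0), the shapes produced by lines 360-365 can be traced without allocating any tensors. The function below is a hypothetical, shapes-only sketch: it mirrors the arange/view/expand/reshape/expand chain but does not call torch.

```python
def trace_router_indices(rows, total_tokens, hidden_dim):
    """Return the shape after each step of lines 360-365 (shapes only, no tensors).

    rows stands for router_scores.size(0); total_tokens for batch_size * sequence_length.
    """
    return [
        (total_tokens,),                    # torch.arange(total_tokens)
        (1, total_tokens),                  # .view(1, -1)
        (rows, total_tokens),               # .expand(router_scores.size(0), -1)
        (rows * total_tokens, 1),           # .reshape(-1, 1)
        (rows * total_tokens, hidden_dim),  # .expand(-1, hidden_dim)
    ]


total_tokens, num_experts, hidden_dim = 6, 8, 16

# router_scores is (total_tokens, num_experts): row count is quadratic in the token count
print(trace_router_indices(total_tokens, total_tokens, hidden_dim)[-1])  # (36, 16)
# router_scores is (num_experts, total_tokens): one copy of every token per expert
print(trace_router_indices(num_experts, total_tokens, hidden_dim)[-1])   # (48, 16)
```

Only the second case yields num_experts * total_tokens rows, which is the count the later view(num_experts, *flat_input.shape) requires.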
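The subsequent torch.gather along dim 0 therefore returns one row per row of router_indices, i.e. total_tokens * total_tokens rows. Every index value lies in [0, total_tokens), so the gather does not go out of bounds; instead it materializes a tensor that grows quadratically with the number of tokens, wasting memory and potentially exhausting it on long inputs.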
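The row count follows from the semantics of torch.gather along dim 0: out[i][j] = input[index[i][j]][j], so the output has as many rows as the index, and only the index values (not the index length) must be in range. The list-based stand-in below is a hypothetical sketch of those semantics for small inputs, not a replacement for torch.gather.

```python
def gather_rows(inp, index):
    """Pure-Python stand-in for torch.gather(inp, 0, index) on 2-D nested lists."""
    return [[inp[index[i][j]][j] for j in range(len(index[i]))] for i in range(len(index))]


total_tokens, hidden_dim = 3, 2
flat_input = [[10 * t + h for h in range(hidden_dim)] for t in range(total_tokens)]

# Index built as in lines 360-365 with router_scores.size(0) == total_tokens:
# the token ids 0..total_tokens-1 repeated total_tokens times, broadcast over hidden_dim.
token_ids = [t for _ in range(total_tokens) for t in range(total_tokens)]
index = [[t] * hidden_dim for t in token_ids]

out = gather_rows(flat_input, index)
print(len(out), max(token_ids))  # 9 2
```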
Finally, at

```python
activations = routed_out.view(num_experts, *flat_input.shape)  # (num_experts, total_tokens, hidden_dim)
```

routed_out has total_tokens * total_tokens * hidden_dim elements (assuming the experts return one output row per gathered input row), whereas the target view needs num_experts * total_tokens * hidden_dim. The call therefore raises a RuntimeError unless num_experts == total_tokens, which holds only by coincidence.
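The mismatch is plain element-count arithmetic. The sketch below (view_is_valid is a hypothetical helper) checks only the necessary condition for .view(): matching element counts. A real .view() also needs compatible strides.

```python
from math import prod


def view_is_valid(source_shape, target_shape):
    """Necessary condition for .view(): source and target hold the same number of elements."""
    return prod(source_shape) == prod(target_shape)


total_tokens, num_experts, hidden_dim = 6, 8, 16
target = (num_experts, total_tokens, hidden_dim)  # view(num_experts, *flat_input.shape)

# routed_out with total_tokens * total_tokens rows (token-major router_scores)
print(view_is_valid((total_tokens * total_tokens, hidden_dim), target))  # False
# routed_out with num_experts * total_tokens rows
print(view_is_valid((num_experts * total_tokens, hidden_dim), target))   # True
```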
To Reproduce
Run any model where MoETransformerObserver is used with hook_config.fused_experts=True. The error will occur as soon as this code path is executed.
Expected behavior
Tensor operations should use the correct shapes:
- router_logits and router_scores should have shape (total_tokens, num_experts)
- The gather/index logic must be fixed to select the appropriate tokens for the relevant experts
- The final .view(num_experts, *flat_input.shape) must match the actual number of elements in the tensor being reshaped
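One way to meet these constraints, assuming router_scores is (total_tokens, num_experts) and that each expert should see every token weighted by its score (as the current code appears to intend), is to take num_experts from router_scores.size(1) rather than size(0). In torch this would roughly be torch.arange(total_tokens, device=device).repeat(num_experts) followed by .view(-1, 1).expand(-1, hidden_dim), with the scores reordered expert-major via router_scores.transpose(0, 1).reshape(-1, 1). The pure-Python sketch below (corrected_layout is a hypothetical name) only checks the resulting index values and shapes; it is a sketch under those assumptions, not a verified patch.

```python
def corrected_layout(scores_shape, hidden_dim):
    """Shapes and token ids for a corrected fused-experts gather (hypothetical sketch).

    Assumes scores_shape == (total_tokens, num_experts), i.e. token-major router output.
    """
    total_tokens, num_experts = scores_shape  # size(0) = tokens, size(1) = experts
    # Mirrors torch.arange(total_tokens).repeat(num_experts): row e * total_tokens + t -> token t
    token_ids = [t for _ in range(num_experts) for t in range(total_tokens)]
    return {
        "token_ids": token_ids,
        # after .view(-1, 1).expand(-1, hidden_dim)
        "index": (len(token_ids), hidden_dim),
        # router_scores.transpose(0, 1).reshape(-1, 1): same expert-major row order
        "scores": (num_experts * total_tokens, 1),
        # target of routed_out.view(num_experts, *flat_input.shape)
        "activations": (num_experts, total_tokens, hidden_dim),
    }


layout = corrected_layout((6, 8), hidden_dim=16)
print(layout["index"], layout["activations"])  # (48, 16) (8, 6, 16)
```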
Additional context
- This logic is correct in other (non-fused) branches, but not in the fused experts branch
- The root cause seems to be incorrect assumptions about tensor shapes carrying over through several lines
Recommendation:
- Revisit the shape logic for all tensors in this block
- Refactor using explicit checks and tests for tensor shapes
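For the explicit checks, a small helper called right after each reshape in the fused branch turns silent shape drift into an error that names the offending tensor. The sketch below is hypothetical (expect_shape is not an existing REAP or PyTorch function); it accepts plain tuples or torch.Size, so in a unit test it can be applied to tensor.shape directly.

```python
def expect_shape(name, shape, expected):
    """Raise ValueError if `shape` does not match `expected`.

    `expected` is a tuple of ints or None, where None matches any size.
    Hypothetical helper, not part of src/reap.
    """
    shape = tuple(shape)
    if len(shape) != len(expected) or any(
        e is not None and s != e for s, e in zip(shape, expected)
    ):
        raise ValueError(f"{name}: expected {expected}, got {shape}")


total_tokens, num_experts, hidden_dim = 6, 8, 16
expect_shape("router_scores", (6, 8), (total_tokens, num_experts))  # passes silently
try:
    expect_shape("router_indices", (36, 16), (num_experts * total_tokens, hidden_dim))
except ValueError as err:
    print(err)  # router_indices: expected (48, 16), got (36, 16)
```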