Bug: Shape mismatch and invalid tensor operations in fused experts path of MoETransformerObserver #11

@naoufelito

Describe the bug

There are critical bugs in the tensor shape logic for the fused_experts path inside MoETransformerObserver._hook_factory in src/reap/observer.py. They produce shape mismatches and invalid tensor operations that raise runtime errors whenever this path runs.

Details and Example Location

  • At line 356:

    _, router_scores = output  # (num_experts, total_tokens)

    The comment claims router_scores has shape (num_experts, total_tokens). However, based on the context (use of module.router(flat_input) and standard MoE conventions), both router_scores and router_logits should have shape (total_tokens, num_experts).

  • Lines 360-365:

    router_indices = (
        torch.arange(batch_size * sequence_length, device=device)
        .view(1, -1)
        .expand(router_scores.size(0), -1)
    )
    router_indices = router_indices.reshape(-1, 1).expand(-1, hidden_dim)

    Given the above, router_scores.size(0) is likely total_tokens rather than num_experts. In that case router_indices first has shape (total_tokens, total_tokens), becomes (total_tokens*total_tokens, 1) after the reshape, and finally (total_tokens*total_tokens, hidden_dim) after the expansion. The intended result is presumably (num_experts*total_tokens, hidden_dim).

  • The subsequent torch.gather that uses router_indices will therefore gather far too many rows (total_tokens copies of the token range instead of num_experts copies), likely leading to errors or excessive memory use.

  • Finally, when doing

    activations = routed_out.view(num_experts, *flat_input.shape)  # (num_experts, total_tokens, hidden_dim)

    routed_out's total number of elements will not match the target view shape, so the call raises a RuntimeError unless num_experts happens to equal total_tokens, which cannot be relied on in practice.
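
The shape arithmetic in the bullets above can be checked without running the model. The following is a minimal, torch-free sketch that only mirrors the shape effect of each call in the quoted snippet. The helper names `trace_router_indices` and `view_is_valid` are hypothetical, and the sketch assumes routed_out has the same shape as router_indices (torch.gather returns a tensor shaped like its index), which may not hold if observer.py transforms the result further.

```python
from math import prod


def trace_router_indices(leading_dim, total_tokens, hidden_dim):
    """Follow the shape through the arange/view/expand/reshape/expand chain.

    leading_dim stands for router_scores.size(0) (hypothetical parameter name).
    """
    shape = (total_tokens,)                            # torch.arange(batch_size * sequence_length)
    shape = (1, total_tokens)                          # .view(1, -1)
    shape = (leading_dim, total_tokens)                # .expand(router_scores.size(0), -1)
    shape = (leading_dim * total_tokens, 1)            # .reshape(-1, 1)
    shape = (leading_dim * total_tokens, hidden_dim)   # .expand(-1, hidden_dim)
    return shape


def view_is_valid(routed_out_shape, num_experts, total_tokens, hidden_dim):
    """True if .view(num_experts, total_tokens, hidden_dim) has a matching element count."""
    return prod(routed_out_shape) == num_experts * total_tokens * hidden_dim


# Illustrative sizes only: 8 experts, 6 tokens, hidden size 4.
num_experts, total_tokens, hidden_dim = 8, 6, 4

# Case described in this issue: router_scores is (total_tokens, num_experts).
buggy = trace_router_indices(total_tokens, total_tokens, hidden_dim)
print(buggy, view_is_valid(buggy, num_experts, total_tokens, hidden_dim))
# (36, 4) False

# Case the code comment assumes: router_scores is (num_experts, total_tokens).
assumed = trace_router_indices(num_experts, total_tokens, hidden_dim)
print(assumed, view_is_valid(assumed, num_experts, total_tokens, hidden_dim))
# (48, 4) True
```

With these sizes the final view only succeeds when the leading dimension of router_scores is num_experts, which is consistent with the mismatch described above.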

To Reproduce

Run any model where MoETransformerObserver is used with hook_config.fused_experts=True. The error will occur as soon as this code path is executed.

Expected behavior

Tensor operations should use the correct shapes:

  • The shape of router_logits and router_scores should be (total_tokens, num_experts)
  • The gather/index logic must be fixed to select appropriate tokens for the relevant experts
  • The final .view(num_experts, *flat_input.shape) must match the actual size of the tensor being reshaped

Additional context

  • This logic is correct in other (non-fused) branches, but not in the fused experts branch
  • The root cause seems to be incorrect assumptions about tensor shapes carrying over through several lines

Recommendation:

  • Revisit the shape logic for all tensors in this block
  • Refactor using explicit checks and tests for tensor shapes
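
As one possible way to act on the second recommendation, here is a small sketch of an explicit shape check. `check_shape` is a hypothetical helper, not something that exists in the repository; it accepts any shape that can be converted to a tuple, so a `tensor.shape` can be passed directly.

```python
def check_shape(name, actual, expected):
    """Raise ValueError unless `actual` matches `expected`.

    A None entry in `expected` acts as a wildcard for that dimension.
    """
    actual = tuple(actual)
    expected = tuple(expected)
    ok = len(actual) == len(expected) and all(
        e is None or a == e for a, e in zip(actual, expected)
    )
    if not ok:
        raise ValueError(f"{name}: expected shape {expected}, got {actual}")


# Possible use inside the fused_experts branch (variable names taken from the
# snippets quoted in this issue; the exact placement is an assumption):
#   check_shape("router_scores", router_scores.shape, (total_tokens, num_experts))
#   check_shape("activations", activations.shape, (num_experts, total_tokens, hidden_dim))

check_shape("router_scores", (6, 8), (6, 8))      # passes silently
check_shape("flat_input", (6, 4), (None, 4))      # wildcard on the token dimension
```

Failing early with a named tensor in the message makes it easier to see which assumption about the layout was wrong than a RuntimeError raised later by .view.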
