[PyTorch] Expose interleave and de-interleave function for GLU tensor for fused grouped MLP via fused ops #3078

Open
ksivaman wants to merge 4 commits into NVIDIA:main from ksivaman:interleave_deinterleave_glu_tensor

@ksivaman
Member

@ksivaman ksivaman commented Jun 3, 2026

Description

Add utilities for converting the FC1 weight format when loading checkpoints. The fused grouped MLP op expects the FC1 weight in an interleaved format. This functionality was originally added in NVIDIA-NeMo/Megatron-Bridge#2841.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add utility functions to interleave and de-interleave the FC1 weight.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: ksivamani <ksivamani@nvidia.com>
@ksivaman ksivaman requested a review from vthumbe1503 June 3, 2026 01:10
@greptile-apps
Contributor

greptile-apps Bot commented Jun 3, 2026

Greptile Summary

This PR adds interleave_glu_tensor and deinterleave_glu_tensor as public utilities to convert FC1 weight tensors between the contiguous [W_all, V_all] checkpoint format and the block-interleaved layout expected by the fused grouped MLP SwiGLU kernel.

  • deinterleave_glu_tensor reshapes to (n_blocks, 2, k, ...), transposes axes 0 and 1, and flattens back — correctly producing [W_all, V_all] from [W0, V0, W1, V1, ...].
  • interleave_glu_tensor applies the symmetric inverse: reshapes to (2, n_blocks, k, ...), transposes, and flattens — producing the block-interleaved layout from a contiguous checkpoint.
  • Both functions are exported via __all__, added to __init__.py, and documented in pytorch.rst.
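
The reshape, transpose, and flatten steps described in the bullets above can be sketched as follows. This is a minimal illustration under stated assumptions, not the code from the PR: the names `interleave_sketch` and `deinterleave_sketch` are hypothetical, input validation is omitted, and only dimension 0 is reordered.

```python
import torch


def deinterleave_sketch(tensor: torch.Tensor, interleave_size: int) -> torch.Tensor:
    """Hypothetical helper: [W0, V0, W1, V1, ...] -> [W_all, V_all] along dim 0."""
    shape = tensor.shape
    n_blocks = shape[0] // (2 * interleave_size)
    # View dim 0 as (block, gate-or-linear half, row within block).
    x = tensor.reshape(n_blocks, 2, interleave_size, *shape[1:])
    # Bring the half axis to the front, then flatten back to the input shape.
    return x.transpose(0, 1).contiguous().reshape(shape)


def interleave_sketch(tensor: torch.Tensor, interleave_size: int) -> torch.Tensor:
    """Hypothetical helper: [W_all, V_all] -> [W0, V0, W1, V1, ...] along dim 0."""
    shape = tensor.shape
    n_blocks = shape[0] // (2 * interleave_size)
    # View dim 0 as (half, block, row within block).
    x = tensor.reshape(2, n_blocks, interleave_size, *shape[1:])
    # Bring the block axis to the front, then flatten back to the input shape.
    return x.transpose(0, 1).contiguous().reshape(shape)


# Eight rows of three columns; every row is filled with its own row index,
# so rows 0-3 play the role of W_all and rows 4-7 the role of V_all.
rows = torch.arange(8).unsqueeze(1).expand(8, 3).contiguous()
interleaved = interleave_sketch(rows, interleave_size=2)
restored = deinterleave_sketch(interleaved, interleave_size=2)
row_order = interleaved[:, 0].tolist()
```

With `interleave_size=2`, `row_order` shows which original row ends up at each position of the interleaved tensor, and `restored` should equal `rows` if the two helpers are inverses.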

Confidence Score: 5/5

Safe to merge — the two new tensor-reordering utilities are mathematically correct inverses and the change is additive with no impact on existing paths.

Both interleave_glu_tensor and deinterleave_glu_tensor apply only reshape, transpose, and contiguous operations on dimension 0; tracing through concrete examples confirms they produce the correct block-interleaved and contiguous layouts respectively. The change touches no existing logic — it only adds new exports and documentation.
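
The inverse relationship claimed here can also be checked with plain index arithmetic, without any tensor library. The sketch below is an illustration only: the function names are hypothetical, `k` stands for the interleave size, and the formulas are derived from the reshape and transpose description in the summary rather than taken from the PR code.

```python
def interleaved_position(i: int, dim0: int, k: int) -> int:
    """Position in the interleaved layout of row i from the [W_all, V_all] layout."""
    half_len = dim0 // 2
    half, within = divmod(i, half_len)   # half 0 is W, half 1 is V
    block, offset = divmod(within, k)    # block index and row within the block
    return block * 2 * k + half * k + offset


def contiguous_position(j: int, dim0: int, k: int) -> int:
    """Position in the [W_all, V_all] layout of row j from the interleaved layout."""
    block, within = divmod(j, 2 * k)
    half, offset = divmod(within, k)
    return half * (dim0 // 2) + block * k + offset


def is_round_trip(dim0: int, k: int) -> bool:
    """True if mapping every row forward and back returns it to its original index."""
    return all(
        contiguous_position(interleaved_position(i, dim0, k), dim0, k) == i
        for i in range(dim0)
    )


# Trace one concrete case: 8 rows, interleave size 2.
layout = [None] * 8
for i in range(8):
    layout[interleaved_position(i, 8, 2)] = i

# Check the round trip for a few sizes where dim0 is a multiple of 2 * k.
all_ok = all(is_round_trip(dim0, k) for dim0, k in [(8, 2), (12, 3), (16, 1), (32, 8)])
```

Here `layout` lists the original row index stored at each interleaved position, and `all_ok` summarizes the round-trip check over the sampled sizes.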

transformer_engine/pytorch/utils.py — deinterleave_glu_tensor is missing the input-validation guards that its counterpart carries, which has already been flagged in earlier review rounds.

Important Files Changed

Filename | Overview
transformer_engine/pytorch/utils.py | Adds deinterleave_glu_tensor and interleave_glu_tensor; logic is correct and they are proper mathematical inverses, but deinterleave_glu_tensor lacks the input-validation guards present in its counterpart.
transformer_engine/pytorch/__init__.py | Exports the two new utility functions; no issues found.
docs/api/pytorch.rst | Adds autoapifunction entries for both new helpers in the correct section; no issues found.

Sequence Diagram

sequenceDiagram
    participant User
    participant interleave_glu_tensor
    participant deinterleave_glu_tensor
    participant FusedGroupedMLP

    Note over User: Checkpoint loading
    User->>interleave_glu_tensor: tensor [W_all, V_all] (contiguous)
    interleave_glu_tensor->>interleave_glu_tensor: reshape(2, n, k, ...) → transpose(0,1) → contiguous → reshape
    interleave_glu_tensor-->>User: tensor [W0,V0,W1,V1,...] (interleaved)
    User->>FusedGroupedMLP: load interleaved weight

    Note over User: Checkpoint saving
    FusedGroupedMLP-->>User: tensor [W0,V0,W1,V1,...] (interleaved)
    User->>deinterleave_glu_tensor: tensor [W0,V0,W1,V1,...] (interleaved)
    deinterleave_glu_tensor->>deinterleave_glu_tensor: reshape(n, 2, k, ...) → transpose(0,1) → contiguous → reshape
    deinterleave_glu_tensor-->>User: tensor [W_all, V_all] (contiguous)

Reviews (3): Last reviewed commit: "Update transformer_engine/pytorch/utils...." | Re-trigger Greptile

Comment on lines +128 to +136
shape = tensor.shape
x = tensor.reshape(
    shape[0] // (2 * interleave_size),
    2,
    interleave_size,
    *shape[1:],
)
x = x.transpose(0, 1).contiguous()
return x.reshape(shape)
Contributor

P2 Missing input validation means users get a cryptic PyTorch reshape error instead of an actionable message. If tensor.shape[0] is not divisible by 2 * interleave_size, the call to reshape raises something like "shape '[N, 2, k, ...]' is invalid for input of size X" with no indication of which argument is wrong. An upfront check surfaces the real constraint immediately.

Suggested change
- shape = tensor.shape
- x = tensor.reshape(
-     shape[0] // (2 * interleave_size),
-     2,
-     interleave_size,
-     *shape[1:],
- )
- x = x.transpose(0, 1).contiguous()
- return x.reshape(shape)
+ if interleave_size <= 0:
+     raise ValueError(f"interleave_size must be a positive integer, got {interleave_size}")
+ if tensor.shape[0] % (2 * interleave_size) != 0:
+     raise ValueError(
+         f"tensor dimension 0 ({tensor.shape[0]}) must be divisible by "
+         f"2 * interleave_size ({2 * interleave_size})"
+     )
+ shape = tensor.shape
+ x = tensor.reshape(
+     shape[0] // (2 * interleave_size),
+     2,
+     interleave_size,
+     *shape[1:],
+ )
+ x = x.transpose(0, 1).contiguous()
+ return x.reshape(shape)
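
The guard in this suggestion can be exercised in isolation to see the messages a caller would get. The sketch below is a standalone illustration: `check_glu_dim0` and `error_message` are hypothetical names, and the function takes the size of dimension 0 directly instead of a tensor so that it runs without PyTorch. The message text mirrors the suggested change.

```python
from typing import Optional


def check_glu_dim0(dim0: int, interleave_size: int) -> int:
    """Return the number of blocks, or raise ValueError naming the bad argument."""
    if interleave_size <= 0:
        raise ValueError(f"interleave_size must be a positive integer, got {interleave_size}")
    if dim0 % (2 * interleave_size) != 0:
        raise ValueError(
            f"tensor dimension 0 ({dim0}) must be divisible by "
            f"2 * interleave_size ({2 * interleave_size})"
        )
    return dim0 // (2 * interleave_size)


def error_message(dim0: int, interleave_size: int) -> Optional[str]:
    """Return the ValueError text for the given arguments, or None if they are valid."""
    try:
        check_glu_dim0(dim0, interleave_size)
    except ValueError as err:
        return str(err)
    return None


valid_blocks = check_glu_dim0(8, 2)   # 8 rows split into blocks of 2 * 2 rows
bad_dim_msg = error_message(10, 2)    # 10 is not a multiple of 4
bad_size_msg = error_message(8, 0)    # non-positive interleave size
```

Checking `interleave_size <= 0` first also avoids a modulo by zero in the divisibility test.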

Comment thread transformer_engine/pytorch/utils.py
Collaborator

@vthumbe1503 vthumbe1503 left a comment

LGTM. Left a few minor comments.

"""Convert a block-interleaved SwiGLU fc1 tensor to contiguous gate/linear layout.

Fused SwiGLU kernels (for example :class:`~transformer_engine.pytorch.ops.SwiGLU`
with ``glu_interleave_size`` set) expect fc1 weights in a block-interleaved
Collaborator

Isn't it the activations that need SwiGLU? Is it that we need to interleave the weights the same way as the activations? A clarification in the comment would be helpful.
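
One general fact that bears on this question: in a linear layer, each weight row produces one output feature, so reordering the rows of the FC1 weight reorders the FC1 output features in exactly the same way. Whether this is the explanation the docstring intends is an assumption and is not confirmed in this thread. A minimal pure-Python sketch, with a hypothetical `linear` helper and made-up numbers:

```python
def linear(x, weight):
    """y[j] = sum_i weight[j][i] * x[i]: one output feature per weight row."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


# Made-up FC1 weight with 8 output features and 2 input features.
weight = [[1, 0], [0, 1], [1, 1], [2, 0], [0, 2], [2, 2], [3, 1], [1, 3]]
x = [10, 1]

# Row order of a block-interleaved layout for 8 rows and interleave size 2
# (rows 0-3 are the first half, rows 4-7 the second half).
order = [0, 1, 4, 5, 2, 3, 6, 7]

y = linear(x, weight)
y_from_permuted_weight = linear(x, [weight[i] for i in order])
y_permuted_afterwards = [y[i] for i in order]

# Permuting the weight rows gives the same result as permuting the outputs.
same = y_from_permuted_weight == y_permuted_afterwards
```

Under this reading, storing the weight in interleaved order would make the FC1 output arrive already interleaved, so no separate reordering of the activations would be needed at run time.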

Comment on lines +171 to +178
x = tensor.reshape(
    2,
    shape[0] // (2 * interleave_size),
    interleave_size,
    *shape[1:],
)
x = x.transpose(0, 1).contiguous()
return x.reshape(shape)
Collaborator

I am curious whether we have done a perf analysis of this function with and without torch.compile, and whether it is worth JIT-compiling it with torch.compile. I would assume that would generate a Triton kernel (not sure if it is a performant one?). I wouldn't block this PR on that, though.

Member

We could probably improve perf with tex.swap_first_dims. That said, I'd expect perf is not critical if this is primarily for checkpointing.

Collaborator

I was wondering whether we could reuse this for activations as well, but it seems not. The dim-ordering logic is a bit different.

ksivaman and others added 3 commits June 3, 2026 22:46
Signed-off-by: ksivamani <ksivamani@nvidia.com>

Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Member

@timmoon10 timmoon10 left a comment

LGTM
