[PyTorch] Expose interleave and de-interleave function for GLU tensor for fused grouped MLP via fused ops#3078
Conversation
Signed-off-by: ksivamani <ksivamani@nvidia.com>
Greptile SummaryThis PR adds
Confidence Score: 5/5Safe to merge — the two new tensor-reordering utilities are mathematically correct inverses and the change is additive with no impact on existing paths. Both interleave_glu_tensor and deinterleave_glu_tensor apply only reshape, transpose, and contiguous operations on dimension 0; tracing through concrete examples confirms they produce the correct block-interleaved and contiguous layouts respectively. The change touches no existing logic — it only adds new exports and documentation. transformer_engine/pytorch/utils.py — deinterleave_glu_tensor is missing the input-validation guards that its counterpart carries, which has already been flagged in earlier review rounds. Important Files Changed
Sequence DiagramsequenceDiagram
participant User
participant interleave_glu_tensor
participant deinterleave_glu_tensor
participant FusedGroupedMLP
Note over User: Checkpoint loading
User->>interleave_glu_tensor: tensor [W_all, V_all] (contiguous)
interleave_glu_tensor->>interleave_glu_tensor: reshape(2, n, k, ...) → transpose(0,1) → contiguous → reshape
interleave_glu_tensor-->>User: tensor [W0,V0,W1,V1,...] (interleaved)
User->>FusedGroupedMLP: load interleaved weight
Note over User: Checkpoint saving
FusedGroupedMLP-->>User: tensor [W0,V0,W1,V1,...] (interleaved)
User->>deinterleave_glu_tensor: tensor [W0,V0,W1,V1,...] (interleaved)
deinterleave_glu_tensor->>deinterleave_glu_tensor: reshape(n, 2, k, ...) → transpose(0,1) → contiguous → reshape
deinterleave_glu_tensor-->>User: tensor [W_all, V_all] (contiguous)
Reviews (3): Last reviewed commit: "Update transformer_engine/pytorch/utils...." | Re-trigger Greptile |
| shape = tensor.shape | ||
| x = tensor.reshape( | ||
| shape[0] // (2 * interleave_size), | ||
| 2, | ||
| interleave_size, | ||
| *shape[1:], | ||
| ) | ||
| x = x.transpose(0, 1).contiguous() | ||
| return x.reshape(shape) |
There was a problem hiding this comment.
Missing input validation means users get a cryptic PyTorch reshape error instead of an actionable message. If
tensor.shape[0] is not divisible by 2 * interleave_size, the call to reshape raises something like "shape '[N, 2, k, ...]' is invalid for input of size X" with no indication of which argument is wrong. An upfront check surfaces the real constraint immediately.
| shape = tensor.shape | |
| x = tensor.reshape( | |
| shape[0] // (2 * interleave_size), | |
| 2, | |
| interleave_size, | |
| *shape[1:], | |
| ) | |
| x = x.transpose(0, 1).contiguous() | |
| return x.reshape(shape) | |
| if interleave_size <= 0: | |
| raise ValueError(f"interleave_size must be a positive integer, got {interleave_size}") | |
| if tensor.shape[0] % (2 * interleave_size) != 0: | |
| raise ValueError( | |
| f"tensor dimension 0 ({tensor.shape[0]}) must be divisible by " | |
| f"2 * interleave_size ({2 * interleave_size})" | |
| ) | |
| shape = tensor.shape | |
| x = tensor.reshape( | |
| shape[0] // (2 * interleave_size), | |
| 2, | |
| interleave_size, | |
| *shape[1:], | |
| ) | |
| x = x.transpose(0, 1).contiguous() | |
| return x.reshape(shape) |
vthumbe1503
left a comment
There was a problem hiding this comment.
LGTM. Left few minor comments.
| """Convert a block-interleaved SwiGLU fc1 tensor to contiguous gate/linear layout. | ||
|
|
||
| Fused SwiGLU kernels (for example :class:`~transformer_engine.pytorch.ops.SwiGLU` | ||
| with ``glu_interleave_size`` set) expect fc1 weights in a block-interleaved |
There was a problem hiding this comment.
Isnt it activations that need SwiGLU? Is it that we need to interleave the weights the same way as the activations? Clarification in the comment would be helpful.
| x = tensor.reshape( | ||
| 2, | ||
| shape[0] // (2 * interleave_size), | ||
| interleave_size, | ||
| *shape[1:], | ||
| ) | ||
| x = x.transpose(0, 1).contiguous() | ||
| return x.reshape(shape) |
There was a problem hiding this comment.
I am curious if we have done perf analysis of this function with and without torch compile and whether it is worth jitting it with torch.compile. I would assume that would generate a triton kernel(not sure if it is a performant one?). Although I wouldnt block this PR for that
There was a problem hiding this comment.
We could probably improve perf with tex.swap_first_dims. That said, I'd expect perf is not critical if this is primarily for checkpointing.
There was a problem hiding this comment.
I was thinking whether we could reuse this for activations as well. But seems like not. The logic in terms of dim ordering is a bit different
Signed-off-by: ksivamani <ksivamani@nvidia.com> Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Description
Add utilities to be used to convert FC1 weight format when loading checkpoints. The fused grouped MLP op uses interleaved format for FC1 weight. This functionality was originally added to NVIDIA-NeMo/Megatron-Bridge#2841.
Type of change
Changes
Checklist: