4 changes: 4 additions & 0 deletions docs/api/pytorch.rst
@@ -55,6 +55,10 @@ PyTorch

.. autoapifunction:: transformer_engine.pytorch.parallel_cross_entropy

.. autoapifunction:: transformer_engine.pytorch.interleave_glu_tensor

.. autoapifunction:: transformer_engine.pytorch.deinterleave_glu_tensor

Recipe availability
-------------------

2 changes: 2 additions & 0 deletions transformer_engine/pytorch/__init__.py
@@ -54,6 +54,8 @@
from transformer_engine.pytorch.utils import get_cudnn_version
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.pytorch.utils import is_bf16_available
from transformer_engine.pytorch.utils import deinterleave_glu_tensor
from transformer_engine.pytorch.utils import interleave_glu_tensor
from transformer_engine.pytorch.graph import make_graphed_callables
from transformer_engine.pytorch.distributed import checkpoint
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
105 changes: 104 additions & 1 deletion transformer_engine/pytorch/utils.py
@@ -17,7 +17,13 @@
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor


-__all__ = ["get_device_compute_capability", "get_cudnn_version", "is_bf16_available"]
+__all__ = [
+    "get_device_compute_capability",
+    "get_cudnn_version",
+    "is_bf16_available",
+    "deinterleave_glu_tensor",
+    "interleave_glu_tensor",
+]


@functools.lru_cache(maxsize=None)
@@ -82,6 +88,103 @@ def get_device_compute_capability() -> Tuple[int, int]:
    return _get_device_compute_capability(torch.cuda.current_device())


def deinterleave_glu_tensor(tensor: torch.Tensor, interleave_size: int) -> torch.Tensor:
    """Convert a block-interleaved GLU fc1 tensor to contiguous gate/linear layout.

    Fused GLU kernels (for example :class:`~transformer_engine.pytorch.ops.SwiGLU`
    with ``glu_interleave_size`` set) expect fc1 weights in a block-interleaved
Collaborator

Isn't it the activations that need SwiGLU? Is it that we need to interleave the weights the same way as the activations? A clarification in the comment would be helpful.

    layout along dimension 0. Checkpoints and frameworks such as Megatron-LM typically
    store the gate (``W``) and linear (``V``) halves as two contiguous blocks
    ``[W_all, V_all]``. This helper reorders along dimension 0 without changing
    the total shape.

    **Layouts along dimension 0** (``k = interleave_size``):

    * **Block-interleaved (input):** ``[W0:k, V0:k, Wk:2k, Vk:2k, ...]``
    * **Contiguous (output):** ``[W_all, V_all]``

    The same convention applies to ``linear_fc1.weight`` (dimension 0 plus any
    remaining dimensions) and ``linear_fc1.bias`` (dimension 0 only).

    Parameters
    ----------
    tensor : torch.Tensor
        Tensor in block-interleaved layout. The length of dimension 0 must be
        divisible by ``2 * interleave_size``.
    interleave_size : int
        Number of rows (for weights) or elements (for bias) per gate/linear block.
        Fused TE GLU paths commonly use ``32``.

    Returns
    -------
    torch.Tensor
        A new tensor with the same shape as ``tensor`` and contiguous
        ``[W_all, V_all]`` ordering along dimension 0.

    See Also
    --------
    :func:`interleave_glu_tensor` : Inverse transformation (contiguous -> block-interleaved).
    """
    shape = tensor.shape
    x = tensor.reshape(
        shape[0] // (2 * interleave_size),
        2,
        interleave_size,
        *shape[1:],
    )
    x = x.transpose(0, 1).contiguous()
    return x.reshape(shape)
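
To make the layouts in the docstring above concrete, here is a minimal pure-Python sketch that applies the same reordering to a list of row labels instead of a tensor. `deinterleave_rows` is a hypothetical stand-in written for illustration, not part of Transformer Engine; it mirrors the three steps in the function body (view as `[n_blocks, 2, k]`, swap the first two axes, flatten) and assumes the list length is divisible by `2 * k`.

```python
def deinterleave_rows(rows, k):
    """Pure-Python stand-in for the reshape/transpose in deinterleave_glu_tensor."""
    n_blocks = len(rows) // (2 * k)
    # View the rows as [n_blocks][2][k]: block index, gate/linear selector, offset.
    blocks = [
        [rows[(2 * b + half) * k:(2 * b + half + 1) * k] for half in range(2)]
        for b in range(n_blocks)
    ]
    # Swap the first two axes to [2][n_blocks][k], then flatten back to one list.
    return [
        row
        for half in range(2)
        for b in range(n_blocks)
        for row in blocks[b][half]
    ]


# Block-interleaved input with k = 2: [W0:2, V0:2, W2:4, V2:4]
interleaved = ["W0", "W1", "V0", "V1", "W2", "W3", "V2", "V3"]
contiguous = deinterleave_rows(interleaved, k=2)
print(contiguous)
```

With `k=2` the gate rows `W0..W3` end up first and the linear rows `V0..V3` second, which is the `[W_all, V_all]` layout described in the docstring.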
Comment on lines +128 to +136
Contributor

P2 Missing input validation means users get a cryptic PyTorch reshape error instead of an actionable message. If tensor.shape[0] is not divisible by 2 * interleave_size, the call to reshape raises something like "shape '[N, 2, k, ...]' is invalid for input of size X" with no indication of which argument is wrong. An upfront check surfaces the real constraint immediately.

Suggested change
-    shape = tensor.shape
-    x = tensor.reshape(
-        shape[0] // (2 * interleave_size),
-        2,
-        interleave_size,
-        *shape[1:],
-    )
-    x = x.transpose(0, 1).contiguous()
-    return x.reshape(shape)
+    if interleave_size <= 0:
+        raise ValueError(f"interleave_size must be a positive integer, got {interleave_size}")
+    if tensor.shape[0] % (2 * interleave_size) != 0:
+        raise ValueError(
+            f"tensor dimension 0 ({tensor.shape[0]}) must be divisible by "
+            f"2 * interleave_size ({2 * interleave_size})"
+        )
+    shape = tensor.shape
+    x = tensor.reshape(
+        shape[0] // (2 * interleave_size),
+        2,
+        interleave_size,
+        *shape[1:],
+    )
+    x = x.transpose(0, 1).contiguous()
+    return x.reshape(shape)
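
As a rough illustration of what the suggested checks report, the sketch below reproduces the two conditions on a plain integer dimension-0 length rather than a tensor. `check_glu_interleave_args` and `error_message` are hypothetical helpers introduced only for this example; the messages are copied from the suggestion above.

```python
def check_glu_interleave_args(dim0, interleave_size):
    """Hypothetical helper mirroring the two checks in the suggested change."""
    if interleave_size <= 0:
        raise ValueError(f"interleave_size must be a positive integer, got {interleave_size}")
    if dim0 % (2 * interleave_size) != 0:
        raise ValueError(
            f"tensor dimension 0 ({dim0}) must be divisible by "
            f"2 * interleave_size ({2 * interleave_size})"
        )


def error_message(dim0, interleave_size):
    """Return the ValueError text for the given arguments, or None if they are valid."""
    try:
        check_glu_interleave_args(dim0, interleave_size)
    except ValueError as exc:
        return str(exc)
    return None


print(error_message(128, 32))  # valid arguments: no error
print(error_message(100, 32))  # 100 is not a multiple of 64
print(error_message(128, 0))   # caught by the first check
```

The order of the two checks matters: without the `interleave_size <= 0` guard, a value of `0` would reach `% (2 * interleave_size)` and fail with an integer division-by-zero error instead of a message naming the argument.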



def interleave_glu_tensor(tensor: torch.Tensor, interleave_size: int) -> torch.Tensor:
Comment thread
ksivaman marked this conversation as resolved.
    """Convert a contiguous SwiGLU fc1 tensor to block-interleaved layout.

    This is the inverse of :func:`deinterleave_glu_tensor`. Use it when loading
    contiguous ``[W_all, V_all]`` checkpoints into a module that uses fused
    interleaved SwiGLU (``glu_interleave_size`` on the activation op).

    **Layouts along dimension 0** (``k = interleave_size``):

    * **Contiguous (input):** ``[W_all, V_all]``
    * **Block-interleaved (output):** ``[W0:k, V0:k, Wk:2k, Vk:2k, ...]``

    Parameters
    ----------
    tensor : torch.Tensor
        Tensor in contiguous gate/linear layout. The length of dimension 0 must be
        divisible by ``2 * interleave_size``.
    interleave_size : int
        Number of rows (for weights) or elements (for bias) per gate/linear block.
        Must match the ``glu_interleave_size`` used by the fused SwiGLU op.

    Returns
    -------
    torch.Tensor
        A new tensor with the same shape as ``tensor`` and block-interleaved
        ordering along dimension 0.

    See Also
    --------
    :func:`deinterleave_glu_tensor` : Inverse transformation (block-interleaved -> contiguous).
    """
    if interleave_size <= 0:
        raise ValueError(f"interleave_size must be a positive integer, got {interleave_size}")
    if tensor.shape[0] % (2 * interleave_size) != 0:
        raise ValueError(
            f"tensor dimension 0 ({tensor.shape[0]}) must be divisible by "
            f"2 * interleave_size ({2 * interleave_size})"
        )
    shape = tensor.shape
    x = tensor.reshape(
        2,
        shape[0] // (2 * interleave_size),
        interleave_size,
        *shape[1:],
    )
    x = x.transpose(0, 1).contiguous()
    return x.reshape(shape)
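
The docstring describes this function as the inverse of `deinterleave_glu_tensor`. A minimal pure-Python sketch of that round trip follows, again on lists of row labels rather than tensors. `interleave_rows` and `deinterleave_rows` are hypothetical stand-ins written for this example (they use explicit slicing instead of reshape/transpose) and assume the list length is divisible by `2 * k`.

```python
def interleave_rows(rows, k):
    """Contiguous [W_all, V_all] -> block-interleaved [W0:k, V0:k, Wk:2k, Vk:2k, ...]."""
    n_blocks = len(rows) // (2 * k)
    half_len = n_blocks * k  # number of gate rows (and of linear rows)
    out = []
    for b in range(n_blocks):
        out += rows[b * k:(b + 1) * k]                        # gate block b
        out += rows[half_len + b * k:half_len + (b + 1) * k]  # linear block b
    return out


def deinterleave_rows(rows, k):
    """Block-interleaved -> contiguous [W_all, V_all]."""
    n_blocks = len(rows) // (2 * k)
    gate, linear = [], []
    for b in range(n_blocks):
        start = 2 * b * k
        gate += rows[start:start + k]
        linear += rows[start + k:start + 2 * k]
    return gate + linear


contiguous = ["W0", "W1", "W2", "W3", "V0", "V1", "V2", "V3"]
interleaved = interleave_rows(contiguous, k=2)
print(interleaved)

# Round trip: deinterleaving the interleaved rows restores the original order.
restored = deinterleave_rows(interleaved, k=2)
print(restored == contiguous)
```

The same `k` has to be used in both directions; the round trip only restores the original order when the block size matches.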
Comment thread
ksivaman marked this conversation as resolved.
Comment on lines +178 to +185
Collaborator

I am curious whether we have done a perf analysis of this function with and without torch.compile, and whether it is worth jitting it with torch.compile. I would assume that would generate a Triton kernel (not sure whether it would be a performant one). I wouldn't block this PR on that, though.

Member

We could probably improve perf with tex.swap_first_dims. That said, I'd expect perf is not critical if this is primarily for checkpointing.

Collaborator

I was wondering whether we could reuse this for activations as well, but it seems not: the logic is a bit different in terms of dim ordering.
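
One possible reading of the weights-versus-activations question raised in this thread, sketched in pure Python: in a linear layer each weight row produces one output feature, so reordering the rows of the fc1 weight along dimension 0 reorders the features of the resulting activation in the same way. `linear` and `interleave_rows` are hypothetical helpers written for this illustration, and the weight values are made up; nothing here is taken from the Transformer Engine kernels.

```python
def linear(weight, x):
    """y[i] = dot(weight[i], x): one output feature per weight row."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def interleave_rows(rows, k):
    """Contiguous [W_all, V_all] -> block-interleaved [W0:k, V0:k, Wk:2k, Vk:2k, ...]."""
    n_blocks = len(rows) // (2 * k)
    half_len = n_blocks * k
    out = []
    for b in range(n_blocks):
        out += rows[b * k:(b + 1) * k]                        # gate block b
        out += rows[half_len + b * k:half_len + (b + 1) * k]  # linear block b
    return out


# Contiguous fc1 weight: four gate rows followed by four linear rows (made-up values).
weight = [[1, 0], [2, 0], [3, 0], [4, 0], [0, 1], [0, 2], [0, 3], [0, 4]]
x = [10, 100]

y_contiguous = linear(weight, x)
y_interleaved = linear(interleave_rows(weight, 2), x)

print(y_contiguous)
print(y_interleaved)
# Interleaving the weight rows interleaves the output features in the same pattern.
print(y_interleaved == interleave_rows(y_contiguous, 2))
```

For a batch of activations shaped `[tokens, features]`, the features sit on the last dimension, while these helpers reorder dimension 0. That may be the difference in dim ordering mentioned above, though the thread does not spell it out.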



def resolve_grouped_linear_single_param_flags(
    single_grouped_weight: bool,
    single_grouped_bias: bool,