diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst
index 99d850d04d..1612b3234b 100644
--- a/docs/api/pytorch.rst
+++ b/docs/api/pytorch.rst
@@ -55,6 +55,10 @@ PyTorch
 
 .. autoapifunction:: transformer_engine.pytorch.parallel_cross_entropy
 
+.. autoapifunction:: transformer_engine.pytorch.interleave_glu_tensor
+
+.. autoapifunction:: transformer_engine.pytorch.deinterleave_glu_tensor
+
 Recipe availability
 -------------------
 
diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py
index 7653d5992e..775200ce01 100644
--- a/transformer_engine/pytorch/__init__.py
+++ b/transformer_engine/pytorch/__init__.py
@@ -54,6 +54,8 @@
 from transformer_engine.pytorch.utils import get_cudnn_version
 from transformer_engine.pytorch.utils import get_device_compute_capability
 from transformer_engine.pytorch.utils import is_bf16_available
+from transformer_engine.pytorch.utils import deinterleave_glu_tensor
+from transformer_engine.pytorch.utils import interleave_glu_tensor
 from transformer_engine.pytorch.graph import make_graphed_callables
 from transformer_engine.pytorch.distributed import checkpoint
 from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py
index fd8f817b33..cfb21e7bff 100644
--- a/transformer_engine/pytorch/utils.py
+++ b/transformer_engine/pytorch/utils.py
@@ -17,7 +17,13 @@
 from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
 
 
-__all__ = ["get_device_compute_capability", "get_cudnn_version", "is_bf16_available"]
+__all__ = [
+    "get_device_compute_capability",
+    "get_cudnn_version",
+    "is_bf16_available",
+    "deinterleave_glu_tensor",
+    "interleave_glu_tensor",
+]
 
 
 @functools.lru_cache(maxsize=None)
@@ -82,6 +88,124 @@ def get_device_compute_capability() -> Tuple[int, int]:
     return _get_device_compute_capability(torch.cuda.current_device())
 
 
+def deinterleave_glu_tensor(tensor: torch.Tensor, interleave_size: int) -> torch.Tensor:
+    """Convert a block-interleaved GLU fc1 tensor to contiguous gate/linear layout.
+
+    Fused GLU kernels (for example :class:`~transformer_engine.pytorch.ops.SwiGLU`
+    with ``glu_interleave_size`` set) expect fc1 weights in a block-interleaved
+    layout along dimension 0. Checkpoints and frameworks such as Megatron-LM typically
+    store the gate (``W``) and linear (``V``) halves as two contiguous blocks
+    ``[W_all, V_all]``. This helper reorders along dimension 0 without changing
+    the total shape.
+
+    **Layouts along dimension 0** (``k = interleave_size``):
+
+    * **Block-interleaved (input):** ``[W0:k, V0:k, Wk:2k, Vk:2k, ...]``
+    * **Contiguous (output):** ``[W_all, V_all]``
+
+    The same convention applies to ``linear_fc1.weight`` (dimension 0 plus any
+    remaining dimensions) and ``linear_fc1.bias`` (dimension 0 only).
+
+    Parameters
+    ----------
+    tensor : torch.Tensor
+        Tensor in block-interleaved layout. The length of dimension 0 must be
+        divisible by ``2 * interleave_size``.
+    interleave_size : int
+        Number of rows (for weights) or elements (for bias) per gate/linear block.
+        Fused TE GLU paths commonly use ``32``.
+
+    Returns
+    -------
+    torch.Tensor
+        A new tensor with the same shape as ``tensor`` and contiguous
+        ``[W_all, V_all]`` ordering along dimension 0.
+
+    Raises
+    ------
+    ValueError
+        If ``interleave_size`` is not positive, or if the length of dimension 0
+        is not divisible by ``2 * interleave_size``.
+
+    See Also
+    --------
+    :func:`interleave_glu_tensor` : Inverse transformation (contiguous -> block-interleaved).
+    """
+    # Validate up front (mirrors interleave_glu_tensor): otherwise a bad
+    # interleave_size surfaces as ZeroDivisionError or an opaque reshape error.
+    if interleave_size <= 0:
+        raise ValueError(f"interleave_size must be a positive integer, got {interleave_size}")
+    if tensor.shape[0] % (2 * interleave_size) != 0:
+        raise ValueError(
+            f"tensor dimension 0 ({tensor.shape[0]}) must be divisible by "
+            f"2 * interleave_size ({2 * interleave_size})"
+        )
+    shape = tensor.shape
+    x = tensor.reshape(
+        shape[0] // (2 * interleave_size),
+        2,
+        interleave_size,
+        *shape[1:],
+    )
+    x = x.transpose(0, 1).contiguous()
+    return x.reshape(shape)
+
+
+def interleave_glu_tensor(tensor: torch.Tensor, interleave_size: int) -> torch.Tensor:
+    """Convert a contiguous SwiGLU fc1 tensor to block-interleaved layout.
+
+    This is the inverse of :func:`deinterleave_glu_tensor`. Use it when loading
+    contiguous ``[W_all, V_all]`` checkpoints into a module that uses fused
+    interleaved SwiGLU (``glu_interleave_size`` on the activation op).
+
+    **Layouts along dimension 0** (``k = interleave_size``):
+
+    * **Contiguous (input):** ``[W_all, V_all]``
+    * **Block-interleaved (output):** ``[W0:k, V0:k, Wk:2k, Vk:2k, ...]``
+
+    Parameters
+    ----------
+    tensor : torch.Tensor
+        Tensor in contiguous gate/linear layout. The length of dimension 0 must be
+        divisible by ``2 * interleave_size``.
+    interleave_size : int
+        Number of rows (for weights) or elements (for bias) per gate/linear block.
+        Must match the ``glu_interleave_size`` used by the fused SwiGLU op.
+
+    Returns
+    -------
+    torch.Tensor
+        A new tensor with the same shape as ``tensor`` and block-interleaved
+        ordering along dimension 0.
+
+    Raises
+    ------
+    ValueError
+        If ``interleave_size`` is not positive, or if the length of dimension 0
+        is not divisible by ``2 * interleave_size``.
+
+    See Also
+    --------
+    :func:`deinterleave_glu_tensor` : Inverse transformation (block-interleaved -> contiguous).
+    """
+    if interleave_size <= 0:
+        raise ValueError(f"interleave_size must be a positive integer, got {interleave_size}")
+    if tensor.shape[0] % (2 * interleave_size) != 0:
+        raise ValueError(
+            f"tensor dimension 0 ({tensor.shape[0]}) must be divisible by "
+            f"2 * interleave_size ({2 * interleave_size})"
+        )
+    shape = tensor.shape
+    x = tensor.reshape(
+        2,
+        shape[0] // (2 * interleave_size),
+        interleave_size,
+        *shape[1:],
+    )
+    x = x.transpose(0, 1).contiguous()
+    return x.reshape(shape)
+
+
 def resolve_grouped_linear_single_param_flags(
     single_grouped_weight: bool,
     single_grouped_bias: bool,