-
Notifications
You must be signed in to change notification settings - Fork 738
[PyTorch] Expose interleave and de-interleave function for GLU tensor for fused grouped MLP via fused ops #3078
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
af05698
3db2eb1
aa76b73
27f85ea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -17,7 +17,13 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||
| from ..debug.pytorch.debug_quantization import DebugQuantizedTensor | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| __all__ = ["get_device_compute_capability", "get_cudnn_version", "is_bf16_available"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| __all__ = [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| "get_device_compute_capability", | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| "get_cudnn_version", | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| "is_bf16_available", | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| "deinterleave_glu_tensor", | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| "interleave_glu_tensor", | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| @functools.lru_cache(maxsize=None) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -82,6 +88,96 @@ def get_device_compute_capability() -> Tuple[int, int]: | |||||||||||||||||||||||||||||||||||||||||||||||||||
| return _get_device_compute_capability(torch.cuda.current_device()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| def deinterleave_glu_tensor(tensor: torch.Tensor, interleave_size: int) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Convert a block-interleaved SwiGLU fc1 tensor to contiguous gate/linear layout. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
ksivaman marked this conversation as resolved.
Outdated
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| Fused SwiGLU kernels (for example :class:`~transformer_engine.pytorch.ops.SwiGLU` | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
ksivaman marked this conversation as resolved.
Outdated
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| with ``glu_interleave_size`` set) expect fc1 weights in a block-interleaved | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isnt it activations that need SwiGLU? Is it that we need to interleave the weights the same way as the activations? Clarification in the comment would be helpful. |
||||||||||||||||||||||||||||||||||||||||||||||||||||
| layout along dimension 0. Checkpoints and frameworks such as Megatron-LM typically | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| store the gate (``W``) and linear (``V``) halves as two contiguous blocks | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| ``[W_all, V_all]``. This helper reorders along dimension 0 without changing | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| the total shape. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| **Layouts along dimension 0** (``k = interleave_size``): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| * **Block-interleaved (input):** ``[W0:k, V0:k, Wk:2k, Vk:2k, ...]`` | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| * **Contiguous (output):** ``[W_all, V_all]`` | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| The same convention applies to ``linear_fc1.weight`` (dimension 0 plus any | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| remaining dimensions) and ``linear_fc1.bias`` (dimension 0 only). | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| Parameters | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| ---------- | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| tensor : torch.Tensor | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| Tensor in block-interleaved layout. The length of dimension 0 must be | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| divisible by ``2 * interleave_size``. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| interleave_size : int | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| Number of rows (for weights) or elements (for bias) per gate/linear block. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| Fused TE SwiGLU paths commonly use ``32``. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
ksivaman marked this conversation as resolved.
Outdated
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| ------- | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.Tensor | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| A new tensor with the same shape as ``tensor`` and contiguous | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| ``[W_all, V_all]`` ordering along dimension 0. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| See Also | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| -------- | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| :func:`interleave_glu_tensor` : Inverse transformation (contiguous -> block-interleaved). | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| shape = tensor.shape | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| x = tensor.reshape( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| shape[0] // (2 * interleave_size), | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| 2, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| interleave_size, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| *shape[1:], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| x = x.transpose(0, 1).contiguous() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| return x.reshape(shape) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+128
to
+136
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| def interleave_glu_tensor(tensor: torch.Tensor, interleave_size: int) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
ksivaman marked this conversation as resolved.
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Convert a contiguous SwiGLU fc1 tensor to block-interleaved layout. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| This is the inverse of :func:`deinterleave_glu_tensor`. Use it when loading | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| contiguous ``[W_all, V_all]`` checkpoints into a module that uses fused | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| interleaved SwiGLU (``glu_interleave_size`` on the activation op). | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| **Layouts along dimension 0** (``k = interleave_size``): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| * **Contiguous (input):** ``[W_all, V_all]`` | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| * **Block-interleaved (output):** ``[W0:k, V0:k, Wk:2k, Vk:2k, ...]`` | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| Parameters | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| ---------- | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| tensor : torch.Tensor | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| Tensor in contiguous gate/linear layout. The length of dimension 0 must be | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| divisible by ``2 * interleave_size``. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| interleave_size : int | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| Number of rows (for weights) or elements (for bias) per gate/linear block. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| Must match the ``glu_interleave_size`` used by the fused SwiGLU op. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| ------- | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.Tensor | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| A new tensor with the same shape as ``tensor`` and block-interleaved | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| ordering along dimension 0. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| See Also | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| -------- | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| :func:`deinterleave_glu_tensor` : Inverse transformation (block-interleaved -> contiguous). | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| shape = tensor.shape | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| x = tensor.reshape( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| 2, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| shape[0] // (2 * interleave_size), | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| interleave_size, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| *shape[1:], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| x = x.transpose(0, 1).contiguous() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| return x.reshape(shape) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
ksivaman marked this conversation as resolved.
Comment on lines
+178
to
+185
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am curious if we have done perf analysis of this function with and without torch compile and whether it is worth jitting it with torch.compile. I would assume that would generate a triton kernel(not sure if it is a performant one?). Although I wouldnt block this PR for that
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could probably improve perf with
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was thinking whether we could reuse this for activations as well. But seems like not. The logic in terms of dim ordering is a bit different
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| def resolve_grouped_linear_single_param_flags( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| single_grouped_weight: bool, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| single_grouped_bias: bool, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.